TabR: Tabular Deep Learning Meets Nearest Neighbors
Summary
The paper presents a model for tabular data classification inspired by retrieval-augmented generation (RAG). The authors propose a new mechanism for similarity calculation and retrieval, and show superior performance compared to previous DL models and the GBDT family on previously proposed tabular benchmarks.
- A Retrieval Augmented Generation (RAG) inspired model for tabular data classification.
- Modified mechanism for similarity calculation (and thus retrieval).
Approach
- ResNet-like backbone + nearest-neighbor-style retrieval for the RAG-like component.
- Attention-like mechanism to calculate similarity scores.
Retrieval-based mechanism for tabular data classification
Similarity mechanism
Retrieval Mechanism
- Start with vanilla attention mechanism
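For reference, the vanilla cross-attention starting point can be sketched as follows (a minimal single-target, single-head numpy sketch; all names are illustrative, not from the paper's code):

```python
import numpy as np

def vanilla_attention(x, ctx, W_Q, W_K, W_V):
    """Scaled dot-product attention of one target x over m context objects."""
    q = W_Q @ x                       # query for the target
    K = ctx @ W_K.T                   # keys for the context, shape (m, d)
    V = ctx @ W_V.T                   # values for the context, shape (m, d)
    scores = K @ q / np.sqrt(q.size)  # dot-product similarity with sqrt(d) scaling
    w = np.exp(scores - scores.max())
    w = w / w.sum()                   # softmax attention weights
    return w @ V                      # weighted sum of values
```

Each of the modifications below changes one piece of this template.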
Step-by-step description of modifying different components
Context labels -- utilize the label of the context object when calculating the value.
Change the similarity module -- instead of using queries W_Q(x) and keys W_K(x_i), use just a single key projection k = W_K(x), and use the (negative squared) L2 distance -||k - k_i||^2 to calculate the similarity.
- In appendix, the authors discuss that:
- It's easier to align just one learned representation (k), instead of two (q and k).
- The encoder module E is actually a linear transformation, meaning that similarity measures that work well in the original space may also work well after applying E.
- E is kept as a linear transformation because it's used very frequently (at each inference step, to encode the candidates), so it's better for it to be lightweight.
- Change the value module -- instead of using only the label embedding W_Y(y_i), use W_Y(y_i) plus a correction term T(k - k_i) to calculate the value.
INTUITION
This step adds more information about the target sample into the value vector. W_Y(y_i) is the raw contribution of sample i (because it tells us about the label associated with that sample), while T(k - k_i) is the correction term. T translates the difference of k and k_i in the key space (k - k_i) to the label space (W_Y(y_i)).
- Remove the sqrt(d) scaling term (an artifact of vanilla attention anyways) from the similarity calculation.
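The modified similarity and value modules from the steps above can be sketched as follows (a minimal numpy sketch; the paper's T module also includes dropout, and all names here are illustrative):

```python
import numpy as np

def tabr_similarity(x, ctx, W_K):
    """One shared key projection; similarity is the negative squared L2 distance."""
    k = W_K @ x                                   # key of the target object
    K = ctx @ W_K.T                               # keys of the candidates, shape (m, d)
    return k, K, -np.sum((K - k) ** 2, axis=1)    # S_i = -||k - k_i||^2, no sqrt(d) scaling

def tabr_value(k, K, Y_emb, T_in, T_out):
    """Label embedding plus a correction translated from key space to label space.
    T is sketched as Linear -> ReLU -> Linear (the paper also uses dropout)."""
    diff = k - K                         # (m, d): target key minus candidate keys
    h = np.maximum(diff @ T_in.T, 0.0)   # ReLU hidden layer of T
    return Y_emb + h @ T_out.T           # v_i = W_Y(y_i) + T(k - k_i)
```

Note that a candidate identical to the target gets similarity exactly 0, the maximum possible value under this metric.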
Putting it together
The output of the retrieval module is then the weighted sum of the value vectors of the top-m samples, where softmax over the similarity scores determines the weights (the similarity score is bounded in (-inf, 0] since we take the negative L2 norm).
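Putting all four modifications together, the retrieval module's forward pass can be sketched end-to-end as follows (a minimal numpy sketch for one target; names are illustrative, not from the paper's code):

```python
import numpy as np

def retrieval_module(x, ctx, Y_emb, W_K, T_in, T_out, m=3):
    """Retrieve the m nearest candidates and return the weighted sum of their values."""
    k = W_K @ x
    K = ctx @ W_K.T
    sims = -np.sum((K - k) ** 2, axis=1)   # negative squared L2, no sqrt(d) scaling
    top = np.argsort(sims)[-m:]            # indices of the m nearest candidates
    s = sims[top] - sims[top].max()        # numerically stable softmax over top-m only
    w = np.exp(s) / np.exp(s).sum()
    diff = k - K[top]
    h = np.maximum(diff @ T_in.T, 0.0)     # T sketched as Linear -> ReLU -> Linear
    v = Y_emb[top] + h @ T_out.T           # v_i = W_Y(y_i) + T(k - k_i)
    return w @ v                           # weighted sum of top-m value vectors
```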
Ablations
- Freeze the context (encoded samples) after training starts to stabilize.
- Online setting (start with limited data and add unseen samples over time).
Findings
(Figure: TabR and previous DL models compared against XGBoost.)
(Figure: TabR and previous DL models compared against GBDTs, with and without HPO.)
- Adding all four modifications together is what makes TabR perform well.
- TabR beats GBDTs on some datasets.
- Numerical embeddings and retrieval seem to be key techniques in good DL performance.
Noted limitations
- While the new module is more efficient than standard attention, it may still not scale well to very large datasets.