Summary

The paper presents a retrieval-augmented generation (RAG) inspired model (TabR) for tabular data classification. The authors propose a new mechanism for similarity calculation and retrieval, and show performance superior to previous deep learning models and the GBDT family on previously proposed tabular benchmarks.

  • A Retrieval Augmented Generation (RAG) inspired model for tabular data classification.
  • Modified mechanism for similarity calculation (and thus retrieval).

Approach

  • ResNet-style backbone + a clustering-like retrieval component for the RAG part.
  • Attention-like mechanism to calculate similarity scores.
Mechanism

Retrieval-based mechanism for tabular data classification

Similarity mechanism

Retrieval Mechanism

  • Start with vanilla attention mechanism
$$\mathcal{S}(\tilde x, \tilde x_i) = W_Q(\tilde x)^T W_K(\tilde x_i) \cdot d^{-1/2} \qquad \mathcal{V}(\tilde x, \tilde x_i, y_i) = W_V(\tilde x_i)$$
  • Step-by-step description of modifying different components

    1. Context labels -- utilize the label $y_i$ of context object $x_i$ when calculating the value.

      $$\mathcal{S}(\tilde x, \tilde x_i) = W_Q(\tilde x)^T W_K(\tilde x_i) \cdot d^{-1/2} \qquad \mathcal{V}(\tilde x, \tilde x_i, y_i) = \underline{W_Y(y_i)} + W_V(\tilde x_i)$$
    2. Change the similarity module -- instead of using $W_Q$ and $W_K$, use just $W_K$, and use the $L_2$ distance to calculate the similarity.

      $$\mathcal{S}(\tilde x, \tilde x_i) = \underline{-\Vert W_K(\tilde x) - W_K(\tilde x_i) \Vert^2} \cdot d^{-1/2} \qquad \mathcal{V}(\tilde x, \tilde x_i, y_i) = W_Y(y_i) + W_V(\tilde x_i)$$
    • In the appendix, the authors discuss that:
      1. It's easier to align just one learned representation ($W_K$) instead of two ($W_Q$ and $W_K$).
      2. The encoder module $E$ is a linear transformation, meaning that similarity measures that work well in the original space may also work well after applying $E$.
        • $E$ is kept linear because it is applied very frequently (at each inference step, to find the candidates), so it's better to have it be lightweight.

    3. Change the value module -- instead of using $W_V$, use $W_K$ and $W_Y$ to calculate the value.

      $$\mathcal{S}(\tilde x, \tilde x_i) = -\Vert W_K(\tilde x) - W_K(\tilde x_i) \Vert^2 \cdot d^{-1/2} \qquad \mathcal{V}(\tilde x, \tilde x_i, y_i) = W_Y(y_i) + \underline{T(W_K(\tilde x) - W_K(\tilde x_i))}$$
      $$T(\cdot) = \text{LinearWithoutBias}(\text{Dropout}(\text{ReLU}(\text{Linear}(\cdot))))$$

INTUITION

This step adds more information about the target sample $x$ into the value vector. $W_Y(y_i)$ is the raw contribution of sample $i$ (it tells us the label associated with that sample), while $T(W_K(\tilde x) - W_K(\tilde x_i))$ is a correction term: $T(\cdot)$ translates the difference between $x$ and $x_i$ in the key space ($W_K$) into the label space ($W_Y$).

  4. Remove the scaling term $d^{-1/2}$ (an artifact of vanilla attention anyway) from the similarity calculation. With $k = W_K(\tilde x)$ and $k_i = W_K(\tilde x_i)$:

     $$\mathcal{S}(\tilde x, \tilde x_i) = -\Vert k - k_i \Vert^2 \qquad \mathcal{V}(\tilde x, \tilde x_i, y_i) = W_Y(y_i) + T(k - k_i)$$
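Taken together, the four modifications can be sketched in a few lines. This is a minimal NumPy sketch with randomly initialized stand-in parameters (not the paper's trained weights); the dropout inside $T$ is omitted, as it would be at inference time:

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8                                          # embedding dimension (stand-in)
W_K = rng.normal(size=(d, d)) / np.sqrt(d)     # shared key projection
W_Y = rng.normal(size=(3, d)) / np.sqrt(d)     # label embeddings for 3 classes
T1 = rng.normal(size=(d, d)) / np.sqrt(d)      # first linear layer of T
T2 = rng.normal(size=(d, d)) / np.sqrt(d)      # final linear layer (no bias)

def T(v):
    # correction MLP: Linear -> ReLU -> LinearWithoutBias (dropout omitted)
    return np.maximum(v @ T1, 0.0) @ T2

def similarity(x_enc, ctx_enc):
    # steps 2 + 4: negative squared L2 distance in key space, no d**-0.5
    k, k_i = x_enc @ W_K, ctx_enc @ W_K
    return -np.sum((k - k_i) ** 2, axis=-1)

def value(x_enc, ctx_enc, y_ctx):
    # steps 1 + 3: label embedding plus key-space correction term
    k, k_i = x_enc @ W_K, ctx_enc @ W_K
    return W_Y[y_ctx] + T(k - k_i)

x = rng.normal(size=d)            # encoded target object  x~
x_ctx = rng.normal(size=(5, d))   # encoded context objects x~_i
y_ctx = rng.integers(0, 3, size=5)

scores = similarity(x, x_ctx)     # shape (5,), all non-positive
values = value(x, x_ctx, y_ctx)   # shape (5, d)
```

Because the similarity is a negative squared distance, the closest context object in key space always gets the highest (least negative) score.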

Putting it together

The output of the retrieval module is then the weighted sum of the value vectors of the top $m$ candidates, where the weights are obtained by applying a softmax over the similarity scores (the raw scores are non-positive negative squared distances; the softmax maps them to weights in $[0, 1]$ that sum to one).

$$\hat y = \text{Predictor}\Big(\tilde x + \sum_{i \in \text{top-}m} \operatorname{softmax}_i\big(\mathcal{S}(\tilde x, \tilde x_i)\big) \cdot \mathcal{V}(\tilde x, \tilde x_i, y_i)\Big)$$
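The full forward pass can be sketched as follows. Again a NumPy sketch with random stand-in parameters: the correction term $T$ is omitted from the values for brevity, and `Predictor` is assumed to be a simple linear head rather than the paper's actual predictor:

```python
import numpy as np

rng = np.random.default_rng(1)
d, n_ctx, m, n_classes = 8, 20, 4, 3

x = rng.normal(size=d)                         # encoded target object x~
ctx = rng.normal(size=(n_ctx, d))              # encoded candidate objects
y_ctx = rng.integers(0, n_classes, size=n_ctx) # candidate labels
W_K = rng.normal(size=(d, d)) / np.sqrt(d)     # key projection
W_Y = rng.normal(size=(n_classes, d)) / np.sqrt(d)  # label embeddings
W_pred = rng.normal(size=(d, n_classes)) / np.sqrt(d)  # linear "Predictor"

# similarity of x to every candidate: -||k - k_i||^2
k, k_ctx = x @ W_K, ctx @ W_K
scores = -np.sum((k - k_ctx) ** 2, axis=-1)

# keep the m nearest candidates, then softmax their scores into weights
top = np.argsort(scores)[-m:]
w = np.exp(scores[top] - scores[top].max())
w /= w.sum()

# weighted sum of value vectors (here just W_Y(y_i); T omitted for brevity)
values = W_Y[y_ctx[top]]
retrieved = (w[:, None] * values).sum(axis=0)

# Predictor(x~ + retrieval output) -> class logits
logits = (x + retrieved) @ W_pred
```

Note that the retrieval output is added residually to $\tilde x$, so the predictor sees both the target's own representation and the aggregated neighbor information.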

Ablations

  • Freeze the context (encoded samples) once training starts to stabilize.
  • Online setting (start with limited data and gradually add previously unseen samples).

Findings

DL vs XGB Performance Comparison

TabR and previous DL models compared against XGB

TabR vs GBDT without HPO

TabR and previous DL models compared against GBDTs with and without HPO.

  • Adding all four modifications together is what makes TabR perform well.
  • TabR beats GBDTs on some datasets.
  • Numerical embeddings and retrieval seem to be key techniques in good DL performance.

Noted limitations

  • While the new module is more efficient than standard attention, it may still not scale well to very large datasets.

Resources