
Trompt: Towards a Better Deep Neural Network for Tabular Data

Summary

The authors propose a prompt-learning-inspired deep architecture for classification and regression on tabular data.

Approach

Overall architecture of Trompt

The Trompt model stacks multiple Trompt cells, each of which processes a row of the table; the $n$th cell receives the prompt generated by the $(n-1)$th cell.
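A toy PyTorch sketch of this chaining follows. `StandInCell` is a placeholder (not the paper's cell, which is detailed in the next section); only the interface, taking a batch of rows plus the previous cell's prompt and emitting a new prompt, mirrors the description above. All sizes are illustrative.

```python
# Toy sketch of chaining Trompt cells; StandInCell is a placeholder whose
# internals do not matter here, only the prompt-passing interface does.
import torch
import torch.nn as nn

B, C, P, d, L = 32, 10, 128, 16, 6  # batch, columns, prompts, latent dim, cells

class StandInCell(nn.Module):
    """Placeholder for a real Trompt cell (detailed in the next section)."""
    def __init__(self):
        super().__init__()
        self.mix = nn.Linear(C + P * d, P * d)

    def forward(self, x, prev_prompt):
        # Combine the raw row features with the flattened previous prompt.
        flat = torch.cat([x, prev_prompt.flatten(1)], dim=-1)
        return self.mix(flat).reshape(-1, P, d)

cells = nn.ModuleList([StandInCell() for _ in range(L)])
x = torch.randn(B, C)              # a batch of table rows
prompt = torch.zeros(B, P, d)      # initial prompt fed to the first cell
for cell in cells:                 # the n-th cell sees the (n-1)-th cell's prompt
    prompt = cell(x, prompt)
print(prompt.shape)                # torch.Size([32, 128, 16])
```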

Trompt Cell

Illustration of a single Trompt cell. Part 1 shows how the previous prompt and the column embeddings are used to generate the next prompt. Part 2 shows how the actual input data is transformed, and Part 3 shows how the feature embeddings are expanded for use with multiple prompts.

SETTING

Given a table with $N$ rows and $C$ columns, we want to learn a Trompt model with a latent dimension of $d$, $L$ cells, and $P$ soft prompts. Below, $B$ denotes the batch size.

Part 1. Feature Importance

The feature importance scores are computed using the column embeddings $\mE_{\text{column}} \in \R^{C \times d}$, the soft prompts $\mE_{\text{prompt}} \in \R^{P \times d}$, and the previous Trompt cell's output $\mO_{\text{prev}} \in \R^{B \times P \times d}$. Note that $\mO_{\text{prev}}$ has an extra batch dimension because the data is loaded in batches. $\mE_{\text{column}}$ and $\mE_{\text{prompt}}$ are expanded into $\mS\mE_{\text{column}} \in \R^{B \times C \times d}$ and $\mS\mE_{\text{prompt}} \in \R^{B \times P \times d}$.

$$
\begin{gather*}
\hat{\mS\mE}_{\text{prompt}} = \text{dense}(\text{concat}(\mS\mE_{\text{prompt}}, \mO_{\text{prev}})) + \mS\mE_{\text{prompt}} + \mO_{\text{prev}} \in \R^{B \times P \times d} \\
\mM_{\text{importance}} = \text{softmax}(\hat{\mS\mE}_{\text{prompt}} \otimes \mS\mE_{\text{column}}^T) \in \R^{B \times P \times C}
\end{gather*}
$$
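To make the shapes concrete, here is a minimal PyTorch sketch of Part 1. The tensor and layer names (`E_column`, `E_prompt`, `dense`) are hypothetical, and this is a reading of the formula above rather than the authors' code.

```python
# Minimal sketch of Part 1: deriving per-sample feature importance scores.
import torch
import torch.nn as nn

B, P, C, d = 32, 128, 10, 16  # illustrative sizes, not the paper's defaults

E_column = nn.Parameter(torch.randn(C, d))  # learned column embeddings
E_prompt = nn.Parameter(torch.randn(P, d))  # learned soft prompts
dense = nn.Linear(2 * d, d)                 # fuses prompt and previous cell output

O_prev = torch.randn(B, P, d)               # output of the previous Trompt cell

# Expand the shared embeddings across the batch dimension.
SE_column = E_column.unsqueeze(0).expand(B, -1, -1)  # (B, C, d)
SE_prompt = E_prompt.unsqueeze(0).expand(B, -1, -1)  # (B, P, d)

# SE_hat_prompt = dense(concat(SE_prompt, O_prev)) + SE_prompt + O_prev
SE_hat_prompt = dense(torch.cat([SE_prompt, O_prev], dim=-1)) + SE_prompt + O_prev  # (B, P, d)

# Importance mask: softmax over columns of the prompt-column interaction scores.
M_importance = torch.softmax(SE_hat_prompt @ SE_column.transpose(1, 2), dim=-1)     # (B, P, C)
print(M_importance.shape)  # torch.Size([32, 128, 10])
```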

INTUITION

The authors say that this mechanism 'enables input-related representations to flow through and derive sample-specific feature importances'. The 'flow through' part comes from the fact that the $n$th cell sees what the $(n-1)$th cell did. The sample-specific part comes from the fact that the prompt is generated based on the previous prompt and input sample, and not by a table-wide importance score. We can think of this as a way to map the interaction between individual prompt tokens and the table columns. In other words, $(\mM_{\text{importance}})_{i,j,k}$ should tell us the importance of the interaction between the $j$th prompt and the $k$th column for sample $i$.

Parts 2 & 3. Generating/Reshaping Feature Embeddings

Now that the per-instance feature importance scores are computed, the representation of the instance itself needs to be computed. This is done through a rather standard procedure [1, 2]: an embedding lookup table for categorical features and a dense layer for numerical features. This process yields $\hat{\mE}_{\text{feature}} \in \R^{B \times C \times d}$. However, $\hat{\mE}_{\text{feature}}$ cannot be combined directly with the importance scores $\mM_{\text{importance}}$ because the shapes do not match. To fix this, $\hat{\mE}_{\text{feature}}$ is transformed once more by a dense layer into $\R^{B \times P \times C \times d}$. We can then expand $\mM_{\text{importance}}$ to $\R^{B \times P \times C \times 1}$ and apply element-wise multiplication to compute the output of the Trompt cell.

$$
\mO = \sum_{i=1}^{C} (\mM_{\text{importance}} \otimes \hat{\mE}_{\text{feature}})_{:,:,i,:} \in \R^{B \times P \times d}
$$
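A rough sketch of Parts 2 and 3 under the same assumed shapes. The `expand` layer is one plausible way to realize the dense expansion described above and may differ from the paper's exact formulation.

```python
# Sketch of Parts 2 & 3: expand feature embeddings per prompt, weight them by
# the importance scores, and sum over the column axis.
import torch
import torch.nn as nn

B, P, C, d = 32, 128, 10, 16

E_feature_hat = torch.randn(B, C, d)                        # embedded input features (Part 2)
M_importance = torch.softmax(torch.randn(B, P, C), dim=-1)  # importance scores from Part 1

# Part 3: a dense layer expands each column embedding into P prompt-specific copies.
expand = nn.Linear(d, P * d)
E_expanded = expand(E_feature_hat).reshape(B, C, P, d).permute(0, 2, 1, 3)  # (B, P, C, d)

# Weight each column's embedding by its importance and sum over columns.
O = (M_importance.unsqueeze(-1) * E_expanded).sum(dim=2)  # (B, P, d)
print(O.shape)  # torch.Size([32, 128, 16])
```

Broadcasting the importance mask over the last axis (via `unsqueeze(-1)`) matches the $\R^{B \times P \times C \times 1}$ expansion in the text without materializing a full $(B, P, C, d)$ copy of the scores.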

For classification, a trained dense layer computes a weight for each prompt, and a weighted sum over the prompt dimension of $\mO$ produces the input to the downstream classifier.

$$
\begin{align*}
\mW_{\text{prompt}} &= \text{softmax}(\text{dense}(\mO)) \in \R^{B \times P} \\
\hat{\mO} &= \sum_{i=1}^{P} (\mW_{\text{prompt}} \odot \mO)_{:,i,:} \in \R^{B \times d} \\
\mP &= \text{softmax}(\text{dense}(\hat{\mO})) \in \R^{B \times T}
\end{align*}
$$

where $\mP$ is the prediction and $T$ is the number of target classes.
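A hedged sketch of this downstream head; `score` and `head` are hypothetical layer names, and the sizes are illustrative.

```python
# Sketch of the prediction head: score each prompt, pool the prompt outputs
# with those weights, then classify the pooled vector.
import torch
import torch.nn as nn

B, P, d, T = 32, 128, 16, 3  # T = number of target classes

O = torch.randn(B, P, d)     # output of a Trompt cell

score = nn.Linear(d, 1)      # scores each prompt's output
head = nn.Linear(d, T)       # downstream classifier

W_prompt = torch.softmax(score(O).squeeze(-1), dim=-1)  # (B, P)
O_hat = (W_prompt.unsqueeze(-1) * O).sum(dim=1)         # (B, d)
pred = torch.softmax(head(O_hat), dim=-1)               # (B, T)
print(pred.shape)  # torch.Size([32, 3])
```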

Experiments

The authors focus on the WhyTrees benchmark [3], which contains 45 datasets that were found to be difficult for both tree-based and deep learning models. Using this benchmark, Trompt is compared against FT-Transformer [1], SAINT [2], ResNet [3], XGBoost, LightGBM, CatBoost, and Random Forest.

Findings

Main Results

Performance comparison after hyperparameter optimization (HPO), grouped by task type (classification/regression) and dataset type (numerical only/heterogeneous).

Classification on medium datasets

  • Trompt outperforms DL baselines.
  • Also gets closer to tree-based models' performance.

Classification on large datasets

  • Less clear performance difference between DL and tree-based models.
  • Trompt shows very strong performance on numerical-only datasets.
    • Also comparable to FT-Transformer (FTT) on datasets with heterogeneous features.

Regression on medium datasets

  • Outperforms other DL methods.
  • Gets closer to tree-based models, but does not quite match them.

Regression on large datasets

  • Slightly below FTT and SAINT on numerical features.
  • But very large gap on heterogeneous features.

HPO Ablation

  • Using more prompts doesn't always help (the default of 128 seems to be enough).

  • Transforming the feature embeddings with a dense layer is crucial.

    • If we don't do this, it means each prompt is looking at the same feature embeddings.
  • Trompt does learn meaningful feature importances.

    • Evaluated by testing against GBDTs on synthetic datasets.

Resources

Footnotes

  1. Gorishniy, Yury, Ivan Rubachev, Valentin Khrulkov, and Artem Babenko. "Revisiting deep learning models for tabular data." Advances in Neural Information Processing Systems 34 (2021): 18932-18943.

  2. Somepalli, Gowthami, Micah Goldblum, Avi Schwarzschild, C. Bayan Bruss, and Tom Goldstein. "SAINT: Improved neural networks for tabular data via row attention and contrastive pre-training." arXiv preprint arXiv:2106.01342 (2021).

  3. Grinsztajn, Léo, Edouard Oyallon, and Gaël Varoquaux. "Why do tree-based models still outperform deep learning on typical tabular data?" Advances in Neural Information Processing Systems 35 (2022): 507-520.