Large Scale Transfer Learning for Tabular Data via Language Modeling
Summary
This paper introduces TabuLa-8b, a Llama 3-8B-based model fine-tuned for tabular data classification. The authors examine its performance in few-shot settings and find that the model fine-tuned for few-shot inference (TabuLa-8b) vastly outperforms out-of-the-box LLMs and is comparable to GBDT-based baselines.
Approach
Model
Serialization
The authors opt for a natural-language style serialization of table rows. For example, the following table:
date | precipitation | temp_max | weather
--- | --- | --- | ---
2015-03-22 | 1.0 | 11.699 | rain
2015-09-19 | 0.0 | 14.722 | sun
will turn into the following text:
Predict the value of weather: ||sun||rain||snow|| The date is 2015-03-22. The precipitation is 1.0. The temp_max is 11.699. What is the value of weather? ||sun||rain||snow||<|endinput|>rain<|endcompletion|>Predict the value of weather: ||sun||rain||snow|| The date is 2015-09-19. The precipitation is 0.0. The temp_max is 14.722. What is the value of weather? ||sun||rain||snow||<|endinput|>sun<|endcompletion|>
Note that the serialization includes the column names as well as the possible values of the target column. The model is then trained to predict the value of a column given the rest of the row. The authors also note that they only count a prediction as correct if the model "finishes" the sequence with <|endcompletion|>.
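For concreteness, here is a minimal sketch of this kind of serialization in Python. The special tokens are taken from the example above; the function name and exact formatting details are my own assumptions, not the authors' code.

```python
# Minimal sketch of the natural-language serialization shown above.
# The special tokens come from the example; everything else is assumed.
def serialize_row(row: dict, target: str, choices: list[str]) -> str:
    """Turn one table row into a prompt/completion pair."""
    choice_str = "||" + "||".join(choices) + "||"
    features = " ".join(
        f"The {col} is {val}." for col, val in row.items() if col != target
    )
    prompt = (
        f"Predict the value of {target}: {choice_str} "
        f"{features} "
        f"What is the value of {target}? {choice_str}<|endinput|>"
    )
    return prompt + f"{row[target]}<|endcompletion|>"


row = {"date": "2015-03-22", "precipitation": 1.0, "temp_max": 11.699, "weather": "rain"}
print(serialize_row(row, target="weather", choices=["sun", "rain", "snow"]))
```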
Attention Mechanism
Row Causal Mask
Similar to causal masks used in language models, the row causal mask is used to ensure that the attention mechanism is only applied to cells from the same row.
The authors use what they call a "row causal mask" to prevent tokens from different rows from attending to each other. This is similar to the causal mask used in language models, but applied at the row level. I think something similar has existed for a while for packing multiple sentences into one batch (the 4D attention mask idea). This also allows them to pack multiple sequences into one input, maximizing the use of the context window for each input.
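As a rough illustration of the packing idea (my own sketch, not the paper's implementation): combine the usual causal mask with a same-segment constraint so that packed examples cannot attend across their boundaries.

```python
import torch

# Illustrative packed-sequence mask: causal within the sequence, and tokens may
# only attend to tokens from the same packed example (segment).
def packed_causal_mask(segment_ids: torch.Tensor) -> torch.Tensor:
    """Return a [seq, seq] boolean mask where True means attention is allowed."""
    seq_len = segment_ids.shape[0]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    same_segment = segment_ids.unsqueeze(0) == segment_ids.unsqueeze(1)
    return causal & same_segment

# Three tokens of example 0 packed together with two tokens of example 1.
print(packed_causal_mask(torch.tensor([0, 0, 0, 1, 1])))
```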
Data Collection
Filtered and processed TabLib1 by:
- Filtering by table quality (missing values, non-tabular data, etc.; see the sketch below)
- Automatically determining tasks (target columns to predict)
This yields 4M tables with 2.1B rows (~100B Llama 3 tokens).
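As a toy illustration of the kind of quality filtering involved (the criteria and thresholds here are my own assumptions, not the paper's):

```python
import pandas as pd

# Toy quality filter in the spirit of the pipeline; thresholds are assumptions.
def passes_quality_filter(df: pd.DataFrame,
                          max_missing_frac: float = 0.2,
                          min_rows: int = 50) -> bool:
    if len(df) < min_rows:
        return False
    missing_frac = df.isna().mean().mean()  # average fraction of missing cells
    return missing_frac <= max_missing_frac
```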
Findings
k-shot Inference
- Base Llama 3 is worse than random guessing at k = 0 (zero-shot).
- Adding a few examples (shots) improves performance significantly.
- TabuLa-8b shows superior performance at low k, but the gap with TabPFN2 and XGBoost starts to close as k increases.
- In general, TabuLa-8b is data efficient (better performance in lower-data settings).
Using Row Causal Mask
SETTING
Train without the row causal mask.
- Results in a significant performance drop.
- Performance degrades further as k increases, suggesting that vanilla fine-tuning may lead to forgetting of the base model's few-shot capabilities.
Usefulness of column headings
SETTING
Replace column headings with uninformative tokens, e.g., X1, X2, etc. (see the sketch below).
- For small k, the model performs better with column headings.
- But as k increases, the performance gap closes.
INTUITION
Maybe the model starts looking more at the cell "values" as k increases, and the column headings become less important. If this is true, we should be able to verify this by looking at the attention weights.
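A minimal sketch of the heading ablation (my own illustration of the setting):

```python
import pandas as pd

# Replace informative column names with placeholders X1, X2, ... (illustrative).
def anonymize_headings(df: pd.DataFrame) -> pd.DataFrame:
    return df.rename(columns={c: f"X{i + 1}" for i, c in enumerate(df.columns)})
```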
Performance change with fewer features
SETTING
Train XGBoost on the full (clean) data and compute feature importances. Use the importance scores to remove certain features for TabuLa-8b (see the sketch below). Compare against an XGBoost model trained on partially corrupted (missing) data, evaluated on the same test data.
- Higher k for TabuLa-8b shows stronger performance.
- Removing important features doesn't suddenly destroy performance; the drops stay reasonable.
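A rough sketch of how such an ablation could be set up with XGBoost's scikit-learn API (illustrative only, not the paper's exact protocol):

```python
import numpy as np
import xgboost as xgb

# Rank features with an XGBoost model trained on the clean data, then drop the
# n_drop most important ones (illustrative sketch of the ablation setting).
def drop_most_important(X: np.ndarray, y: np.ndarray, n_drop: int) -> np.ndarray:
    model = xgb.XGBClassifier().fit(X, y)
    order = np.argsort(model.feature_importances_)[::-1]  # most important first
    keep = np.sort(order[n_drop:])                        # keep the rest
    return X[:, keep]
```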
Sensitivity to column ordering
SETTING
Shuffle the columns of the tables at inference time (see the sketch below).
- Performance drops, but remains good.
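The shuffling itself is simple; a short pandas sketch (illustrative):

```python
import pandas as pd

# Randomly permute the column order at inference time (illustrative).
def shuffle_columns(df: pd.DataFrame, seed: int = 0) -> pd.DataFrame:
    return df.sample(frac=1.0, axis=1, random_state=seed)
```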
Performance on numerical/categorical columns
SETTING
Run inference only on tables where at least one column, or all columns, are numerical (see the sketch below).
- Performance on tables with some numerical columns is still better than the baselines.
- Performance on tables with all numerical columns is on par with the baselines.
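A small sketch of how tables could be bucketed for this evaluation (my own illustration, using pandas dtypes as the notion of "numerical"):

```python
import pandas as pd

# Bucket a table by how many of its feature columns are numerical (illustrative).
def numeric_profile(df: pd.DataFrame, target: str) -> str:
    feats = df.drop(columns=[target])
    n_numeric = feats.select_dtypes(include="number").shape[1]
    if n_numeric == feats.shape[1]:
        return "all numerical"
    return "some numerical" if n_numeric > 0 else "no numerical"
```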
Resources
Footnotes
1. Eggert, Gus, Kevin Huo, Mike Biven, and Justin Waugh. "TabLib: A Dataset of 627M Tables with Context." arXiv preprint arXiv:2310.07875 (2023). Approximate Labs Post
2. Hollmann, Noah, Samuel Müller, Katharina Eggensperger, and Frank Hutter. "TabPFN: A Transformer That Solves Small Tabular Classification Problems in a Second." arXiv preprint arXiv:2207.01848 (2022).