Retrieval & Fine-Tuning for In-Context Tabular Models
Summary
While TabPFN performs well in certain circumstances, it does not scale to larger datasets because its memory use grows quadratically with the context size. This paper proposes improving the context given to the PFN model with a nearest-neighbor-based approach.
Approach
Architecture
Description of the proposed architecture. a) Perform kNN for $x_{\text{qy}}$ in $\mathcal{D}_{\text{train}}$ and use the neighbors as input to the TabPFN classifier. b) Use an approximation of kNN (pre-computed from randomly selected points) to improve efficiency.
Context from local data
MOTIVATION
TabPFN is limited to randomly sampling the dataset when it is too large to fit in the context. This can lead to suboptimal performance. How can we optimize the context given to the TabPFN model?
Original TabPFN classification of a query point $x_{\text{qy}}$, given a training dataset $\mathcal{D}_{\text{train}}$, model parameters $\theta$, and $C$ classes:
$$p_{\theta}(y_{\text{qy}} \mid x_{\text{qy}}, \mathcal{D}_{\text{context}}) = \cfrac{\exp(f_{\theta}(x_{\text{qy}}, \mathcal{D}_{\text{context}})[y_{\text{qy}}])}{\sum_{c=1}^{C} \exp(f_{\theta}(x_{\text{qy}}, \mathcal{D}_{\text{context}})[c])}$$
using $[\cdot]$ as the indexing operator and $\mathcal{D}_{\text{context}} = \mathcal{D}_{\text{train}}$ (the context is the entire training dataset).
The proposed method, LoCalPFN, uses a $k$-nearest neighbors (kNN) approach to improve the context given to the TabPFN model. Thus, $\mathcal{D}_{\text{context}}$ is now a subset of $\mathcal{D}_{\text{train}}$: the $k$ training points closest to $x_{\text{qy}}$.
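The retrieval step can be sketched in a few lines. This is only an illustration under stated assumptions: it uses plain Euclidean distance in NumPy, and `count_logits` is a hypothetical stand-in for $f_{\theta}$ (it simply counts class labels in the context), not the TabPFN model or the authors' code.

```python
import numpy as np

def knn_context(X_train, y_train, x_qy, k):
    """Return the k training points closest to x_qy (Euclidean distance)."""
    dists = np.linalg.norm(X_train - x_qy, axis=1)
    idx = np.argsort(dists)[:k]
    return X_train[idx], y_train[idx]

def softmax(logits):
    """Numerically stable softmax over a 1-D vector of class logits."""
    e = np.exp(logits - logits.max())
    return e / e.sum()

def predict_proba_local(logits_fn, X_train, y_train, x_qy, k):
    """Class probabilities for x_qy using only its kNN as context."""
    X_ctx, y_ctx = knn_context(X_train, y_train, x_qy, k)
    return softmax(logits_fn(x_qy, X_ctx, y_ctx))

def count_logits(x_qy, X_ctx, y_ctx, n_classes=2):
    """Hypothetical stand-in for f_theta: class counts in the context."""
    return np.bincount(y_ctx, minlength=n_classes).astype(float)

# Toy data: two well-separated Gaussian blobs, one per class.
rng = np.random.default_rng(0)
X_train = np.vstack([rng.normal(-2, 0.5, (50, 2)), rng.normal(2, 0.5, (50, 2))])
y_train = np.array([0] * 50 + [1] * 50)
x_qy = np.array([2.0, 2.0])

proba = predict_proba_local(count_logits, X_train, y_train, x_qy, k=10)
print(proba)  # probability mass concentrates on class 1
```

Any in-context classifier that maps a query and a context to class logits could replace `count_logits`; the only change relative to the global setting is that the context is built per query.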
Improving efficiency for fine-tuning
MOTIVATION
Original TabPFN takes input of shape $(B, L_{\text{ctx}} + L_{\text{qy}}, d)$, where $B$ is the batch size (set to 1 because a single context is shared by every query point), $L_{\text{ctx}}$ is the number of context points, $L_{\text{qy}}$ is the number of query points, and $d$ is the dimension of the input. If we apply the above approach, we need to re-compute the kNN context for each query point, so each query forms its own batch element: with $N_{\text{qy}}$ query points, the input now has $B = N_{\text{qy}}$, $L_{\text{ctx}} = k$, and $L_{\text{qy}} = 1$. This can become very expensive for fine-tuning.
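The change in input shape can be made concrete with a small NumPy sketch. The sizes (`N`, `N_qy`, `d`, `k`) and function names are illustrative assumptions, and labels are omitted so that only the feature tensor is shown.

```python
import numpy as np

def shared_context_batch(X_ctx, X_qy):
    """One context shared by all queries: shape (1, L_ctx + L_qy, d)."""
    return np.concatenate([X_ctx, X_qy], axis=0)[None, :, :]

def per_query_knn_batch(X_train, X_qy, k):
    """One kNN context per query: shape (N_qy, k + 1, d)."""
    # Pairwise squared Euclidean distances, shape (N_qy, N).
    d2 = ((X_qy[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=-1)
    idx = np.argsort(d2, axis=1)[:, :k]      # (N_qy, k) neighbor indices
    contexts = X_train[idx]                  # (N_qy, k, d)
    # Append each query as the last element of its own sequence.
    return np.concatenate([contexts, X_qy[:, None, :]], axis=1)

rng = np.random.default_rng(0)
N, N_qy, d, k = 1000, 64, 8, 32
X_train = rng.normal(size=(N, d))
X_qy = rng.normal(size=(N_qy, d))

shared = shared_context_batch(X_train, X_qy)
local = per_query_knn_batch(X_train, X_qy, k)
print(shared.shape)  # (1, 1064, 8)
print(local.shape)   # (64, 33, 8)
```

In the local case, every query requires its own neighbor search and its own sequence, which is what makes exact per-query retrieval costly inside a fine-tuning loop.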
Instead, the authors propose to pre-compute the kNN contexts to approximate the exact process. If we want to fine-tune the model on $N_{\text{qy}}$ query points, we start by selecting $B$ random points, compute the $k$ nearest neighbors of each with $k = L_{\text{ctx}} + L_{\text{qy}}$ and $L_{\text{qy}} = N_{\text{qy}} / B$, and store them. Each kNN group can then be split into query and context points to fine-tune the TabPFN model. This ensures that the query points and context points are always local to each other.
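A minimal sketch of this pre-computation, assuming Euclidean distance and NumPy only. The function names and the symbols `B`, `L_ctx`, and `L_qy` (number of random anchor points, context points per group, and query points per group) are illustrative choices, not the authors' implementation.

```python
import numpy as np

def precompute_knn_groups(X, B, L_ctx, L_qy, rng):
    """Pick B random anchors and store the indices of their
    k = L_ctx + L_qy nearest neighbors (the anchor itself included)."""
    k = L_ctx + L_qy
    anchors = rng.choice(len(X), size=B, replace=False)
    d2 = ((X[anchors][:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)  # (B, N)
    return np.argsort(d2, axis=1)[:, :k]                               # (B, k)

def split_groups(groups, L_ctx, rng):
    """Shuffle each group independently, then split into context and query indices."""
    shuffled = rng.permuted(groups, axis=1)
    return shuffled[:, :L_ctx], shuffled[:, L_ctx:]

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 4))

N_qy, B, L_ctx = 64, 8, 24
L_qy = N_qy // B                      # 8 query points per group

groups = precompute_knn_groups(X, B, L_ctx, L_qy, rng)
ctx_idx, qy_idx = split_groups(groups, L_ctx, rng)

# Each batch element is one local neighborhood: (B, L_ctx + L_qy, d).
X_batch = np.concatenate([X[ctx_idx], X[qy_idx]], axis=1)
print(X_batch.shape)  # (8, 32, 4)
```

Because all points in a group are neighbors of the same anchor, the queries in a group share one context, so only `B` neighbor searches are needed instead of `N_qy`.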
Findings
Limits of TabPFN/Benefit of local context
Comparison between TabPFN and LoCalPFN on a toy dataset
a) As the complexity/size of the dataset increases, vanilla TabPFN struggles. b) Using local context as input instead of the whole training set improves performance. c) Performance vs. $k$. A large $k$ tends to "oversmooth" and suffer from high bias/underfitting, while a small $k$ enables more complex decision boundaries but can suffer from more variance/overfitting.
Experiments
SETTING
96 datasets from the TabZilla [1] benchmark suite. Main methods compared: TabPFN, TabPFN + kNN (no fine-tuning), and LoCalPFN.
Dataset size/complexity
- TabPFN is already quite competitive on small datasets.
- But LoCalPFN still improves upon it (and upon TabPFN + kNN).
- LoCalPFN can outperform other models on larger/more complex* datasets.
Ablations
- Using both fine-tuning and kNN yields the best performance.
- kNN in the original feature space is already quite good. Using a one-hot embedding improves it a bit further.
- But using the embeddings from the TabPFN encoder is not as good.
INTUITION
Feature values in tabular datasets can be semantically meaningful. Thus a distance metric that decomposes over individual features, i.e. $d(x, x') = \sum_i d(x_i, x'_i)$, can be more effective than a learned distance metric.
- Using the local context is better than using the global context.
- Instead of a single randomly-sampled context, compare against variations that try to use the full data for a global context.
- Compared against random ensemble and ensemble with no overlap.
- When not fine-tuning (TabPFN + kNN), $k$ does not matter as much. But when fine-tuning (LoCalPFN), a larger $k$ can be better.
- kNN is better than data distillation [2].
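The intuition about feature-wise distances can be illustrated with a small sketch. The per-feature distances below (squared difference for numerical features, a constant penalty for mismatched categories) are illustrative assumptions rather than the paper's exact choice; they are picked so that the sum over features equals the squared Euclidean distance between one-hot encoded rows.

```python
import numpy as np

def featurewise_distance(x, x_prime, is_categorical):
    """Distance that decomposes over features: sum_i d_i(x_i, x'_i)."""
    total = 0.0
    for xi, xpi, cat in zip(x, x_prime, is_categorical):
        if cat:
            # Two different one-hot vectors differ in exactly two positions.
            total += 2.0 * float(xi != xpi)
        else:
            total += (xi - xpi) ** 2
    return total

def encode(x, is_categorical, n_categories):
    """One-hot encode categorical features and keep numerical ones as-is."""
    parts = []
    for xi, cat, n in zip(x, is_categorical, n_categories):
        parts.append(np.eye(n)[int(xi)] if cat else np.array([float(xi)]))
    return np.concatenate(parts)

# One numerical feature and one categorical feature with 3 categories.
is_categorical = [False, True]
n_categories = [0, 3]          # the entry for the numerical feature is unused
x, x_prime = [1.5, 2], [0.5, 0]

d_feat = featurewise_distance(x, x_prime, is_categorical)
d_enc = float(((encode(x, is_categorical, n_categories)
                - encode(x_prime, is_categorical, n_categories)) ** 2).sum())
print(d_feat, d_enc)  # 3.0 3.0
```

Both computations agree, which shows that kNN on one-hot encoded features is an instance of a distance that decomposes over the original features.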
*: The authors define a proxy for complexity as the difference between the lowest and highest performer.
Resources
Footnotes
[1] McElfresh, Duncan, Sujay Khandagale, Jonathan Valverde, Vishak Prasad C, Ganesh Ramakrishnan, Micah Goldblum, and Colin White. "When do neural nets outperform boosted trees on tabular data?" Advances in Neural Information Processing Systems 36 (2024).
[2] Ma, Junwei, Valentin Thomas, Guangwei Yu, and Anthony Caterini. "In-Context Data Distillation with TabPFN." arXiv preprint arXiv:2402.06971 (2024).