Summary

While TabPFN performs well in certain settings, it does not scale well to larger datasets: its memory cost grows quadratically with the context size. This paper proposes a method to improve the context given to the PFN model using a nearest-neighbors-based approach.

Approach

Architecture


Description of the proposed architecture: a) Retrieve the $k$NN of $x_{\text{qy}}$ in $\mathcal{D}_{\text{train}}$ and use them as the context for the TabPFN classifier. b) Use an approximation of $k$NN (pre-computed around randomly selected points) to improve efficiency.

Context from local data

MOTIVATION

When $\mathcal{D}_{\text{train}}$ is too large, TabPFN has to fall back to randomly subsampling it, which can lead to suboptimal performance. How can we optimize the context given to the TabPFN model?

Original TabPFN classification, given $\mathcal{D}_{\text{train}} \triangleq \{(x^i_{\text{train}}, y^i_{\text{train}})\}_{i=1}^{N}$ with $x^{i}_{\text{train}} \in \mathbb{R}^D$, $y^{i}_{\text{train}} \in \{1, \dots, C\}$, and a query point $x_{\text{qy}}$:

$$p_{\theta}(y_{\text{qy}} \mid x_{\text{qy}}, \mathcal{D}_{\text{context}}) = \frac{\exp(f_{\theta}(x_{\text{qy}}, \mathcal{D}_{\text{context}})[y_{\text{qy}}])}{\sum_{c=1}^{C} \exp(f_{\theta}(x_{\text{qy}}, \mathcal{D}_{\text{context}})[c])}$$

using $[\cdot]$ as the indexing operator and $\mathcal{D}_{\text{context}} \triangleq \mathcal{D}_{\text{train}}$ (the context is the entire training dataset).
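The prediction rule above is a softmax over $C$ logits computed from the query and its context. The sketch below illustrates it under stated assumptions: `toy_f_theta` is a hypothetical stand-in for the TabPFN network $f_\theta$ (it is not a transformer), and classes are indexed $0, \dots, C-1$ instead of $1, \dots, C$.

```python
import numpy as np


def toy_f_theta(x_qy, X_ctx, y_ctx, num_classes):
    """Hypothetical stand-in for f_theta: returns one logit per class.

    The logit of class c is the negative mean squared distance from x_qy to
    the context points labelled c (-inf if the class is absent from the context).
    """
    logits = np.full(num_classes, -np.inf)
    for c in range(num_classes):
        members = X_ctx[y_ctx == c]
        if len(members) > 0:
            logits[c] = -np.mean(np.sum((members - x_qy) ** 2, axis=1))
    return logits


def predict_proba(x_qy, X_ctx, y_ctx, num_classes):
    """p(y_qy | x_qy, D_context): softmax over the logits f_theta(x_qy, D_context)."""
    logits = toy_f_theta(x_qy, X_ctx, y_ctx, num_classes)
    # Subtracting the max logit keeps exp() numerically stable; absent classes get probability 0.
    weights = np.exp(logits - np.max(logits))
    return weights / weights.sum()


# For vanilla TabPFN, (X_ctx, y_ctx) would be the whole (or a random subset of) D_train.
X_ctx = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
y_ctx = np.array([0, 0, 1])
probs = predict_proba(np.array([0.0, 0.5]), X_ctx, y_ctx, num_classes=2)
```

The only thing that changes between TabPFN and LoCalPFN in this rule is which rows are passed as `X_ctx, y_ctx`.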

The proposed method, LoCalPFN, uses a $k$-nearest-neighbors approach to improve the context given to the TabPFN model. Thus, $\mathcal{D}_{\text{context}} \triangleq k\text{NN}(x_{\text{qy}})$ is now a query-dependent subset of $\mathcal{D}_{\text{train}}$.
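To contrast the two context choices, here is a minimal NumPy sketch. The function names are hypothetical, and Euclidean distance in the original feature space is an assumed choice of metric; `random_context` mimics the random-subsampling fallback used when $\mathcal{D}_{\text{train}}$ is too large.

```python
import numpy as np


def random_context(X_train, y_train, max_ctx, rng):
    """Vanilla fallback: one random subset of D_train, shared by every query."""
    size = min(max_ctx, len(X_train))
    idx = rng.choice(len(X_train), size=size, replace=False)
    return X_train[idx], y_train[idx]


def knn_context(x_qy, X_train, y_train, k):
    """Local context: the k training points closest to x_qy (Euclidean distance assumed)."""
    dists = np.linalg.norm(X_train - x_qy, axis=1)
    idx = np.argsort(dists, kind="stable")[:k]
    return X_train[idx], y_train[idx]


X_train = np.arange(10, dtype=float).reshape(-1, 1)
y_train = (X_train[:, 0] > 4).astype(int)

# kNN context depends on the query; here the three nearest points, closest first.
X_ctx, y_ctx = knn_context(np.array([3.2]), X_train, y_train, k=3)

# Random context ignores the query entirely.
rng = np.random.default_rng(0)
X_rand, y_rand = random_context(X_train, y_train, max_ctx=4, rng=rng)
```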

Improving efficiency for fine-tuning

MOTIVATION

Original TabPFN takes an input of shape $(B, L_{\text{ctx}} + L_{\text{qy}}, d)$, where $B$ is the batch size (set to 1 because a single context is shared by every query point), $L_{\text{ctx}}$ is the number of context points, $L_{\text{qy}} = N_{\text{qy}}$ is the number of query points, and $d$ is the input dimension. With the approach above, however, each query point needs its own $k$NN context, so the input now has $B = N_{\text{qy}}$, $L_{\text{ctx}} = k$, and $L_{\text{qy}} = 1$. This becomes very expensive for fine-tuning.
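The shape bookkeeping above can be made concrete with a small helper. This is an illustrative sketch, not code from the paper: `input_shapes` and `num_rows` are hypothetical, and counting rows (token positions) fed to the transformer is only a rough proxy for cost that ignores attention details.

```python
def input_shapes(n_qy, l_ctx, k, d):
    """Return (shared, per_query) input shapes, each as (B, L_ctx + L_qy, d)."""
    # Shared context: a single sequence with all context points, then all queries.
    shared = (1, l_ctx + n_qy, d)
    # Exact kNN: one sequence per query, holding its k neighbours plus the query itself.
    per_query = (n_qy, k + 1, d)
    return shared, per_query


def num_rows(shape):
    """Total number of rows (token positions) across the batch: B * L."""
    batch, length, _ = shape
    return batch * length


shared, per_query = input_shapes(n_qy=1000, l_ctx=1000, k=1000, d=100)
print(num_rows(shared), num_rows(per_query))  # 2000 1001000
```

With 1,000 queries and $k = 1{,}000$, the exact per-query approach processes roughly 500 times more rows than the shared-context one.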

Instead, the authors propose pre-computing the $k$NN contexts to approximate the exact process. To fine-tune the model on $N_{\text{qy}}$ query points, we first select $B$ random points, compute the $k$NN context of each with $k = L_{\text{ctx}} + L_{\text{qy}}$ and $L_{\text{qy}} = N_{\text{qy}} / B$, and store these groups. Each $k$NN group can then be split into query and context points to fine-tune the TabPFN model. This ensures that query and context points are always local to each other.
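A minimal sketch of the pre-computation, assuming Euclidean distance, centre points drawn without replacement, and a random query/context split inside each stored group. The helper names are hypothetical, and details such as how often the groups are refreshed are not specified here.

```python
import numpy as np


def precompute_knn_groups(X_train, num_groups, l_ctx, l_qy, rng):
    """Pick B random centre points and store the k = L_ctx + L_qy nearest neighbours of each."""
    k = l_ctx + l_qy
    centres = rng.choice(len(X_train), size=num_groups, replace=False)
    groups = []
    for c in centres:
        # Euclidean distance (assumed metric) from the centre to every training point.
        dists = np.linalg.norm(X_train - X_train[c], axis=1)
        groups.append(np.argsort(dists, kind="stable")[:k])
    return np.stack(groups)  # shape (B, k): indices into X_train, computed once


def split_group(group, l_ctx, rng):
    """Shuffle one stored neighbourhood and split it into context and query indices."""
    perm = rng.permutation(group)
    return perm[:l_ctx], perm[l_ctx:]


rng = np.random.default_rng(0)
X_train = rng.normal(size=(200, 3))
groups = precompute_knn_groups(X_train, num_groups=4, l_ctx=20, l_qy=5, rng=rng)
ctx_idx, qy_idx = split_group(groups[0], l_ctx=20, rng=rng)
```

Here $B = 4$ and $L_{\text{qy}} = 5$, so one pass over the stored groups covers $N_{\text{qy}} = B \cdot L_{\text{qy}} = 20$ query points, each paired with context points drawn from its own neighbourhood.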

Findings

Limits of TabPFN/Benefit of local context

Toy dataset comparison between TabPFN and LoCalPFN


a) As the complexity/size of the dataset increases, vanilla TabPFN struggles. b) Using the local context as input instead of the whole training set improves performance. c) Performance vs. $k$: a large $k$ tends to "oversmooth" and suffers from high bias (underfitting), while a small $k$ enables more complex decision boundaries but can suffer from higher variance (overfitting).

Experiments

SETTING

96 datasets from the TabZilla¹ benchmark suite. Main methods compared: TabPFN, TabPFN + $k$NN (no fine-tuning), and LoCalPFN.

Dataset size/complexity

  • TabPFN is already quite competitive on small datasets.
    • But LoCalPFN still improves on it (and on TabPFN + $k$NN).
  • LoCalPFN can outperform other models on larger/more complex* datasets.

Ablations

  • Using both fine-tuning and $k$NN yields the best performance.
  • $k$NN in the original feature space already works quite well; using a one-hot embedding improves it slightly further.
    • But using the embeddings from the TabPFN encoder is not as good.

      INTUITION

      Feature values in tabular datasets can be semantically meaningful. Thus, a distance metric that decomposes over individual features, i.e. $d(x, x') = \sum_{i} d(x_i, x_i')$, can be more effective than a learned distance metric.

  • Using the local context is better than using the global context.
    • Instead of a single randomly sampled context, the authors compare against variations that try to use the full data as a global context.
      • Compared against a random ensemble and an ensemble with no overlap.
    • Without fine-tuning (TabPFN + $k$NN), $k$ matters less. With fine-tuning (LoCalPFN), a larger $k$ can be better.
    • $k$NN is better than data distillation².
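The intuition about feature-wise distances can be illustrated with squared Euclidean distance, which decomposes into a sum of per-feature terms and ranks neighbours exactly as Euclidean distance does. This sketch assumes numeric features are used as-is and categorical features are one-hot encoded; the helper names and toy rows are hypothetical.

```python
import numpy as np


def one_hot(column, categories):
    """Encode one categorical column as indicator vectors over `categories`."""
    column = np.asarray(column)[:, None]
    return (column == np.asarray(categories)[None, :]).astype(float)


def per_feature_terms(x, x_prime):
    """d(x_i, x'_i) for each feature i, using the squared difference."""
    return (np.asarray(x, dtype=float) - np.asarray(x_prime, dtype=float)) ** 2


def decomposable_distance(x, x_prime):
    """d(x, x') = sum_i d(x_i, x'_i)."""
    return float(np.sum(per_feature_terms(x, x_prime)))


# Toy rows: one numeric feature followed by a one-hot encoded colour feature.
numeric = np.array([[1.0], [1.5], [4.0]])
colour = one_hot(["red", "blue", "red"], categories=["red", "blue", "green"])
X = np.hstack([numeric, colour])

d01 = decomposable_distance(X[0], X[1])  # numeric gap 0.25 plus a colour mismatch of 2
d02 = decomposable_distance(X[0], X[2])  # numeric gap 9, same colour
```

Under this encoding a category mismatch always contributes 2, while numeric features contribute in their own units, so how numeric features are scaled affects which neighbours are retrieved.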

*: As a proxy for dataset complexity, the authors use the performance gap between the best- and worst-performing models.
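As a formula, the proxy for one dataset is $\max_m s_m - \min_m s_m$ over the scores $s_m$ of the compared models. The sketch below uses made-up scores for unnamed models purely for illustration; they are not results from the paper.

```python
def complexity_proxy(scores):
    """Gap between the best and worst model score on one dataset."""
    values = list(scores.values())
    return max(values) - min(values)


# Hypothetical scores for three unnamed models on a single dataset.
example = {"model_a": 0.90, "model_b": 0.75, "model_c": 0.80}
gap = complexity_proxy(example)
```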

Resources

Footnotes

  1. McElfresh, Duncan, Sujay Khandagale, Jonathan Valverde, Vishak Prasad C, Ganesh Ramakrishnan, Micah Goldblum, and Colin White. "When do neural nets outperform boosted trees on tabular data?" Advances in Neural Information Processing Systems 36 (2024).

  2. Ma, Junwei, Valentin Thomas, Guangwei Yu, and Anthony Caterini. "In-Context Data Distillation with TabPFN." arXiv preprint arXiv:2402.06971 (2024).