Summary

  • The authors propose CALinear, a new architecture that spans a meta-function space learned as a combination of multiple basis linear functions.

Approach

  • Instead of learning one linear layer after MHSA, learn a set of basis linear layers plus a coefficient generator.
    • The claim is that this is more expressive and more efficient.

Overview

Architecture overview of XTFormer

Vanilla attention:

$$x' = \text{MHSA}(x) = \text{Softmax}\left(\text{Linear}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\right)$$

The authors propose changing $\text{Linear}$ to $\text{CALinear}$, where:

$$\begin{gather*}
\text{CALinear}(\cdot) = \sum_{i=1}^{n} c_i \cdot \text{Linear}_i(\cdot) \\
c_i = \text{Softmax}(\text{M}_{\text{cal}}(v_i)) \\
\text{M}_{\text{cal}} = \text{Softmax}(\text{MLP}(x))
\end{gather*}$$

where $v_i$ is a learnable context vector for feature $i$ and $\text{M}_{\text{cal}}$ is the calibration module.

In the proposed architecture, only $v_i$ and the embedding layer are trained from scratch for a new dataset. $\text{M}_{\text{cal}}$ is trained over multiple datasets during pre-training and frozen during fine-tuning.
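
To make the mechanism concrete, here is a minimal PyTorch sketch of a CALinear layer that follows the formulas above. The module details (an MLP for $\text{M}_{\text{cal}}$, one context vector per basis layer, the layer sizes) are assumptions, not the paper's reference implementation.

```python
import torch
import torch.nn as nn


class CALinear(nn.Module):
    """Weighted combination of n basis linear layers (minimal sketch)."""

    def __init__(self, d_in: int, d_out: int, n_basis: int = 4, d_ctx: int = 16):
        super().__init__()
        # Basis linear functions Linear_i, pre-trained across many tables.
        self.basis = nn.ModuleList([nn.Linear(d_in, d_out) for _ in range(n_basis)])
        # Learnable context vectors v_i, trained from scratch for each new dataset.
        self.v = nn.Parameter(torch.randn(n_basis, d_ctx))
        # Calibration module M_cal (an MLP): pre-trained over multiple datasets,
        # then frozen during fine-tuning.
        self.m_cal = nn.Sequential(
            nn.Linear(d_ctx, d_ctx),
            nn.ReLU(),
            nn.Linear(d_ctx, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # c_i = Softmax_i(M_cal(v_i)) -> one coefficient per basis layer
        scores = self.m_cal(self.v).squeeze(-1)   # (n_basis,)
        c = torch.softmax(scores, dim=0)          # (n_basis,)
        # CALinear(x) = sum_i c_i * Linear_i(x)
        return sum(ci * lin(x) for ci, lin in zip(c, self.basis))
```

In stage 1 of downstream training only `self.v` (together with the embedding and output head outside this module) would be trained; `self.m_cal` and the basis layers stay frozen until refinement.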

Training for downstream task

Two stages (a fine-tuning sketch follows the list):

  1. Task Calibration: learn task- (i.e., dataset-) specific modules from scratch (embedding, output classifier, feature context $v_i$).
  2. Refinement: fine-tune all parameters on the downstream task.
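
A hedged sketch of how the two stages could be wired up, assuming the model exposes `embedding` and `head` submodules and uses the `CALinear` class from the sketch above; attribute names, learning rates, and epoch counts are my assumptions.

```python
import itertools
import torch


def two_stage_finetune(model, train_one_epoch, epochs_cal=20, epochs_refine=20):
    # Stage 1: task calibration. Freeze everything, then unfreeze only the
    # dataset-specific modules: embedding layer, output classifier, and the
    # feature context vectors v_i. The pre-trained bases and M_cal stay frozen.
    for p in model.parameters():
        p.requires_grad = False
    task_params = list(itertools.chain(
        model.embedding.parameters(),
        model.head.parameters(),
        (m.v for m in model.modules() if isinstance(m, CALinear)),
    ))
    for p in task_params:
        p.requires_grad = True
    opt = torch.optim.AdamW(task_params, lr=1e-3)
    for _ in range(epochs_cal):
        train_one_epoch(model, opt)

    # Stage 2: refinement. Unfreeze all parameters and fine-tune end to end.
    for p in model.parameters():
        p.requires_grad = True
    opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
    for _ in range(epochs_refine):
        train_one_epoch(model, opt)
```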

Findings

  • XGB and CatBoost outperform most DL baselines.
  • But CALinear beats them (pretty much) consistently.
  • Using 4-6 basis functions significantly outperforms using 1-2 basis functions.
    • But 4 and 6 do not show a significant difference.
  • Using $\text{M}_{\text{cal}}$ yields better performance than learning $c_i$ directly.
  • More task calibration yields better performance in full-data settings, but in limited-data settings it actually lowers performance.
    • Refinement seems to help in all cases.

Resources