
learning to (learn at test time): transformers without quadratic attention

Nov 10, 2024

The recent paper "Learning to (Learn at Test Time): RNNs with Expressive Hidden States" introduces test-time training (TTT), an exciting new approach to sequence modeling. The authors present the TTT model as a recurrent neural network (RNN) capable of achieving transformer-level quality. In this blog post, I motivate the TTT model from the other direction — as a transformer with RNN-like linear complexity in the sequence length. Towards this end, I give an alternate formulation of TTT as an approximation to the transformer's quadratic attention mechanism. Apart from providing fresh intuition for TTT, this perspective suggests interesting possibilities such as initializing TTT models with pre-trained transformer weights.

approximating attention with test-time training

For simplicity, let us first consider attention with a single query vector.

Single-query attention has the type signature \( o = a(q, K, V) \), where \( q \) is the query vector, \( K = [k_1, \dots, k_T] \) collects the \( T \) key vectors, \( V = [v_1, \dots, v_T] \) collects the corresponding value vectors, and \( o \) is the output vector.

The standard scaled-dot-product-attention (\( a^{\text{SDPA}} \)) is computed as follows:

  1. \( w = \text{softmax}\left(\left[\frac{q \cdot k_1}{\sqrt{d_k}}, \dots, \frac{q \cdot k_T}{\sqrt{d_k}}\right]\right) \)
  2. \( o = \sum_{i=1}^T w_i v_i \)
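
For concreteness, here is a minimal NumPy sketch of single-query SDPA following the two steps above. The function name and the array shapes are my own choices for illustration, not anything prescribed by the paper.

```python
import numpy as np

def sdpa_single_query(q, K, V):
    """Single-query scaled dot-product attention.

    q: (d_k,) query vector
    K: (T, d_k) matrix of key vectors
    V: (T, d_v) matrix of value vectors
    Returns o: (d_v,) output vector.
    """
    d_k = q.shape[0]
    scores = K @ q / np.sqrt(d_k)        # q . k_i / sqrt(d_k), shape (T,)
    w = np.exp(scores - scores.max())    # softmax, numerically stabilized
    w = w / w.sum()
    return w @ V                         # sum_i w_i v_i
```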

Below is an intuitive breakdown of the two steps:

  1. Soft-select a key \( k_i \) most similar to the query \( q \),
  2. Return the associated value \( v_i \).

This is reminiscent of applying a nearest-neighbor model with train set \( \{(k_1, v_1), \dots, (k_T, v_T)\} \) and test input \( q \). What if we instead use a parametric model such as a neural network? That would look like the following:

  1. Train a neural network \( f_\theta \) to predict \( v_i \) for input \( k_i \), and,
  2. Return the prediction of the network on \( q \).

This suggests \( a^{\text{TTT}} \) as a substitute for \( a^{\text{SDPA}} \), computed as follows:

  1. \( \theta = \arg \min_\theta \sum_{i=1}^T \Vert f_\theta(k_i) - v_i \Vert_2^2 \)
  2. \( o = f_\theta(q) \).

Step 1 above trains the network \( f_\theta \) with an \( L_2 \) loss. Note that this training happens within the lifetime of a single attention call. The attention call can be part of an outer model, such as a transformer. Running inference with the outer model at test time still requires training \( f_\theta \) within the attention call, hence the name "test-time training".
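
As an illustration of steps 1 and 2, the sketch below fits \( f_\theta \) to the key-value pairs by gradient descent and then evaluates it at \( q \). Taking a linear \( f_\theta(k) = Wk \), a fixed step count, and a fixed learning rate are simplifying assumptions on my part; they just keep the gradient available in closed form.

```python
import numpy as np

def ttt_single_query(q, K, V, lr=0.1, steps=100):
    """Approximate single-query attention by training f_theta on (k_i, v_i).

    f_theta(k) = W @ k is linear purely for illustration, so the gradient
    of the L2 loss can be written by hand instead of using autograd.
    """
    T, d_k = K.shape
    d_v = V.shape[1]
    W = np.zeros((d_v, d_k))             # theta, initialized at zero
    for _ in range(steps):
        # Step 1: gradient descent on (1/T) * sum_i ||W k_i - v_i||^2
        # (mean rather than sum, so the step size is easier to pick)
        residual = K @ W.T - V           # rows are f_theta(k_i) - v_i, shape (T, d_v)
        grad = 2.0 * residual.T @ K / T  # shape (d_v, d_k), same as W
        W -= lr * grad
    # Step 2: return the trained network's prediction on the query
    return W @ q
```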

how does TTT achieve linear scaling?

So far, we train \( f_\theta \) on all key-value pairs for every single-query attention call. In this setting, \( a^{\text{TTT}} \) is at least linear in the sequence length, no better than \( a^{\text{SDPA}} \). To see how TTT can be more efficient, we need to consider how single-query attention is invoked inside the over-arching self-attention computation.

In self-attention, there is a query vector and an associated output vector for each key-value pair. The type signature is \( O = A(Q, K, V) \), where \( Q = [q_1, \dots, q_T] \), \( K = [k_1, \dots, k_T] \), and \( V = [v_1, \dots, v_T] \) collect the query, key, and value vectors, and \( O = [o_1, \dots, o_T] \) collects the output vectors.

The relation to single-query attention is given by: \(o_t = a(q_t, [k_1, \dots, k_t], [v_1, \dots, v_t])\). The dependence of \( o_t \) on \( (k_i, v_i) \) only for \( i \le t \) makes this kind of self-attention causal.
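Writing this relation out as a loop over the single-query routine (reusing the sdpa_single_query sketch from earlier) makes the cost easy to see:

```python
import numpy as np

def sdpa_self_attention(Q, K, V):
    """Causal self-attention via repeated single-query calls.

    The call at timestep t touches t key-value pairs, so the total work
    is 1 + 2 + ... + T, i.e. quadratic in the sequence length T.
    """
    T = Q.shape[0]
    return np.stack([sdpa_single_query(Q[t], K[:t + 1], V[:t + 1])
                     for t in range(T)])
```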

\( A^{\text{SDPA}} \) is quadratic in the sequence length, since computing each \( o_t \) takes time linear in \( t \) and none of that work can be shared across timesteps. \( A^{\text{TTT}} \) avoids this quadratic dependence by reusing the learned neural network from the previous timestep, instead of training a fresh one from scratch. For example, if we use one step of gradient descent on each pair \( (k_t, v_t) \) with learning rate \( \eta \), \( A^{\text{TTT}} \) can be computed as below:

For \( t=1, \dots, T\):

  1. \( \theta_t = \theta_{t-1} - \eta \nabla_{\theta_{t-1}} \Vert f_{\theta_{t-1}}(k_t) - v_t \Vert_2^2 \)
  2. \( o_t = f_{\theta_t}(q_t) \)

This is linear in the sequence length \( T \).
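
Continuing the linear-\( f_\theta \) sketch from before (again an illustrative choice, not the paper's exact parameterization), the recurrence could look like this:

```python
import numpy as np

def ttt_self_attention(Q, K, V, lr=0.1):
    """Causal TTT self-attention: one gradient step per timestep.

    Q, K have shape (T, d_k) and V has shape (T, d_v). Each timestep
    costs O(d_k * d_v), so the whole loop is linear in T.
    """
    T, d_k = K.shape
    d_v = V.shape[1]
    W = np.zeros((d_v, d_k))              # theta_0
    O = np.zeros((T, d_v))
    for t in range(T):
        # theta_t = theta_{t-1} - lr * grad of ||f_theta(k_t) - v_t||^2
        residual = W @ K[t] - V[t]        # f_{theta_{t-1}}(k_t) - v_t
        W -= lr * 2.0 * np.outer(residual, K[t])
        # o_t = f_{theta_t}(q_t), read out from the freshly updated network
        O[t] = W @ Q[t]
    return O
```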

afterword

What I have presented is only one way to gain an introductory understanding of TTT. This blog post has nothing to say about making the idea work in practice, or why it should do any better than all the other sub-quadratic sequence models out there. In my opinion, the original paper is profound in its philosophy and technical achievements. I highly recommend giving it a thorough read, if you haven't already. Sections 2.1-2.3 cover the material most directly comparable to this post, and they are excellent reads even if this post has already given you the gist.

acknowledgements

This post was inspired by discussions with Krish Parikh and Marcel Roed. Thanks to Yu Sun and Himanshu Singh for proofreading and feedback.