The recent paper "Learning to (Learn at Test Time): RNNs with Expressive Hidden States" introduces test-time training (TTT), an exciting new approach to sequence modeling. The authors present the TTT model as a recurrent neural network (RNN) capable of achieving transformer-level quality. In this blog post, I motivate the TTT model from the other direction — as a transformer with RNN-like linear complexity in the sequence length. Towards this end, I give an alternate formulation of TTT as an approximation to the transformer's quadratic attention mechanism. Apart from providing fresh intuition for TTT, this perspective suggests interesting possibilities such as initializing TTT models with pre-trained transformer weights.
For simplicity, let us first consider attention with a single query vector.
Single-query attention has the type signature \( o = a(q, K, V) \), where \( q \) is a single query vector, \( K = (k_1, \dots, k_T) \) and \( V = (v_1, \dots, v_T) \) stack the \( T \) key and value vectors row-wise, and \( o \) is the output vector.
The standard scaled dot-product attention (\( a^{\text{SDPA}} \)) is computed in two steps, writing \( d \) for the key dimension:
\[ \text{Step 1:} \quad \alpha = \operatorname{softmax}\!\left(\frac{K q}{\sqrt{d}}\right), \qquad \text{Step 2:} \quad o = V^\top \alpha = \sum_{t=1}^{T} \alpha_t \, v_t. \]
An intuitive breakdown of the two steps: step 1 scores each key by its (scaled) dot-product similarity to the query and normalizes those scores into weights; step 2 returns the corresponding weighted average of the values.
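To make this concrete, here is a minimal NumPy sketch of single-query \( a^{\text{SDPA}} \); the function and variable names are my own, not the paper's.

```python
import numpy as np

def sdpa_single_query(q, K, V):
    """Single-query scaled dot-product attention: o = a_SDPA(q, K, V).

    q: (d,) query vector; K: (T, d) keys; V: (T, d_v) values.
    """
    d = q.shape[0]
    # Step 1: score each key by its scaled dot-product similarity to the query,
    # then normalize the scores into weights with a softmax.
    scores = K @ q / np.sqrt(d)          # (T,)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()             # softmax over the T keys
    # Step 2: return the similarity-weighted average of the values.
    return V.T @ weights                 # (d_v,)
```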
This is reminiscent of applying a nearest-neighbor model with train set \( \{(k_1, v_1), \dots, (k_T, v_T)\} \) and test input \( q \). What if we instead use a parametric model, such as a neural network \( f_\theta \)?
This suggests \( a^{\text{TTT}} \) as a substitute for \( a^{\text{SDPA}} \), computed in two steps:

Step 1 (train): fit \( f_\theta \) to the key-value pairs by minimizing the loss \( \sum_{t=1}^{T} \lVert f_\theta(k_t) - v_t \rVert_2^2 \) over \( \theta \), e.g. with gradient descent.

Step 2 (predict): \( o = f_\theta(q) \), using the parameters obtained in step 1.
Step 1 above trains the network \( f_\theta \) with an \( L_2 \) loss. Note that this training happens within the lifetime of a single attention call. The attention call can be part of an outer model, such as a transformer. Inferencing the outer model at test-time would still require training \( f_\theta \) within the attention call, hence the name "test-time training".
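Here is a minimal sketch of \( a^{\text{TTT}} \) for a single query, taking \( f_\theta \) to be a linear map in the spirit of the paper's TTT-Linear instantiation; the zero initialization, learning rate, and step count are placeholder choices of mine, not values from the paper.

```python
import numpy as np

def ttt_single_query(q, K, V, lr=0.1, n_steps=10):
    """Single-query TTT attention: o = a_TTT(q, K, V), with a linear f_theta.

    f_theta(k) = theta.T @ k, trained with full-batch gradient descent
    on the L2 loss over all (k_t, v_t) pairs.
    """
    T, d = K.shape
    d_v = V.shape[1]
    theta = np.zeros((d, d_v))
    # Step 1 (train): fit f_theta to the "train set" {(k_t, v_t)}.
    for _ in range(n_steps):
        residual = K @ theta - V         # (T, d_v), rows are f_theta(k_t) - v_t
        grad = 2.0 * K.T @ residual / T  # gradient of the mean L2 loss
        theta -= lr * grad
    # Step 2 (predict): apply the trained model to the "test input" q.
    return theta.T @ q                   # (d_v,)
```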
So far, we train \( f_\theta \) on all key-value pairs for every single-query attention call. In this setting, \( a^{\text{TTT}} \) is at least linear in the sequence length, no better than \( a^{\text{SDPA}} \). To see how TTT can be more efficient, we need to consider how single-query attention is invoked inside the over-arching self-attention computation.
In self-attention, there is a query vector and an associated output vector for each key-value pair. The type signature is \( O = A(Q, K, V) \), where \( Q = (q_1, \dots, q_T) \) and \( O = (o_1, \dots, o_T) \) stack the per-timestep queries and outputs. With causal masking, each output is a single-query attention over the prefix seen so far: \( o_t = a(q_t, K_{1:t}, V_{1:t}) \), where \( K_{1:t} \) and \( V_{1:t} \) denote the first \( t \) key and value vectors.
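For reference, the quadratic baseline \( A^{\text{SDPA}} \) can be sketched as a loop of single-query calls over growing prefixes, reusing the `sdpa_single_query` function from the earlier sketch.

```python
import numpy as np

def sdpa_self_attention(Q, K, V):
    """A_SDPA with causal masking: o_t attends to the first t key-value pairs."""
    T = Q.shape[0]
    # The call for timestep t costs O(t), so the total cost is quadratic in T.
    return np.stack([sdpa_single_query(Q[t], K[: t + 1], V[: t + 1])
                     for t in range(T)])
```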
\(A^{\text{SDPA}} \) is quadratic in the sequence length, as each \( o_t \) is linear in \( t\), and no further optimization is possible. \(A^{\text{TTT}} \) avoids this quadratic dependence by reusing the learnt neural network from the previous timestep, instead of training a fresh one from scratch. For example, if we use one step of gradient descent on every pair \( (k_t, v_t) \) with learning rate \( \eta\), \(A^{\text{TTT}} \) can be computed as below:
Starting from an initialization \( \theta_0 \), for \( t = 1, \dots, T \):
\[ \theta_t = \theta_{t-1} - \eta \, \nabla_\theta \Big\lVert f_\theta(k_t) - v_t \Big\rVert_2^2 \Big|_{\theta = \theta_{t-1}}, \qquad o_t = f_{\theta_t}(q_t). \]
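Below is a minimal NumPy sketch of this recurrence, again with a linear \( f_\theta \); the zero initialization of \( \theta_0 \) and the learning rate are placeholders rather than choices from the paper.

```python
import numpy as np

def ttt_self_attention(Q, K, V, lr=0.1):
    """A_TTT: causal self-attention approximated by online test-time training.

    Q, K: (T, d); V: (T, d_v). One gradient step per (k_t, v_t) pair,
    reusing theta from the previous timestep instead of retraining from scratch.
    """
    T, d = K.shape
    d_v = V.shape[1]
    theta = np.zeros((d, d_v))                       # theta_0 (placeholder init)
    outputs = np.zeros((T, d_v))
    for t in range(T):
        q_t, k_t, v_t = Q[t], K[t], V[t]
        # One gradient step on the L2 loss for the current (k_t, v_t) pair.
        grad = 2.0 * np.outer(k_t, theta.T @ k_t - v_t)
        theta = theta - lr * grad                    # theta_t
        # Output for this timestep: the updated model applied to q_t.
        outputs[t] = theta.T @ q_t                   # o_t = f_{theta_t}(q_t)
    return outputs
```

For a fixed model size, each timestep does a constant amount of work, so the total cost is linear in \( T \).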
What I have presented is only one way to gain an introductory understanding of TTT. This blog post has nothing to say about making the idea work in practice, or why it should do any better than all the other sub-quadratic sequence models out there. In my opinion, the original paper is profound in its philosophy and technical achievements. I highly recommend giving it a thorough read, if you haven't already. Sections 2.1-2.3 of the paper are the most directly comparable to this post, and they are excellent reads even after it.
This post was inspired by discussions with Krish Parikh and Marcel Roed. Thanks to Yu Sun and Himanshu Singh for proofreading and feedback.