1. Introduction
In one of my previous articles, we explored Google’s Titans (Behrouz et al., 2024) [1] and how TTT (Test-Time Training) can be used to equip an LLM with a human-like, malleable memory, which can update its information at test time.
Test-time training, as the name suggests, is a paradigm that lets a model update its parameters on unseen data. At test time, however, there are no ground-truth labels to steer the model in the right direction (using them would be outright cheating). Instead, the model performs a self-supervised task on the test data, one designed beforehand and baked into the model, and in solving it the model “subconsciously” learns about that data.
Examples of such tasks, sketched in code right after this list, include:
- Rotation Prediction (Gidaris et al., 2018) [2]: The input images are rotated by a random multiple of 90° (e.g., 90°, 180°, or 270°), and the model must predict which rotation was applied. This forces it to recognize salient features and determine which way is “up”.
- Masked-Language Modeling (Devlin et al., 2019) [3]: One or more tokens are masked out of the test instance. The model’s job is to predict the missing tokens, with the masked tokens serving as the ground truth, which incentivizes a multi-faceted understanding of language.
- Confidence Maximization (Sun et al., 2020) [4]: The model is incentivized to make its output distribution (e.g., class probabilities [0.3, 0.4, 0.3]) more peaked (e.g., [0.1, 0.8, 0.1]), curbing its tendency to hedge between classes.
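To make these objectives concrete, here is a minimal PyTorch-style sketch of each one. The models passed in (`rotation_classifier`, `masked_lm`, `classifier`) and their output shapes are illustrative assumptions, not the original papers’ implementations, and the confidence objective is written as simple entropy minimization:

```python
import torch
import torch.nn.functional as F

def rotation_prediction_loss(rotation_classifier, image):
    """Rotate a (C, H, W) test image by a random multiple of 90° and ask a
    4-way classifier to predict which rotation was applied."""
    k = torch.randint(0, 4, (1,)).item()
    rotated = torch.rot90(image, k, dims=(1, 2))
    logits = rotation_classifier(rotated.unsqueeze(0))        # (1, 4)
    return F.cross_entropy(logits, torch.tensor([k]))

def masked_lm_loss(masked_lm, token_ids, mask_id):
    """Hide one token and ask the model to reconstruct it; the hidden token
    acts as the ground truth."""
    pos = torch.randint(0, token_ids.numel(), (1,)).item()
    corrupted = token_ids.clone()
    corrupted[pos] = mask_id
    logits = masked_lm(corrupted.unsqueeze(0))                # (1, seq_len, vocab)
    return F.cross_entropy(logits[0, pos:pos + 1], token_ids[pos:pos + 1])

def confidence_loss(classifier, image):
    """Entropy minimization: reward a more 'peaked' (confident) prediction."""
    log_probs = torch.log_softmax(classifier(image.unsqueeze(0)), dim=-1)
    return -(log_probs.exp() * log_probs).sum(dim=-1).mean()
```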
But these are all educated guesses about which task will translate best into learning, because humans imagined them. And since humans are not exactly the “smartest” ones around these days, why don’t we let AI figure it out for itself?
Our gradient descent and optimization algorithms are widely considered among the most consequential algorithms humanity has ever invented. So why not hand test-time training over to these algorithms altogether and let the models learn about learning?
2. Motivation: Why was it needed?
At its heart, this research was driven by a core frustration with the existing Test-Time Training (TTT) paradigm. Prior TTT algorithms have historically relied on a form of artistry: a human “designer” (i.e., a creative researcher) must hand-craft a self-supervised task like the ones above and hope that practicing this specific task will somehow translate into better performance on the main objective. The paper aptly calls this an “art, combining ingenuity with trial and error,” a process that is highly vulnerable to human error and bias.
Not only can human-designed tasks perform suboptimally, they can even be counter-productive. Imagine making rotation prediction a model’s TTT task. If an image has direction-specific meaning, say a downward-pointing arrow that signifies “download this file,” the TTT task may flip it into an upward-pointing arrow (which signifies upload) and completely corrupt the model’s understanding of that image.
Moreover, we can read this as part of a broader trend: ever-decreasing reliance on human ingenuity and ever-increasing reliance on automation. Tasks like curating a word bank of thousands of “bad words” just to classify spam emails are a relic of the past that reminds us how far we’ve come. Over the years, a conventional rule has emerged: automation eventually eclipses the very human ingenuity that conceived it.
Visual depiction of why manual TTT design would be inferior to Meta-TTT via Gradient Descent.
3. Learning to (Learn at Test Time)
Researchers at Meta, Stanford, and Berkeley (Sun et al., 2024) [5] came together for this monumental collaboration and successfully parameterized the TTT task itself, which means that the model, instead of humans, can now choose which task will have the greatest impact on improving performance on the main objective.
This means that now the model can not only train on test data, but also choose how that test data is to be used to train itself!
3.1 How Does It Work?
The researchers split the entire process into two parts, an Inner Loop and an Outer Loop: the Outer Loop trains the model on its main objective and defines the TTT task, while the Inner Loop trains the hidden layers on that TTT task.
3.1.1. The Outer Loop: Taking Human Ingenuity Out of The Equation
This acts as the “meta-teacher” of the system. Apart from teaching the model to classify images, it is also tasked with creating the curriculum that the Inner Loop performs TTT on. It achieves this by turning the entire TTT process into one giant differentiable function and optimizing it end to end.
This multi-step process can be outlined as below:
The full architectural diagram of the model, including a zoomed-in view of the MTTT layer.
The numbers in black indicate the sequence of information flow in the model (Steps).
Steps 1 & 2: Input Preparation
First, the input image X is broken down into patches, and each patch is converted into an embedding by the Embedding Layers. This gives us a sequence of patch embedding vectors, which we’ll call P = (P₁, P₂, …, Pₙ).
Step 3: The Overall Architecture
This sequence P is then fed through a stack of MTTT layers, which form the brain of the model. After passing through all the layers, the final representation is sent to a standard Classification Head to produce the final output. To understand what happens in each MTTT layer, we zoom into one and dissect its inner machinery.
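Before zooming in, here is a rough skeleton of that overall flow in PyTorch-style pseudocode. Everything here (class name, dimensions, pooling choice) is my own illustrative assumption; the MTTT layer itself is treated as a black box built by `mttt_layer_fn` and unpacked in the steps below:

```python
import torch
import torch.nn as nn

class MTTTClassifier(nn.Module):
    """Illustrative skeleton: patchify -> embed -> stacked MTTT layers -> classify.
    `mttt_layer_fn` builds one MTTT layer; its internals are unpacked in the steps below."""
    def __init__(self, mttt_layer_fn, patch_size=16, dim=192, depth=12, num_classes=1000):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        self.layers = nn.ModuleList([mttt_layer_fn(dim) for _ in range(depth)])
        self.head = nn.Linear(dim, num_classes)

    def forward(self, images):                                   # images: (B, 3, 224, 224)
        # Steps 1 & 2: split the image into patches and embed each one -> P of shape (B, N, dim)
        p = self.patch_embed(images).flatten(2).transpose(1, 2)
        # Step 3: pass P through the stack of MTTT layers (each runs its own inner loop)
        for layer in self.layers:
            p = layer(p)
        # Steps 8 & 9: pool the final representation and classify -> ŷ
        return self.head(p.mean(dim=1))
```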
Step 4: Learning From the Embeddings
Each MTTT layer has a set of learnable parameters W₀ (Step 4b), which act as a “generic” or “start-off” state, before it sees any data.
The original input patch embeddings (P) are marked as Step 4a.
Step 5: The Inner Loop and Data Transformation
The Outer Loop now invokes the Inner Loop, which we’ll treat as a black-box for now. As per the diagram, it provides two key things:
- The Starting Point (5b): The initial layer weights, W₀, are fed to the Inner Loop along with the current input. The Inner Loop outputs weights W_T for the layer, which are tuned specifically for the current input.
W_T, W₀: the input-specific weights and the baseline generic weights, respectively.
P: the patch embedding vector.
θ_I: the learnable parameters of the Inner Loop.
- The Data (5a): The input embeddings P are passed through a simple linear transformation (ψ) before being processed by the adapted layer. This increases expressivity and lets every MTTT layer learn a different set of attributes about the input.
Here, the new weights W_T, now tuned specifically for the puppy image in the figure, are loaded into the layer.
Steps 6 & 7: The Main Task Forward Pass
Now that the feature extractor has the specialized weights W_T, it uses them to process the data for the main task.
The transformed input embeddings from Step 5a are processed by the input-specific feature extractor layer (Step 6) and become the output of the first MTTT layer (Step 7). That output is then processed by the remaining MTTT layers, repeating the whole process again.
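Putting Steps 4 through 7 together, one MTTT layer roughly looks like this. It is a sketch under my own naming assumptions (ψ as a linear layer, W₀ as a matrix for the linear case); the inner loop is injected as a black box here and unpacked in Section 3.1.2:

```python
import torch
import torch.nn as nn

class MTTTLayer(nn.Module):
    """One MTTT layer (sketch): adapt W0 into W_T on the current input, then run
    the main-task forward pass with the adapted weights."""
    def __init__(self, dim, inner_loop):
        super().__init__()
        self.psi = nn.Linear(dim, dim)                         # ψ: main-task input transform (Step 5a)
        self.W0 = nn.Parameter(torch.randn(dim, dim) * 0.02)   # generic "start-off" weights (Step 4b)
        self.inner_loop = inner_loop                           # holds θ_I; unpacked in Section 3.1.2

    def forward(self, p):                                      # p: patch embeddings (B, N, dim), Step 4a
        w_t = self.inner_loop(self.W0, p)                      # Step 5b: W0 -> W_T for this very input
        return torch.matmul(self.psi(p), w_t)                  # Steps 6 & 7: the layer's output
```

With the InnerLoop sketched in Section 3.1.2, a factory such as `lambda d: MTTTLayer(d, InnerLoop(d))` could be handed to the skeleton from Step 3.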
Steps 8 & 9: The Final Output
After the data has passed through all the stacked MTTT layers (Step 8) and the final Classification Head (Step 9), we get a final prediction, ŷ.
Test vs Train:
If the model is being tested, ŷ remains the final output; if the model is being trained, the output (Step 9) is used to compute a loss (typically cross-entropy) against the ground truth y.
Using this loss, the Outer Loop computes the gradient with respect to all parameters; this gradient is called the “meta-gradient”. Along with training the model on the main task, it also trains the Inner Loop’s parameters, which define the TTT self-supervised task. In essence, the model uses the final classification error signal to ask itself:
“How should I have set up the test-time learning problem so that the final outcome would have been better?”
This lets the model set up the most effective self-supervised task for improving performance on the main task, taking human guesswork and intuition completely out of the equation.
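In code, the outer loop is an ordinary supervised training step; the “meta” part is simply that backpropagation flows through the unrolled inner-loop gradient steps, so the parameters defining the TTT task (θ_I, ψ, W₀) receive gradients too. A minimal sketch, assuming `model` is the MTTT classifier sketched earlier:

```python
import torch.nn.functional as F

def outer_loop_step(model, optimizer, images, labels):
    """One outer-loop update: the cross-entropy gradient flows back through every
    unrolled inner-loop step, so the TTT task parameters are trained too."""
    logits = model(images)                   # the forward pass runs all the inner loops
    loss = F.cross_entropy(logits, labels)   # main-task loss against the ground truth y
    optimizer.zero_grad()
    loss.backward()                          # the "meta-gradient": reaches θ_I, ψ, W0, and the head
    optimizer.step()
    return loss.item()
```

At test time, the same forward pass still runs the inner loops (so the layers still adapt to each input), but no loss or optimizer step is taken.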
3.1.2 The Inner Loop: Unveiling the Black-Box
Now that we understand the Outer Loop, we can open up the black box, a.k.a. the Inner Loop.
Its goal is to take the generic layer weights (W₀) and rapidly adapt them into specialized weights (W_T) for the input it’s currently observing.
It achieves this by solving the self-supervised reconstruction task, which the Outer Loop designed for it. This self-contained learning procedure looks like this:
Zoomed-in view of the Inner Loop, describing its inner workings.
The numbers in black indicate the sequence of information flow (Steps).
Steps 1-3: Setting Up the Learning Problem
The Inner Loop gets two distinct inputs from the Outer Loop:
- the input patch embeddings (Step 2), and
- the generic weights for the feature extractor, W₀.
As shown in Step 3, these original embeddings P = (P₁, P₂, …) are turned into a “test-time dataset”, where each datapoint is a single patch’s embedding, yielded sequentially.
Steps 4 & 5: The Forward Pass – Creating a Puzzle
First, an input patch is passed through the Encoder (a linear layer whose parameters, θ_Φ, were learned by the Outer Loop). This function “corrupts” the input (Step 4), creating a puzzle that the subsequent network must solve. The corrupted patch is then fed into the Feature Extractor (the “Brain”), which processes it using its current generic weights (Step 5) to create a feature representation.
Steps 6 & 7: The Learning Step – Solving the Puzzle
The feature representation from the “Brain” is then passed to the Decoder (a linear layer whose parameters, θ_g, were also learned). The Decoder’s job is to use these features to reconstruct the original, uncorrupted patch (Step 6). The Inner Loop then measures how well it did by calculating a loss, typically Mean Squared Error (MSE), between its reconstruction and the original patch. This error signal drives the Gradient Step (Step 7), which calculates a small update to the Feature Extractor’s weights.
Steps 8-9: The Final Output
This update process, from the old weights to the new, is shown in Step 8a. After running for a set number of steps, T (until all patches have been used sequentially), the final, adapted weights W_T are ready. The Inner Loop’s job is complete, and as shown in Step 8b, it hands these new weights back to the Outer Loop for the main-task prediction.
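Written out, the whole inner loop looks roughly like this. It is a minimal sketch assuming a linear feature extractor (the MTTT-Linear case), one patch per inner step as described above, and a fixed step size; the encoder θ_Φ, decoder θ_g, and (potentially) the step size are what the outer loop learns:

```python
import torch
import torch.nn as nn

class InnerLoop(nn.Module):
    """Adapts generic weights W0 into input-specific weights W_T by solving the
    reconstruction task whose pieces (encoder, decoder, step size) the outer loop learns."""
    def __init__(self, dim, lr=0.1):
        super().__init__()
        self.encoder = nn.Linear(dim, dim, bias=False)   # θ_Φ: "corrupts" the patch (Step 4)
        self.decoder = nn.Linear(dim, dim, bias=False)   # θ_g: reconstructs it (Step 6)
        self.lr = lr                                     # inner step size (could also be learned)

    def forward(self, w0, p):                            # w0: (dim, dim), p: (B, N, dim)
        w = w0.unsqueeze(0).repeat(p.shape[0], 1, 1)     # Step 1: start every example from W0
        for t in range(p.shape[1]):                      # Step 3: patches as a test-time dataset
            patch = p[:, t]                              # one patch embedding per inner step
            corrupted = self.encoder(patch)              # Step 4: create the puzzle
            features = torch.bmm(corrupted.unsqueeze(1), w).squeeze(1)  # Step 5: the "Brain"
            recon = self.decoder(features)               # Step 6: try to solve the puzzle
            loss = ((recon - patch) ** 2).mean()         # reconstruction (MSE) loss
            grad = torch.autograd.grad(loss, w, create_graph=True)[0]   # Step 7: inner gradient
            w = w - self.lr * grad                       # Step 8a: one gradient step on W
        return w                                         # Step 8b: hand W_T back to the layer
```

Because the adaptation itself is plain autograd, this forward pass must not be wrapped in torch.no_grad() even at test time, and during outer-loop training the create_graph=True flag is what lets the meta-gradient flow back into the encoder, decoder, and W₀.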
3.2 Attention as a Special Case of the MTTT Framework
So far, we’ve treated MTTT as a novel framework. But here is where the paper delivers its most elegant insight: the attention mechanisms that are globally accepted as the de facto standard are just simple versions of this very same “learning to learn” process. This also makes sense: the model is no longer constrained to adhere to one particular schema; it can choose and curate the right framework for itself, which makes MTTT act as a superset that encompasses everything, including attention.
The authors prove this with a series of mathematical derivations (well beyond the scope of this article). They show that if you make specific choices for the “Brain” of the inner loop (the Feature Extractor), the entire complex, two-loop MTTT procedure simplifies into an attention mechanism.
Case 1: Feature Extractor = Simple Linear Model
Linear attention (Katharopoulos et al., 2020) [6] is a much faster cousin of the self-attention (Vaswani et al., 2017) [7] we use widely today. Unlike self-attention, which computes an N×N attention matrix (where N is the number of tokens) and therefore hits an O(N²) bottleneck, linear attention computes the Kᵀ V matrix (D×D, where D is the hidden dimension), which is linear in N.
By multiplying Kᵀ and V first, we never form the O(N²) attention matrix that standard self-attention computes.
When “the brain” is just a single linear layer that takes one learning step (T = 1, i.e., a single patch), its “correction” (the gradient step) is mathematically a step of linear regression. The researchers show that this entire process collapses perfectly into the formula for Linear Attention. The Encoder learns the role of the Key (K), the Decoder learns the role of the Value (V), and the main task’s input transformation (ψ) learns the role of the Query (Q)!
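The associativity trick behind that caption is easy to verify numerically. This small check (not the paper’s code) shows the two orderings give identical results while only one of them ever materializes an N×N matrix:

```python
import torch

N, D = 196, 64                                  # tokens, hidden dimension
Q, K, V = (torch.randn(N, D, dtype=torch.float64) for _ in range(3))

quadratic = (Q @ K.T) @ V                       # forms the (N, N) attention matrix: O(N²·D)
linear    = Q @ (K.T @ V)                       # forms only a (D, D) matrix:        O(N·D²)

print(torch.allclose(quadratic, linear))        # True: same result, very different cost
```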
Case 2: Feature Extractor = Kernel Estimator
Now, if the learning layer (feature extractor) is replaced with a kernel estimator (which computes a weighted average), specifically the Nadaraya-Watson estimator (Nadaraya, 1964) [8] & (Watson, 1964) [9], the MTTT process becomes identical to standard Self-Attention. The kernel’s similarity function collapses to the Query-Key dot product, and its normalization step becomes the Softmax function.
The standard self-attention formula is also just an instantiation of the “learning to learn” superset
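This equivalence, too, can be checked in a few lines: a Nadaraya-Watson estimator with an exponential kernel over the keys is literally the softmax-attention computation. A small numerical check, with the 1/√D scaling as my bandwidth assumption:

```python
import torch

N, D = 8, 4
torch.manual_seed(0)
Q, K, V = (torch.randn(N, D, dtype=torch.float64) for _ in range(3))

# Standard scaled dot-product self-attention
attn = torch.softmax(Q @ K.T / D ** 0.5, dim=-1) @ V

# Nadaraya-Watson estimator with an exponential kernel κ(q, k) = exp(q·k / √D)
kernel = torch.exp(Q @ K.T / D ** 0.5)                 # pairwise similarities
nw = (kernel @ V) / kernel.sum(dim=-1, keepdim=True)   # kernel-weighted average of the values

print(torch.allclose(attn, nw))                        # True: identical computations
```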
What does this mean?
The authors state that in the past three decades of machine learning and AI, a clear pattern regarding the performance of algorithms can be observed.
We know that:
- When the feature extractor is a linear model, we get fast but not so impressive linear attention.
- When the feature extractor is a kernel, we get the ubiquitous self-attention.
- When the feature extractor is a deep-learning model (an MLP, for example), we get….?
What happens if we put an even better learner (like MLP) inside the Inner Loop? Would it perform better?
4. MTTT-MLP: The Primary Contribution
The answer to the above question is the main contribution of the authors in this paper. They equip the inner loop with a small, 2-layer Multi-Layer Perceptron (MLP) as the feature extractor.
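Concretely, the inner-loop “brain” becomes a small two-layer MLP whose weights are what the inner loop adapts for each input. A sketch (the hidden width and activation are my assumptions, not necessarily the paper’s):

```python
import torch.nn as nn

class MLPFeatureExtractor(nn.Module):
    """Inner-loop learner for MTTT-MLP: a small 2-layer MLP whose weights take over
    the role that the single matrix W0 -> W_T played in the linear sketch above."""
    def __init__(self, dim, hidden=None):
        super().__init__()
        hidden = hidden or 4 * dim            # hidden width is an assumption on my part
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x):
        return self.net(x)
```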
4.1 Self-Attention vs. MTTT-MLP vs. Linear-Attention
The authors put MTTT-MLP to the test in two drastically different scenarios on the ImageNet dataset:
Scenario 1: The Standard Scenario (ImageNet with Patches)
First, they tested a Vision Transformer (ViT) setup on standard 224×224 images, broken into 196 patches. In this configuration, the O(N²) methods are practical as well, making it a level playing field for all the models.
- The Results:
- MTTT-MLP (74.6% acc.) beat its theoretical predecessor, MTTT-Linear (72.8% acc.), confirming the hypothesis that more complex learners perform better.
- However, standard self-attention (76.5% acc.) still reigned supreme. Although this runs contrary to the hypothesis, it makes sense: when you can afford the expensive quadratic computation on short sequences, the original is hard to top.
Scenario 2: The Non-Standard Scenario (ImageNet with Raw Pixels)
The researchers drastically changed the environment by feeding the model raw pixels instead of patches. This inflates the sequence length from a manageable 196 to a massive 50,176 tokens, which is the very arch-nemesis of the standard attention algorithms.
- The Results:
- This comparison could only be made between linear attention and MTTT-MLP, because self-attention failed to even run: modeling 50,176 tokens means roughly 2.5 billion entries in the attention matrix, which immediately throws an OOM (Out-Of-Memory) error on any standard GPU (a quick size check follows this list).
- Linear Attention turned in a mediocre performance, achieving around 54-56% accuracy.
- MTTT-MLP won this round by a large margin, reaching 61.9% accuracy.
- Even when pitted against a larger Linear Attention model with 3x the parameters and 2x the FLOPs, MTTT-MLP still won by around a 10% margin.
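A quick back-of-envelope check (my arithmetic, not the paper’s) shows why self-attention could not run: the attention matrix alone, before gradients and multiple heads, is already around 10 GB per image in float32.

```python
# Back-of-envelope check of the numbers above (my arithmetic, not from the paper)
tokens = 224 * 224                    # every raw pixel becomes a token -> 50,176
entries = tokens ** 2                 # attention-matrix entries        -> 2,517,630,976 (~2.5 B)
gigabytes = entries * 4 / 1e9         # float32, one head, one image    -> ~10 GB
print(tokens, entries, round(gigabytes, 1))
```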
The key takeaway from these experiments: although self-attention reigned supreme in raw performance when it could run, MTTT-MLP provides a huge boost in modeling power over linear attention while retaining the same sweet O(N) complexity that lets it scale to massive inputs.
4.2 Watching How the Inner Loop Learns
To interpret the trends of their novel approach, the authors provide a pair of graphs that help us peek into how the inner loop learns and how the outer loop makes it learn the best possible lessons.
Steps vs. Accuracy: The More The Merrier, But Not Always
The x-axis shows the number of inner-loop gradient steps (T), and the y-axis shows the final classification accuracy on the ImageNet dataset.
As T increases from 1 to 4, the model’s accuracy on the main classification task increases commensurately. This demonstrates that allowing the layer to perform a few steps of self-adaptation on each image directly translates to better overall performance. This shows that the inner loop does indeed help the main task, but the benefit isn’t infinite.
The performance peaks at T=4 and then dips slightly. This means T=4 is the sweet spot, where the model learns enough to aid the main task, but not so much that it over-focuses on the current input and loses generality.
Epochs vs. Loss: Synergy Between the Two Loops
The x-axis shows the training epochs, and the y-axis shows the inner loop’s reconstruction loss on the TTT task. The colors of different lines indicate the inner loop’s training steps (T).
This graph is the most information-dense. It gives us a look at how the performance of the inner loop changes as the outer loop learns to design a more sophisticated TTT task.
There are two key trends to observe:
Inner-Loop Optimization (The Vertical Trend)
If you look at the blue line (T=0) as a whole, you’ll notice that it has the highest loss: the outer loop keeps getting better at designing the TTT task as epochs progress, while the inner loop never takes a single step to learn from it.
If you look at any single epoch (a vertical slice of the graph), the loss for all the other settings (T ∈ [1, 4]) is lower than the blue line, and every increment in T decreases the loss further. This indicates that the more the inner loop is allowed to learn, the better its performance gets (the expected behavior).
Outer-Loop Meta-Learning (The Horizontal Trend)
This can feel counterintuitive, as every single line trends upward in loss over the course of training. Notice that all the lines except the blue one (T=0) start from roughly the same loss value at epoch 0, much lower than the blue’s, because early on the inner loop is training on an easy TTT task: the outer loop hasn’t yet had a chance to design a hard one, so every configuration except T=0 aces it.
But as the outer loop picks up pace over the epochs, the inner loop finds it harder and harder to solve the now increasingly difficult (but increasingly helpful) task, causing its loss to slowly creep up.
References:
[1] Behrouz, Ali, Peilin Zhong, and Vahab Mirrokni. “Titans: Learning to memorize at test time.” arXiv preprint arXiv:2501.00663 (2024).
[2] Gidaris, Spyros, Praveer Singh, and Nikos Komodakis. “Unsupervised representation learning by predicting image rotations.” arXiv preprint arXiv:1803.07728 (2018).
[3] Devlin, Jacob, et al. “BERT: Pre-training of deep bidirectional transformers for language understanding.” Proceedings of the 2019 conference of the North American chapter of the association for computational linguistics: human language technologies, volume 1 (long and short papers). 2019.
[4] Sun, Yu, et al. “Test-time training with self-supervision for generalization under distribution shifts.” International conference on machine learning. PMLR, 2020.
[5] Sun, Yu, et al. “Learning to (learn at test time): RNNs with expressive hidden states.” arXiv preprint arXiv:2407.04620 (2024).
[6] Katharopoulos, Angelos, et al. “Transformers are RNNs: Fast autoregressive transformers with linear attention.” International conference on machine learning. PMLR, 2020.
[7] Vaswani, Ashish, et al. “Attention is all you need.” Advances in neural information processing systems 30 (2017).
[8] Nadaraya, Elizbar A. “On estimating regression.” Theory of Probability & Its Applications 9.1 (1964): 141-142.
[9] Watson, Geoffrey S. “Smooth regression analysis.” Sankhyā: The Indian Journal of Statistics, Series A (1964): 359-372.