When is Test-Time Training Valid?

A dependency analysis of three TTT strategies, from this discussion.

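Each matrix shows which tokens influence each token's loss. Green = valid dependency (matches standard AR). Red = information leakage.

[Figure: token-dependency matrices for standard AR and the three TTT cases. Legend: valid dependency, leakage, no dependency.]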
Key insight: Case 3's dependency matrix is identical to Standard AR. Per-document TTT is just another way of conditioning on context within each document. Cases 1 & 2 introduce cross-document or future-token dependencies that corrupt the evaluation.

Three cases

Here are three instantiations of test-time training (TTT) that I see in PRs to this repo, ordered from least to most legitimate.

1: Train on val, then measure loss on val

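```python
def eval_ttt_cheat(model, val_tokens, num_epochs):
    # batches, sgd_step, loss and mean are placeholder helpers.
    # Train on the validation tokens for several epochs...
    for epoch in range(num_epochs):
        for batch in batches(val_tokens):
            sgd_step(model, loss(model, batch))
    # ...then score the very tokens the model was just fit to.
    return mean(loss(model, batch) for batch in batches(val_tokens))
```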

No, not even close. This is literally training on the test set and is obviously invalid. You could initialize a large random model, train it on the validation set for many epochs, and drive its loss to roughly 0.
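
A toy sketch of that failure mode, under loud assumptions: the hypothetical "model" here is one row of softmax logits per position (initialized at zero rather than randomly, for simplicity), which is the extreme case of a model with enough capacity to memorize its input. Fitting it with plain SGD on the validation sequence drives the validation loss toward 0, while the loss on a fresh sequence from the same distribution ends up worse than the untrained model's log(8).

```python
import math
import random

VOCAB = 8  # hypothetical vocabulary size

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def mean_loss(table, seq):
    """Mean NLL of `seq` when position i is predicted from its own logit row."""
    return sum(-math.log(softmax(row)[t]) for row, t in zip(table, seq)) / len(seq)

def train_on_val(seq, num_epochs, lr=1.0):
    """Case 1: fit a per-position logit table to the validation sequence itself."""
    table = [[0.0] * VOCAB for _ in seq]
    for _ in range(num_epochs):
        for row, t in zip(table, seq):
            probs = softmax(row)
            for k in range(VOCAB):
                # Gradient of -log p_t w.r.t. the logits is p - onehot(t).
                row[k] -= lr * (probs[k] - (1.0 if k == t else 0.0))
    return table

rng = random.Random(0)
val_seq = [rng.randrange(VOCAB) for _ in range(64)]
fresh_seq = [rng.randrange(VOCAB) for _ in range(64)]
table = train_on_val(val_seq, num_epochs=200)
print(mean_loss(table, val_seq), mean_loss(table, fresh_seq), math.log(VOCAB))
```

The first printed value should be close to 0 and the second well above the third (the uniform baseline): the low validation loss measures memorization, not modeling.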

2: TTT autoregressively on the token stream

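```python
def eval_ttt_token_stream(model, val_tokens, chunk_size, context_len):
    ttt_model = deepcopy(model)  # one adapted copy whose weights persist over the whole input
    for chunk in chunks(val_tokens, chunk_size):
        # Causal context window ending at the chunk; on the full stream it crosses documents.
        context = val_tokens[max(0, chunk.end - context_len) : chunk.end]
        loss_per_token = ttt_model(context)
        # Score first: these losses are computed before training on this chunk.
        accumulate(loss_per_token[chunk.start_in_context:])
        # Then adapt on the chunk that was just scored.
        sgd_step(ttt_model, mean(loss_per_token[chunk.start_in_context:]))
    return total_bpb()
```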

Well, it depends on what the validation set is supposed to measure. The validation token stream in this competition is 50k FineWeb documents concatenated together, with a BOS token at the start of each document. So if the goal is to model the distribution of sequences in FineWeb, the validation set is 50k samples from that underlying distribution. If the competition had chosen 10k or 100k documents instead, the expected validation loss should not change; only the variance of the estimator would.
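
A minimal sketch of how such a stream splits back into documents, assuming a hypothetical BOS id of 1 (the actual tokenizer's BOS id is not stated here). `split_at_bos` is an illustrative helper, not the repo's `split_at_BOS`; any tokens before the first BOS are kept as their own fragment.

```python
BOS_ID = 1  # hypothetical BOS token id

def split_at_bos(tokens, bos_id=BOS_ID):
    """Split a concatenated stream into documents; each document keeps its leading BOS."""
    docs = []
    for t in tokens:
        if t == bos_id or not docs:  # a BOS (or the very first token) opens a new document
            docs.append([])
        docs[-1].append(t)
    return docs

stream = [1, 5, 6, 7, 1, 8, 9, 1, 4]
docs = split_at_bos(stream)
print(docs)  # [[1, 5, 6, 7], [1, 8, 9], [1, 4]]
```

Because the split is lossless, concatenating the documents gives back the original stream.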

However, if we train on the whole token stream, this no longer holds: the 100k-document case provides more test-time training data, so its expected loss is lower. The metric then depends on the size of the eval set, which is clearly broken and contrary to the standard definition of an eval set, so this should not be allowed (I think it currently is under the rules). You could argue that the point of this competition is to model this exact token sequence (50k docs), which would make this valid, but I think that framing is quite unnatural.
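
The dependence on eval-set size can be reproduced in a toy simulation. This is a sketch under loud assumptions: documents are i.i.d. draws from a hypothetical 100-token unigram distribution, and "TTT" is replaced by an add-one smoothed count model that updates after scoring each token (a stand-in for an SGD step, not what any PR actually runs). `CountModel`, `token_stream_ttt` and `per_doc_ttt` are illustrative names.

```python
import math
import random

VOCAB = 100
# Hypothetical "true" distribution shared by all documents:
# token 0 has probability 0.9 and the other 99 tokens share the remaining 0.1.
PROBS = [0.9] + [0.1 / (VOCAB - 1)] * (VOCAB - 1)

def make_docs(num_docs, doc_len=20, seed=0):
    """Sample i.i.d. documents from PROBS."""
    rng = random.Random(seed)
    return [rng.choices(range(VOCAB), weights=PROBS, k=doc_len) for _ in range(num_docs)]

class CountModel:
    """Toy adaptive model: add-one smoothed unigram counts (stand-in for SGD updates)."""

    def __init__(self):
        self.counts = [0] * VOCAB
        self.total = 0

    def score_then_update(self, token):
        # The loss is computed before the model has seen this token.
        nll = -math.log((self.counts[token] + 1) / (self.total + VOCAB))
        self.counts[token] += 1
        self.total += 1
        return nll

def token_stream_ttt(docs):
    model = CountModel()  # a single model whose state carries across every document
    losses = [model.score_then_update(t) for doc in docs for t in doc]
    return math.fsum(losses) / len(losses)

def per_doc_ttt(docs):
    losses = []
    for doc in docs:
        model = CountModel()  # reset to the base model at each document boundary
        losses.extend(model.score_then_update(t) for t in doc)
    return math.fsum(losses) / len(losses)

docs = make_docs(1000)
for n in (10, 100, 1000):
    print(n, round(token_stream_ttt(docs[:n]), 3), round(per_doc_ttt(docs[:n]), 3))
```

With state carried across documents, the mean loss on the first 1000 documents comes out well below the loss on the first 10, even though every document is drawn from the same distribution; the per-document numbers stay roughly flat, differing only by sampling noise.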

Another way to see why this should be disallowed: imagine you want to measure your model's "emergent capability", so you build an eval set of 50k examples of a task the model was not trained on (say, multiplying numbers). If you then ran token-stream TTT at inference time and your model got non-zero accuracy, you could not call it an "emergent phenomenon"; you would only have shown that your model can learn to multiply numbers when trained on thousands of examples. Token-stream TTT is identical to continuing to train for one epoch on the eval set and reporting the mean loss (in the first epoch, each batch's loss is measured before the model trains on that batch).
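
The equivalence in the last sentence can be checked directly on a toy. In this sketch (all names hypothetical), the "model" is a single softmax over a 4-token vocabulary trained with plain SGD; `eval_token_stream_ttt` follows the score-then-update loop of case 2, and `train_one_epoch` is an ordinary training loop that logs each batch's loss before its update. Both perform the same operations in the same order, so they return the same number.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def batch_loss_and_grad(logits, batch):
    """Mean NLL of a unigram softmax model on `batch`, and its gradient w.r.t. the logits."""
    probs = softmax(logits)
    loss = -sum(math.log(probs[t]) for t in batch) / len(batch)
    grad = list(probs)
    for t in batch:
        grad[t] -= 1.0 / len(batch)  # d(-log p_t)/dz = p - onehot(t), averaged over the batch
    return loss, grad

def sgd_step(logits, grad, lr):
    return [z - lr * g for z, g in zip(logits, grad)]

def eval_token_stream_ttt(logits, chunks, lr):
    """Case 2 shape: score each chunk with the current weights, then take an SGD step on it."""
    weighted, count = [], 0
    for chunk in chunks:
        loss, grad = batch_loss_and_grad(logits, chunk)  # scored before training on it
        weighted.append(loss * len(chunk))
        count += len(chunk)
        logits = sgd_step(logits, grad, lr)              # then adapt on the same chunk
    return math.fsum(weighted) / count

def train_one_epoch(logits, batches, lr):
    """Ordinary training: one pass over the data, logging each batch's pre-update loss."""
    log = []
    for batch in batches:
        loss, grad = batch_loss_and_grad(logits, batch)
        log.append((loss, len(batch)))
        logits = sgd_step(logits, grad, lr)
    return math.fsum(l * n for l, n in log) / sum(n for _, n in log)

stream = [0, 1, 0, 0, 2, 0, 1, 0, 3, 0, 0, 1] * 5
chunks = [stream[i:i + 4] for i in range(0, len(stream), 4)]
init = [0.0] * 4
a = eval_token_stream_ttt(init, chunks, lr=0.5)
b = train_one_epoch(init, chunks, lr=0.5)
print(a, b, a == b)
```

`math.fsum` is used for both means so that the comparison does not depend on summation order.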

3: TTT autoregressively on each document, independently

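```python
def eval_ttt_docbased(model, val_tokens, chunk_size, context_len):
    for doc in split_at_BOS(val_tokens):
        # eval_ttt_token_stream deep-copies `model`, so every document starts from the
        # original weights: no adapted state or context crosses a document boundary.
        eval_ttt_token_stream(model, doc, chunk_size, context_len)
    return total_bpb()  # per-token losses were accumulated across all documents
```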

Yes! The dependency graph is identical to the non-TTT case: the loss on token i of document j depends only on tokens i' < i of document j. There is no dependency on future tokens (the problem with training on test) and no dependency between documents (the problem with doing TTT on the whole sequence).
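
The dependency claim can be checked mechanically. The sketch below is a toy tracer, not code from any PR: `doc_of` maps each stream position to a hypothetical document id, "standard AR" is assumed to attend to every earlier token of the same document, and the chunked-TTT tracer assumes `context_len >= chunk_size`, as the case-2 loop implicitly requires. A position's loss depends on token j if j is in its causal context or in any context the weights were trained on before that position was scored.

```python
def standard_ar(doc_of):
    """Reference: the loss at i may depend on every earlier token j < i of the same document."""
    n = len(doc_of)
    return [[j < i and doc_of[j] == doc_of[i] for j in range(n)] for i in range(n)]

def train_on_val(doc_of):
    """Case 1: the weights were fit to every validation token, including i itself and later ones."""
    n = len(doc_of)
    return [[True] * n for _ in range(n)]

def chunked_ttt(doc_of, chunk_size, context_len, reset_per_doc):
    """Cases 2 and 3: trace which tokens can influence the loss at each position."""
    assert context_len >= chunk_size, "each chunk must fit in its own context window"
    n = len(doc_of)
    deps = [[False] * n for _ in range(n)]
    if reset_per_doc:  # case 3: one segment per document, fresh weights at each boundary
        starts = [i for i in range(n) if i == 0 or doc_of[i] != doc_of[i - 1]]
        segments = list(zip(starts, starts[1:] + [n]))
    else:              # case 2: the whole concatenated stream is one segment
        segments = [(0, n)]
    for start, end in segments:
        trained_on = set()  # tokens the adapted weights have already been updated on
        for c_start in range(start, end, chunk_size):
            c_end = min(c_start + chunk_size, end)
            ctx_start = max(start, c_end - context_len)
            for i in range(c_start, c_end):
                for j in trained_on:           # influence through the updated weights
                    deps[i][j] = True
                for j in range(ctx_start, i):  # influence through causal attention
                    deps[i][j] = True
            trained_on.update(range(ctx_start, c_end))  # sgd_step after scoring the chunk
    return deps

def leakage(deps, reference):
    """Number of dependencies that standard AR would not have."""
    return sum(d and not r for d_row, r_row in zip(deps, reference) for d, r in zip(d_row, r_row))

doc_of = [0] * 5 + [1] * 3 + [2] * 6  # three documents of lengths 5, 3 and 6
ref = standard_ar(doc_of)
case2 = chunked_ttt(doc_of, chunk_size=2, context_len=4, reset_per_doc=False)
case3 = chunked_ttt(doc_of, chunk_size=2, context_len=4, reset_per_doc=True)
print(leakage(train_on_val(doc_of), ref), leakage(case2, ref), case3 == ref)  # 168 63 True
```

On this 14-token stream, case 1 leaks 168 of 196 cells, case 2 leaks the 63 cells where an earlier token from a different document influences the loss, and case 3 reproduces the standard-AR matrix exactly.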

I think people get hung up on the fact that TTT does backprop at test time. Backprop is just another operator, no different from attending to tokens or updating some hidden state. If you attend to the label token, you are leaking the answer; the same holds for TTT if you train on a token before scoring it. For example, in Energy-based Transformers the entire inference procedure is backprop. Similarly, LLM test-time-training papers such as End-to-End Test-Time Training for Long Context and Let's (not) just put things in Context all do case 3 here.
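
To illustrate "backprop is just another operator", here is a hypothetical scalar recurrence whose state update is one SGD step on a squared error, loosely in the spirit of TTT-style layers (it is not the method of any of the papers above). Reading out before updating keeps output t dependent only on earlier inputs, exactly like a causal RNN; updating on step t before reading out is the TTT analogue of attending to the label token.

```python
def ttt_scan(keys, values, queries, lr=0.1, update_first=False):
    """Scalar 'fast weight' w, updated by one SGD step on (w*k - v)**2 at every position.
    update_first=False: read out, then update (causal).
    update_first=True: update on step t, then read out (leaks v_t into output t)."""
    w = 0.0
    outputs = []
    for k, v, q in zip(keys, values, queries):
        if not update_first:
            outputs.append(w * q)
        w -= lr * 2.0 * (w * k - v) * k  # the gradient step is simply the hidden-state update
        if update_first:
            outputs.append(w * q)
    return outputs

keys = [1.0, 0.5, -1.0, 2.0, 0.3, 1.5]
values = [0.2, -0.4, 1.0, 0.7, -0.1, 0.5]
queries = [1.0] * 6
changed = list(values)
changed[3] += 5.0  # perturb the "label" at step 3

causal = ttt_scan(keys, values, queries), ttt_scan(keys, changed, queries)
leaky = (ttt_scan(keys, values, queries, update_first=True),
         ttt_scan(keys, changed, queries, update_first=True))
print(causal[0][3] == causal[1][3], leaky[0][3] == leaky[1][3])
```

In the causal version, outputs 0 through 3 are unchanged when `values[3]` is perturbed and only later outputs move; in the update-first version, output 3 itself changes.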

Other notes

"I've also noticed that people are doing TTT on multiple parts of the evaluation sequence in parallel (via DDP rank), justifying it by only assigning each 'doc' to one rank"

This is just an efficiency improvement. If you follow case 3 above (as this PR does), parallelizing across documents gives exactly the same result as evaluating the token stream one document at a time, because each document is independent of the others.
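
A small sketch of why the rank split is harmless under case 3, assuming a hypothetical round-robin assignment of documents to ranks and a toy add-one unigram count model that updates after scoring each token, standing in for per-document TTT (all names are illustrative, not from the PR). Per-document evaluation gives the same mean loss for any world size; if state were instead carried across each rank's documents (token-stream TTT per rank), the result would depend on how many ranks you used.

```python
import math
import random

VOCAB = 16

def score_doc(doc, counts=None):
    """Per-token NLL under an add-one unigram model that updates after scoring each token.
    Passing `counts` carries adapted state in; omitting it starts from the base model."""
    counts = [0] * VOCAB if counts is None else counts
    losses = []
    for t in doc:
        losses.append(-math.log((counts[t] + 1) / (sum(counts) + VOCAB)))
        counts[t] += 1
    return losses

def shard(docs, world_size, rank):
    """Hypothetical round-robin assignment: document k goes to rank k % world_size."""
    return [(k, d) for k, d in enumerate(docs) if k % world_size == rank]

def per_doc_eval(docs, world_size):
    """Case 3 across ranks: every document starts from fresh state, whichever rank runs it."""
    per_doc = {}
    for rank in range(world_size):  # in a real run each rank executes this in parallel
        for k, doc in shard(docs, world_size, rank):
            per_doc[k] = score_doc(doc)
    losses = [x for k in sorted(per_doc) for x in per_doc[k]]
    return math.fsum(losses) / len(losses)

def sharded_stream_eval(docs, world_size):
    """Contrast: each rank carries state across its own documents (token-stream TTT per rank)."""
    losses = []
    for rank in range(world_size):
        counts = [0] * VOCAB
        for _, doc in shard(docs, world_size, rank):
            losses.extend(score_doc(doc, counts))
    return math.fsum(losses) / len(losses)

rng = random.Random(0)
docs = [[rng.randrange(4) for _ in range(rng.randint(5, 30))] for _ in range(8)]
for ws in (1, 2, 4):
    print(ws, per_doc_eval(docs, ws), sharded_stream_eval(docs, ws))
```

The first column of results is identical for every world size; the second changes with it.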

"How does TTT even make sense?"

Let me know if case 3 above (and the referenced papers) clears it up. I agree that the first two cases are incoherent.