
Supervising models with their own activations

Jasper Gilley


TLDR: We look directly at the activation space of Zipfian grokking models using a Generative Latent Prior (GLP) — a diffusion model trained on the task model’s own activations. We discover that the Sisyphean collapse cycle is driven by a single direction in the model’s penultimate-layer representations. Two independent methods — GLP residual analysis and Zipf-weighted gradient decomposition — converge on the same 1-D memorization subspace. Projecting this direction out of the backward pass via a custom gradient rule completely eliminates the collapse pathology, and a replay-consuming meta-model can rediscover the direction autonomously.

In the first essay in this series, we introduced Zipfian grokking: modify the loss weighting in a modular arithmetic grokking setup to follow Zipf’s Law, and the model enters a perpetual cycle of grokking and un-grokking — what we called the Sisyphean dynamics. In the second essay, we showed that adding an inverse dynamics auxiliary objective can partially stabilize training by biasing representations toward Fourier structure. But that approach required domain knowledge: someone had to choose the right transformation family.

Here, we take a different approach. Instead of designing auxiliary objectives, we look directly at the model’s own internal representations and ask: what, geometrically, is going wrong?

Modeling activations with a diffusion model

We train a Generative Latent Prior (GLP)¹, a small flow-matching diffusion model, on the penultimate-layer activations of the task model during training. The GLP learns the empirical distribution of the model’s internal representations: what “normal” activations look like at any given point in training.
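
For readers unfamiliar with flow matching, here is a minimal sketch of the kind of objective such a prior can be trained with. It assumes the common straight-line (rectified-flow) convention, where \(t = 0\) is Gaussian noise and \(t = 1\) is data. `flow_matching_loss`, `linear_velocity`, and `sgd_step` are hypothetical names. The single affine map is only a stand-in for the SwiGLU MLP stack described in note 1 and is far too small to model real activations.

```python
import jax
import jax.numpy as jnp


def flow_matching_loss(velocity_fn, params, h1, key):
    """Conditional flow-matching loss on a batch of activations h1 of shape (N, D)."""
    k_noise, k_t = jax.random.split(key)
    x0 = jax.random.normal(k_noise, h1.shape)       # noise endpoint (t = 0)
    t = jax.random.uniform(k_t, (h1.shape[0], 1))   # one time per sample in [0, 1)
    x_t = (1.0 - t) * x0 + t * h1                   # point on the straight noise-to-data path
    target = h1 - x0                                # velocity of that path (constant in t)
    pred = velocity_fn(params, x_t, t)
    return jnp.mean(jnp.sum((pred - target) ** 2, axis=-1))


def linear_velocity(params, x_t, t):
    # Hypothetical stand-in for the GLP network: one affine map of [x_t, t].
    return jnp.concatenate([x_t, t], axis=-1) @ params["W"] + params["b"]


def sgd_step(params, h1, key, lr=1e-2):
    """One plain gradient-descent step on the flow-matching loss."""
    loss, grads = jax.value_and_grad(
        lambda p: flow_matching_loss(linear_velocity, p, h1, key)
    )(params)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss
```

Because “normal” changes as the task model trains, a prior like this would presumably be refit or fine-tuned on fresh \(h_1\) batches rather than trained once.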

The key quantity is the residual: for each training sample, we compute the difference between the model’s actual penultimate-layer activation \(h_1\) and its GLP-denoised manifold projection \(h_1^{\text{manifold}}\):

$$r = h_1 - h_1^{\text{manifold}}$$

The scalar norm \(\|r\|\) tells us how far a sample’s representation is from the learned manifold — a measure of representational novelty. But the full 128-dimensional residual vector tells us something richer: in what direction each sample deviates from normal.
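
The post does not specify how the manifold projection \(h_1^{\text{manifold}}\) is computed. One plausible recipe for a flow-matching prior, assumed here (SDEdit-style), is to partially noise each activation to an intermediate time `t_start` and then integrate the learned velocity field back to the data end with Euler steps. `project_to_manifold`, `glp_residuals`, `t_start`, and `n_steps` are illustrative choices, not the post's settings.

```python
import jax
import jax.numpy as jnp


def project_to_manifold(velocity_fn, params, h1, key, t_start=0.5, n_steps=20):
    """Noise h1 to time t_start, then Euler-integrate the learned flow to t = 1."""
    eps = jax.random.normal(key, h1.shape)
    x = (1.0 - t_start) * eps + t_start * h1          # same interpolant as in training
    dt = (1.0 - t_start) / n_steps
    for i in range(n_steps):
        t = jnp.full((h1.shape[0], 1), t_start + i * dt)
        x = x + dt * velocity_fn(params, x, t)
    return x


def glp_residuals(velocity_fn, params, h1, key, **kwargs):
    """Residual vectors r = h1 - h1_manifold and their norms ||r||."""
    r = h1 - project_to_manifold(velocity_fn, params, h1, key, **kwargs)
    return r, jnp.linalg.norm(r, axis=-1)
```

As a sanity check, suppose the “manifold” is a single point \(\mu\). The exact flow-matching velocity is then \((\mu - x)/(1 - t)\), the final Euler step lands exactly on \(\mu\), and the residual reduces to \(h_1 - \mu\).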

A 1-dimensional memorization subspace

We save full residual vectors at 41 dense snapshot epochs across 100k training epochs of the standard Zipfian grokking setup (\(s = 1.5\), where two training pairs carry ~50% of the gradient).² The directional structure of the residuals reveals a striking pattern.
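
As a quick check on the ~50% figure, here is the arithmetic under two assumptions: the loss weight of the training pair with rank \(k\) is proportional to \(k^{-s}\) (our assumption about the exact form of the Zipfian weighting), and the 30% split of the \(97^2\) pairs in note 2 yields 2,822 training pairs. Under those assumptions the top-2 share comes out to about 52.6%, the figure this post reports for the top-2 group. `zipf_weights` is a hypothetical helper.

```python
import jax.numpy as jnp


def zipf_weights(n, s=1.5):
    """Normalized loss weights w_k proportional to k**(-s) for ranks k = 1..n (assumed form)."""
    ranks = jnp.arange(1, n + 1, dtype=jnp.float32)
    w = ranks ** (-s)
    return w / w.sum()


p = 97
n_train = int(0.3 * p * p)                               # 30% of 9,409 pairs -> 2,822
w = zipf_weights(n_train, s=1.5)
top2_share = float(w[:2].sum())                          # about 0.526
top1pct_share = float(w[: round(0.01 * n_train)].sum())  # weight carried by the top 28 pairs
```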

At and just before each collapse, the two highest-weight training samples have residuals pointing in nearly the same direction — and that direction is the dominant mode of representational variation:

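| Epoch | Test acc | Top-2 cosine | cos(top-2, PC1) | PC1 var % | Phase |
|---|---|---|---|---|---|
| 24,000 | 98.5% | +0.16 | 0.16 | 3.6% | Approaching peak |
| 25,000 | 99.0% | +0.68 | 0.90 | 4.2% | Fourier peak |
| 26,000 | 99.2% | +0.92 | 0.98 | 10.1% | Pre-collapse |
| 27,000 | 99.2% | +0.75 | 0.98 | 16.2% | Collapse onset |
| 28,000 | 83.3% | +0.44 | 0.73 | 4.1% | Collapsing |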

At epoch 26,000 — one snapshot before collapse — the memorization direction captures the top-2 residuals almost exactly (cos = 0.98 with PC1). The projections onto this direction track Zipfian weight rank:

| Group | Projection at epoch 27,000 |
|---|---|
| Top-2 (52.6% of gradient) | +24.6 |
| Top-5 | +13.5 |
| Top-10 | +5.6 |
| Top-28 (1%) | +1.9 |
| Population mean | −0.05 |

The memorization pressure from the Zipfian weighting carves a single groove in the 128-dimensional penultimate layer. And the same direction recurs across all three collapse cycles — cosine similarity 0.88–0.93 across collapses separated by 30,000 epochs. The memorization subspace is a structural feature of how Zipfian gradient pressure distorts the Fourier manifold, not an accident of a particular training epoch.
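
The quantities in the two tables above can be computed from one snapshot of residual vectors. The sketch below assumes:

- the memorization direction is PC1 of the mean-centered residual covariance;
- “cos(top-2, PC1)” compares PC1 with the mean of the two highest-weight residuals;
- projections use the raw (uncentered) residuals.

`memorization_diagnostics` and `cosine` are hypothetical helper names. Because an eigenvector's sign is arbitrary, the sketch orients PC1 so that the top-2 samples project positively.

```python
import jax.numpy as jnp


def cosine(a, b):
    return (a @ b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


def memorization_diagnostics(R, w):
    """R: GLP residuals (N, D) at one snapshot; w: Zipf loss weights (N,)."""
    Rc = R - R.mean(axis=0)
    evals, evecs = jnp.linalg.eigh(Rc.T @ Rc / R.shape[0])   # ascending eigenvalues
    m = evecs[:, -1]                                         # residual PC1
    order = jnp.argsort(-w)                                  # highest weight first
    top2 = R[order[:2]]
    # PC1 is defined only up to sign: orient it so the top-2 samples project positively.
    m = m * jnp.sign(top2.sum(axis=0) @ m)
    proj = R @ m
    return {
        "m": m,
        "pc1_var_frac": evals[-1] / evals.sum(),
        "top2_cos": cosine(top2[0], top2[1]),
        "cos_top2_pc1": cosine(top2.mean(axis=0), m),
        "proj_by_group": {k: proj[order[:k]].mean() for k in (2, 5, 10, 28)},
        "proj_population_mean": proj.mean(),
    }
```

To compare directions across collapse cycles, use the absolute cosine, e.g. `abs(cosine(m_cycle1, m_cycle2))`, since each cycle's PC1 can come out with either sign.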

Fourier solution contamination

The most important finding is about what happens to the Fourier solution during the pre-collapse phase.

At epoch 27,000, the model still classifies at 99.2% test accuracy, but \(h_1\) is effectively one-dimensional: PC1 captures 94.5% of variance, and the participation ratio — the effective number of dimensions carrying variance, computed as \((\sum \lambda_i)^2 / \sum \lambda_i^2\) — has dropped to 1.1 out of 128. If you classify test samples by nearest class centroid (NCC) in raw \(h_1\) space, accuracy is a pitiful 27.4%. The representations look ruined.
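
The participation ratio depends only on the eigenvalues of the \(h_1\) covariance. As a worked check, normalize the total variance to 1 and assume PC1 holds 94.5% of it while the remaining 5.5% is spread evenly over the other 127 dimensions (an idealization, since the real tail is not uniform). The ratio is then \(1/(0.945^2 + 127 \cdot (0.055/127)^2) \approx 1.12\), in line with the reported 1.1. The function names are hypothetical.

```python
import jax.numpy as jnp


def participation_ratio_from_spectrum(lam):
    """(sum_i lambda_i)^2 / sum_i lambda_i^2 for a vector of covariance eigenvalues."""
    return lam.sum() ** 2 / (lam ** 2).sum()


def participation_ratio(h1):
    """Participation ratio of activations h1 with shape (N, D)."""
    hc = h1 - h1.mean(axis=0)
    lam = jnp.linalg.eigvalsh(hc.T @ hc / h1.shape[0])
    return participation_ratio_from_spectrum(lam)


# Idealized spectrum: 94.5% of variance on PC1, the rest spread evenly over 127 dims.
lam = jnp.concatenate([jnp.array([0.945]), jnp.full(127, 0.055 / 127)])
pr = participation_ratio_from_spectrum(lam)   # about 1.12
```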

But remove the single memorization direction and measure NCC in the orthogonal complement:

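| Epoch | Model acc | NCC (raw \(h_1\)) | NCC (\(h_1\) − PC1) |
|---|---|---|---|
| 24,000 | 98.5% | 99.6% | 98.8% |
| 25,000 | 99.0% | 97.6% | 99.8% |
| 26,000 | 99.2% | 81.9% | 99.9% |
| 27,000 | 99.2% | 27.4% | 99.8% |
| 28,000 | 83.3% | 66.7% | 90.4% |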

At epoch 27,000 — where raw \(h_1\) is essentially 1D and raw NCC is 27% — removing one direction recovers 99.8% NCC accuracy. The Fourier solution is completely intact in the 127 dimensions orthogonal to the memorization direction, even when that direction captures 94.5% of variance.
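
Here is a minimal sketch of the nearest-class-centroid measurement, with and without the memorization direction. The post does not specify the details, so the sketch assumes Euclidean distance, centroids estimated from training-set features, and PC1 taken from the training features. `ncc_accuracy`, `project_out`, and `top_pc` are hypothetical helper names.

```python
import jax.numpy as jnp


def project_out(h, m):
    """Remove the component of each row of h along direction m."""
    m_hat = m / jnp.linalg.norm(m)
    return h - (h @ m_hat)[:, None] * m_hat


def top_pc(h):
    """First principal direction of h (rows are samples)."""
    hc = h - h.mean(axis=0)
    _, evecs = jnp.linalg.eigh(hc.T @ hc)
    return evecs[:, -1]


def ncc_accuracy(h_train, y_train, h_test, y_test, n_classes):
    """Nearest-class-centroid accuracy; centroids are class means of training features."""
    onehot = (y_train[:, None] == jnp.arange(n_classes)).astype(h_train.dtype)
    counts = jnp.maximum(onehot.sum(axis=0), 1.0)        # guard against empty classes
    centroids = (onehot.T @ h_train) / counts[:, None]   # (C, D)
    # Squared distances via ||h||^2 - 2 h.c + ||c||^2, avoiding an (N, C, D) tensor.
    d2 = (
        (h_test ** 2).sum(axis=1, keepdims=True)
        - 2.0 * h_test @ centroids.T
        + (centroids ** 2).sum(axis=1)
    )
    return jnp.mean(jnp.argmin(d2, axis=1) == y_test)
```

With `m = top_pc(h1_train)`, compute the raw number as `ncc_accuracy(h1_train, y_train, h1_test, y_test, 97)` and the projected number as the same call on `project_out(h1_train, m)` and `project_out(h1_test, m)`. Removing \(m\) from both train and test features keeps the centroids and the queries in the same subspace.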

It is like trying to hear a conversation next to a jet engine. The conversation (Fourier structure) hasn’t gotten quieter — the engine (memorization amplification) has gotten louder. This is contamination, not degradation: the generalization structure is hidden by memorization variance, not damaged by it.

After the collapse event (epoch 28,000), the picture changes. Removing the memorization direction now helps only partially: NCC rises from 66.7% to 90.4%, well short of the 99.8% recovered one snapshot earlier. Part of the damage is now real structural damage rather than recoverable contamination. This supports the distinction: the pre-collapse phase is a window of opportunity in which the solution can be saved by suppressing a single direction.

The memorization direction from the loss landscape

We now have two independent ways to identify the memorization direction:

Method 1: GLP residuals. The data-driven approach described above — PC1 of the \(h_1\) GLP residual covariance at pre-collapse epochs.

Method 2: Zipf-weighted effective gradient. Compute \(\partial L / \partial h_1\) for each training sample, then weight by the Zipfian loss weights to get the effective gradient — the net force the loss is exerting on \(h_1\) representations. At pre-collapse, the direction of this effective gradient is overwhelmingly aligned with the memorization direction:

| Epoch | Phase | cos(effective gradient, memorization dir) | Top-1% contribution along \(m\) |
|---|---|---|---|
| 24,000 | Build | −0.49 | |
| 25,000 | Peak | −0.59 | |
| 26,000 | Pre-collapse | −0.85 | 93.4% |
| 27,000 | Fragile | −0.27 | |
| 28,000 | Collapsed | +0.10 | |

The negative sign means the loss gradient points along \(-m\), so gradient descent pushes features along \(+m\), the memorization direction. At epoch 26,000 the effective gradient has cosine magnitude 0.85 with the memorization direction (so roughly \(0.85^2 \approx 72\%\) of its squared norm lies along \(m\)), and 93.4% of the component along \(m\) comes from the top 1% of samples (28 out of 2,822). The generalization gradient lives in the orthogonal complement, distributed broadly across all samples.
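
Here is a minimal sketch of the Zipf-weighted effective gradient. It assumes a linear classification head with cross-entropy loss; note 2 gives layer widths but not the head's exact form. `per_sample_grad_h1` uses `jax.vmap` over `jax.grad` to get \(\partial L_i / \partial h_1\) for each sample, and `effective_gradient_stats` weights and sums them. We read “top-1% contribution along \(m\)” as the share of the weighted component along \(\hat{m}\) that comes from the 1% highest-weight samples. That reading, the normalized `m_from_grad` direction, and all function names are assumptions.

```python
import jax
import jax.numpy as jnp


def linear_head(params, h):
    # Hypothetical classification head: logits = h W + b.
    return h @ params["W"] + params["b"]


def per_sample_grad_h1(head_fn, head_params, h1, y):
    """dL_i/dh1_i for each sample's cross-entropy loss, shape (N, D)."""
    def loss_one(h, label):
        return -jax.nn.log_softmax(head_fn(head_params, h))[label]
    return jax.vmap(jax.grad(loss_one))(h1, y)


def effective_gradient_stats(G, w, m_hat, top_frac=0.01):
    """Zipf-weighted effective gradient and how it splits along the unit direction m_hat."""
    g_eff = (w[:, None] * G).sum(axis=0)          # net force the weighted loss puts on h1
    cos = (g_eff @ m_hat) / jnp.linalg.norm(g_eff)
    along = w * (G @ m_hat)                       # each sample's contribution along m_hat
    k = max(1, round(top_frac * w.shape[0]))      # 28 of 2,822 for top_frac = 0.01
    top = jnp.argsort(-w)[:k]
    top_share = along[top].sum() / along.sum()
    m_from_grad = g_eff / jnp.linalg.norm(g_eff)  # candidate direction (sign is irrelevant)
    return g_eff, cos, top_share, m_from_grad
```

In the gradient-based suppression condition, a direction like `m_from_grad` would be one plausible choice to recompute on the 100-epoch schedule. Its sign does not matter, because the projection removes \(\pm\hat{m}\) alike.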

The two methods — manifold-based (GLP residuals) and loss-landscape-based (Zipf-weighted gradient) — identify the same direction: cosine similarity > 0.97 at pre-collapse epochs. This convergence makes the direction actionable: project it out of the gradient and the memorization force vanishes while the orthogonal generalization signal stays intact.

Suppressing the memorization gradient

We insert a custom_vjp³ between the encoder output and the classification head. In the forward pass, it is an identity. In the backward pass, it projects the memorization component out of \(\partial L / \partial h_1\):

$$\frac{\partial L}{\partial h_1}\bigg|_{\text{clean}} = \frac{\partial L}{\partial h_1} - \left(\frac{\partial L}{\partial h_1} \cdot \hat{m}\right) \hat{m}$$
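
Here is a minimal JAX sketch of the gradient rule in the equation above, written with `jax.custom_vjp` (the mechanism named in note 3). The name `suppress_direction` is ours, and \(\hat{m}\) is assumed to have unit norm. The direction receives a zero cotangent, so it is never trained.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def suppress_direction(h1, m_hat):
    """Identity in the forward pass; m_hat is assumed to be unit norm."""
    return h1


def _suppress_fwd(h1, m_hat):
    return h1, m_hat                       # save m_hat for the backward pass


def _suppress_bwd(m_hat, g):
    # Project the component along m_hat out of dL/dh1 (works for (D,) or (N, D) cotangents).
    g_clean = g - (g @ m_hat)[..., None] * m_hat
    return g_clean, jnp.zeros_like(m_hat)  # no gradient flows into the direction itself


suppress_direction.defvjp(_suppress_fwd, _suppress_bwd)
```

In a model, wrap the encoder output before the head, e.g. `logits = head(suppress_direction(h1, m_hat))`, where `head` is a placeholder for your classification head. Passing an all-zero `m_hat` turns the rule into a plain identity, which is one simple way to switch suppression on only from some epoch onward.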

We test three conditions, identical in architecture, optimizer, and seed, differing only in whether and how \(\hat{m}\) is chosen:

| | Baseline | Gradient-based | Frozen GLP direction |
|---|---|---|---|
| \(\hat{m}\) source | N/A | Zipf-weighted eff. gradient | PC1 of \(h_1\) GLP residuals @ epoch 27k |
| Updated? | | Every 100 epochs | Never |
| Accuracy collapses | 3 | 0 | 0 |
| Final test acc | 0.871 | 0.999 | 0.998 |

[Figure: baseline training dynamics, no suppression]
Baseline: the Sisyphean collapse cycle. Three collapses at epochs ~28k, ~58k, and ~88k. Top-left panel shows test accuracy (red) and train accuracy (blue).

[Figure: training dynamics with gradient-based memorization suppression]
Gradient-based suppression: zero collapses. Test accuracy climbs to 99.9% and holds for the remaining 80k epochs.

Both suppression conditions completely eliminate the Sisyphean collapse. The gradient-based method, which recomputes the memorization direction every 100 epochs, achieves 99.9% sustained test accuracy with healthy representation geometry throughout (\(h_1\) PC1 stays below 12.5%, participation ratio stays above 11.5).

The frozen GLP direction — extracted once at epoch 27,000 and never updated — is equally effective: 99.8% sustained accuracy with zero collapses. It eventually develops a geometric anomaly (\(h_1\) PC1 climbs to ~97% around epoch 60k), but accuracy holds at 99.9% through the reorganization. That a single direction, discovered by a manifold model with no knowledge of Zipfian weights, suffices to prevent all future collapses confirms that the memorization subspace is low-dimensional and causally decisive.

Autonomous discovery

The previous sections used directions that were identified by us: we chose which epoch to extract from, which layer to look at, and which analysis to run. Can a meta-model discover the relevant direction on its own?

Replay as a metacognitive buffer

We treat the 41 saved snapshots from the directional residuals experiment as a replay buffer: a record of the model’s internal representational states across its full training history, including three complete collapse cycles.⁴ Each snapshot contains the full \(h_1\) activation covariance, which encodes the geometry of the model’s representations at that moment.

We train tiny 1-D probes on these replayed states. Each probe takes as input the recent covariance drift \(\Delta C_e = C_e - C_{e-\Delta}\) — how \(h_1\)’s representational geometry has been changing — and outputs a direction \(\hat{m}\) in \(h_1\), optimized to maximize \(\hat{m}^T \Delta C_e \hat{m}\) at epochs that precede geometric collapse (future peak \(h_1\) PC1 variance). Critically, the probe only trains on states where the model still looks healthy (\(h_1\) PC1 < 0.2), and its input is recent change, not the current state itself. It must find the dangerous axis from recent representational drift while the model still appears fine — a genuinely metacognitive test.
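
The post describes the probes only at this level of detail. The stand-in below assumes each probe is a single fixed unit vector, chosen to maximize \(\hat{m}^T \Delta C_e \hat{m}\) weighted by a per-epoch target. For that simplified objective the optimum is closed-form: the top eigenvector of the target-weighted sum of drifts, so no gradient training is needed. The actual probes are trained and take \(\Delta C_e\) as input, so this illustrates the objective rather than reproducing the method. The function names, `lag` (counted in snapshots, not epochs), and the use of future peak \(h_1\) PC1 fraction as the target are assumptions.

```python
import jax.numpy as jnp


def covariance(h1):
    """Covariance of activations h1 with shape (N, D)."""
    hc = h1 - h1.mean(axis=0)
    return hc.T @ hc / h1.shape[0]


def drift_direction(delta_covs, targets):
    """Unit m maximizing sum_e targets[e] * m^T dC_e m (top eigenvector of the weighted sum)."""
    M = jnp.tensordot(targets, delta_covs, axes=1)   # (D, D)
    M = 0.5 * (M + M.T)                              # symmetrize before eigh
    _, evecs = jnp.linalg.eigh(M)
    return evecs[:, -1]


def fit_drift_probe(covs, pc1_frac, future_peak, lag=1, healthy_max=0.2):
    """covs: (E, D, D) h1 covariances at replay snapshots, in training order."""
    delta = covs[lag:] - covs[:-lag]                 # dC_e = C_e - C_(e - lag)
    healthy = pc1_frac[lag:] < healthy_max           # train only on states that look fine
    return drift_direction(delta[healthy], future_peak[lag:][healthy])


def replay_scores(delta_covs, m):
    """Score each snapshot by how much variance grew along m: m^T dC_e m."""
    return jnp.einsum("i,eij,j->e", m, delta_covs, m)
```

Scoring every replayed snapshot with `replay_scores` and z-scoring the result gives a curve comparable in spirit to the bottom panel of the replay figure. To mirror the post's protocol, held-out cycles should be excluded from the inputs to `fit_drift_probe`.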

The replay scores

[Figure: replay meta-model scores overlaid on training dynamics]
Top: current (black) and future (red) h1 PC1 variance fraction, showing the three collapse spikes. Middle: test accuracy (green) with pink shading marking “collapse soon” windows. Bottom: z-scored replay scores from several probe variants. The healthy past-trajectory probe (orange) spikes positively before each collapse, while the model still has >95% accuracy.

The healthy past-trajectory probe (orange, bottom panel) assigns high scores to epochs that precede collapse, even when trained only on the first two cycles. The signal comes from representational drift alone — no access to accuracy, loss, or any outcome measure — and appears thousands of epochs before accuracy drops.

Alignment with the known memorization axis

The best healthy past-trajectory probe recovers essentially the same direction as the known memorization axis:

| Reference direction | Cosine with replay-discovered direction |
|---|---|
| Frozen memorization dir @ epoch 27k | 0.999 |
| Residual PC1 @ epoch 58k (cycle 2) | 0.958 |
| Residual PC1 @ epoch 88k (cycle 3) | 0.933 |

This is not approximate alignment — 0.999 cosine in 128 dimensions is near-identity. The replay meta-model, with no knowledge of Zipfian weights, Fourier structure, or modular arithmetic, found the same structural axis that our hand-crafted analysis identified.

The direction also generalizes functionally. Projecting the replay-discovered direction out of raw \(h_1\) on held-out cycle-3 collapse snapshots recovers nearest-class-centroid accuracy:

| Epoch | Raw test NCC | NCC after projection | Phase |
|---|---|---|---|
| 86,000 | 99.8% | 100.0% | Pre-collapse |
| 87,000 | 95.3% | 100.0% | Approaching collapse |
| 88,000 | 70.4% | 94.4% | Collapse |

Using this direction as the frozen suppression target in the same custom_vjp intervention — projecting it out of \(\partial L / \partial h_1\) starting at epoch 20,000 — eliminates the collapse cycle entirely:

| Metric | Baseline | Replay-discovered frozen direction |
|---|---|---|
| Accuracy collapses | 3 | 0 |
| Final test acc | 0.871 | 0.9997 |
| Max test acc | 0.999 (brief) | 0.9998 (sustained) |

[Figure: training dynamics with replay-discovered direction suppression]
Causal validation: training dynamics when projecting out the replay-discovered memorization direction. Test accuracy (red, top-left) climbs to 99.9% and holds for the remaining 80k epochs, with zero collapses. The late geometric reorganization (h1 PC1 climbing around epoch 55k, middle-left) mirrors the earlier frozen-direction experiment and does not affect accuracy.

The replay-discovered direction eliminates the Sisyphean collapse just as effectively as the hand-identified direction. A meta-model consuming only replayed representational states can autonomously discover a low-dimensional control variable that is causally sufficient to fix the core pathology.

Metacognitive supervision

The memorization direction result demonstrates something broader than a fix for a toy problem. It is a concrete example of what we might call metacognitive supervision: modulating a model’s learning based on the organization of its own representations, rather than based on per-example performance.

The relevant control variable has three properties worth highlighting:

  1. It is low-dimensional. A single direction in 128-D space governs whether training is healthy or pathological. Projecting out that direction removes ~85% of the memorization force while affecting only ~3% of the generalization gradient.
  2. It is stable. The same direction recurs across collapse cycles separated by 30,000 epochs (cosine > 0.88). The pathology has a fixed geometric signature in representation space, even though the Fourier solution itself takes on different neural instantiations each time the model rebuilds after collapse.
  3. It is autonomously discoverable. A replay-consuming meta-model, with no prior knowledge of modular arithmetic or Fourier structure, recovers the memorization direction from internal representational statistics alone — and the recovered direction is causally valid.

These properties — together with the fact that they hold in a setting where we understand the ground truth completely — suggest that similar control variables may exist in larger-scale systems where the ground truth is not known.

Standard training treats models as passive recipients of data: present examples, compute gradients, update weights. The model’s own learned representations, which may already encode rich knowledge about the task’s structure, play no role in deciding how new information should interact with existing knowledge. Unsurprisingly, this can be destructive. The Zipfian grokking model already knows the Fourier solution; the problem is that naive backpropagation lets a handful of high-weight examples overwrite that knowledge anyway.

Biological learners do not work this way. The hippocampal-neocortical interplay described by complementary learning systems theory⁵ aggressively filters for information that is genuinely new and surprising before allowing it to interact with consolidated representations. Replay, gating, and consolidation ensure that what a system already knows is protected from what it is currently experiencing.

What we have demonstrated here is a scaled-down but genuine proof of concept for this idea. The model’s own activation space contains enough structure to identify which gradient components are destructive — and suppressing those components lets the model keep what it already knows. The model is, in a sense, more intelligent than we typically treat it as being; it is just rarely given the opportunity to bring that intelligence to bear on its own learning process. This line of work is about giving it that opportunity.


Code

All code for these experiments can be found here.


Citation

@article{gilley2026activationsupervision,
  title   = {Supervising models with their own activations},
  author  = {Gilley, Jasper},
  year    = {2026},
  month   = {May},
  url     = {https://jagilley.github.io/activation-supervision.html}
}

Notes

  1. Luo, Feng, Darrell, Radford, and Steinhardt, Learning a Generative Meta-Model of LLM Activations (2026). We adapt their architecture — a stack of SwiGLU MLP blocks with flow-matching objective — to our 128-dimensional setting.
  2. 2-layer MLP, hidden dims [128, 128], \(p = 97\), 30% train split, AdamW (lr = \(10^{-3}\), weight decay = 1.0), full-batch gradient descent.
  3. JAX’s mechanism for defining custom backward passes, analogous to PyTorch’s autograd.Function.
  4. Snapshots are spaced at 1k-epoch intervals with denser sampling around expected collapses. The probes are trained on the first two collapse cycles and evaluated on the held-out third cycle.
  5. McClelland, McNaughton, and O’Reilly, Why There Are Complementary Learning Systems in the Hippocampus and Neocortex (1995).
