Exploring Automated Weight Space Interpretability
December 10, 2024 • David Atkinson (work advised by David Bau)
Introduction and Related Work
When paired with an architecture, the weights of a neural network fully specify its behavior over all possible input distributions. Further, neural networks are themselves universal function approximators. Why, then, don't we simply learn a function from weights to interpretation? Why all this hard work?
Weight space interpretability (perhaps naively!) takes this question seriously. We would like to train a model, variously called a meta-model or hypernetwork, to predict some property of another model, here referred to as a base model.
Much of the work in this vein has focused on recovering training hyperparameters and test-time performance of base models, which are often small image models, or (even smaller) implicit neural representations (Mildenhall et al., 2021). Many of these target properties are included in the above figure from Schürholt et al. (2024). A notable exception to these relatively simple tasks comes from Langosco et al. (2023), who leverage meta-models to detect trojans in image models and to generate RASP (Weiss et al., 2021) code from compiled TRACR models (Lindner et al., 2023). The current work takes inspiration from those ambitious applications and investigates settings that are similar in spirit. The most relevant difference is the focus here on different kinds of generalization, beyond the standard train-test split.
Every project in this "Weight Space Learning" space is forced to confront the tremendously large number of hidden unit permutation symmetries (i.e., neurons can be swapped with each other without changing network behavior, as long as the associated weights are swapped as well). There are three broad responses:
- Graph Neural Networks are naturally permutation equivariant. One simple architecture, used in, e.g., Kofinas et al. (2024), represents bias weights as node features and neuron weights as edge features. Kalogeropoulos et al. (2024) go further and propose an architecture which additionally encodes scaling symmetries. These architectures tend to be conceptually elegant, but slow and memory intensive.
- Just as convolutional networks use weight sharing to achieve translation equivariance, some projects have proposed to do the same for permutation equivariance (e.g., Zhou et al. (2023) and Navon et al. (2023)). These architectures are more efficient, but require careful design for each new base model architecture (although see Zhou et al. (2024)).
- More straightforwardly, data augmentation techniques can be used to encourage arbitrary equivariances (e.g., Shamsian et al. (2023) and Shamsian et al. (2024)).
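To make the symmetry concrete, here is a minimal sketch on a toy two-layer MLP (not one of the base models used in this post; the sizes, the tanh nonlinearity, and the helper names `mlp` and `permute_hidden` are all illustrative assumptions). Reordering the hidden units, together with their incoming and outgoing weights, leaves the output unchanged up to floating-point error.

```python
import math
import random

def mlp(x, w1, b1, w2, b2):
    """Toy two-layer MLP with a tanh hidden layer; weights are lists of rows."""
    hidden = [math.tanh(sum(w * xi for w, xi in zip(row, x)) + b)
              for row, b in zip(w1, b1)]
    return [sum(w * h for w, h in zip(row, hidden)) + b
            for row, b in zip(w2, b2)]

def permute_hidden(w1, b1, w2, perm):
    """Reorder hidden units consistently.

    Incoming weights (rows of w1, entries of b1) and outgoing weights
    (columns of w2) are moved together, so the function is preserved.
    """
    w1_p = [w1[j] for j in perm]
    b1_p = [b1[j] for j in perm]
    w2_p = [[row[j] for j in perm] for row in w2]
    return w1_p, b1_p, w2_p

rng = random.Random(0)
n_in, n_hidden, n_out = 3, 5, 2
w1 = [[rng.gauss(0, 1) for _ in range(n_in)] for _ in range(n_hidden)]
b1 = [rng.gauss(0, 1) for _ in range(n_hidden)]
w2 = [[rng.gauss(0, 1) for _ in range(n_hidden)] for _ in range(n_out)]
b2 = [rng.gauss(0, 1) for _ in range(n_out)]

perm = [2, 0, 4, 1, 3]  # an arbitrary fixed reordering of the 5 hidden units
w1_p, b1_p, w2_p = permute_hidden(w1, b1, w2, perm)

x = [rng.gauss(0, 1) for _ in range(n_in)]
original = mlp(x, w1, b1, w2, b2)
permuted = mlp(x, w1_p, b1_p, w2_p, b2)
max_diff = max(abs(a - b) for a, b in zip(original, permuted))
print(f"max output difference: {max_diff:.2e}")
```

A hidden layer of width n admits n! such reorderings, so even this five-unit toy has 120 weight vectors that all describe the same function.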
Experiments and Methods
ImageNet Classifiers
In our first setting, we ask whether a meta-model can recover the class a base model was trained to recognize.
To train our base models, we start with SqueezeNet 1.1 (Iandola et al., 2016), a 1.2 million parameter pre-trained ImageNet model. After randomly reinitializing the 400 thousand weights following the fire8 module, we train each base model to distinguish a randomly selected ImageNet class from the other 999 classes. We train 60,000 base models in this way.
Our first meta-model is simple: a 6-layer, 400 million parameter transformer encoder with a model dimension of 1024 and 16 heads. To train it, we flatten each of the seven reinitialized weight matrices and transform them to vectors of dimension \(d_{\text{model}}\). These are fed to the transformer, whose output is linearly mapped to a vector of dimension 1000 (the number of ImageNet classes). We use cross-entropy loss, AdamW, OneCycle learning rate scheduling, and a batch size of 1024.
(Why this particular setup, in which we retrain a portion of an existing model instead of training a new model from scratch? The hope is that the retrained model will be forced to leverage a pre-existing set of semantically meaningful representations. Conjecturally, this encourages newly-learned weights to be distributionally similar across classes, removing a less interesting source of variation for the meta-model to discriminate on.)
After training, we see a respectable 47% top-10 validation accuracy.
Augmentation
In the previous experiment, we saw that the meta-model quickly overfits to the training distribution. To address this, we leverage the permutation symmetry of the base model's neurons. We randomly permute the weights in such a way as to preserve the model's behavior, and then add those permuted weights to the base model dataset. This should encourage our meta-model to learn a representation of the base model's behavior that is invariant to any behavior-preserving permutation of its weights. And indeed, we see substantial improvements up to augmentation levels which include 400 permuted base models for every one original base model.
After augmentation:
k | Top-k accuracy |
---|---|
1 | 71.3% |
10 | 86.2% |
Testing Generalization: Two-class Classifiers
Generalization over random seeds is encouraging, but if meta-models are to be useful, they will have to generalize over far more dimensions than that. One way to test this is to evaluate generalization over a small distributional shift. Here, we train a small number of base models to recognize two classes, rather than one. If our meta-model has learned a good representation of the base models, it should predict the two trained-on classes as most likely.
To some extent, we see this in the figure and table below.
k | Top-k accuracy, either class present [random acc] | Top-k accuracy, both classes present [random acc] |
---|---|---|
2 | 28.8% [0.4%] | 1.4% [0.0%] |
10 | 56.0% [2.0%] | 8.1% [0.0%] |
100 | 93.2% [19.0%] | 45.7% [1.0%] |
500 | 99.9% [75.0%] | 87.1% [25.0%] |
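The bracketed random accuracies can be sanity-checked with a short calculation. The sketch below assumes that "random" means a uniformly random ranking of the 1000 classes, with two target classes per base model; the function names are ours. Under that assumption, the printed values round to the bracketed numbers in the table.

```python
from math import comb

N_CLASSES = 1000  # number of ImageNet classes

def p_either(k, n=N_CLASSES):
    """Chance that at least one of the 2 target classes lands in a random top-k."""
    return 1 - comb(n - 2, k) / comb(n, k)

def p_both(k, n=N_CLASSES):
    """Chance that both target classes land in a random top-k."""
    return comb(k, 2) / comb(n, 2)

for k in (2, 10, 100, 500):
    print(f"k={k:>3}  either: {p_either(k):.1%}  both: {p_both(k):.1%}")
```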
Testing Generalization: Neuron Permutations
Here, we perform a more demanding test of generalization. Especially with a small enough prediction range, a meta-model classifier might learn undesirable shortcuts, such as low-complexity weight statistics. We would prefer it to learn a richer representation of the base model's behavior. To test this, we take a pretrained meta-model and test it on a set of base models which have had their neurons permuted without correspondingly permuting the neurons' associated weights. In behavior space, these models should be randomly distributed, even if, by weight distribution statistics, they remain close to their original, unpermuted parents. The meta-model's performance on these base models, then, serves as an indication of what it is doing when it predicts the training class.
What we see suggests that the meta-model is doing something in-between these two hypothesized extremes. Permuting a single layer of neurons (i.e., moving from the blue to the orange line) results in much better than random performance for the meta-model, ruling out perfect behavioral sensitivity. But the performance does drop, and continues to drop as more layers are permuted, suggesting that the meta-model is sensitive to more than layer-wise weight statistics.
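As a toy illustration of why such a permutation separates weight statistics from behavior, consider the sketch below. It reads "permuting neurons without permuting their associated weights" as reordering the incoming weights and biases of a hidden layer while leaving the outgoing weights in place; that reading, the tiny MLP, and all the numbers are assumptions made for illustration, not the experimental setup.

```python
import math

def mlp(x, w1, b1, w2, b2):
    """Toy two-layer MLP with a tanh hidden layer; weights are lists of rows."""
    hidden = [math.tanh(sum(w * xi for w, xi in zip(row, x)) + b)
              for row, b in zip(w1, b1)]
    return [sum(w * h for w, h in zip(row, hidden)) + b
            for row, b in zip(w2, b2)]

w1 = [[0.5, -1.0], [1.5, 0.25], [-0.75, 2.0]]  # 3 hidden units, 2 inputs
b1 = [0.1, -0.2, 0.3]
w2 = [[1.0, -2.0, 0.5]]                        # 1 output
b2 = [0.0]

# Reorder the hidden units' incoming weights only; w2 is deliberately
# left untouched, so each unit is now read out with the wrong weight.
perm = [1, 2, 0]
w1_broken = [w1[j] for j in perm]
b1_broken = [b1[j] for j in perm]

def flat_sorted(w):
    """Sorted multiset of a layer's weights: a stand-in for layer-wise statistics."""
    return sorted(v for row in w for v in row)

same_statistics = flat_sorted(w1) == flat_sorted(w1_broken)

x = [1.0, -1.0]
out_original = mlp(x, w1, b1, w2, b2)[0]
out_broken = mlp(x, w1_broken, b1_broken, w2, b2)[0]
print(same_statistics, out_original, out_broken)
```

Every per-layer statistic that ignores ordering (mean, variance, histogram) is identical for the two networks, yet their outputs differ, so a meta-model relying only on such statistics could not tell them apart.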
Early Forays into Language
Regexes
Here, we report a preliminary experiment in which our base models consist of two-layer, 5,000 parameter GRUs (Cho et al., 2014), trained to mimic the behavior of either the regex (aa)+ or a+. Specifically, to train each base model, we randomly choose one of the two regexes, instantiate a GRU, and train it as a binary classifier over strings drawn from (a|b)*, with lengths ranging from 0 to 10. (Note that we ensure a measure of balance in the training data: one-third of the input strings contain a b, making them trivially classifiable. The other two-thirds consist only of a's, making the task a parity one.) Given our previous experience with image models, it's not surprising that, as we see in the table below, the meta-model (another transformer, with 3 million parameters) is able to successfully distinguish between the two classes of base models, with no data augmentation used.
Models Seen | MSE |
---|---|
60k | 0.72 |
120k | 0.08 |
640k | 0.008 |
Against this encouraging sign, the (small, to be fair) meta-model does need more than 100,000 base models to achieve good performance.
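For concreteness, the base-model training data described above could be generated along the following lines. This is a sketch under stated assumptions rather than the code used for the experiment: it assumes labels come from matching the whole string against the regex, and the sampling scheme and names (`sample_string`, `make_dataset`) are hypothetical.

```python
import random
import re

REGEXES = {"(aa)+": re.compile(r"(aa)+"), "a+": re.compile(r"a+")}
MAX_LEN = 10

def sample_string(rng):
    """Sample a string over {a, b}: about 1/3 contain a 'b', the rest are all 'a's."""
    if rng.random() < 1 / 3:
        length = rng.randint(1, MAX_LEN)        # a 'b' needs at least one character
        chars = [rng.choice("ab") for _ in range(length)]
        chars[rng.randrange(length)] = "b"      # guarantee at least one 'b'
        return "".join(chars)
    return "a" * rng.randint(0, MAX_LEN)        # all-'a' string, possibly empty

def make_dataset(pattern, n, seed=0):
    """Return n (string, label) pairs; label is 1 iff the whole string matches."""
    rng = random.Random(seed)
    regex = REGEXES[pattern]
    return [(s, int(regex.fullmatch(s) is not None))
            for s in (sample_string(rng) for _ in range(n))]

# Each base model would be trained on data labeled by one randomly chosen regex.
pattern = random.Random(0).choice(sorted(REGEXES))
dataset = make_dataset(pattern, n=5)
print(pattern, dataset)
```

With whole-string matching, any string containing a b is a negative example for both regexes, so the two tasks differ only on the all-a strings of odd length.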
Discussion
In practice, of course, model activations over a chosen dataset provide an enormous amount of information, and it would be foolish to throw that all away. But I do think there's a more modest, albeit still ambitious, role that weight space interpretability could play, which is to serve as a complementary source of information to more traditional interpretation techniques.
You might worry, for example, that techniques like probing and SAEs depend on a particular data distribution, which leaves those techniques vulnerable to "unknown unknowns". Behavior that only responds to extremely conjunctive or sparse properties of the input distribution is currently difficult to find with traditional techniques. Weight space interpretability can in theory be helpful here.
Of course, a wealth of problems lie between that future and today. I highlight two very high-level ones:
First, ground truth labels are difficult to find. In practice, we are forced to try enforcing certain properties through training. There is a tension here, where, if we knew the ground truth for a particular property of interest, there would be no need for a meta-model. Right now, the hope is that the magic of ML generalization saves us here, but cleverer approaches would be helpful.
Second, every paper so far applies a medium-sized model to a small-to-tiny model. It's not clear whether that relationship can be reversed, or whether large models can be successfully decomposed into smaller components without undesirable loss of fidelity (consider Wen et al. (2023)). If not, practical applications of weight space interpretability will be much harder to come by.
I've collected the code for the experiments described above here.