Sparse Dictionary Learning and Transformer Interpretability
An informal note on sparse coding and its application to language model interpretability
High-level takeaway
In tackling the curse of dimensionality, those working in mechanistic interpretability are currently focused on wrangling high-dimensional, superimposed, and composite feature spaces into something manageably human-interpretable. For the past year, Anthropic has been teasing a line of research that applies compressed sensing techniques, such as sparse dictionary learning, to the task of deconstructing the hierarchy of learned semantic structures in autoregressive language models similar to those in the GPT family.
While Anthropic’s public updates on these experiments have, so far, been limited to results on toy models, the most promising work in this area has come from Yun et al. with this paper in which they apply sparse dictionary learning to a pre-trained BERT model and find evidence of three tiers of semantic structure: word-level polysemy disambiguation (low-level factors), sentence-level pattern formation (mid-level factors) and long-range dependency (high-level factors). They provide a visualization website to explore these three categories and how they develop across each of the model’s layers.
As their work is focused specifically on BERT, I made some relatively minor adjustments to their code to apply this method to GPT models. My (ongoing) adjustments to their code can be found here.
While I don’t have the compute to train a sparse dictionary for GPT, there are threads I’m excited about picking up with this work. The so-called low-level factors are those whose sparse dictionary activation values peak at early layers and gradually decay to a low plateau at later layers. The mid-level and high-level factors both gradually peak at later layers, but the authors of the paper have to sort the two categories manually because they don’t have an automatic heuristic to distinguish them. Attention heads play a big role in managing both mid-range and long-range semantics, so in order to distinguish between the two it might be useful to additionally learn dictionaries for the mid-layer residual stream (the output right after the attention, before the MLP) and adapt circuit analysis techniques to analyze the two dictionaries.
Anthropic has the institutional capacity to adapt these techniques for models much larger than the twelve layer BERT/GPT models in this work. If they do, it would be interesting to see how the hierarchical semantic representations change as the models scale in parameter size. It might also be worth exploring this method in different data scaling regimes. Would the factorizations look the same in models trained in the Chinchilla scaling law regime versus a high epoch/low data regime? Taking inspiration from Eleuther’s Pythia suite and charting how the semantic representations change during training would also be a worthwhile line of work.
Informal technical discussion: sparse dictionary learning and its application
Sparse dictionary learning is a family of numerical methods for learning an overcomplete basis set to represent input data. By definition, an overcomplete basis introduces ambiguity into the representation of input data. Enforcing sparsity on the coefficients is how that ambiguity is managed; even with sparsity, the overcomplete basis allows for a richer representation of the data than a complete basis would. For models like GPT or BERT, the goal is to learn a basis set that can model the hidden states at each layer. For those interested, here is a gentle but significantly more technical explanation of sparse dictionary learning.
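To make that concrete, the standard sparse coding objective (generic notation here, not necessarily what Yun et al. use) asks for a dictionary D with more columns than the hidden dimension and, for each hidden state x, a coefficient vector α that is mostly zeros:

```latex
\min_{D,\;\alpha}\; \tfrac{1}{2}\,\lVert x - D\alpha \rVert_2^2 \;+\; \lambda\,\lVert \alpha \rVert_1,
\qquad D \in \mathbb{R}^{d \times k},\; k > d
```

The L1 penalty pushes most entries of α to zero, so each hidden state ends up written as a small combination of dictionary elements even though the dictionary itself is much larger than a complete basis.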
The dictionary training method in Yun et al.’s paper is roughly:
Collect hidden states for many batches of sentences at each layer (or every other layer, as they do in practice).
Apply the FISTA algorithm to the sparse dictionary learning objective: for a given number of iterations, update the sparse activation vector to reduce the error between the hidden state and the dictionary multiplied by that activation vector (a minimal sketch of this step follows the list).
Each token embedding of the resulting error is weighted by the inverse of the token’s frequency in the training corpus.
The basis dictionary is updated using a second-order optimization technique from the AdaGrad family.
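To make the FISTA step above concrete, here is a minimal NumPy sketch of sparse inference against a fixed dictionary; the function names, shapes, and hyperparameters are my own and not taken from Yun et al.’s code.

```python
import numpy as np

def soft_threshold(v, tau):
    """Proximal operator of the L1 norm (elementwise shrinkage)."""
    return np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)

def fista(X, D, lam=0.1, n_iter=200):
    """Sparse inference: min_A 0.5 * ||X - A @ D.T||^2 + lam * ||A||_1.

    X: (n, d) batch of hidden states; D: (d, k) dictionary with k > d.
    Returns A: (n, k) sparse activation vectors, one per hidden state.
    """
    # Lipschitz constant of the gradient of the smooth (reconstruction) term.
    L = np.linalg.norm(D, ord=2) ** 2
    A = np.zeros((X.shape[0], D.shape[1]))
    Z, t = A.copy(), 1.0
    for _ in range(n_iter):
        grad = (Z @ D.T - X) @ D          # gradient of the reconstruction error
        A_next = soft_threshold(Z - grad / L, lam / L)
        t_next = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
        Z = A_next + ((t - 1.0) / t_next) * (A_next - A)  # momentum step
        A, t = A_next, t_next
    return A
```

During dictionary training, D would itself be updated between FISTA passes, with each token’s reconstruction residual weighted by the inverse of that token’s corpus frequency as described above; at inference time D stays frozen and only the activations are solved for.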
The inference and visualization methods in Yun et al.’s paper are roughly:
Collect the hidden states for a batch of sentences at a given layer or multiple layers.
Apply the FISTA algorithm for a given number of iterations to minimize the error between the hidden state and the learned dictionary multiplied by an activation vector (this vector is the only thing getting updated each iteration).
Collect the tokens which set off the highest activation values for some specified subset of basis vectors.
The so-called low-level factors are those whose activation values peak at early layers and gradually decay to a low plateau at later layers.
Mid-level and high-level factors both gradually peak at later layers, but for now the authors of the paper have to manually sort the two categories; they don’t have an automatic heuristic to distinguish them.
The authors adapt the black box interpretability tool LIME to generate saliency maps: for an inference sentence, they generate multiple versions in which tokens are randomly masked, and then run the LIME explainer module to detect the tokens whose perturbation through masking is most disruptive to the original activation value.
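As a much simpler stand-in for that LIME-based procedure, here is a hypothetical leave-one-out sketch that masks one token at a time and measures the drop in a chosen factor’s activation; `factor_activation` is an assumed helper that would run the model, apply FISTA, and return a single dictionary element’s activation for the sentence.

```python
from typing import Callable, List

def masking_saliency(tokens: List[str],
                     factor_activation: Callable[[List[str]], float],
                     mask_token: str = "[MASK]") -> List[float]:
    """Score each token by how much masking it reduces a factor's activation."""
    baseline = factor_activation(tokens)
    saliency = []
    for i in range(len(tokens)):
        masked = tokens[:i] + [mask_token] + tokens[i + 1:]
        # Tokens whose masking is most disruptive to the activation score highest.
        saliency.append(baseline - factor_activation(masked))
    return saliency
```

This drops LIME’s local surrogate model entirely, but it captures the same intuition: the salient tokens are the ones whose removal most disrupts the original activation value.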
Overall, Yun et al.’s paper, code, and visualization website make up the most promising work to date in the very nascent field of mechanistic interpretability. I’m looking forward to what labs with the dedicated resources for this kind of work are able to do in this direction.
Future work
Training data
The authors acknowledge shortcomings and limitations of their work in their paper, and only one limitation stands out enough for me to mention that they seem not to have addressed: the so-called “sentences” they generate aren’t actually sentences. They are samples of the corpus taken based on length. Here is an example of a few:
“ 's prosecutors convicted Wilson and his brother Jackie of murder, and Andrew Wilson was sentenced to death. On April 2, 1987 the Illinois Supreme Court overturned the convictions, ruling that Wilson was forced to confess involuntarily after being beaten by police. \n = = = First campaign for Mayor, 1983 : challenge to Jane Byrne"
' her. They have four children : Nora, Patrick, Elizabeth and Kevin, all born at Mercy Hospital and Medical Center in Chicago. Their second son, Kevin, died at age two of complications from spina bifida in 1981. \n Daley graduated from De La Salle Institute high school in Chicago and obtained'
'th Ward, the Daley family\'s home neighborhood and ward. " I\'ve never known them to be anything but hard working, and I feel for them at this difficult time, " Daley said. " It is fair criticism to say I should have exercised greater oversight to ensure that every worker the city hired'
While the context length is long enough to capture full thoughts, and the structure is intact outside of the truncated beginnings and ends of each “sentence”, I would be curious to see whether a more deliberate splitting of the corpus would affect the basis factorizations.
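For concreteness, here is the kind of more deliberate splitting I have in mind: break the corpus on sentence boundaries first, then pack whole sentences into fixed-length chunks rather than slicing the raw token stream by length. The chunking scheme and `max_words` value are my own assumptions, not anything from Yun et al.’s pipeline.

```python
import re

def pack_sentences(text: str, max_words: int = 64):
    """Split text on sentence boundaries, then pack whole sentences into
    chunks of at most max_words whitespace-split words."""
    # Naive boundary detection; a proper splitter (e.g. nltk's sent_tokenize)
    # would handle abbreviations and quotes better.
    sentences = re.split(r"(?<=[.!?])\s+", text)
    chunks, current = [], []
    for sent in sentences:
        words = sent.split()
        if current and len(current) + len(words) > max_words:
            chunks.append(" ".join(current))
            current = []
        current.extend(words)
    if current:
        chunks.append(" ".join(current))
    return chunks
```

Even this crude packing would keep sentence boundaries intact, so no chunk starts or ends mid-thought the way the length-based samples above do.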
Threads worth pursuing
Yun et al. note that it took about 2 days on an Nvidia 1080 Ti GPU to train the dictionary. I don’t personally have these resources, so I won’t be able to provide sparse dictionaries for the GPT models. Nevertheless, in addition to more careful data curation methods, there are some next steps I would pursue in my adaptation of this work for GPT models:
Revisit their method for saliency mapping through random masking. Does this make sense for GPT models? Is there a more GPT-appropriate perturbation for this task?
Mid-level and high-level factor disambiguation ideas:
Circuit tracing and causal tracing methods might be adapted to distinguish between mid-level and high-level factors.
Yun et al.’s work focuses only on the residual stream at the output of each layer; it could be worthwhile to learn dictionaries for the mid-layer residual stream (the output right after the attention layer, before the MLP layer). It is easy to imagine that attention heads play a big role in both mid- and long-range semantics; a rough sketch of how to capture this mid-layer stream from a GPT-2 model follows below.
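Here is that sketch, assuming a Hugging Face GPT-2; everything beyond the standard model and tokenizer classes is my own naming. In GPT-2’s block structure, the input to each block’s second LayerNorm (`ln_2`) is exactly the residual stream right after attention and before the MLP, so a forward pre-hook on `ln_2` captures the mid-layer residual stream.

```python
import torch
from transformers import GPT2Model, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2").eval()

mid_layer_states = {}  # layer index -> (batch, seq_len, hidden) tensor

def make_hook(layer_idx):
    def hook(module, inputs):
        # The input to ln_2 is the residual stream after attention, before the MLP.
        mid_layer_states[layer_idx] = inputs[0].detach()
    return hook

handles = [block.ln_2.register_forward_pre_hook(make_hook(i))
           for i, block in enumerate(model.h)]

with torch.no_grad():
    batch = tokenizer("The quick brown fox jumps over the lazy dog",
                      return_tensors="pt")
    model(**batch)

for handle in handles:
    handle.remove()

# mid_layer_states[i] can now be stacked across batches and fed to FISTA
# to learn a per-layer dictionary for the mid-layer residual stream.
```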