Inspecting Head Behavior through Pattern Projections
In this post, I explore a novel way to interpret attention head functionality in transformer models by reformulating the Query and Key projections into a single matrix, $W^P = (W^Q)^T W^K$. This formulation allows us to view each head not just as a black-box attention unit but as a distinct pattern extractor operating in model space. By projecting an input embedding onto $W^P$, we obtain a “pattern vector” that can be directly compared to vocabulary embeddings using cosine similarity. This method opens up a new avenue for understanding what each head is “searching for” in the input sequence.
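To make the reformulation concrete, here is the algebra behind it (ignoring the per-head $1/\sqrt{d_{head}}$ scaling and the bias terms): the pre-softmax attention score that head $i$ assigns from a query token $x_q$ to a key token $x_k$ is

\[\text{score}(x_q, x_k) = (W^Q_i x_q) \cdot (W^K_i x_k) = x_q^T (W^Q_i)^T W^K_i \, x_k = x_q^T W^P_i \, x_k = p(x_q) \cdot x_k, \quad \text{where } p(x_q)^T = x_q^T W^P_i.\]

The pattern vector $p(x_q)$ therefore lives in the same space as the token embeddings, and its similarity to a candidate embedding indicates roughly how strongly the head would attend to that token if it appeared in context.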
Rather than attempting to provide a definitive explanation of attention, the aim here is to demonstrate the potential of this approach through some initial experiments. I’ll start by applying the method to BERT’s early layers—where patterns tend to be more syntactic and less confounded by positional or deeper semantic factors—and then extend the analysis to GPT‑2. The following code shows how to extract these pattern projection matrices and leverage them to probe head behavior.
By offering an alternative, low-rank perspective on head behavior, I hope to provide both an accessible teaching tool and a stepping stone toward more advanced interpretability research. Let’s dive into the code and see what insights we can uncover!
▂▂▂▂▂▂▂▂▂▂▂▂
S1. BERT Heads
A naive approach seems to work well for at least the first few layers of BERT, where, even without adding position information or running the tokens through the model, we can simply compare the head patterns back to the input embeddings with cosine similarity.
Let’s check out some results!
1.1. Analysis Functions
I’m doing this analysis first on bert-base-uncased, due to its simplicity and “small” size.
First, load the model and extract its embedding layer.
import torch
from transformers import AutoModel, AutoTokenizer
import numpy as np
import scipy.spatial
import pandas as pd
# Load BERT-base model and tokenizer
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
# Extract embedding layer
embedding_layer = model.embeddings.word_embeddings.weight.detach().cpu().numpy()
Next, a function for constructing the Pattern projection matrix for a given head in a given layer of BERT.
def get_BERT_WP(layer, head):
    """Get the W^P_i matrix for the specified layer and head in BERT."""
    # Extract W^Q and W^K for the chosen layer.
    W_Q = model.encoder.layer[layer].attention.self.query.weight.detach().cpu().numpy()
    W_K = model.encoder.layer[layer].attention.self.key.weight.detach().cpu().numpy()

    # Extract just the slice for this head.
    num_heads = model.config.num_attention_heads
    head_size = W_Q.shape[0] // num_heads

    W_Q_i = W_Q[head * head_size:(head + 1) * head_size, :]
    W_K_i = W_K[head * head_size:(head + 1) * head_size, :]

    # Compute W^P for this head (transposing W^Q_i first).
    W_P_i = np.dot(W_Q_i.T, W_K_i)  # Shape (768, 768)

    return W_P_i
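As a quick sanity check (my addition, not part of the original notebook flow), we can verify that the scalar $x_q^T W^P_i x_k$ matches the raw query-key dot product computed from the sliced weights. Biases and the $1/\sqrt{d_{head}}$ scaling are ignored, as in the derivation above; the words "run" and "election" are arbitrary single-token choices.

# Sanity check (illustrative): x_q^T W^P_i x_k should equal (W^Q_i x_q) . (W^K_i x_k).
layer, head, head_size = 0, 1, 64  # BERT-base: 768 dims / 12 heads = 64 dims per head

W_P_i = get_BERT_WP(layer, head)
x_q = embedding_layer[tokenizer.convert_tokens_to_ids("run")]
x_k = embedding_layer[tokenizer.convert_tokens_to_ids("election")]

W_Q = model.encoder.layer[layer].attention.self.query.weight.detach().cpu().numpy()
W_K = model.encoder.layer[layer].attention.self.key.weight.detach().cpu().numpy()
W_Q_i = W_Q[head * head_size:(head + 1) * head_size, :]
W_K_i = W_K[head * head_size:(head + 1) * head_size, :]

lhs = x_q @ W_P_i @ x_k              # via the pattern projection matrix
rhs = (W_Q_i @ x_q) @ (W_K_i @ x_k)  # via the explicit query / key projections
print(np.isclose(lhs, rhs))          # expect: True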
Similarity metrics: cosine similarity is what seems to work here, presumably because the vectors need to be normalized.
# Compute dot product similarity
def dot_product_similarity(vec, matrix):
return np.dot(matrix, vec)
# Compute cosine similarity
def cosine_similarity(vec, matrix):
return 1 - scipy.spatial.distance.cdist([vec], matrix, metric="cosine")[0]
For a given head and input word, find the closest matching vocabulary embeddings to the resulting pattern vector.
def find_head_matches(W_P_i, input_word, k=15):
# Some tokens may get split up...
tokens = tokenizer.tokenize(input_word)
if len(tokens) != 1:
print(f"Warning: The word '{input_word}' was tokenized into multiple tokens: {tokens}. Using the first token.")
token = tokens[0]
# Convert the token to its corresponding ID in the vocabulary.
word_id = tokenizer.convert_tokens_to_ids(token)
# Extract their original embeddings
word_emb = np.array(embedding_layer[word_id])
# Project the token embedding to get the pattern vector.
pattern = np.dot(word_emb, W_P_i)
    # Calculate cosine similarities.
similarities = cosine_similarity(pattern, embedding_layer)
# Sort to retrieve the top k.
top_indices = similarities.argsort()[-k:][::-1]
top_words = []
# Construct a list of tuples to return, (word_str, similarity)
for idx in top_indices:
# Convert the vocabulary index back into a token string.
word_str = tokenizer.convert_ids_to_tokens(int(idx))
top_words.append((word_str, similarities[idx]))
return top_words
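A minimal usage example (my addition; the head and word are arbitrary choices):

# Example: what is layer 0, head 1 "looking for" when it sees the word "run"?
W_P_i = get_BERT_WP(0, 1)
for word_str, sim in find_head_matches(W_P_i, "run", k=5):
    print(f"{word_str:>12}  {sim:.2f}")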
1.2. Probing Layers 0 - 3
I had GPT suggest a few words to try, and found the closest matching embeddings for all of the heads in the first four layers (layers 0 - 3).
Skimming through, here were some interesting examples:
Disambiguating Heads
These all seem to be examples where the head is looking for a context word to clarify the right meaning of the input word.
- For example, when given the word “run”, these two heads appear to look for context for it:
- Layer 0, head 1:
- election, innings, theatre, mayor, sales, reelected, theaters, selling, wickets, theater, commercials, electoral, elections, gallons, elected,
- Layer 1, head 3:
- pitcher, home, inning, goalkeeper, wickets, schumacher, pitchers, bowler, baseball, wicket, shortstop, nfl, mlb, ##holder, drivers,
- I noticed a number of examples of this behavior in Layer 0, head 2:
- “dog” –> hot, sent, watch, guard, radio, guide, send, hound, unsuccessfully, voice, neck, sends, success, feel, mas
- “bed” –> truck, river, playoff, creek, speed, flow, vehicle, lecture, fish, stream, flower, thunder, drain, narrow, dry
- “drive” –> disk, disc, leaf, flash, club, magnetic, wheel, gene, reverse, rip, data, blood, commercially, serpent, captive
Unsure…
“happy”
- Layer 0, head 3:
- make, making, made, makes, people, not, ##made, women, ##ria, something, felt, city, dee, men, paper,
- Layer 3, head 8:
- picture, faces, genesis, aftermath, emotional, expression, concurrently, pictures, expressions, emotion, hearts, account, jasper, mental, disorders,
Special Tokens
Many of the heads produced a pattern closely resembling the special tokens, matching a finding from “What Does BERT Look At?” by Clark et al.
For example, layer 1 head 1,
“couch” –> [CLS], [MASK], [SEP], ##⁄, ##rricular, ##fully, ##vances, ##ostal, pmid, ##⋅, ##atable, ##tained, ##lessly, ##genase, ##ingly,
With cosine similarities 0.55, 0.29, 0.23, 0.17, 0.14, …
Self-Attending
Some patterns matched the input word and its synonyms, implying the head is attending to the input token.
Layer 2, head 6,
“couch” –> couch, sofa, lagoon, ##ppel,
With cosine similarities: 0.23, 0.2, 0.17, 0.17
# Select a sample set of words
words = ["couch", "dog", "run", "happy"]
# Layers to process (layers 0 through 3)
layers = [0, 1, 2, 3]
num_heads = model.config.num_attention_heads
# Store results
results = []
# For each of the layers / heads / words...
for layer in layers:
for head in range(num_heads):
# Get the Pattern projection matrix for this head.
W_P_i = get_BERT_WP(layer, head)
# For each of the words...
for word in words:
# Find the matching word embeddings.
matches = find_head_matches(W_P_i, word, k=15)
# Separate the words and scores.
top_k_strs = ""
top_k_sims = ""
# Turn them into strings
for word_str, sim in matches:
top_k_strs += f"{word_str:>8}, "
top_k_sims += f"{sim:.2}, "
# Add the result as a row.
results.append({
"Word": word,
"Layer": layer,
"Head": head,
"Top-k": top_k_strs,
"Scores": top_k_sims
})
# Convert results to DataFrame and display
df_results = pd.DataFrame(results)
# Set pandas precision to 3 decimal points
pd.options.display.float_format = '{:.3f}'.format
display(df_results)
df_results.to_csv("bert_head_results.csv")
| | Word | Layer | Head | Top-k | Scores |
|---|---|---|---|---|---|
| 0 | couch | 0 | 0 | for, with, ", on, ... | 0.16, 0.15, 0.14, 0.14, 0.13, 0.13, 0.12, 0.12... |
| 1 | dog | 0 | 0 | engine, roman, rome, html, vehic... | 0.15, 0.15, 0.15, 0.14, 0.14, 0.14, 0.13, 0.13... |
| 2 | run | 0 | 0 | an, [SEP], the, [MASK], ... | 0.14, 0.14, 0.14, 0.13, 0.12, 0.11, 0.1, 0.1, ... |
| 3 | happy | 0 | 0 | [MASK], -, to, on, ... | 0.27, 0.16, 0.16, 0.16, 0.14, 0.14, 0.14, 0.14... |
| 4 | couch | 0 | 1 | [MASK], [CLS], -, ,, ... | 0.44, 0.27, 0.16, 0.15, 0.15, 0.15, 0.14, 0.14... |
| ... | ... | ... | ... | ... | ... |
| 187 | happy | 3 | 10 | [CLS], [SEP], ¨, ##⋅, forg... | 0.29, 0.2, 0.11, 0.11, 0.11, 0.1, 0.1, 0.1, 0.... |
| 188 | couch | 3 | 11 | [CLS], [SEP], allmusic, credits, sherlo... | 0.22, 0.12, 0.089, 0.088, 0.087, 0.086, 0.084,... |
| 189 | dog | 3 | 11 | [CLS], ##⁄, ##icio, ##igraphy, trans... | 0.25, 0.12, 0.11, 0.1, 0.1, 0.1, 0.099, 0.097,... |
| 190 | run | 3 | 11 | according, ##ː, checked, took, depen... | 0.12, 0.11, 0.1, 0.098, 0.095, 0.094, 0.093, 0... |
| 191 | happy | 3 | 11 | [CLS], [SEP], icao, nothin, someth... | 0.28, 0.15, 0.1, 0.1, 0.099, 0.09, 0.09, 0.09,... |

192 rows × 5 columns
1.3. Probing Specific Heads
When suspecting a behavior in a particular head, this version loops through different groups of words to try on that head.
layer = 0
head = 3
print(f"\n==== Layer {layer}, Head {head} ====\n")
W_P_i = get_BERT_WP(layer, head)
word_groups = [
# Note: Some words break into multiple tokens, so I avoid them:
# "joyful", "ecstatic", "elated", "cheery"
["happy", "content", "cheerful", "sad", "miserable", "depressed"],
["scared", "confused", "hopeful", "discouraged"],
["cat", "wolf", "puppy"],
["walk", "walking", "walked"],
["jump", "jumping", "jumped"],
["justice", "freedom", "democracy"]
]
for input_words in word_groups:
for input_word in input_words:
print(f"{input_word:>10}: ", end="")
matches = find_head_matches(W_P_i, input_word)
for word, sim in matches:
# Print out the matching words, padded to 10 characters each
print(f"{word:>10} ", end="")
print()
print("\n\n--------------------------------\n\n")
==== Layer 0, Head 3 ====
happy: make making made makes people not ##made women ##ria something felt city dee men paper
content: make making [MASK] made makes una lot town ir con arrow city people ##made nick
cheerful: [MASK] arte una cod 338 ag im ir city apr 336 ##cc 268 pot ##tm
sad: [MASK] [CLS] making make made paper light makes technology tan ##made obe ir business я
miserable: making [MASK] make fir made jenny una ##made city lara makes veil li bar sci
depressed: [MASK] technology making make ##made [CLS] veronica made wasps una craft spice nora ##bic paper
--------------------------------
scared: [MASK] [CLS] fir una paper craft ##ior hoc я ##owe ir ##mas nothing im technology
confused: nobody never not nothing without trust un absence grant women paper technology resistant lacking force
hopeful: [MASK] un city all paper una town mist country women art village para ##cc light
discouraged: [MASK] una forte diva ag katie app jennifer jenny olympia technology nora ##ever disco hoc
--------------------------------
cat: [MASK] [CLS] fifty una amazon couple bride twenty þ abe quo forty abdul 50 ##iss
wolf: [MASK] ##ada ##oor counter melody ##中 sv dementia lullaby ##ধ ##ad pen ##བ ##oche ##道
puppy: [MASK] [CLS] una un im nec 266 disco 336 paper technology 338 334 ina pac
--------------------------------
walk: help helping let saw watching helped helps watched make letting watch made seen making on
walking: pro probe feeling woman saw tech women watch paper op view ec pod preview angel
walked: ... [MASK] as ? when lara va if be feeling miranda ##zzi ##was 16 tech
--------------------------------
jump: make help made saw watch making seen let helped see from seeing watching force helping
jumping: [MASK] un una women people technology woman ag paper town force spirit communication diva sam
jumped: [MASK] lara force [CLS] pac un tech ai ##ida anything if ##g grayson be ...
--------------------------------
justice: men help people company di light team make film women power game to man force
freedom: help offer escape make grant give company break force display di demand given technology request
democracy: team the foot ##zak development ##tech [MASK] berg training paper power body gym probe child
--------------------------------
layer = 0
head = 2
print(f"\n==== Layer {layer}, Head {head} ====\n")
W_P_i = get_BERT_WP(layer, head)
word_groups = [
# Test different furniture and objects to check if it maps "couch" to physical items
["couch", "chair", "table", "bed", "sofa"],
# Test different animals to see if "dog" behavior generalizes
["dog", "cat", "lion", "elephant", "wolf"],
# Test different movement-related words to see if "run" finds action-based contexts
["run", "sprint", "walk", "fly", "drive"],
# Test security & alert-related words based on "dog" results (guard, watch, radio)
["guard", "watch", "alert", "detect", "surveillance"],
# Test whether it groups abstract concepts (following "happy" head test)
["freedom", "justice", "law", "rights", "democracy"]
]
for input_words in word_groups:
for input_word in input_words:
print(f"{input_word:>10}: ", end="")
matches = find_head_matches(W_P_i, input_word)
for word, sim in matches:
# Print out the matching words, padded to 10 characters each
print(f"{word:>10} ", end="")
print()
print("\n\n--------------------------------\n\n")
==== Layer 0, Head 2 ====
couch: em truck ar cell car sea tanks tr rotor cam bull ##car org ##em kim
chair: rear sedan club tree bank season vehicle car parish vehicles church bed dresser argument partial
table: round water hill rounds writing left hall opened center periodic heart todd turn square accept
bed: truck river playoff creek speed flow vehicle lecture fish stream flower thunder drain narrow dry
sofa: ha da um na inn om sa org ##la granny em ##ha ser \ ##dran
--------------------------------
dog: hot sent watch guard radio guide send hound unsuccessfully voice neck sends success feel mas
cat: del bihar ##del lo hotel hotels della ##wal hold le pussy molly ##slin ##a po
lion: george angelo sea ##con tree mountain train canyon over cell voice river answered ##jer ##tia
elephant: fan mac ##lio ##ier sy ##ille theo ##zy ##ian ty ti yan ##illon ##ach hagen
wolf: li mis ##x semi ##zi auto thought shatter space ##t rosen ##tia un ##isi gray
--------------------------------
run: home child prison narrow patient remote female crowd fish imagination willow transfer limited festival reduced
sprint: fan ##ber ##del sara ram ##hom star om ##fra laura ##ele ##lo ##fan ##run ##ndra
walk: dry square heath florida em rose cotton snow earth up patricia marsh dead grass media
fly: swan pigs ears fish ear drops slave drama crane dragon paris battles grapes eye eyes
drive: disk disc leaf flash club magnetic wheel gene reverse rip data blood commercially serpent captive
--------------------------------
guard: called honor honour goal waist stood sin neural named off nu followed cell star memorial
watch: wrist night bird human hand finished palm thousand humans dream turn farm devon case streak
alert: information national earlier nu ##ency lin sound higher chief after cal match jones peek tissue
detect: ##le ##iti ##ele mor er ##hom worry ) theo ##is ##el im ##dran ##lo um
surveillance: shah za sha im om com ##uit marsh ##com jana pri tito ser jen bell
--------------------------------
freedom: catholic held ned pere relative poet vida reserve homeland media environmental democrat ordained ##ia )
justice: deep inter raven ang criminal pl sound ##to upper air dora park formal dev primary
law: bird fish dog gift je match sound barn purple forest andre live pack rolled competition
rights: fan hair left heard called all naming film have universal human came bel words broadcasting
democracy: ##ri pseudo * ka ##re ##io ##con sven ##di ki theo ##ana bank ) nor
--------------------------------
▂▂▂▂▂▂▂▂▂▂▂▂
S2. GPT-2: Layer-Wise Evolution of an Embedding
2.1. Load Model
We start by loading GPT‑2 and its tokenizer from the Hugging Face Transformers library. GPT‑2’s token embeddings are stored in the model’s transformer component (specifically in wte).
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# Load GPT‑2 model and tokenizer.
# Note: GPT‑2 uses a byte-level BPE tokenizer, so tokenization behavior may differ from BERT.
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Extract the token embeddings.
# In GPT‑2, embeddings are stored in model.transformer.wte.
embedding_layer = model.transformer.wte.weight.detach().cpu().numpy()
# Print the shape of the embedding matrix to inspect vocabulary size and embedding dimension.
print("Embedding matrix shape:", embedding_layer.shape)
# Retrieve the LM head matrix
lm_head_embeddings = model.lm_head.weight.detach().cpu().numpy()
Embedding matrix shape: (50257, 768)
2.2. Run Sequence
- Preparing the Input and Obtaining Hidden States:
  We tokenize an input sentence (e.g., " The cat sat on the") and perform a forward pass with output_hidden_states=True so that we receive the embedding output as well as the output from each transformer layer.
- Predicting the Next Token:
  Using the final logits (after all layers), we determine the next token by selecting the one with the highest probability. Since GPT‑2 ties its input embeddings with its LM head weights, we can retrieve the predicted token’s embedding from the same matrix.
- Layer-wise Analysis:
  For each layer (starting from the initial embedding), we:
  - Extract the hidden state corresponding to the last token.
  - Apply the final output layer normalization to the hidden state to bring it back into the vocabulary embedding space.
  - Compute the dot product similarity between the normalized hidden state and the predicted token’s embedding.
  - Retrieve the top tokens that are most similar to the hidden state.

This lets us observe how, as the data passes through each layer, the representation of the context (here, the last token) evolves toward the characteristics of the token that the model will eventually predict.
Run our sentence through the model and get the hidden states, plus the final predicted token.
import torch
# Example sentence.
input_text = " The cat sat on the"
inputs = tokenizer(input_text, return_tensors="pt")
# Get outputs with all hidden states.
outputs = model(**inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states # Tuple: (embedding output, layer 1, ..., layer 12)
# Determine the predicted next token using the final logits.
logits = outputs.logits
predicted_token_id = torch.argmax(logits[0, -1, :]).item()
predicted_token = tokenizer.decode([predicted_token_id])
print("Predicted next token:", predicted_token)
# Retrieve predicted token's embedding from LM head.
predicted_emb = lm_head_embeddings[predicted_token_id]
Predicted next token: floor
2.3. Compare Hidden States
# Sort similarities and return the top k words with their scores.
def get_top_k_words(similarities, k=5):
    top_indices = similarities.argsort()[-k:][::-1]

    top_words = []
    for idx in top_indices:
        # Convert the vocabulary index back into a token string.
        word_str = tokenizer.convert_ids_to_tokens(int(idx))

        if word_str[0] == "Ġ":
            word_str = word_str[1:]  # Remove the leading space marker.
        else:
            word_str = "…" + word_str  # Add a symbol to indicate a subword.

        top_words.append((word_str, similarities[idx]))

    return top_words
import torch
import pandas as pd
# For each layer, compare the hidden state of the last token with the predicted token.
print("\nLayer-wise similarity analysis:")
# Dictionaries to store top tokens and their similarity scores per layer.
token_table = {}
sim_table = {}
for i, hs in enumerate(hidden_states):
# Get the last token's hidden state.
last_token_state = hs[0, -1, :]
# Apply the final layer normalization to each hidden state to bring it into the embedding space.
if i == len(hidden_states) - 1:
# For the final layer, the normalization has already been applied.
state = last_token_state.detach().cpu().numpy()
else:
state = model.transformer.ln_f(last_token_state).detach().cpu().numpy()
print(f"\nLayer {i}:")
# Direct dot product similarity with the predicted token embedding.
dot_sim = state.dot(predicted_emb)
print(f"Dot product with predicted token '{predicted_token}': {dot_sim:.3f}")
# Retrieve top-5 similar tokens using dot product similarity.
similarities = dot_product_similarity(state, lm_head_embeddings)
top_words = get_top_k_words(similarities, k=5)
print("Top 5 similar tokens to hidden state:")
for token, sim in top_words:
print(f" {token:10s} {sim:.3f}")
# Save results for table construction.
token_table[f"Layer {i}"] = [token for token, sim in top_words]
sim_table[f"Layer {i}"] = [sim for token, sim in top_words]
# Construct DataFrames where columns are layers and rows are rank positions.
tokens_df = pd.DataFrame(token_table, index=[f"Rank {i+1}" for i in range(5)])
sims_df = pd.DataFrame(sim_table, index=[f"Rank {i+1}" for i in range(5)])
print("\nTable of Top 5 Tokens (per layer):")
display(tokens_df)
print("\nTable of Similarities (per layer):")
display(sims_df)
Layer-wise similarity analysis:
Layer 0:
Dot product with predicted token ' floor': 48.647
Top 5 similar tokens to hidden state:
destro 81.759
mathemat 80.826
livest 79.613
challeng 78.476
…theless 78.156
Layer 1:
Dot product with predicted token ' floor': 0.726
Top 5 similar tokens to hidden state:
same 11.226
latter 8.338
first 8.058
world 7.621
last 7.551
Layer 2:
Dot product with predicted token ' floor': 1.161
Top 5 similar tokens to hidden state:
same 13.096
latter 9.439
first 8.846
world 8.409
last 8.355
Layer 3:
Dot product with predicted token ' floor': 0.513
Top 5 similar tokens to hidden state:
same 11.183
last 5.757
world 5.405
first 5.294
latter 5.180
Layer 4:
Dot product with predicted token ' floor': 1.848
Top 5 similar tokens to hidden state:
same 10.353
opposite 5.604
last 4.869
first 4.859
next 4.608
Layer 5:
Dot product with predicted token ' floor': -0.502
Top 5 similar tokens to hidden state:
same 5.646
opposite 1.171
next 0.955
table 0.894
very 0.596
Layer 6:
Dot product with predicted token ' floor': -1.299
Top 5 similar tokens to hidden state:
same 2.140
opposite -0.933
table -1.194
floor -1.299
board -2.418
Layer 7:
Dot product with predicted token ' floor': -1.397
Top 5 similar tokens to hidden state:
same 1.074
shoulders -1.286
table -1.347
floor -1.397
opposite -1.717
Layer 8:
Dot product with predicted token ' floor': -6.399
Top 5 similar tokens to hidden state:
floor -6.399
edge -7.088
same -7.228
ground -7.307
table -7.308
Layer 9:
Dot product with predicted token ' floor': -7.291
Top 5 similar tokens to hidden state:
ground -7.132
floor -7.291
table -7.300
edge -7.660
bottom -8.409
Layer 10:
Dot product with predicted token ' floor': -9.594
Top 5 similar tokens to hidden state:
floor -9.594
bed -11.009
sofa -11.136
table -11.299
ground -11.632
Layer 11:
Dot product with predicted token ' floor': -43.615
Top 5 similar tokens to hidden state:
floor -43.615
sofa -44.381
bed -44.545
couch -44.738
table -45.060
Layer 12:
Dot product with predicted token ' floor': -80.597
Top 5 similar tokens to hidden state:
floor -80.597
bed -80.720
couch -80.899
ground -81.091
edge -81.102
Table of Top 5 Tokens (per layer):
| | Layer 0 | Layer 1 | Layer 2 | Layer 3 | Layer 4 | Layer 5 | Layer 6 | Layer 7 | Layer 8 | Layer 9 | Layer 10 | Layer 11 | Layer 12 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Rank 1 | destro | same | same | same | same | same | same | same | floor | ground | floor | floor | floor |
| Rank 2 | mathemat | latter | latter | last | opposite | opposite | opposite | shoulders | edge | floor | bed | sofa | bed |
| Rank 3 | livest | first | first | world | last | next | table | table | same | table | sofa | bed | couch |
| Rank 4 | challeng | world | world | first | first | table | floor | floor | ground | edge | table | couch | ground |
| Rank 5 | …theless | last | last | latter | next | very | board | opposite | table | bottom | ground | table | edge |
Table of Similarities (per layer):
| | Layer 0 | Layer 1 | Layer 2 | Layer 3 | Layer 4 | Layer 5 | Layer 6 | Layer 7 | Layer 8 | Layer 9 | Layer 10 | Layer 11 | Layer 12 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Rank 1 | 81.759 | 11.226 | 13.096 | 11.183 | 10.353 | 5.646 | 2.140 | 1.074 | -6.399 | -7.132 | -9.594 | -43.615 | -80.597 |
| Rank 2 | 80.826 | 8.338 | 9.439 | 5.757 | 5.604 | 1.171 | -0.933 | -1.286 | -7.088 | -7.291 | -11.009 | -44.381 | -80.720 |
| Rank 3 | 79.613 | 8.058 | 8.846 | 5.405 | 4.869 | 0.955 | -1.194 | -1.347 | -7.228 | -7.300 | -11.136 | -44.545 | -80.899 |
| Rank 4 | 78.476 | 7.621 | 8.409 | 5.294 | 4.859 | 0.894 | -1.299 | -1.397 | -7.307 | -7.660 | -11.299 | -44.738 | -81.091 |
| Rank 5 | 78.156 | 7.551 | 8.355 | 5.180 | 4.608 | 0.596 | -2.418 | -1.717 | -7.308 | -8.409 | -11.632 | -45.060 | -81.102 |
▂▂▂▂▂▂▂▂▂▂▂▂
S3. GPT-2 Heads: Analyzing Pattern Vectors in Context
Building on our analysis of BERT’s attention heads, we now extend the same methodology to GPT-2.
This time, we create the pattern vectors from actual intermediate hidden states generated while processing a sentence, instead of using vocabulary embeddings directly.
3.1. Extracting Pattern Projection Matrices for GPT-2
We start by defining a function to extract the $W^P$ matrices for specific heads in GPT-2. Since GPT-2 is a decoder model, it uses causal self-attention, but the core extraction process is similar.
import torch
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Load GPT-2 model and tokenizer
model_name = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
# Extract token embedding matrix
embedding_layer = model.transformer.wte.weight.detach().cpu().numpy()
# Retrieve the LM head matrix
lm_head_embeddings = model.lm_head.weight.detach().cpu().numpy()
def get_GPT2_WP(layer, head):
attn_layer = model.transformer.h[layer].attn
# Get the full weight matrix from c_attn; here W is (768, 2304)
W = attn_layer.c_attn.weight.detach().cpu().numpy()
# Transpose so that the Q, K, V segments are along axis 0:
W_T = W.T # Now shape is (2304, 768)
# Slice out Q and K: each is 768 rows (for GPT‑2, hidden_size=768)
W_Q = W_T[:model.config.n_embd, :] # (768, 768)
W_K = W_T[model.config.n_embd:2*model.config.n_embd, :] # (768, 768)
num_heads = model.config.n_head # e.g., 12
head_size = model.config.n_embd // num_heads # 768 // 12 = 64
# Extract the slice for the given head along the rows
W_Q_i = W_Q[head * head_size:(head + 1) * head_size, :] # (64, 768)
W_K_i = W_K[head * head_size:(head + 1) * head_size, :] # (64, 768)
# Compute the pattern projection matrix
W_P_i = np.dot(W_Q_i.T, W_K_i) # Results in (768, 768)
return W_P_i
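As with BERT, here is a quick check (my addition) that this slicing reproduces the raw query-key dot product, again ignoring the biases and the $1/\sqrt{d_{head}}$ scaling that GPT-2 applies inside its attention module; the tokens " cat" and " floor" are arbitrary choices.

# Illustrative check: h_q^T W^P_i h_k should equal the q . k computed directly
# from the c_attn weight columns for this head (biases and scaling ignored).
layer, head = 0, 0
n_embd = model.config.n_embd
head_size = n_embd // model.config.n_head

W_P_i = get_GPT2_WP(layer, head)
W = model.transformer.h[layer].attn.c_attn.weight.detach().cpu().numpy()  # (768, 2304)

h_q = embedding_layer[tokenizer.encode(" cat")[0]]
h_k = embedding_layer[tokenizer.encode(" floor")[0]]

q = h_q @ W[:, head * head_size:(head + 1) * head_size]                    # query slice
k = h_k @ W[:, n_embd + head * head_size:n_embd + (head + 1) * head_size]  # key slice

print(np.isclose(h_q @ W_P_i @ h_k, q @ k))  # expect: True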
3.2. Comparing GPT-2 Patterns to Vocabulary Embeddings
Now, we extend our prior method by applying layer normalization to the pattern vectors before computing their similarity to the vocabulary embeddings.
def find_head_matches_GPT2(W_P_i, hidden_state, k=15):
    # Project the hidden state to obtain the pattern vector.
    pattern = np.dot(hidden_state, W_P_i)

    # Apply the final layer normalization to bring the pattern back into embedding space.
    pattern = model.transformer.ln_f(torch.tensor(pattern)).detach().cpu().numpy()

    # Compute cosine similarities against the vocabulary embeddings.
    similarities = cosine_similarity(pattern, embedding_layer)

    # Retrieve top-k matches.
    top_indices = similarities.argsort()[-k:][::-1]

    top_words = []
    for idx in top_indices:
        # Convert the vocabulary index back into a token string.
        word_str = tokenizer.convert_ids_to_tokens(int(idx))

        if word_str[0] == "Ġ":
            word_str = word_str[1:]  # Remove the leading space marker.
        else:
            word_str = "…" + word_str  # Add a symbol to indicate a subword.

        top_words.append((word_str, similarities[idx]))

    return top_words
3.3. Run Sequence
import torch
# Example sentence.
#input_text = " The cat sat on the"
input_text = " While formerly a Democrat, in next year's election, the senator intends to"
inputs = tokenizer(input_text, return_tensors="pt")
# Get outputs with all hidden states.
outputs = model(**inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states # Tuple: (embedding output, layer 1, ..., layer 12)
# Determine the predicted next token using the final logits.
logits = outputs.logits
predicted_token_id = torch.argmax(logits[0, -1, :]).item()
predicted_token = tokenizer.decode([predicted_token_id])
print("Predicted next token:", predicted_token)
# Retrieve predicted token's embedding from LM head.
predicted_emb = lm_head_embeddings[predicted_token_id]
Predicted next token: run
3.4. Probing GPT-2 Heads
Compare the pattern vector for a hidden state to the vocabulary.
num_heads = model.config.n_head
# Store results
results = []
# For each layer...
for layer_i, hs in enumerate(hidden_states):
# Get the last token's hidden state.
last_token_state = hs[0, -1, :]
print("Layer", layer_i)
    # As a sanity check, repeat the layer-wise dot-product / closest-word analysis
    # from Section 2 on this new sentence before probing the heads.
# Direct dot product similarity with the predicted token embedding.
state = model.transformer.ln_f(last_token_state).detach().cpu().numpy()
dot_sim = state.dot(predicted_emb)
print(f" Dot product with predicted token '{predicted_token}': {dot_sim:.3f}")
# Find the current most similar word to the hidden state. It will gradually
# become more like the predicted word, as we saw in section 2.
sims = dot_product_similarity(state, lm_head_embeddings)
top_indices = sims.argsort()[-1:][::-1]
closest_word = tokenizer.convert_ids_to_tokens(int(top_indices[0])).replace("Ġ", " ")
print(f" Closest word:", closest_word, f"({sims[top_indices[0]]:.3f})")
print()
# ======== Analyze Head Patterns ========
# Hidden state of token to predict...
last_token_state = last_token_state.detach().cpu().numpy()
# For each of the heads...
for head in range(num_heads):
# Get the pattern matrix
W_P_i = get_GPT2_WP(layer_i, head)
# Match the head pattern to the vocabulary.
matches = find_head_matches_GPT2(W_P_i, last_token_state, k=10)
# Separate the words and scores.
top_k_strs = ""
top_k_sims = ""
# Turn them into strings
for word_str, sim in matches:
top_k_strs += f"{word_str:>8}, "
top_k_sims += f"{sim:.2}, "
# Add the result as a row.
results.append({
"Closest Word": closest_word,
"Layer": layer_i,
"Head": head,
"Top-k": top_k_strs,
"Scores": top_k_sims
})
    # Note: hidden_states has 13 entries (the embedding output plus each of the 12
    # block outputs), but model.transformer.h only has 12 attention layers. Since
    # hidden_states[i] is the input to block i, we stop after layer 11 -- the final
    # hidden state has no attention layer left to consume it.
if layer_i == 11:
break
# Convert results to DataFrame and display
df_results = pd.DataFrame(results)
# Set pandas precision to 3 decimal points
pd.options.display.float_format = '{:.3f}'.format
display(df_results)
Layer 0
Dot product with predicted token ' run': 73.287
Closest word: destro (112.097)
Layer 1
Dot product with predicted token ' run': 4.512
Closest word: be (10.018)
Layer 2
Dot product with predicted token ' run': 0.683
Closest word: be (7.008)
Layer 3
Dot product with predicted token ' run': 2.225
Closest word: be (6.327)
Layer 4
Dot product with predicted token ' run': -3.593
Closest word: be (1.300)
Layer 5
Dot product with predicted token ' run': -6.860
Closest word: be (-2.620)
Layer 6
Dot product with predicted token ' run': -13.255
Closest word: make (-9.850)
Layer 7
Dot product with predicted token ' run': -15.505
Closest word: retire (-11.693)
Layer 8
Dot product with predicted token ' run': -20.814
Closest word: vote (-15.865)
Layer 9
Dot product with predicted token ' run': -26.565
Closest word: vote (-19.277)
Layer 10
Dot product with predicted token ' run': -35.651
Closest word: vote (-27.857)
Layer 11
Dot product with predicted token ' run': -74.890
Closest word: vote (-72.214)
| | Closest Word | Layer | Head | Top-k | Scores |
|---|---|---|---|---|---|
| 0 | destro | 0 | 0 | …ò, pione, …ÃĥÃĤÃĥÃĤÃĥÃĤÃĥÃĤÃĥÃĤÃĥÃĤÃ... | 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.12, 0.11... |
| 1 | destro | 0 | 1 | to, for, from, by, ... | 0.24, 0.2, 0.2, 0.2, 0.2, 0.19, 0.19, 0.19, 0.... |
| 2 | destro | 0 | 2 | not, that, in, at, ... | 0.1, 0.096, 0.094, 0.094, 0.092, 0.09, 0.09, 0... |
| 3 | destro | 0 | 3 | …ÃĥÃĤÃĥÃĤÃĥÃĤÃĥÃĤÃĥÃĤÃĥÃĤÃĥÃĤÃĥÃĤÃĥÃĤÃĥÃĤÃĥÃĤÃ... | 0.094, 0.094, 0.092, 0.092, 0.09, 0.09, 0.09, ... |
| 4 | destro | 0 | 4 | an, in, at, …ÃĥÃĤÃĥÃĤÃĥÃĤÃĥÃ... | 0.12, 0.12, 0.12, 0.12, 0.11, 0.11, 0.11, 0.11... |
| ... | ... | ... | ... | ... | ... |
| 139 | vote | 11 | 7 | neighb, …PDATE, …Þ, eleph, nomi... | 0.12, 0.12, 0.12, 0.12, 0.12, 0.11, 0.11, 0.11... |
| 140 | vote | 11 | 8 | proble, '', …oda, recently, report... | 0.0049, 0.0019, -0.0015, -0.004, -0.006, -0.00... |
| 141 | vote | 11 | 9 | …escription, horizont, mathemat, …ãĤ¯, c... | 0.18, 0.17, 0.17, 0.16, 0.16, 0.16, 0.15, 0.15... |
| 142 | vote | 11 | 10 | Chaser, Surviv, …Firstly, …Interested, …Pr... | 0.03, 0.029, 0.029, 0.027, 0.026, 0.026, 0.024... |
| 143 | vote | 11 | 11 | up, on, in, out, fo... | 0.28, 0.28, 0.27, 0.27, 0.27, 0.27, 0.27, 0.27... |

144 rows × 5 columns
3.5. Observations
Analysis of the results by ChatGPT o3-mini-high:
There’s a lot to unpack!
- Function Word and Syntactic Patterns:
  In the early layers (layer 0), many head patterns consistently match tokens that are common function words: prepositions, articles, and conjunctions (e.g., “to”, “for”, “in”, “that”). This suggests that several heads are capturing low-level syntactic or relational patterns rather than content per se.
| index | Closest Word | Layer | Head | Top-k | Scores |
|---|---|---|---|---|---|
| 1 | destro | 0 | 1 | to, for, from, by, on, in, that, into, with, of | 0.24, 0.2, 0.2, 0.2, 0.2, 0.19, 0.19, 0.19, 0.19, 0.19 |
| 7 | destro | 0 | 7 | an, in, at, for, as, on, that, a, not, by | 0.18, 0.18, 0.18, 0.18, 0.18, 0.17, 0.17, 0.17, 0.17, 0.17 |
- Political and Temporal Semantics:
  One head pattern in layer 5 matches “Libertarian,” and another in layer 6 gives a token that appears to be “November.” These tokens indicate that some heads are homing in on the political and electoral context of the sentence. Layers 8–11 then begin to show tokens like “vote,” “rights,” “president,” and even fragments that resemble “nomine” (hinting at “nominee”). This progression suggests the model is gradually shifting from general syntactic features toward more semantically rich, context-dependent political concepts.
| index | Closest Word | Layer | Head | Top-k | Scores |
|---|---|---|---|---|---|
| 69 | be | 5 | 9 | …20439, …Welcome, Donald, Libertarian, âĢº, Amelia, Canaver, Kathryn, …ãĤ¶, Practices | 0.04, 0.035, 0.034, 0.031, 0.029, 0.028, 0.028, 0.027, 0.026, 0.025 |
| 74 | make | 6 | 2 | …ovember, Various, normally, …*/(, Simply, Normally, …nesday, withd, …Normally, …CRIPTION | 0.031, 0.03, 0.03, 0.026, 0.026, 0.024, 0.021, 0.021, 0.021, 0.021 |
| 108 | vote | 9 | 0 | …President, Parallel, himself, …Dialog, commentary, …*/(, Twitter, …ãĤ±, …Republican, President | 0.11, 0.098, 0.097, 0.094, 0.088, 0.088, 0.088, 0.087, 0.087, 0.086 |
| 127 | vote | 10 | 7 | proposed, proposals, …sub, …The, …An, …President, …government, qu, proposal, Equality | 0.15, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14 |
- Diverse, Sometimes Noisy Representations:
  A number of heads (especially in some layers) produce tokens with unusual or garbled characters (for example, sequences of non-standard symbols). These may indicate that certain heads are either less interpretable or are picking up on subword fragments and idiosyncratic patterns that don’t align neatly with our intuitive understanding of words. Their consistent appearance across a range of heads might point to a more nuanced or experimental role in the model’s internal representation.
- Layer-Specific Behavior:
  The shift from heads returning mostly functional tokens in lower layers to heads returning politically charged or temporally relevant tokens in higher layers is noteworthy. It aligns with the idea that early layers capture general patterns (e.g., syntax) while later layers increasingly reflect the specific semantic context, in this case the political narrative of a senator’s electoral intentions.
In summary, aside from the clear political and temporal cues, the experiment reveals a layered internal structure where different heads focus on different aspects of language—from basic syntactic roles to more context-specific and even slightly noisy subword patterns. This multifaceted behavior is exactly what makes probing head functionality both challenging and fascinating.
S4. Related Work
written by OpenAI o3-mini-high
Traditional interpretability work on transformer attention has largely focused on visualizing attention weights or using gradient‐based feature attribution methods. For example, studies like “What Does BERT Look At?” by Clark et al. and critiques such as “Attention is Not Explanation” by Jain and Wallace examine the distribution and impact of attention scores. While these methods have deepened our understanding of how models attend to different tokens, they primarily address the output of the attention mechanism rather than its internal transformations.
Our approach takes a complementary route by reframing the query-key interactions into a single projection matrix,

\[W^P = (W^Q)^T W^K,\]

which directly produces a “pattern vector” when applied to an input embedding. This vector encapsulates what a particular head is searching for in model space, allowing us to compare it directly against vocabulary embeddings using cosine similarity.
Key differences include:
- Focus on Internal Transformations:
  Instead of solely examining attention weights, our method isolates the low-rank structure inherent in the linear transformations. This provides a more granular view of how individual heads process information, a perspective that complements structural probing methods like those proposed by Hewitt and Manning.
- Quantitative Analysis of Head Function:
  By extracting pattern vectors and analyzing their singular value distributions, we can quantify the effective rank of each head’s transformation (see the sketch after this list). This not only informs us about the head’s capacity for representing complex patterns but also opens up potential avenues for efficient approximations and dynamic rank selection.
- Bridging Representation and Attention:
  Our technique links the abstract notion of attention to the concrete space of word embeddings. This connection offers an interpretable framework that goes beyond what is typically captured by mere attention weight visualizations.
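As a minimal illustration of the second point (my addition; it assumes get_GPT2_WP and numpy from the earlier cells are still in scope, and the 90% energy threshold plus the choice of layer 0, head 1 are arbitrary):

# Singular value spectrum of one head's W^P. Since W^P_i = (W^Q_i)^T W^K_i with
# 64-row factors, its rank is at most head_size = 64, far below its 768 x 768 shape.
W_P_i = get_GPT2_WP(0, 1)
singular_values = np.linalg.svd(W_P_i, compute_uv=False)

# "Effective rank" here: number of singular values capturing 90% of the squared energy.
energy = np.cumsum(singular_values**2) / np.sum(singular_values**2)
effective_rank = int(np.searchsorted(energy, 0.90)) + 1

print("Matrix shape:", W_P_i.shape)
print("Rank upper bound (head_size):", model.config.n_embd // model.config.n_head)
print("Effective rank at 90% energy:", effective_rank)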
In summary, while existing methods provide valuable insights into where attention is allocated, our probing technique delves into how each head transforms the input, offering a fresh perspective on the inner workings of transformer models.
S5. Conclusion
written by OpenAI o3-mini-high
By explicitly recognizing the roles of $W^P$ (as a pattern extractor) and its counterpart $W^M = W^V W^O$ (as a modifier or “message” generator), we open up new avenues for both interpretability and efficiency. For example, understanding which heads are responsible for syntactic versus semantic processing could inform targeted pruning or specialized training regimes. Moreover, if the effective rank of these matrices is significantly lower than their theoretical limit, it might be possible to develop dynamic, low-rank approximation techniques that reduce computational overhead without compromising performance.
In summary, this new framing of attention deepens our understanding of how transformer models process language. It provides educators with a more intuitive tool for explaining head behavior and offers researchers a fresh lens through which to explore model efficiency and specialization. As we refine these insights and extend the analysis to include modifier / “message” vectors, we expect further opportunities to bridge the gap between theoretical understanding and practical advancements in model architecture.
Next to research…
Probe the “message” vectors!
# Compute the modifier / "message" matrix for this head:
W_M_i = np.dot(W_V_i, W_O_i) # shape (768, 768)
# Now compute the modifier vector m for the attended-to word:
m = np.dot(x2_word, W_M_i) # shape (768,)
# TODO - What can we learn from `m`??
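To make that teaser slightly more concrete, here is a hedged sketch (my addition, not a tested part of this notebook) of how $W^M_i$ could be assembled for a GPT-2 head, mirroring get_GPT2_WP above; the value weights come from the last third of c_attn, and the per-head output weights are the corresponding rows of c_proj. The head indices and the token " cat" are arbitrary choices.

def get_GPT2_WM(layer, head):
    """Sketch: modifier / "message" matrix W^M_i = W^V_i W^O_i for one GPT-2 head."""
    attn = model.transformer.h[layer].attn
    n_embd = model.config.n_embd
    head_size = n_embd // model.config.n_head

    # Value projection for this head: the last third of c_attn's columns.
    W = attn.c_attn.weight.detach().cpu().numpy()  # (768, 2304)
    v_start = 2 * n_embd + head * head_size
    W_V_i = W[:, v_start:v_start + head_size]      # (768, 64)

    # Output projection rows that write this head's slot back into the residual stream.
    W_O = attn.c_proj.weight.detach().cpu().numpy()          # (768, 768)
    W_O_i = W_O[head * head_size:(head + 1) * head_size, :]  # (64, 768)

    return np.dot(W_V_i, W_O_i)  # (768, 768)

# e.g., the "message" head (0, 1) would write if it attended to the token " cat":
x2_word = embedding_layer[tokenizer.encode(" cat")[0]]
m = np.dot(x2_word, get_GPT2_WM(0, 1))  # shape (768,)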