Visualising contextualised large language model embeddings with context

Categories: deep learning, LLMs, visualisation

Large language models (predictably) learn to represent the semantic meaning of sentences.

Author: Augustas Macijauskas

Published: April 10, 2024

A follow-up to this post.

Imports

Toggle the cells below if you want to see which imports are being made.

Code
%load_ext autoreload
%autoreload 2
Code
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

Utils

Use [CLS] pooling according to this, i.e. take the final-layer hidden state of the [CLS] token as the sentence embedding:

def compute_sentence_embedding(sentence: str, model, tokenizer):
    sentence_tokenized = tokenizer(sentence, return_tensors="pt")

    print(f"Num tokens: {sentence_tokenized['input_ids'].shape[1]}")

    # [CLS] pooling: use the final hidden state of the first ([CLS]) token
    # as the sentence embedding.
    with torch.no_grad():
        return model(**sentence_tokenized).last_hidden_state[0, 0, :]

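A common alternative is mean pooling, i.e. averaging the final hidden states over the non-padding tokens. A minimal sketch of what that could look like (a hypothetical helper, not used in the comparisons below):

def compute_mean_pooled_embedding(sentence: str, model, tokenizer):
    # Hypothetical alternative to [CLS] pooling: average the final hidden
    # states, weighted by the attention mask so that padding is ignored.
    inputs = tokenizer(sentence, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state  # (1, seq_len, hidden_dim)
    mask = inputs["attention_mask"].unsqueeze(-1)  # (1, seq_len, 1)
    return (hidden * mask).sum(dim=1)[0] / mask.sum(dim=1)[0]
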
def perform_distance_comparison(s1, s2, s3):
    # The pair (s1, s2) is expected to be semantically closer than (s1, s3)
    # under both Euclidean distance and cosine similarity.
    euclidean_dist_1 = torch.linalg.vector_norm(s1 - s2).item()
    euclidean_dist_2 = torch.linalg.vector_norm(s1 - s3).item()

    print(f"|s1 - s2| = {euclidean_dist_1:.3f}")
    print(f"|s1 - s3| = {euclidean_dist_2:.3f}")
    print(f"|s1 - s2| < |s1 - s3| = {euclidean_dist_1 < euclidean_dist_2}")

    cosine_sim_1 = F.cosine_similarity(s1[None, :], s2[None, :])[0].item()
    cosine_sim_2 = F.cosine_similarity(s1[None, :], s3[None, :])[0].item()

    print(f"sim(s1, s2) = {cosine_sim_1:.3f}")
    print(f"sim(s1, s3) = {cosine_sim_2:.3f}")
    print(f"sim(s1, s2) > sim(s1, s3) = {cosine_sim_1 > cosine_sim_2}")

Easier example

model_name = "google-bert/bert-base-cased"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
sentence_1 = "sensitive information"
sentence_2 = "confidential details"
sentence_3 = "sensitive individual"
sentence_1_transformers = compute_sentence_embedding(sentence_1, model, tokenizer)
sentence_2_transformers = compute_sentence_embedding(sentence_2, model, tokenizer)
sentence_3_transformers = compute_sentence_embedding(sentence_3, model, tokenizer)

sentence_1_transformers.shape, sentence_2_transformers.shape, sentence_3_transformers.shape
Num tokens: 4
Num tokens: 4
Num tokens: 4
(torch.Size([768]), torch.Size([768]), torch.Size([768]))
# Both should be true
perform_distance_comparison(
    sentence_1_transformers, sentence_2_transformers, sentence_3_transformers
)
|s1 - s2| = 4.981
|s1 - s3| = 6.788
|s1 - s2| < |s1 - s3| = True
sim(s1, s2) = 0.955
sim(s1, s3) = 0.900
sim(s1, s2) > sim(s1, s3) = True

Harder example

sentence_1 = "your data removal request has been reviewed and concluded"
sentence_2 = "the sensitive personal information has been deleted"
sentence_3 = "she has been a sensitive person"
sentence_1_transformers = compute_sentence_embedding(sentence_1, model, tokenizer)
sentence_2_transformers = compute_sentence_embedding(sentence_2, model, tokenizer)
sentence_3_transformers = compute_sentence_embedding(sentence_3, model, tokenizer)

sentence_1_transformers.shape, sentence_2_transformers.shape, sentence_3_transformers.shape
Num tokens: 11
Num tokens: 9
Num tokens: 8
(torch.Size([768]), torch.Size([768]), torch.Size([768]))
# Both should be true
perform_distance_comparison(
    sentence_1_transformers, sentence_2_transformers, sentence_3_transformers
)
|s1 - s2| = 5.626
|s1 - s3| = 7.529
|s1 - s2| < |s1 - s3| = True
sim(s1, s2) = 0.941
sim(s1, s3) = 0.893
sim(s1, s2) > sim(s1, s3) = True

Try the same with a text embedding model

model_name = "mixedbread-ai/mxbai-embed-large-v1"
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
sentence_1_transformers = compute_sentence_embedding(sentence_1, model, tokenizer)
sentence_2_transformers = compute_sentence_embedding(sentence_2, model, tokenizer)
sentence_3_transformers = compute_sentence_embedding(sentence_3, model, tokenizer)

sentence_1_transformers.shape, sentence_2_transformers.shape, sentence_3_transformers.shape
Num tokens: 11
Num tokens: 9
Num tokens: 8
(torch.Size([1024]), torch.Size([1024]), torch.Size([1024]))
perform_distance_comparison(
    sentence_1_transformers, sentence_2_transformers, sentence_3_transformers
)
|s1 - s2| = 14.404
|s1 - s3| = 19.134
|s1 - s2| < |s1 - s3| = True
sim(s1, s2) = 0.672
sim(s1, s3) = 0.372
sim(s1, s2) > sim(s1, s3) = True
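
For reference, the same comparison can usually be reproduced with the sentence-transformers library, which applies the pooling strategy configured in the model repository ([CLS] pooling for this model, as far as I can tell). A sketch, assuming sentence-transformers is installed:

from sentence_transformers import SentenceTransformer

st_model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

# encode() runs tokenisation, the forward pass and pooling in one call.
embeddings = st_model.encode(
    [sentence_1, sentence_2, sentence_3], convert_to_tensor=True
)
perform_distance_comparison(embeddings[0], embeddings[1], embeddings[2])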