Tokenizers deep dive

A deep dive into understanding and building tokenizers, with the end goal of replicating the Llama 3 tokenizer.
deep learning
LLMs
tokenization
Author

Augustas Macijauskas

Published

May 3, 2024

In the conclusion of my recent blog post, I disagreed with Andrej Karpathy’s claims about the current state of tokenizer availability and tooling. In particular:

This prompted me to do a deep dive into tokenizers and how one would go about building one from scratch. Inspired by the top-down approach of Jeremy Howard, whose courses I enjoy a lot, this blog post will start with the very basics of tokenizers and then focus on how to train a Llama 3-like tokenizer on your own data with as little code as possible.

In the second part, we will explore what influence different proportions of English/non-English/code data have on the final vocabulary learnt by the tokenizer, as well as discuss a paper on how tokenizer design choices impact the downstream performance of the LLM, so stay tuned! Now, let’s jump right in!

Warm-up: basics of tokenization

I am mostly assuming that if you found and opened this article, you have at least a basic understanding of what tokenization is and what its purpose is, but let’s do a quick recap just to make sure we are all on the same page.

Why do we need tokenization in the first place?

  1. The majority of machine learning models operate on numbers, and (large) language models are no exception.
  2. Whether it is natural language, code, or something else, we want to be able to input strings into language models.
  3. This is where tokenization comes in: it is the process of converting strings of text into numbers that can then be fed into language models.
Note

This last point is the reason why current state-of-the-art language models are not truly end-to-end systems: the raw text is first converted into numbers by a separately built tokenizer, and only then fed into the model.

How do tokenizers work on a high level?

  1. First, the long input string is split into smaller chunks called tokens.
  2. Then, a vocabulary, which is a mapping from known tokens to integers, is used to convert the tokens to integer ids.
  3. These ids can be fed as input into language models.
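To make these three steps concrete, here is a toy sketch in plain Python. The splitting rule and the five-entry vocabulary are made up for illustration; real tokenizers learn their vocabulary from data and split text much more carefully.

```python
# Hypothetical toy vocabulary: a mapping from known tokens to integer ids
vocab = {"Hello": 0, "world": 1, ",": 2, "we're": 3, "live": 4}


def toy_tokenize(text: str) -> list[str]:
    # Step 1: split the string into tokens (here: on whitespace, with commas
    # split off as separate tokens)
    return text.replace(",", " ,").split()


def toy_encode(text: str) -> list[int]:
    # Step 2: look up each token's id in the vocabulary
    return [vocab[token] for token in toy_tokenize(text)]


# Step 3: these ids are what would be fed into a language model
print(toy_encode("Hello world, we're live"))  # [0, 1, 2, 3, 4]
```

Note that toy_encode raises a KeyError for any word missing from vocab, which is exactly the unseen-word problem mentioned in point 4 of the next list.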

What are some qualities that we want our tokenizers to have?

We want our tokenizers to be able to process various kinds of inputs. In this article, we will not be able to fully dig into the various design choices made when building tokenizers to achieve this, but here is a (likely non-exhaustive) list of qualities that we want the tokenizers to have:

  1. They should work on both English and non-English texts, both uppercase and lowercase, and both Latin and non-Latin alphabets.
  2. They should be able to handle various kinds of special characters that we might find in the wild, such as emojis 🤗.
  3. They should be able to handle code.
  4. (Optional) They should be able to tokenize unseen words without the need to include a special <unk> token in their vocabulary.
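Point 4 is usually achieved by working at the level of bytes, which is what the byte-level Llama 3 tokenizer does. The plain-Python sketch below is an illustration only (there are no merges, and the function names are my own), but it shows the key fact: any string, emojis included, can be written as a sequence of UTF-8 bytes, and there are only 256 possible byte values, so a base vocabulary with one token per byte can encode any input without an <unk> token.

```python
# Base vocabulary: one token per possible byte value (0-255)
base_vocab = {bytes([b]): b for b in range(256)}


def encode_bytes(text: str) -> list[int]:
    # Any string can be written as UTF-8 bytes, so this never hits an unknown token
    return [base_vocab[bytes([b])] for b in text.encode("utf-8")]


def decode_bytes(ids: list[int]) -> str:
    # Reverse the mapping: ids -> bytes -> UTF-8 string
    return bytes(ids).decode("utf-8")


print(encode_bytes("🤗"))  # [240, 159, 164, 151]: the emoji takes 4 bytes
print(decode_bytes(encode_bytes("naïve 🤗 café")))  # naïve 🤗 café
```

The price is longer sequences (a single emoji costs four tokens here); BPE, discussed later in this post, recovers compression by merging frequent byte sequences into new tokens.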

How does this look in code?

Before diving into training a tokenizer, let’s illustrate some of the aspects described above in code. We will find that all of the above is very simple to do by using the tokenizers library!

%load_ext autoreload
%autoreload 2
import json
from pprint import pprint
from tokenizers import (
    Tokenizer, Regex, models, pre_tokenizers, processors, decoders
)
from transformers import AutoTokenizer
# This code will only be used later when training a tokenizer
from datasets import load_dataset

dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")
len(dataset)
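# Note: the Llama 3 repositories on the Hugging Face Hub are gated, so you need to
# accept the license and log in (e.g. with `huggingface-cli login`) before this works
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)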

Now that we have a tokenizer defined, let’s see how we can tokenize a piece of text:

text = "Hello world, we're live 🤗."
tokens = tokenizer.tokenize(text)
tokens
['Hello', 'Ġworld', ',', 'Ġwe', "'re", 'Ġlive', 'ĠðŁ', '¤', 'Ĺ', '.']

We can see that the string was split into smaller tokens. The Ġ is how the Llama 3 tokenizer represents spaces in its vocabulary, so you can think of tokens such as Ġworld and Ġwe as " world" and " we", i.e. words preceded by a space.

Note

It turns out that the use of the Ġ Unicode character is very likely a historical artifact that started with the GPT-2 implementation. Strong evidence for this claim is that the GPT-4 tokenizer has ordinary spaces in its vocabulary. See the Open questions section at the end for more.

Next, let’s turn these tokens into ids:

token_ids = tokenizer.convert_tokens_to_ids(tokens)
token_ids
[9906, 1917, 11, 584, 2351, 3974, 11410, 97, 245, 13]

As simple as that! We could now feed these token_ids into a language model, though it is important to note that in practice we would perform everything in one step:

inputs = tokenizer(text, return_tensors="pt")
pprint(inputs, sort_dicts=False)  # pprint stands for "pretty print"
{'input_ids': tensor([[128000,   9906,   1917,     11,    584,   2351,   3974,  11410,     97,
            245,     13]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
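Compared to the step-by-step version, the one-step call adds two things: the extra leading id 128000 is the <|begin_of_text|> special token, and attention_mask is all ones because a single string needs no padding. The plain-Python sketch below mimics this for a single string; it is a conceptual illustration, not how transformers implements it, and encode_single is a hypothetical helper.

```python
BEGIN_OF_TEXT_ID = 128000  # id of <|begin_of_text|> in the Llama 3 vocabulary


def encode_single(token_ids: list[int]) -> dict[str, list[int]]:
    # Prepend the beginning-of-text token to the ids of the text itself
    input_ids = [BEGIN_OF_TEXT_ID] + token_ids
    # No padding for a single string, so every position is attended to
    return {"input_ids": input_ids, "attention_mask": [1] * len(input_ids)}


# The ids produced by convert_tokens_to_ids above
token_ids = [9906, 1917, 11, 584, 2351, 3974, 11410, 97, 245, 13]
print(encode_single(token_ids))
```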

We would then call a Hugging Face-compatible model like so:

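from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(model_id)  # downloads the full model weights
outputs = model(**inputs)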
Note

For the sake of completeness, it is worth noting that one would only tokenize a single string on its own during inference. During training, one would either tokenize the whole dataset in batches before training or tokenize each batch of data on the fly. The former is usually more reliable and makes it easier to restart training if a crash occurs (see Thomas Wolf’s video).

Training a Llama 3 tokenizer

This section discusses the code needed to train a new tokenizer from an old one using the tokenizers library. For now, we will assume that we already have a dataset to train on, and we will simply reuse the architecture of the Llama 3 tokenizer, so that we do not have to worry about the different tokenizer design choices.

First, we define a function that takes a Hugging Face dataset and returns a batched iterator over its text column:

# Define an iterator over the training split of the dataset
def batch_iterator(dataset, batch_size=1000, verbose=False):
    if verbose:
        print(f"Dataset size: {len(dataset)}")

    for i in range(0, len(dataset), batch_size):
        yield dataset[i:i+batch_size]["text"]

With that, we can go ahead and use it to train a new tokenizer:

new_tokenizer = tokenizer.train_new_from_iterator(
    batch_iterator(dataset, verbose=False),
    len(tokenizer.get_vocab()),
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
)
new_tokenizer.add_tokens(list(tokenizer.get_added_vocab().keys()))

print(f"Vocab length={len(new_tokenizer.get_vocab())}")

# Save the new tokenizer
new_tokenizer.save_pretrained("new-llama-tokenizer-english-only");
Vocab length=128256

This takes some time (about 1 minute on my machine), but the result is a working tokenizer, as easy as that!

Assembling a Llama 3 tokenizer ourselves

Let’s now dive a bit deeper and see how we could construct the tokenizer ourselves. This is where the different design decisions come into play.

Essentially, a tokenizer is assembled from up to five components, most of which are optional. We will cover them briefly here, but I encourage reading this great piece of documentation to learn more about the various components available in the tokenizers library that can be used to assemble a tokenizer.

The 5 components are:

  1. Normalizer (optional): pre-process the input string for a given use case, e.g. strip accents (é -> e) or make the strings lowercase.
  2. Pre-tokenizer (optional): performs some initial splitting, mainly to prevent tokens that are too long, e.g. ones that span multiple common words. A pre-tokenizer that splits on whitespace, for example, rules such tokens out.
  3. Model: the algorithm that performs the tokenization itself (i.e. takes strings, splits them, and converts them into tokens).
  4. Post-processor (optional): used for post-processing the tokenized string, e.g. we could add special tokens or apply some other template.
  5. Decoder (optional): runs the tokenizer in the opposite direction, i.e. maps a list of token ids back to readable text. In the simplest case, the model’s vocabulary alone is enough to perform the reversal, but certain normalizers and pre-tokenizers add special characters that the decoder has to remove.
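To see how the components fit together, here is a toy pipeline in plain Python that applies hypothetical stand-ins for all five components in order. It is not how the tokenizers library is implemented; the lowercasing normalizer and the whitespace-discarding pre-tokenizer are deliberately simple choices made for illustration (Llama 3 uses neither).

```python
import re

# Hypothetical vocabulary with one special token
VOCAB = {"<bos>": 0, "hello": 1, "world": 2, "!": 3}
ID_TO_TOKEN = {i: t for t, i in VOCAB.items()}


def normalize(text: str) -> str:
    # 1. Normalizer: lowercase the input
    return text.lower()


def pre_tokenize(text: str) -> list[str]:
    # 2. Pre-tokenizer: split into words and punctuation, discarding whitespace
    return re.findall(r"\w+|[^\w\s]", text)


def run_model(pieces: list[str]) -> list[int]:
    # 3. Model: map each piece to an id (a BPE model would split pieces further)
    return [VOCAB[p] for p in pieces]


def post_process(ids: list[int]) -> list[int]:
    # 4. Post-processor: prepend the special beginning-of-sequence token
    return [VOCAB["<bos>"]] + ids


def decode(ids: list[int]) -> str:
    # 5. Decoder: map ids back to text, skipping the special token; since the
    # pre-tokenizer dropped the whitespace, we can only guess where spaces go
    return " ".join(ID_TO_TOKEN[i] for i in ids if i != VOCAB["<bos>"])


def encode(text: str) -> list[int]:
    return post_process(run_model(pre_tokenize(normalize(text))))


ids = encode("Hello world!")
print(ids)          # [0, 1, 2, 3]
print(decode(ids))  # hello world !
```

The round trip is lossy: the capital H is gone because of the normalizer, and a space appears before the exclamation mark because the pre-tokenizer threw the original whitespace away.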

With this knowledge, we can now build our own Llama 3 tokenizer! We can check this file to learn what the hyperparameters are and just copy them over. It turns out that Llama 3 does not use a normalizer (normalization usually makes tokenization an irreversible process: e.g. if you remove accents, you cannot recover them during decoding), while its pre-tokenizer consists of a regex splitter followed by a byte-level pre-tokenizer.

To understand the former, see Andrej Karpathy’s video on tokenizers from 57:36 to 01:14:59. To understand the latter, recall that BPE tokenizers are trained by starting with an initial vocabulary (here, the 256 possible byte values) and then iteratively merging the most frequent pairs of adjacent tokens until the desired vocabulary size is reached. Back when GPT-2 was created, instead of representing each byte by the Unicode character with the same code point, the researchers decided to remap the bytes to a different set of printable Unicode characters. The reasons for this are not entirely clear: the GPT-2 source code only remarks that this avoids mapping to whitespace/control characters that their BPE code "barfs on", so it is likely that using the first 256 characters directly caused errors in their implementation.

Note

This byte-level pre-tokenizer is the reason why the Ġ characters occur, since the space character " " is mapped to the Ġ character during pre-tokenization.
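The remapping can be reproduced in a few lines of plain Python. The sketch below is my own re-implementation of the idea behind the bytes_to_unicode helper in the GPT-2 code (the other names are illustrative): the 188 bytes that are already printable, non-whitespace characters keep their own character, while the remaining 68 bytes are shifted to code points from 256 upwards. This is where the Ġ and the odd-looking emoji pieces in the earlier tokenize output come from.

```python
def bytes_to_unicode() -> dict[int, str]:
    # Bytes that are already printable, non-whitespace characters map to themselves
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    mapping = {b: chr(b) for b in printable}
    # All other bytes (whitespace, control characters, ...) get fresh
    # characters starting at code point 256, in increasing byte order
    n = 0
    for b in range(256):
        if b not in mapping:
            mapping[b] = chr(256 + n)
            n += 1
    return mapping


BYTE_TO_CHAR = bytes_to_unicode()
CHAR_TO_BYTE = {c: b for b, c in BYTE_TO_CHAR.items()}


def to_byte_level(text: str) -> str:
    # Pre-tokenization direction: UTF-8 bytes -> printable characters
    return "".join(BYTE_TO_CHAR[b] for b in text.encode("utf-8"))


def from_byte_level(s: str) -> str:
    # Decoding direction: printable characters -> bytes -> original text
    return bytes(CHAR_TO_BYTE[c] for c in s).decode("utf-8")


print(to_byte_level("Hello world"))  # HelloĠworld
print(to_byte_level(" 🤗"))          # ĠðŁ¤Ĺ
print(from_byte_level("ĠðŁ¤Ĺ"))      # a space followed by 🤗
```

Because the mapping is one-to-one, decoding recovers the original text exactly.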

Before jumping further into the model and decoder used by Llama 3, let’s quickly see these concepts in action.

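# The third-party `regex` module is needed: Python's built-in `re` does not support \p{L}/\p{N}
import regex as re
# The GPT-4 (cl100k_base) split pattern
gpt_4_pat = re.compile(r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""")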

some_text = "Hello world, we're live 🤗. The number is 1234. How's it going!!!?"
print(re.findall(gpt_4_pat, some_text))
['Hello', ' world', ',', ' we', "'re", ' live', ' 🤗.', ' The', ' number', ' is', ' ', '123', '4', '.', ' How', "'s", ' it', ' going', '!!!?']

We can see that the regex splits on whitespace (keeping a space attached to the word that follows it), splits numbers into chunks of at most 3 digits, and separates out apostrophe contractions and punctuation (this list is non-exhaustive). As for the ByteLevel pre-tokenizer, we can see that ordinary English characters remain unchanged, spaces are indeed replaced by the Ġ Unicode character, and the emoji is replaced by a combination of several characters, one per byte of its UTF-8 encoding:

pre_tokenizer = pre_tokenizers.ByteLevel(
    add_prefix_space=False, trim_offsets=True, use_regex=False
)
pre_tokenization_result = pre_tokenizer.pre_tokenize_str(some_text)[0][0]
pre_tokenization_result
"HelloĠworld,Ġwe'reĠliveĠðŁ¤Ĺ.ĠTheĠnumberĠisĠ1234.ĠHow'sĠitĠgoing!!!?"

Returning to the components used by the Llama 3 tokenizer, the model employs the Byte Pair Encoding (BPE) algorithm. To avoid duplicating the excellent explanations already available, I encourage you to watch the relevant sections of Andrej Karpathy’s video or consult the corresponding Wikipedia entry. The post-processor is a sequence of two steps: the ByteLevel post-processor, which in the tokenizers library only (optionally) trims whitespace from the token offsets and leaves the tokens themselves unchanged, and a TemplateProcessing step, which prepends the <|begin_of_text|> special token to the output (this can be turned off by setting add_special_tokens=False when calling a tokenizer). Finally, the ByteLevel decoder reverses the byte-to-character mapping applied during pre-tokenization, mapping token ids back to human-readable strings.
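For intuition, here is a minimal byte-level BPE trainer, encoder, and decoder in plain Python, in the spirit of the one built in Karpathy’s video. It is a simplified sketch with function names of my own, not the tokenizers library’s implementation: it ignores pre-tokenization (so merges may cross word boundaries), breaks ties by first occurrence, and is far too slow for real datasets.

```python
from collections import Counter


def merge(ids: list[int], pair: tuple[int, int], new_id: int) -> list[int]:
    # Replace every non-overlapping occurrence of `pair` (left to right) with `new_id`
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out


def train_bpe(text: str, num_merges: int) -> dict[tuple[int, int], int]:
    # Start from the 256 byte values and repeatedly merge the most frequent pair
    ids = list(text.encode("utf-8"))
    merges = {}
    for step in range(num_merges):
        counts = Counter(zip(ids, ids[1:]))
        if not counts:
            break
        pair = max(counts, key=counts.get)  # ties go to the pair seen first
        new_id = 256 + step
        ids = merge(ids, pair, new_id)
        merges[pair] = new_id
    return merges


def bpe_encode(text: str, merges: dict[tuple[int, int], int]) -> list[int]:
    # Apply the learned merges in the order they were learned
    ids = list(text.encode("utf-8"))
    for pair, new_id in merges.items():
        ids = merge(ids, pair, new_id)
    return ids


def bpe_decode(ids: list[int], merges: dict[tuple[int, int], int]) -> str:
    # Every token expands back into the bytes it was built from
    vocab = {b: bytes([b]) for b in range(256)}
    for (a, b), new_id in merges.items():
        vocab[new_id] = vocab[a] + vocab[b]
    return b"".join(vocab[i] for i in ids).decode("utf-8")


merges = train_bpe("aaabdaaabac", num_merges=3)
print(merges)  # {(97, 97): 256, (256, 97): 257, (257, 98): 258}
ids = bpe_encode("aaabdaaabac", merges)
print(ids)     # [258, 100, 258, 97, 99]
print(bpe_decode(ids, merges))  # aaabdaaabac
```

The real Llama 3 tokenizer runs the same kind of merges, but only within the chunks produced by the regex pre-tokenizer and on the remapped byte characters shown earlier.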

Note

How these design choices are made when building a completely new tokenizer from scratch is out of the scope of this article, but it turns out that, in practice, it is enough to simply use a BPE model and only tune the dataset used for training, the vocabulary size, and the regular expression used to split strings during pre-tokenization.

Assembling our own Llama 3 tokenizer

That was a lot to take in, but we can finally assemble our own Llama 3 tokenizer:

# Define the most important part - the model
tokenizer_custom = Tokenizer(models.BPE(ignore_merges=True))

# Add a pre-tokenizer
tokenizer_custom.pre_tokenizer = pre_tokenizers.Sequence([
    pre_tokenizers.Split(
        pattern=Regex(
            "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
        ),
        behavior="isolated",
        invert=False,
    ),
    pre_tokenizers.ByteLevel(
        add_prefix_space=False, trim_offsets=True, use_regex=False
    )
])

# Add the post-processor
tokenizer_custom.post_processor = processors.Sequence([
    processors.ByteLevel(
        add_prefix_space=True, trim_offsets=False, use_regex=True
    ),
    processors.TemplateProcessing(
        single="<|begin_of_text|> $A",
        pair="<|begin_of_text|> $A <|begin_of_text|>:1 $B:1",
        special_tokens=[("<|begin_of_text|>", 128000)],
    ),
])

# And finally, the decoder
tokenizer_custom.decoder = decoders.ByteLevel()

Let’s train and save it:

from tokenizers import trainers

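trainer = trainers.BpeTrainer(
    # `vocab_size` is the size of the final vocabulary, including the 256
    # initial byte-level characters; we leave room for the 256 special tokens
    # of Llama 3 that are added after training
    vocab_size=len(tokenizer.get_vocab()) - len(tokenizer.get_added_vocab()),
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet()
)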

tokenizer_custom.train_from_iterator(
    batch_iterator(dataset, verbose=False), trainer=trainer
)
tokenizer_custom.add_tokens(list(tokenizer.get_added_vocab().keys()))

print(f"Vocab length={len(tokenizer_custom.get_vocab())}")

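# Save the new tokenizer (creating the target directory first, in case it does not exist)
import os
os.makedirs("new-llama-tokenizer-custom-english-only", exist_ok=True)
tokenizer_custom.save("new-llama-tokenizer-custom-english-only/tokenizer.json")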
Vocab length=128256

We can check that the resulting tokenizer is indeed the same as the one that we have previously trained:

with open("new-llama-tokenizer-english-only/tokenizer.json", "r") as f:
    new_tokenizer_config = json.load(f)

with open("new-llama-tokenizer-custom-english-only/tokenizer.json", "r") as f:
    custom_tokenizer_config = json.load(f)

custom_tokenizer_config["model"]["merges"] == new_tokenizer_config["model"]["merges"]
True
Note

Upon inspecting the corresponding tokenizer configuration files (file 1 and file 2), you might notice that the two vocabularies are not exactly the same. This is because a tokenizer trained from an old one has all the added tokens prepended at the start of its vocabulary, which shifts the ids of all other tokens. This is why above we check that the merges (i.e. which tokens were merged and in what order) are the same for the two tokenizers, which they are.
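A toy example with two made-up vocabularies shows why comparing the vocabularies directly would be misleading: dictionaries with the same tokens under different ids are not equal, while the set of tokens is unaffected by where the added tokens were placed. The same holds for merges, which tokenizer.json stores as pairs of token strings rather than ids.

```python
# Same tokens, but the special token sits at the start of vocab_a and at the end of vocab_b
vocab_a = {"<|begin_of_text|>": 0, "a": 1, "b": 2, "ab": 3}
vocab_b = {"a": 0, "b": 1, "ab": 2, "<|begin_of_text|>": 3}

# Merges refer to tokens by their string form, so they do not depend on ids
merges_a = [("a", "b")]
merges_b = [("a", "b")]

print(vocab_a == vocab_b)                # False: the ids differ
print(vocab_a.keys() == vocab_b.keys())  # True: the same set of tokens
print(merges_a == merges_b)              # True
```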

Conclusion

And that is it, we have successfully replicated the structure of the Llama 3 tokenizer using the tokenizers library! As we can see, it is not that hard if we familiarize ourselves with the API of the tokenizers library and know where to find the configuration parameters.

Next steps

There are many things that did not make it into this article which I briefly discuss here.

  1. How are the design choices actually made and implemented in practice? Above we just copied the configuration parameters, but how would we go about implementing a state-of-the-art tokenizer ourselves?
    • I enjoyed the Getting the most out of your tokenizer for pre-training and domain adaptation paper which discusses these topics (link).
    • My takeaways are that one should:
      1. Use BPE.
      2. Use a diverse training set that is large enough to cover different languages and code well.
      3. Choose vocabulary size to trade off text compression and memory overhead well (check the paper to understand what this means). A vocabulary size of 100k seems to be a good starting point, especially for smaller models.
  2. The inner workings of the BPE algorithm itself, as well as how one would go about implementing it. For a proper deep dive, looking into the internals of the BPE algorithm and its implementation would seem like a must, but I think that there are already plenty of great resources on the topic. Here are some of them that helped me build good intuitions for what is going on:
    • Karpathy’s aforementioned video is a must-see for every LLM practitioner.
    • The implementation of the GPT-2 tokenizer (link)
    • The educational version of the tiktoken library (note that it is meant to be easy to understand, but is not actually very efficient).
    • [Much harder] The source code of the tiktoken and tokenizers libraries. Both of them are written in Rust 🥲.
  3. Tokenizer parameters used during the inference stage. These mainly include truncation and padding; they are not inherently complex concepts, so I decided to exclude them here.
    • Consult the Hugging Face documentation and various blog posts to learn more about them.

I strongly encourage you to check out these resources to get a better sense of what the BPE algorithm is doing and other aspects of tokenizers!
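To get a feel for the memory side of the vocabulary-size trade-off mentioned in item 1 above, recall that the token embedding matrix holds one d_model-dimensional vector per vocabulary entry. The back-of-the-envelope sketch below assumes the 4096 hidden size of Llama 3 8B and 2 bytes per parameter (bf16); the helper names are my own. If the output projection is not tied to the embeddings, it costs the same amount again.

```python
def embedding_params(vocab_size: int, d_model: int) -> int:
    # One d_model-dimensional vector per token in the vocabulary
    return vocab_size * d_model


def embedding_gib(vocab_size: int, d_model: int, bytes_per_param: int = 2) -> float:
    # bytes_per_param=2 corresponds to bf16/fp16 weights
    return embedding_params(vocab_size, d_model) * bytes_per_param / 2**30


d_model = 4096  # hidden size of Llama 3 8B
for vocab_size in (32_000, 100_000, 128_256):
    params = embedding_params(vocab_size, d_model)
    gib = embedding_gib(vocab_size, d_model)
    print(f"{vocab_size:>7,} tokens -> {params:>11,} params ({gib:.2f} GiB)")
```

With 128,256 tokens, the matrix has about 525 million parameters (roughly 0.98 GiB in bf16), and this fixed cost takes up a bigger share of a smaller model’s parameter budget.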

Open questions (for the curious reader)

I leave here some interesting questions and observations that occurred to me while writing this article. I hope to come back to (at least some of) them at some point in the future, but if you end up going down the rabbit holes to find the answers, I would be curious to know them too! Here is the list:

  1. Replacing spaces with the Ġ Unicode character in BPE-based tokenizers is not necessary. The evidence for this is that the GPT-4 tokenizer no longer uses them, so there is a high chance that it is just a historical artifact that people got used to since the release of GPT-2 code and never questioned it afterwards.
    • Unless there is something that I am missing here, I believe that the community should ideally get rid of them over time.
  2. The Llama-3-8B-Instruct and Llama-3-70B-Instruct tokenizers are almost identical apart from one setting: the former has the model.ignore_merges key specified as true, while the latter does not have such a key specified. The question is simple: what difference does this make? A good starting point to answer this question could be this piece of code.
  3. Running inference using the tiktoken library is much faster than with the tokenizers library.
  4. The first 100256 (256 initial + 100k merges) tokens of the GPT-4 and Llama 3 tokenizers are the same (see this note).
    • OpenAI never released their training code, so how could Meta know?

Sources

  1. Let’s build the GPT Tokenizer (video) by Andrej Karpathy.
  2. A little guide to building Large Language Models in 2024 (video) by Thomas Wolf.
  3. Training a new tokenizer from an old one (article) on the Hugging Face course.
  4. Training from memory (part of Hugging Face docs).