Reproducing the Attention Is All You Need Paper from Scratch

Ever since ChatGPT’s release in November 2022, the excitement
surrounding transformer models has grown steadily.

Though I have worked with transformer models in the past, my
experience mostly revolves around using the
sentence-transformers package and Hugging Face
interfaces to deploy and fine-tune pre-existing models.

Would I be able to code a transformer from scratch, solely using
basic PyTorch functions, and successfully develop the self-attention
mechanism, encoder, and decoder myself, without referring to the PyTorch
implementation?

In this blog post, I will attempt to reproduce the Attention Is
All You Need paper (Vaswani et al., 2017, https://arxiv.org/abs/1706.03762)
from scratch, focusing specifically on the base model that translates
from English to German.

This endeavor comprises the following steps:

  1. Recreate, clean, and tokenize the training data
  2. Create custom transformer variants
    • Implement the multi-headed attention, encoder, and decoder structure
      from scratch, using simple building block elements from PyTorch.
    • Build a reference variant using PyTorch’s built-in
      Transformer class, so that our own implementation can be
      compared against it.
  3. Train the models
  4. Test the models and compare their performance against the BLEU
    scores reported in the paper
  5. Write a Flask microservice to host and use the model in
    practice

The code can be found here: https://github.com/scott-weeden/Transformer-Workbench

The final models can be downloaded here: https://scottweeden.online/wp-content/uploads/transformer-from-scratch-results.zip

1. Set Up and Tokenize the Training Data

The training data used in the paper is from the WMT-14 translation
task. I downloaded the raw data from https://www.statmt.org/wmt14/translation-task.html,
selecting all files with an EN-DE part: commoncrawl.de-en,
europarl-v7.de-en, and
news-commentary-v9.de-en.

The test data consists of the newstest 2013 (validation) and
newstest 2014 (actual test set) datasets.

First, I cleaned the training data, converting special Unicode
quotation marks to their regular ASCII counterparts (e.g., “ and „
both map to "). This makes the model handle quotation marks more
consistently and better matches text typed directly on a keyboard.

I also replaced characters from Unicode ranges belonging to specific
scripts (such as Chinese, Japanese, Arabic, and Cyrillic) with a
single unknown token, [UNK], to keep the tokenizer
vocabulary small.
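The two cleaning steps can be sketched with the standard library alone. This is a minimal illustration, not the code from the repository: the exact quote mapping, the four Unicode ranges, and the choice to collapse a whole run of foreign-script characters into one [UNK] (rather than one [UNK] per character) are my assumptions.

```python
import re

# Assumed mapping of typographic quotation marks to plain ASCII quotes.
QUOTE_MAP = str.maketrans({
    "\u201c": '"',  # left double quotation mark
    "\u201d": '"',  # right double quotation mark
    "\u201e": '"',  # double low-9 quotation mark (German opening quote)
    "\u00ab": '"',  # left-pointing double angle quotation mark
    "\u00bb": '"',  # right-pointing double angle quotation mark
    "\u2018": "'",  # left single quotation mark
    "\u2019": "'",  # right single quotation mark
    "\u201a": "'",  # single low-9 quotation mark
})

# Illustrative script ranges: CJK ideographs, Hiragana/Katakana, Arabic, Cyrillic.
FOREIGN_SCRIPT = re.compile("[\u4e00-\u9fff\u3040-\u30ff\u0600-\u06ff\u0400-\u04ff]+")


def clean_text(text: str) -> str:
    """Normalize quotes, then replace each run of foreign-script characters with [UNK]."""
    text = text.translate(QUOTE_MAP)
    return FOREIGN_SCRIPT.sub("[UNK]", text)


print(clean_text("\u201eHallo\u201c, sagte er: \u041f\u0440\u0438\u0432\u0435\u0442"))
# prints: "Hallo", sagte er: [UNK]
```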

Next, I fully tokenized the text using the Hugging Face
tokenizers library (to focus on building the core transformer
components rather than coding the tokenizer from scratch).

I wrapped every text sequence (both source and target) in a
[START] and an [END] token. Some texts in the
training data were extremely long, so to keep memory requirements
low, and because such long sequences didn’t occur in the test sets, I
truncated all sequences to at most 200 tokens. This limit is still far
above the sequence lengths in the test data.
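Here is a minimal sketch of the wrapping, truncation, and padding steps on plain token lists. The function names are hypothetical, and I assume the 200-token limit includes the two special tokens, with the cut made before [END] is appended so truncated sequences still end in [END]; the post doesn't specify either detail. With the Hugging Face tokenizers library, the same effect can be configured through a TemplateProcessing post-processor and enable_truncation.

```python
from typing import List

START, END, PAD = "[START]", "[END]", "[PAD]"
MAX_LEN = 200  # truncation limit from the post


def add_special_tokens(tokens: List[str], max_len: int = MAX_LEN) -> List[str]:
    """Wrap tokens in [START]/[END], keeping the total length <= max_len.

    Assumption: truncation happens before [END] is appended.
    """
    body = tokens[: max_len - 2]  # leave room for the two special tokens
    return [START] + body + [END]


def pad_batch(seqs: List[List[str]], pad: str = PAD) -> List[List[str]]:
    """Right-pad every sequence to the length of the longest one in the batch."""
    width = max(len(s) for s in seqs)
    return [s + [pad] * (width - len(s)) for s in seqs]


print(pad_batch([add_special_tokens(["Hallo"]), add_special_tokens(["Guten", "Tag"])]))
```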

Token Counts on the newstest-14 test dataset

The top 20 occurring tokens are:

'[PAD]', '[START]', '[END]', 'the', '.', 
'in', 'of', 'die', 'and', 'der', 
'to', 'und', 'a', 'is', ',', 
'zu', 'that', 'for', 'von', 'den'

A full list of all tokens in the vocabulary and their counts can be
found in the file df_token_counts.csv in the downloaded
results.

Sorting Into Batches

The data is then sorted into batches. As described in the paper, each
batch contains approximately 25,000 source and 25,000 target tokens.
This results in 6,230 batches. The mean batch size (i.e., number of
texts) is 724, with a mean length of 45 tokens.

Since the Attention paper trained the base model for 100,000 steps,
this corresponds to roughly 16 epochs (100,000 / 6,230 ≈ 16).
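A token-budget batching scheme like this can be sketched as follows. This is an illustration under assumptions, not the post’s code: I sort by length so similar-length pairs share a batch, and I count the budget in padded tokens (rows times the longest row), checking source and target separately.

```python
from typing import List, Sequence


def batch_by_tokens(src_lens: Sequence[int], tgt_lens: Sequence[int],
                    max_tokens: int = 25_000) -> List[List[int]]:
    """Group example indices so each padded batch stays within max_tokens
    on both the source and the target side."""
    # Sorting by length keeps padding (and wasted compute) low.
    order = sorted(range(len(src_lens)), key=lambda i: (src_lens[i], tgt_lens[i]))
    batches: List[List[int]] = []
    current: List[int] = []
    max_src = max_tgt = 0
    for i in order:
        s, t = max(max_src, src_lens[i]), max(max_tgt, tgt_lens[i])
        n = len(current) + 1
        # Padded cost of a batch = number of rows x longest row.
        if current and (s * n > max_tokens or t * n > max_tokens):
            batches.append(current)
            current = []
            s, t = src_lens[i], tgt_lens[i]
        current.append(i)
        max_src, max_tgt = s, t
    if current:
        batches.append(current)
    return batches


# Epoch arithmetic from the text: 100,000 steps over 6,230 batches per epoch.
print(f"{100_000 / 6_230:.1f} epochs")  # prints "16.1 epochs"
```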

2. Transformer from Scratch

It’s time to code the transformer model. First, I built the variant
based on PyTorch’s reference implementation. Because PyTorch’s
Transformer class does not include embedding layers, this variant
needed embeddings as well, so I wrote a convenient class that handles
both token and positional embeddings. The positional embeddings were
initialized with the sine-cosine positional encodings from the original
paper, but I decided to keep these parameters trainable during
optimization instead of freezing them.
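A minimal sketch of such an embedding class follows; it is not the repository’s class. The name TokenAndPositionalEmbedding, max_len=512, padding id 0, and scaling the token embeddings by √d_model (as the original paper does) are my assumptions. The key point is that the sine-cosine table is registered as an nn.Parameter instead of a buffer, so the optimizer keeps updating it.

```python
import math

import torch
import torch.nn as nn


class TokenAndPositionalEmbedding(nn.Module):
    """Token embedding plus a trainable positional embedding that starts
    from the sine-cosine encoding of Vaswani et al. (2017)."""

    def __init__(self, vocab_size: int, embedding_dim: int,
                 max_len: int = 512, padding_idx: int = 0):
        super().__init__()
        self.token = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        position = torch.arange(max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2) * (-math.log(10000.0) / embedding_dim)
        )
        table = torch.zeros(max_len, embedding_dim)
        table[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        table[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        # nn.Parameter (not a buffer): the encoding is trainable, not frozen.
        self.position = nn.Parameter(table)
        self.scale = math.sqrt(embedding_dim)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # token_ids: [batch, seq_len] -> [batch, seq_len, embedding_dim]
        seq_len = token_ids.size(1)
        return self.token(token_ids) * self.scale + self.position[:seq_len]
```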

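The original base Transformer has the following structure: six encoder
and six decoder layers, a model dimension of 512, a feed-forward inner
dimension of 2,048, and eight attention heads of 64 dimensions each,
with a dropout rate of 0.1.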

I used the same structure for both the reference and from-scratch
implementations.
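For the reference variant, PyTorch’s built-in class can be configured directly with these base hyperparameters. This is a sketch of how that might look, not the post’s exact code: batch_first=True and norm_first=True (pre-LN, matching the from-scratch model) are my assumptions, and the random tensors stand in for already-embedded sentences.

```python
import torch
import torch.nn as nn

# Base-model hyperparameters from Vaswani et al. (2017).
BASE = dict(d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6,
            dim_feedforward=2048, dropout=0.1)

reference = nn.Transformer(**BASE, batch_first=True, norm_first=True)

src = torch.randn(2, 9, 512)  # [batch, src_len, d_model], already embedded
tgt = torch.randn(2, 7, 512)  # [batch, tgt_len, d_model], already embedded
causal = reference.generate_square_subsequent_mask(7)  # hide future target tokens
out = reference(src, tgt, tgt_mask=causal)  # [batch, tgt_len, d_model]
```

A final linear layer over the vocabulary (not shown) turns these outputs into next-token logits.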

The Importance of Being Padded

Initially, I tried fitting the transformer model without masking the
padding tokens (which are needed to make all sentences in a batch the
same length). I suspected that the transformer would be powerful enough
to learn to ignore these tokens without an explicit mask (which sets the
attention weights on all padding positions to zero). Unfortunately, this
did not work: the model was confused by the noise and often produced
strange translations at test time, with the output depending on how many
padding tokens had been appended. Thus, I included a padding mask in
both transformer variants.
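A minimal sketch of such a padding mask, assuming [PAD] has token id 0 (the post doesn’t state the id). It uses the convention of PyTorch’s nn.MultiheadAttention, where True in key_padding_mask marks a key position to ignore.

```python
import torch
import torch.nn as nn

PAD_ID = 0  # assumption: token id of [PAD]


def make_key_padding_mask(token_ids: torch.Tensor, pad_id: int = PAD_ID) -> torch.Tensor:
    """[batch, seq_len] boolean mask; True marks padding that attention must ignore."""
    return token_ids == pad_id


ids = torch.tensor([[5, 8, 3, 0, 0],   # three real tokens, two padding tokens
                    [7, 2, 9, 4, 6]])  # no padding
mask = make_key_padding_mask(ids)

attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
x = torch.randn(2, 5, 16)
_, weights = attn(x, x, x, key_padding_mask=mask)
# weights: [batch, query_len, key_len], averaged over heads; columns 3 and 4
# of the first sentence receive zero weight.
```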

After adding the padding mask to the PyTorch reference
implementation, I managed to get consistent translations out of it.

Transformer Structure from Scratch

In a translation task, the transformer architecture consists of an
encoder and a decoder, working together to transform a source sentence
into a target sentence in a different language.

Encoder: The encoder’s primary function is to
process and understand the source sentence. It takes the input sequence
in the source language and converts it into a continuous, context-aware
representation. The encoder is composed of multiple identical layers,
each containing a multi-head self-attention mechanism and a
position-wise feed-forward network. These layers work together to
capture the contextual relationships between words in the source
sentence. The final output of the encoder is a set of continuous vector
representations that carry the semantic and contextual information of
the input sequence.

Decoder: The decoder’s main role is to generate the
target sentence in the desired language based on the encoder’s output.
Like the encoder, the decoder consists of multiple identical layers, but
with an additional multi-head attention mechanism that connects it to
the encoder. This attention mechanism enables the decoder to focus on
relevant parts of the source sentence while generating each word in the
target sentence. The decoder uses a combination of masked self-attention
(to prevent it from attending to future words in the target sequence),
encoder-decoder attention, and position-wise feed-forward networks to
generate the target sentence step by step, predicting one word at a
time.
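The masked self-attention in the decoder is usually implemented with an additive causal mask. Here is a minimal sketch (the helper name causal_mask is hypothetical; the construction mirrors the one PyTorch uses for nn.Transformer.generate_square_subsequent_mask). Adding −∞ above the diagonal before the softmax gives future positions exactly zero weight.

```python
import torch


def causal_mask(size: int) -> torch.Tensor:
    """[size, size] additive mask: 0 where query i may see key j (j <= i),
    -inf above the diagonal (future tokens)."""
    return torch.triu(torch.full((size, size), float("-inf")), diagonal=1)


scores = torch.randn(4, 4)  # toy attention scores for a 4-token target
weights = torch.softmax(scores + causal_mask(4), dim=-1)
# Row i now only distributes weight over positions 0..i.
```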

In short, the encoder processes the source sentence to create a
meaningful representation, and the decoder uses this representation,
along with its attention mechanisms, to generate the translated target
sentence.

First, I treated multi-headed attention as an existing class and
wrote the core structure of the encoder and decoder, then combined them
into a full Transformer model. Following the On Layer Normalization in
the Transformer Architecture paper (Xiong et al., 2020), I applied layer
normalization at the start of each sub-layer (pre-LN) instead of after
the residual addition, so the residual path itself is never normalized
and the gradients stay well-behaved.
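A sketch of one pre-LN encoder layer makes the difference concrete: each sub-layer computes x + f(LayerNorm(x)), so the identity path is untouched. nn.MultiheadAttention stands in for the attention class here, as in my first step; the class name PreLNEncoderLayer and the ReLU feed-forward block (as in the original paper) are illustrative assumptions. Pre-LN stacks usually also need one final LayerNorm after the last layer.

```python
import torch
import torch.nn as nn


class PreLNEncoderLayer(nn.Module):
    """Encoder layer with layer normalization before each sub-layer (pre-LN)."""

    def __init__(self, d_model: int = 512, nhead: int = 8,
                 d_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout), nn.Linear(d_ff, d_model)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, key_padding_mask=None) -> torch.Tensor:
        h = self.norm1(x)  # normalize only the sub-layer input
        attn_out, _ = self.attn(h, h, h, key_padding_mask=key_padding_mask,
                                need_weights=False)
        x = x + self.dropout(attn_out)                # residual 1: identity path kept
        x = x + self.dropout(self.ff(self.norm2(x)))  # residual 2
        return x
```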

The Basics of Multi-Headed Attention

Transformers use a mechanism called self-attention to capture the
relationships between words in a sequence. Self-attention allows the
model to weigh the importance of each word in context, considering not
just its individual meaning, but also its relevance to other words in
the sequence.

In the self-attention process, the input sequence is converted into
three representations: key k, query q, and
value v vectors. These vectors are created by multiplying
the input embeddings with separate learned weight matrices.

For each word in the input sequence, the self-attention mechanism
calculates a score by taking the dot product of its query vector with
the key vector of every word in the sequence (including itself),
divided by the square root of the key dimension. These scores are
then passed through a softmax function to produce attention weights,
which sum to 1 and represent the importance of each word in the
context of the current word.

Next, the attention weights are multiplied by their corresponding
value vectors. The resulting weighted value vectors are summed to
produce a new context-aware representation for each word in the input
sequence.
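The steps above fit in a few lines for a single head. In this sketch, random matrices stand in for the learned weight matrices and the function name is my own; it computes softmax(qkᵀ/√d_k)·v.

```python
import math

import torch


def scaled_dot_product_attention(q, k, v):
    """q: [n_q, d_k], k: [n_k, d_k], v: [n_k, d_v] -> ([n_q, d_v], [n_q, n_k])."""
    scores = q @ k.T / math.sqrt(q.size(-1))  # one score per (query, key) pair
    weights = torch.softmax(scores, dim=-1)   # each row sums to 1
    return weights @ v, weights               # weighted sum of value vectors


torch.manual_seed(0)
x = torch.randn(5, 16)  # 5 words with 16-dim embeddings
w_q, w_k, w_v = (torch.randn(16, 16) for _ in range(3))  # stand-ins for learned weights
out, weights = scaled_dot_product_attention(x @ w_q, x @ w_k, x @ w_v)
```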

In practice, this self-attention process is performed in parallel by
multiple heads in the multi-head attention layer, each working on its
own lower-dimensional projection of the queries, keys, and values. This
enables the model to capture different types of relationships between
words. The outputs from all heads are then concatenated and passed
through a linear layer to create the final self-attention output, which
is sent to the subsequent layers in the transformer architecture.

The First Try at Multi-Headed Attention

Time to code the most complex part: the actual multi-headed
self-attention. For the first version, I thought it would be simplest to
loop over the attention heads with a for loop, apply the masking
there, and then concatenate the results.

While implementing the core self-attention mechanism seemed
straightforward, the result was slow and did not work: it failed to
translate even simple sentences.

As a next step, I wrote a small test wrapper class around PyTorch’s
self-attention implementation. With this wrapper, the model did learn,
which showed that my own self-attention implementation was the culprit.

I strongly suspect that something with the masking and later
concatenation process went awry, but I failed to properly debug it in
the original implementation. The code is still in the git repository, so
feel free to try to fix it yourself.

The Second Try at Multi-Headed Attention

Since the implementation with the explicit for loop was slow anyway
(one epoch took 3 to 4 hours), I rewrote the class completely, this
time with a single linear layer each for q, k, and v that covers all
heads at once.


import torch
import torch.nn as nn


class NewMultiAttentionHeads(nn.Module):
    def __init__(self, embedding_dim, nr_attention_heads):
        """
        A multi-head attention module, implementing the scaled dot-product attention mechanism.
        This is the properly working and efficient implementation.

        Parameters
        ----------
        embedding_dim : int
            The dimension of the input embeddings.
        nr_attention_heads : int
            The number of attention heads.

        Attributes
        ----------
        q_lin : torch.nn.Linear
            The linear layer for the query matrix (of all heads).
        k_lin : torch.nn.Linear
            The linear layer for the key matrix (of all heads).
        v_lin : torch.nn.Linear
            The linear layer for the value matrix (of all heads).
        head_reducer : torch.nn.Linear
            The linear layer to reduce the multi-head attention output to the original embedding dimension.
        """
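        super().__init__()
        # Attributes used by split_heads and forward below.
        assert embedding_dim % nr_attention_heads == 0, "embedding_dim must be divisible by nr_attention_heads"
        self.embedding_dim = embedding_dim
        self.nr_attention_heads = nr_attention_heads
        self.per_head_dim = embedding_dim // nr_attention_heads
        self.sqrt_per_head_dim = self.per_head_dim ** 0.5  # scaling factor sqrt(d_k)

        # One projection per q, k, v, each covering all heads at once.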
        self.q_lin = nn.Linear(embedding_dim, embedding_dim)
        self.k_lin = nn.Linear(embedding_dim, embedding_dim)
        self.v_lin = nn.Linear(embedding_dim, embedding_dim)

        self.head_reducer = nn.Linear(embedding_dim, embedding_dim)

    # ...

A helper method, split_heads, reshapes these projections so that each
attention head receives its own slice of size per_head_dim.

class NewMultiAttentionHeads(nn.Module):
    # ...

    def split_heads(self, x, batch_size):
        """
        Split the input tensor into attention heads.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor.
            Shape: [batch size, sequence length, embedding dim]
        batch_size : int
            The size of the batch.

        Returns
        -------
        torch.Tensor
            The reshaped input tensor.
            Shape: [batch size, nr_attention_heads, sequence length, per_head_dim]
        """
        x = x.view(batch_size, -1, self.nr_attention_heads, self.per_head_dim)
        return x.permute(0, 2, 1, 3)

    # ...

Keeping track of this reshaping was the tricky part, but with all heads
in one tensor it was much easier to check how the masks are applied
(and broadcast).


class NewMultiAttentionHeads(nn.Module):
    # ...

    def forward(self, x, y=None, mask=None, key_padding_mask=None):
        """
        Forward pass of the multi-head attention module.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor.
            Shape: [batch size, sequence length, embedding dim]
        y : torch.Tensor, optional
            The optional input tensor for cross-attention.
            Shape: [batch size, sequence length, embedding dim]
        mask : torch.Tensor, optional
            The mask to be applied to the attention scores.
            Shape: [batch size, 1, 1, sequence length]
        key_padding_mask : torch.Tensor, optional
            The mask for key padding.
            Shape: [batch size, sequence length]

        Returns
        -------
        torch.Tensor
            The output tensor of the multi-head attention.
            Shape: [batch size, sequence length, embedding dim]
        """
        batch_size = x.size(0)

        q = self.split_heads(self.q_lin(x), batch_size)
        k = self.split_heads(self.k_lin(y if y is not None else x), batch_size)
        v = self.split_heads(self.v_lin(y if y is not None else x), batch_size)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / self.sqrt_per_head_dim

        if mask is not None:
            attn_scores = attn_scores + mask
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(1)
            attn_scores = attn_scores.masked_fill(key_padding_mask, float('-inf'))

        attn_weights = nn.functional.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_weights, v)
        context = context.permute(0, 2, 1, 3).contiguous()
        context = context.view(batch_size, -1, self.embedding_dim)

        return self.head_reducer(context)

With this implementation, I finally got my from-scratch model working.
It is also reasonably efficient, needing only 5 minutes more per epoch
than the PyTorch implementation.

So while implementing the actual mappings, softmaxes, and
self-attention was fairly straightforward, the main difficulty in
getting multi-headed attention to work was the masking, especially of
the padding tokens, and PyTorch’s broadcasting rules.
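Masking and broadcasting mistakes like these can be caught by recomputing PyTorch’s own self-attention by hand from its weights and comparing outputs, in the spirit of the wrapper test described earlier. The helper manual_mha is hypothetical; it relies on nn.MultiheadAttention’s packed in_proj_weight (present when query, key, and value share the embedding size, stacked in q, k, v order) and its out_proj layer.

```python
import torch
import torch.nn as nn


def manual_mha(mha: nn.MultiheadAttention, x: torch.Tensor,
               key_padding_mask: torch.Tensor = None) -> torch.Tensor:
    """Recompute batch_first self-attention from mha's weights. x: [B, L, E]."""
    B, L, E = x.shape
    H = mha.num_heads
    d = E // H
    # in_proj_weight stacks the query, key, and value projections: [3E, E].
    q, k, v = nn.functional.linear(x, mha.in_proj_weight, mha.in_proj_bias).chunk(3, dim=-1)

    def split(t):  # [B, L, E] -> [B, H, L, d]
        return t.reshape(B, L, H, d).transpose(1, 2)

    q, k, v = split(q), split(k), split(v)
    scores = q @ k.transpose(-2, -1) / d ** 0.5  # [B, H, L, L]
    if key_padding_mask is not None:
        # [B, L] -> [B, 1, 1, L]: broadcast over heads and query positions.
        scores = scores.masked_fill(key_padding_mask[:, None, None, :], float("-inf"))
    context = torch.softmax(scores, dim=-1) @ v  # [B, H, L, d]
    context = context.transpose(1, 2).reshape(B, L, E)  # concatenate heads
    return mha.out_proj(context)


torch.manual_seed(0)
mha = nn.MultiheadAttention(16, 4, batch_first=True)
x = torch.randn(2, 5, 16)
pad = torch.tensor([[False, False, False, True, True], [False] * 5])
expected, _ = mha(x, x, x, key_padding_mask=pad)
max_diff = (manual_mha(mha, x, pad) - expected).abs().max()  # should be ~1e-7
```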

3. The Training of the Transformer Models: A Journey with AdamW

Training a transformer model can be a delicate task, with many
factors to consider. Plain Stochastic Gradient Descent (SGD) is prone to
training instabilities, and transformer models are especially sensitive
to them. As a result, researchers and practitioners often opt for
variants of the Adam optimizer, such as AdamW, which has become
particularly popular for transformer models in recent years and is used,
for example, in the Whisper paper and many other research projects.

The models were trained on an A100 server on Google Cloud; its 40 GB
of GPU memory is enough to process a full batch at once. One epoch of
the PyTorch reference transformer took approximately 65 minutes; my
from-scratch model needed 70 minutes per epoch.

Hyperparameters: Beta Values, Weight Decay, and Epsilon

I use the beta values β1=0.9 and β2=0.98, the same values as in the
original Attention paper. The RoBERTa paper found β2=0.98 to be more
stable with large batches than the β2=0.999 used in the
BERT paper.

The original Attention model did not use any weight decay,
while BERT and RoBERTa both used a value of 0.01.
I chose a weight decay of 0.001 as a compromise between these papers,
and between stability and performance.

Finally, for the epsilon value, I adopt the ϵ=1e-6 used in both the
RoBERTa and Whisper papers.
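In PyTorch, these settings map directly onto torch.optim.AdamW. In this minimal sketch, a single linear layer is a placeholder for either transformer variant, and lr is the peak learning rate described in the next section. Many BERT-style setups exclude biases and LayerNorm weights from weight decay; whether that was done here is not stated, so the sketch decays all parameters.

```python
import torch
import torch.nn as nn

model = nn.Linear(512, 512)  # placeholder for either transformer variant

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,            # peak learning rate, reached at the end of warm-up
    betas=(0.9, 0.98),  # as in the original paper and RoBERTa
    eps=1e-6,           # as in RoBERTa and Whisper
    weight_decay=1e-3,  # compromise between 0 (original) and 0.01 (BERT/RoBERTa)
)
```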

Harnessing the Power of Longer Training

The original Attention paper trained the base model for
100,000 steps (roughly 16 epochs with the batching above), with a
warm-up of 4,000 steps. However, a key insight of the RoBERTa paper was
that even BERT was severely undertrained. To capitalize on this
observation without significantly increasing training costs, I decided
to train the reference model for 1+25 epochs. After seeing that the loss
could potentially decrease further, I trained the from-scratch model for
1+30 epochs, and its loss did indeed keep decreasing noticeably.

For the first epoch, the learning rate is increased linearly with
each batch until it reaches a peak of 1e-3. In the
subsequent 25 (or 30) epochs, the learning rate is decayed linearly,
with one update per epoch. This approach is closer to BERT and
RoBERTa training; the original Attention paper instead decayed the
learning rate with the inverse square root of the step number.
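The schedule can be written as a plain function of the global batch index. This is a sketch under assumptions: I take 6,230 batches per epoch from the batching section and let the decay head linearly toward zero, which the post does not state explicitly.

```python
def learning_rate(step: int, steps_per_epoch: int = 6_230,
                  decay_epochs: int = 25, peak_lr: float = 1e-3) -> float:
    """Linear warm-up over the first epoch (updated every batch), then a
    linear decay toward zero that changes only once per epoch."""
    if step < steps_per_epoch:  # warm-up epoch
        return peak_lr * (step + 1) / steps_per_epoch
    epoch = step // steps_per_epoch - 1  # 0 for the first decay epoch
    return peak_lr * max(0.0, 1.0 - epoch / decay_epochs)
```

With PyTorch, such a function could be plugged into torch.optim.lr_scheduler.LambdaLR, which multiplies the optimizer’s base learning rate by the returned factor: pass `lambda step: learning_rate(step) / 1e-3` and call `scheduler.step()` after every batch.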

The loss over the 25 or 30 epochs of the transformer models
Zoomed-in loss after the first 50,000 steps, for a better plot scale

4. Testing the Transformers

Now, let’s compare the transformer models with the original results
from the 2017 paper.

First, I wrote two functions: one performing a greedy translation of
a given text, and one implementing a simple beam search. My beam search
function is much simpler than the one used in the Attention paper, which
uses a beam size of 4 and a length penalty (α = 0.6) to compare
hypotheses of different lengths. Mine simply keeps a few more
candidate sequences, without any penalties.
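Both decoding strategies can be sketched independently of the transformer. Here a hypothetical toy model (toy_step, a lookup table of next-token probabilities) replaces the network; as described above, the beam search ranks candidates only by summed log-probability, with no length penalty.

```python
import math
from typing import Callable, Dict, List, Tuple

Prefix = Tuple[str, ...]
StepFn = Callable[[Prefix], Dict[str, float]]  # prefix -> {token: log-prob}


def greedy_decode(step: StepFn, max_len: int = 50) -> List[str]:
    seq: Prefix = ("[START]",)
    while len(seq) < max_len and seq[-1] != "[END]":
        log_probs = step(seq)
        seq += (max(log_probs, key=log_probs.get),)  # always take the best token
    return list(seq)


def beam_decode(step: StepFn, beam_size: int = 4, max_len: int = 50) -> List[str]:
    """Keep the beam_size best prefixes by total log-probability; no penalties."""
    beams = [(0.0, ("[START]",))]
    finished = []
    while beams:
        candidates = [(score + lp, seq + (tok,))
                      for score, seq in beams
                      for tok, lp in step(seq).items()]
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for score, seq in candidates[:beam_size]:
            if seq[-1] == "[END]" or len(seq) >= max_len:
                finished.append((score, seq))
            else:
                beams.append((score, seq))
    return list(max(finished, key=lambda c: c[0])[1])


# Hypothetical toy model: "a" looks best first, but the "b" path is more likely overall.
TABLE = {
    ("[START]",): {"a": 0.6, "b": 0.4},
    ("[START]", "a"): {"x": 0.55, "y": 0.45},
    ("[START]", "b"): {"z": 1.0},
}


def toy_step(prefix: Prefix) -> Dict[str, float]:
    probs = TABLE.get(prefix, {"[END]": 1.0})
    return {tok: math.log(p) for tok, p in probs.items()}
```

Greedy decoding commits to "a" and ends with total probability 0.6 × 0.55 = 0.33, while a beam of size 2 keeps "b" alive and finds the sequence with probability 0.4.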

The assessment of translation quality is usually performed with
BLEU.

BLEU is an automatic evaluation metric for machine translation that
measures the quality of translated text by comparing it to
human-generated reference translations (in our case only one). BLEU
operates at the level of n-grams (i.e. sequences of n words), and its
main objective is to quantify the similarity between the
machine-generated translation and the reference translations.

BLEU works by counting the number of matching n-grams between the
candidate translation and the reference translations, for various n-gram
lengths (typically 1 to 4). Each candidate n-gram is counted at most as
often as it appears in the reference (clipping), and the counts are then
divided by the total number of n-grams in the candidate translation,
resulting in one precision score per n-gram length.

Finally, the n-gram precision scores are combined using their
geometric mean, multiplied by a brevity penalty for candidates that are
shorter than the reference. The resulting BLEU score ranges from 0% to
100%, with higher scores indicating better translation quality.
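To make these mechanics concrete, here is a small, unsmoothed sentence-level BLEU against a single reference. It is a simplified sketch, not NLTK’s or SacreBLEU’s implementation (no standardized tokenization, no smoothing), but it shows how a single n-gram order without matches drives the whole score to zero.

```python
import math
from collections import Counter
from typing import List, Sequence


def ngrams(tokens: Sequence[str], n: int) -> Counter:
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def sentence_bleu(candidate: List[str], reference: List[str], max_n: int = 4) -> float:
    """Unsmoothed BLEU (0-100) of one candidate against one reference."""
    precisions = []
    for n in range(1, max_n + 1):
        cand, ref = ngrams(candidate, n), ngrams(reference, n)
        # Clip each candidate n-gram count by its count in the reference.
        overlap = sum(min(count, ref[gram]) for gram, count in cand.items())
        precisions.append(overlap / max(sum(cand.values()), 1))
    if min(precisions) == 0:
        return 0.0  # the geometric mean is zero if any order has no match
    geo_mean = math.exp(sum(math.log(p) for p in precisions) / max_n)
    c, r = len(candidate), len(reference)
    brevity_penalty = 1.0 if c > r else math.exp(1 - r / c)
    return 100 * brevity_penalty * geo_mean


ref = "die Katze sitzt auf der Matte".split()
print(sentence_bleu("die Katze liegt auf der Matte".split(), ref))  # prints 0.0
```

Swapping a single word in a six-word sentence already leaves no matching 4-gram, so this fluent, correct paraphrase scores zero.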

While BLEU provides a quick and objective measure of translation
quality, it has some limitations: it does not reward valid translations
that use different wording than the reference, and it focuses on local
n-gram matches rather than capturing higher-level semantic meaning.

It is also often unclear how the text was pre-processed and at which
level (raw words vs. tokenized words) the comparison is made (see the
blog post on computing and reporting BLEU scores).

SacreBLEU addresses this by providing a standardized BLEU implementation
with fixed tokenization, so that scores are comparable across papers. I
ended up implementing two different metrics: one based on NLTK’s BLEU
and one based on TorchMetrics’ SacreBLEU implementation.

The following table shows how the models perform for different
decoding types and metrics.

As a comparison, the reported values for the EN-DE base model from
the attention paper are 25.8 BLEU for newstest-2013 and 27.3 BLEU for
newstest-2014.

Dataset       | Model type       | Decoding | BLEU (NLTK) | SacreBLEU (TorchMetrics)
------------- | ---------------- | -------- | ----------- | ------------------------
newstest-2013 | reference        | beam     | 18.83       | 21.67
newstest-2013 | reference        | greedy   | 17.56       | 20.30
newstest-2013 | own-from-scratch | beam     | 17.85       | 20.56
newstest-2013 | own-from-scratch | greedy   | 16.26       | 18.82
newstest-2014 | reference        | beam     | 21.19       | 23.69
newstest-2014 | reference        | greedy   | 19.38       | 21.87
newstest-2014 | own-from-scratch | beam     | 20.13       | 22.57
newstest-2014 | own-from-scratch | greedy   | 17.56       | 19.97

When examining these results, one can see that many sequences receive
a score of zero (for all models and metrics):

Histogram of SacreBLEU scores for each sequence in newstest-2014

While the model definitely has its flaws, even some remarkably good
translations of long texts receive a score of zero. In one such example,
the model used synonyms like Vorschriften for regulations instead of
Bestimmungen, or Besatzungsmitglieder for crew members instead of
Kabinenpersonal. Without smoothing, a single n-gram order with no
matches (typically the 4-grams) can pull the geometric mean, and with
it the whole score, down to zero.

I’m left pondering the reasons for these results.

I also stumbled upon a source claiming to
reproduce the attention paper. Interestingly, their test code
seems to assess tokenized text, as opposed to the raw text I used.

Unfortunately, the models from the attention paper or the
aforementioned reproduction are not available for download. Being able
to access them would have enabled me to test their performance with my
scripts and decoding routines, providing a fair comparison and even
allowing me to compare the translation of individual sentences.

BLEU is often criticized for its approximate nature, so instead of
getting lost in the details of this peculiar metric, I decided to put
the model to the test with a more manual assessment.

5. Flask Translation Application

What’s the point of a trained model if you don’t put it to use? It’s
time to create a small Flask app to host and query the model. Since the
model was trained on short snippets (mostly single sentences), the text
to be translated will first be split into sentences. Each sentence is
then translated separately. While this is not optimal in terms of
translation quality, it is by far the simplest to implement.
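The split-then-translate logic can be kept separate from Flask, so that a route handler only needs to call one function. In this minimal sketch, the regex splitter is my assumption (it breaks after ., !, or ? followed by whitespace and an uppercase letter or quote, and will wrongly split after abbreviations such as "Dr. Smith"), and translate_sentence is a hypothetical stand-in for the model’s greedy or beam translation.

```python
import re
from typing import Callable, List

# Split on whitespace that follows ., ! or ? and precedes an uppercase letter or quote.
_SENTENCE_END = re.compile(r'(?<=[.!?])\s+(?=["A-Z])')


def split_sentences(text: str) -> List[str]:
    return [s.strip() for s in _SENTENCE_END.split(text.strip()) if s.strip()]


def translate_text(text: str, translate_sentence: Callable[[str], str]) -> str:
    """Translate each sentence separately and join the results with spaces."""
    return " ".join(translate_sentence(s) for s in split_sentences(text))


print(split_sentences("The sky is blue. It is day! Is it night? No."))
```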

The input screen of the app

Let’s put the model to work on some Wikipedia articles:

The output screen for some snippets on machine translation from Wikipedia

Impressively, the model can even translate complex words like
computational linguistics to its proper German translation
Computerlinguistik. Although the wording could be smoother, the
translation doesn’t contain any glaring errors.

The output for the Wikipedia introduction on transformers

Curiously, the model replaced RNN with RN. Otherwise, the translation
is of good quality.

Next, let’s try translating the beginning of Genesis:

Translation of the first few lines of Genesis

Apart from some minor inaccuracies, the model seems to struggle
with short quoted words. For example, "sky" was translated as
"Sky" and "Sich" (whatever “sich” is supposed
to be). Nonetheless, the words day and night were
translated well, despite the quotation marks.

Now for something more challenging: a snippet of Stevenson’s Treasure
Island.

Translation of a paragraph from Treasure Island

The translation quality of this text is somewhat lacking compared to
the others. For instance, the model doesn’t translate chest as a
treasure chest but as its anatomical namesake.
Interestingly, Google Translate and DeepL make the same
mistake, while GPT-4 correctly translated it as treasure chest,
even providing a consistent explanation and context for the shanty.

More generally, our translation app struggles with “literary” texts and
their vivid descriptions. This is understandable, given that it was
trained primarily on non-literary text snippets. Oddly enough, the model
gets so confused that it can’t even translate the word
colour (replacing it with color yields the same
results). It translates it as Kolonisten (meaning
colonists) with greedy decoding and as Räuber
(meaning bandits) with beam search.

Conclusion

Although my model didn’t achieve the BLEU scores reported in the
original paper, I’m impressed with how well it performed on Wikipedia
articles. I also achieved my primary goal of implementing a transformer
from scratch: the reference and from-scratch implementations
demonstrated similar performance.

If I were to train a model specifically for translation tasks, I
would gather more high-quality data, ensuring that enough literary
translation pairs are included in the training set.

I hope you find the code and model useful. For my next project, I’ll
return to a language modeling task, which is more familiar territory for
me and my previous work with transformer models.

References