
Amrits Blog

🤖 Machine Learning | Analytics | Tech 👨‍💻


Building GPT from scratch

2024-02-16

A diagram of a transformer

Many modern language, image and video models are based on the transformer architecture, proposed in 2017 in the landmark paper "Attention Is All You Need".

Before we look to the future of models like GPT-5 and Sora, we'll first need to tackle the basics. In this blog post I will go over the transformer architecture piece by piece in Python and train it on the works of Shakespeare.

This article is based on Andrej Karpathy's video on transformers.

What's done cannot be undone. - Macbeth

Index

  • Section 1: The dataset
  • Section 2: Encoding and Decoding
  • Section 3: Preparing the data
    • block size
    • batching & shuffling
  • Section 4: The bigram model
    • Embeddings
    • Predictions
    • Generation
    • Model training
  • Section 5: Self Attention
    • Simple attention
    • Single head attention
    • Multi head attention
    • Feed Forward Layer
  • Section 6: Building the transformer
    • The decoder block

Section 1: The Dataset

The data used to train this model can be found here

First things first, let's open up the training data and take a quick look.

inputs:

with open("./data/input.txt", 'r') as file:
    text:str = file.read()
    
print(f"number of characters: {len(text)}")
print(text[:300])

outputs:

number of characters: 1115394

First Citizen: Before we proceed any further, hear me speak.

All: Speak, speak.

First Citizen: You are all resolved rather to die than to famish?

All: Resolved. resolved.

First Citizen: First, you know Caius Marcius is chief enemy to the people.

All: We know't, we know't.

First Citizen: Let us

Excellent! With the data loaded let's move to the next step, encoding the data.

Section 2: Encoding & Decoding

This project will use a simple character level tokenizer. This means each unique character will be represented by a numerical value. Read more about it here

We can achieve this by:

  • Obtaining all the unique characters in the text
  • Creating a mapping from character to index for the encoder
  • Creating a mapping from index to character for the decoder

inputs:

# Get unique characters 
chars = sorted(list(set(text)))

# Create a mapping from character to integer 
string_to_integer = {char:i for i, char in enumerate(chars)}
# Create a mapping from integer to string 
integer_to_string = {i:char for i,char in enumerate(chars)}

# Define functions to encode & decode data 
encode = lambda sentence: [string_to_integer[char] for char in sentence]
decode = lambda numbers: ''.join([integer_to_string[number] for number in numbers])

# View the vocabulary list 
print(f"Vocab Size: {len(chars)}") 
print(f"Unique Characters: {''.join(chars)}")

# Try out some encoding and decoding 
print("Encode:")
print(encode("Hey! How's it going?"))
print("Decode:")
print(decode(encode("Hey! How's it going?")))

outputs:

Vocab Size: 65

Unique Characters: 
!$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
 
Encode:
[20, 43, 63, 2, 1, 20, 53, 61, 5, 57, 1, 47, 58, 1, 45, 53, 47, 52, 45, 12]

Decode:
Hey! How's it going?

Section 3: Preparing the data

Now that we have our simple encoder and decoder working let's tokenize the works of Shakespeare! We will:

  • Encode the text
  • Convert the encoded text to a PyTorch tensor
  • Split the data into training and testing
  • Implement batches

Let's handle the first two steps below

input:

import torch 

# Tokenize the text and convert into a torch tensor 
data = torch.tensor(encode(text), dtype=torch.long)

# View tensor properties
print(f"data shape: {data.shape}")
print(f"data type: {data.dtype}")
print(f"First 100 characters:\n {data[:100]}")

output:

data shape: torch.Size([1115394])

data type: torch.int64

First 100 characters:
 tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])

Great, our data is now in a format PyTorch can use. Next, let's split it into training and validation sets using a 90-10 split.

block size

We will now divide the training data into blocks of a fixed length, the block size. There are two key reasons for this:

  • Computational efficiency: rather than feeding the transformer the entire text at once, which would be far too expensive, we feed it short sequential blocks.
  • Context: each block trains the model on every context length from 1 up to the block size, so it learns to make predictions with as little as one character of context.

In the code snippet below, we set a block size of eight but take nine characters (an offset of +1). The extra character is needed because we predict the next character at every position: to make eight predictions within a block, we need eight inputs plus the character that follows the last one.

Let's use a smaller example to show why we need five characters for a block size of four.

The sequence of length five [10, 7, 2, 2, 15] generates four predictions:

Sequence | Target
--- | ---
10 | 7
10, 7 | 2
10, 7, 2 | 2
10, 7, 2, 2 | 15
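The same idea in a few lines of plain Python. This is a minimal sketch using the toy sequence from the table; the names `toy_block_size`, `x`, `y` and `pairs` are illustrative only.

```python
# Toy sequence from the table above: a block size of 4 needs 4 + 1 = 5 tokens
sequence = [10, 7, 2, 2, 15]
toy_block_size = 4

x = sequence[:toy_block_size]           # inputs:  [10, 7, 2, 2]
y = sequence[1:toy_block_size + 1]      # targets: [7, 2, 2, 15], the inputs shifted by one

# Every prefix of x is a training context, and its target is the token that follows it
pairs = [(x[:t + 1], y[t]) for t in range(toy_block_size)]
for context, target in pairs:
    print(context, "->", target)
```

Five tokens give four (context, target) pairs, one per position in the block.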

Let's demonstrate this using a real example with our data.

input:

# Calculate the index that represents 90% of the data length
n = int(0.9 * len(data))
# First 90% of the data will be training 
train_data = data[:n]
# Remaining 10% of the data will be used for validation
test_data = data[n:]

# Define block size
block_size = 8
# Create a sequence of inputs for the specified block size
x = train_data[:block_size]
# Create target values for each input 
y = train_data[1:block_size + 1]

# Iterate through the first block mapping input to a target
for i in range(block_size):
    context_input = x[:i + 1]
    target = y[i]
    print(f"when input is {context_input} the target is : {target}")

output:

when input is tensor([18]) the target is : 47
when input is tensor([18, 47]) the target is : 56
when input is tensor([18, 47, 56]) the target is : 57
when input is tensor([18, 47, 56, 57]) the target is : 58
when input is tensor([18, 47, 56, 57, 58]) the target is : 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target is : 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target is : 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target is : 58

Batching & Shuffling

To make our computation more efficient we will organise our data blocks into batches. This takes advantage of the parallel computing capabilities of a GPU: all the blocks in a batch are processed at the same time, completely independently of each other.

We will sample random locations in the data to pull out these blocks. Changing the order in which we process blocks reduces the chance of overfitting to the order of the text and improves the model's generalisation; this process is known as shuffling.

The function get_batch() will be responsible for:

  • Select data: use the training or validation data depending on the requested split.
  • Generate random indices: create a tensor of shape (batch_size,) containing random starting offsets, which shuffles the data.
  • Generate context: from each offset, take an input sequence of length block_size and a target sequence of the same length shifted one position to the right.
  • Construct tensors: the sequences are stacked on top of each other to create input and target tensors of dimensions batch_size x block_size.

Let's now do a code implementation for one training batch, so we can see how it looks!

inputs:

# Set a seed value for reproducibility
torch.manual_seed(1337)
# Maximum context length for each prediction
block_size = 8
# Number of blocks we will process in parallel
batch_size = 4


def get_batch(split: str) -> tuple[torch.Tensor, torch.Tensor]:
    # Use training or validation data depending on the requested split
    data = train_data if split == 'train' else test_data
    # Random starting offsets, one per sequence in the batch: shape (batch_size,)
    ix = torch.randint(len(data) - block_size, (batch_size,))
    # Input sequences stacked together: batch_size x block_size
    x = torch.stack([data[i:i + block_size] for i in ix])
    # Targets are the same windows shifted one token to the right
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y


# Generate one training batch of inputs / labels
x_batch, y_batch = get_batch('train')
# View the shape and value of our inputs
print(f"inputs:\n{x_batch.shape}\n{x_batch}")
# View the shape and value of our targets
print(f"targets:\n{y_batch.shape}\n{y_batch}")

# View the context sequence and target for our batch
print("\n----------------------------------------------\n")

# Iterate through batch dimension
for batch in range(batch_size):
    print(f"\nBatch : {batch + 1}/{batch_size}")
    # Iterate through time dimension
    for t in range(block_size):
        context = x_batch[batch, :t + 1]
        target = y_batch[batch, t]
        print(f"input: {context.tolist()} target: {target}")

outputs:

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])

----------------------------------------------


Batch : 1/4
input: [24] target: 43
input: [24, 43] target: 58
input: [24, 43, 58] target: 5
input: [24, 43, 58, 5] target: 57
input: [24, 43, 58, 5, 57] target: 1
input: [24, 43, 58, 5, 57, 1] target: 46
input: [24, 43, 58, 5, 57, 1, 46] target: 43
input: [24, 43, 58, 5, 57, 1, 46, 43] target: 39

Batch : 2/4
input: [44] target: 53
input: [44, 53] target: 56
input: [44, 53, 56] target: 1
input: [44, 53, 56, 1] target: 58
input: [44, 53, 56, 1, 58] target: 46
input: [44, 53, 56, 1, 58, 46] target: 39
input: [44, 53, 56, 1, 58, 46, 39] target: 58
input: [44, 53, 56, 1, 58, 46, 39, 58] target: 1

Batch : 3/4
input: [52] target: 58
input: [52, 58] target: 1
input: [52, 58, 1] target: 58
input: [52, 58, 1, 58] target: 46
input: [52, 58, 1, 58, 46] target: 39
input: [52, 58, 1, 58, 46, 39] target: 58
input: [52, 58, 1, 58, 46, 39, 58] target: 1
input: [52, 58, 1, 58, 46, 39, 58, 1] target: 46

Batch : 4/4
input: [25] target: 17
input: [25, 17] target: 27
input: [25, 17, 27] target: 10
input: [25, 17, 27, 10] target: 0
input: [25, 17, 27, 10, 0] target: 21
input: [25, 17, 27, 10, 0, 21] target: 1
input: [25, 17, 27, 10, 0, 21, 1] target: 54
input: [25, 17, 27, 10, 0, 21, 1, 54] target: 39
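As a sanity check on the shifting logic, here is a small sketch that applies the same slicing as get_batch to hypothetical toy data, `torch.arange(100)`, where token i simply has value i. With this data every target should equal its input plus one.

```python
import torch

torch.manual_seed(0)
toy_data = torch.arange(100)       # hypothetical stand-in where token i has value i
toy_block_size, toy_batch_size = 8, 4

# Same slicing logic as get_batch, applied to the toy data
ix = torch.randint(len(toy_data) - toy_block_size, (toy_batch_size,))
toy_x = torch.stack([toy_data[i:i + toy_block_size] for i in ix])
toy_y = torch.stack([toy_data[i + 1:i + toy_block_size + 1] for i in ix])

print(toy_x.shape, toy_y.shape)                  # torch.Size([4, 8]) torch.Size([4, 8])
# The target at position t is the input at position t + 1
print(torch.equal(toy_y[:, :-1], toy_x[:, 1:]))  # True
```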

Section 4: The Bigram language model

Before we build the transformer, let's first establish a baseline using a very simple language model.

In short, an N-gram model predicts the next item in a sequence from the previous N - 1 items, assuming that the probability of the next item depends only on those items. A bigram model (N = 2) therefore predicts the next item based only on the previous one. Learn more here

In practice the next word in a sentence does not depend solely on the previous one, but this simple model gives us a baseline to compare the transformer against.
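To make the bigram assumption concrete, here is a counting-based sketch on a hypothetical toy string. Note that this swaps in plain frequency counts; the model we build below learns its probabilities with gradient descent instead, but the underlying question, P(next character | previous character), is the same.

```python
from collections import Counter, defaultdict

toy_text = "the cat sat on the mat"  # hypothetical toy corpus

# Count how often each character follows each other character
counts = defaultdict(Counter)
for prev_char, next_char in zip(toy_text, toy_text[1:]):
    counts[prev_char][next_char] += 1

def next_char_probs(prev_char):
    """Estimate P(next | prev) from the bigram counts."""
    total = sum(counts[prev_char].values())
    return {c: n / total for c, n in counts[prev_char].items()}

print(next_char_probs("t"))  # {'h': 0.5, ' ': 0.5}
print(next_char_probs("a"))  # {'t': 1.0}
```

Only the immediately preceding character matters: after "a" the sketch always predicts "t", no matter what came before.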

Within our bigram model we will

  • Define an embedding table
  • Define a forward pass where we make predictions and calculate the loss
  • Define a generate function that produces new text

Embeddings

Why do we need an embedding table?

We can tokenize characters, sub-words or words into a numerical representation; we've already done this at the character level. To demonstrate what embeddings look like, let's take a very simple example using word-level tokenization for the following sentences:

The dog ate my homework

The dog feasted on my socks

The numeric representation of that would be

1 2 3 4 5

1 2 6 7 4 8
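The word-level encoding above can be reproduced with a small sketch that assigns each new word the next index, starting at 1. It assumes simple whitespace splitting and keeps words case-sensitive.

```python
sentences = ["The dog ate my homework", "The dog feasted on my socks"]

# Give each previously unseen word the next index, starting at 1
word_to_index = {}
for sentence in sentences:
    for word in sentence.split():
        if word not in word_to_index:
            word_to_index[word] = len(word_to_index) + 1

encoded = [[word_to_index[w] for w in s.split()] for s in sentences]
print(word_to_index)
print(encoded)  # [[1, 2, 3, 4, 5], [1, 2, 6, 7, 4, 8]]
```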

Here's what our data dictionary would look like:

Word | Index
--- | ---
The | 1
Dog | 2
Ate | 3
My | 4
Homework | 5
feasted | 6
on | 7
socks | 8

Now, the question is: how could we measure the closeness of each word? In this example, how could we represent that ate and feasted have a similar meaning?

The indices we assigned are arbitrary and carry no meaning, so instead each word is given a dense vector which is learned from the data. Here is an example of what the embeddings might look like with an embedding size of 2 x 2:

Word | Index | Embedding
--- | --- | ---
The | 1 | [[0.46, 0.79], [0.20, 0.51]]
Dog | 2 | [[0.59, 0.05], [0.61, 0.17]]
Ate | 3 | [[0.07, 0.95], [0.97, 0.81]]
My | 4 | [[0.12, 0.50], [0.03, 0.91]]
Homework | 5 | [[0.26, 0.66], [0.31, 0.52]]
on | 7 | [[0.55, 0.18], [0.97, 0.78]]
socks | 8 | [[0.94, 0.89], [0.60, 0.92]]
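In PyTorch an embedding table is an nn.Embedding layer: a learnable matrix with one row per index, and "embedding" a token just selects its row. The sketch below uses one 2-dimensional vector per word (rather than the 2 x 2 values in the table) and 9 rows so that indices 1 to 8 are valid; both choices are assumptions for illustration, and the values are random until trained.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# 9 rows so indices 1..8 are valid (row 0 is unused in this toy example)
embedding = nn.Embedding(num_embeddings=9, embedding_dim=2)

sentence = torch.tensor([1, 2, 6, 7, 4, 8])  # "The dog feasted on my socks"
vectors = embedding(sentence)
print(vectors.shape)  # torch.Size([6, 2]): one 2-D vector per word

# Looking up an index is just selecting a row of the weight matrix
print(torch.equal(vectors[2], embedding.weight[6]))  # True ("feasted" is index 6)
```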

For a practical example, the visualisation in the paper "Cross-domain sentiment-aware word embeddings for review sentiment analysis" plots learned word embeddings on a two-dimensional axis.

Words that share similar themes appear "closer" together in the embedding space, such as light, good, sharp, best and excellent. Similarly, poor and bad are grouped close together.
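"Closeness" in an embedding space is commonly measured with cosine similarity, which compares the direction of two vectors. The 2-D vectors below are hand-picked, hypothetical values chosen for illustration, not embeddings learned from any data.

```python
import torch
import torch.nn.functional as F

# Hypothetical 2-D embeddings, hand-picked for illustration (not learned)
emb = {
    "good":      torch.tensor([0.9, 0.8]),
    "excellent": torch.tensor([0.8, 0.9]),
    "bad":       torch.tensor([-0.7, -0.8]),
}

def similarity(a, b):
    # Cosine similarity: 1 = same direction, 0 = unrelated, -1 = opposite
    return F.cosine_similarity(emb[a], emb[b], dim=0).item()

print(similarity("good", "excellent"))  # close to 1
print(similarity("good", "bad"))        # close to -1
```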

An example

In our bigram model the embedding table has dimensionality vocab_size x vocab_size: row i holds the scores (logits) for which character should follow character i, so looking up a token's embedding is the prediction. In general a very high-dimensional embedding space is impractical, as it increases computational cost and can dilute meaningful relationships between words; embedding sizes such as 50, 100 or 300 are common. Here our vocabulary size is only 65, so a 65 x 65 table is small.
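A short sketch of what that 65 x 65 table does to a batch of token ids. The random token ids are hypothetical; the point is the output shape and that each score vector is simply a row of the table.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size = 65  # number of unique characters in our dataset
bigram_table = nn.Embedding(vocab_size, vocab_size)  # 65 x 65 table

# Hypothetical batch: 4 sequences of 8 random token ids
toy_tokens = torch.randint(vocab_size, (4, 8))
toy_logits = bigram_table(toy_tokens)

print(toy_logits.shape)             # torch.Size([4, 8, 65]): a score for every possible next character
print(bigram_table.weight.numel())  # 4225 = 65 * 65 learnable parameters
# The scores at position (0, 0) are exactly the table row for that token
print(torch.equal(toy_logits[0, 0], bigram_table.weight[toy_tokens[0, 0]]))  # True
```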

Making predictions and calculating the loss

Within our forward function we will pass the inputs and target values as arguments.

  • The inputs are passed into our embedding table, which returns a row of scores (logits) for each input token
  • The predictions will have dimensionality batch_size x block_size x vocab_size
  • PyTorch's cross-entropy loss expects the class scores either as (N, C) or with the class dimension second, so we reshape the predictions to (batch_size * block_size) x vocab_size
  • We also flatten the target values into one dimension of length batch_size * block_size
  • Using cross entropy as our loss function we can measure how well the model predicted the next character
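The sketch below, on random hypothetical logits and targets, checks that flattening to (B*T, C) gives the same loss as moving the class dimension to position 1, and computes the loss of a model that predicts all 65 characters with equal probability, which works out to ln(65) ≈ 4.17.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C = 4, 8, 65
logits = torch.randn(B, T, C)           # hypothetical predictions
targets = torch.randint(C, (B, T))      # hypothetical next characters

# Option 1: flatten to (B*T, C) and (B*T,), as our model does
loss_flat = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))
# Option 2: put the class dimension second -> (B, C, T), targets stay (B, T)
loss_perm = F.cross_entropy(logits.permute(0, 2, 1), targets)
print(torch.allclose(loss_flat, loss_perm))  # True

# All-zero logits mean a uniform guess over 65 characters: loss = ln(65)
uniform_loss = F.cross_entropy(torch.zeros(B * T, C), targets.view(B * T))
print(uniform_loss.item(), math.log(C))
```

This gives a useful reference point: the untrained model below reports a loss of about 4.88, somewhat worse than a uniform guess, and training should push the loss well below 4.17.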

Generate

This function will generate new tokens from the model by:

  • Looping for new tokens: the function runs for max_new_tokens iterations, extending each sequence by one token per iteration.
  • Making predictions: the forward function is called and the predictions are stored in the logits variable.
  • Isolating the last time step: only the predictions at the last position are kept, since we want to predict what comes next given the latest context.
  • Converting to probabilities: a softmax is applied to turn the logits into probabilities.
  • Selecting the next token: torch.multinomial samples one token per sequence from those probabilities.
  • Appending the new token: the sampled token is concatenated onto the current sequence.
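A single iteration of that loop, in isolation. The random logits stand in for a model's output, and to make the sampling outcome predictable the sketch (as a hypothetical setup) forces token 2 to be the only possible choice at the last step.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C = 2, 5, 4
seq = torch.zeros((B, T), dtype=torch.long)   # current sequences
logits = torch.randn(B, T, C)                 # stand-in for the model's output

# Hypothetical: make token 2 the only possible choice at the last step
logits[:, -1, :] = float('-inf')
logits[:, -1, 2] = 0.0

last = logits[:, -1, :]                                 # (B, C): only the newest position matters
probs = F.softmax(last, dim=-1)                         # (B, C): each row sums to 1
next_token = torch.multinomial(probs, num_samples=1)    # (B, 1)
seq = torch.cat((seq, next_token), dim=1)               # (B, T+1)

print(seq.shape)                        # torch.Size([2, 6])
print(next_token.squeeze(1).tolist())   # [2, 2]
```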

inputs:

import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)


class BigramModel(nn.Module):

    def __init__(self, vocab_size: int):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, inputs, targets=None):
        # Make a prediction with our current inputs
        logits = self.token_embedding_table(inputs)  # dimensions (Batch_size, block_size, vocab_size)

        if targets is None:
            loss = None
        else:
            # Unpack each dimension  
            B, T, C = logits.shape
            # Reshape logits & targets 
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)

            # Calculate the loss based on our predictions and the target
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, inputs, max_new_tokens):
        # Input is (batch_size x block_size) 
        for _ in range(max_new_tokens):
            # Retrieve predictions
            logits, loss = self(inputs)
            # Focus on the last time step of prediction
            logits = logits[:, -1, :]  # dimensions: (B,C)
            # Calculate probabilities by applying a softmax
            probs = F.softmax(logits, dim=-1)  # dimensions: (B,C) 
            # Sample from the distribution
            inputs_next = torch.multinomial(probs, num_samples=1)  # dimensions: (B, 1) 
            # Append sampled index to the running sequence 
            inputs = torch.cat((inputs, inputs_next), dim=1)  # (B, T+1) 
        return inputs


vocab_size = len(chars)
model = BigramModel(vocab_size=vocab_size)

logits, loss = model(x_batch, y_batch)
print(logits.shape)
print(loss)

# Let's generate 100 new characters starting from index 0, which represents a newline
print(decode(model.generate(inputs=torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))

outputs:

torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)

Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3

The output is complete gibberish right now as we haven't trained our model.

Model Training

For the training of the bigram model we will use the following hyperparameters:

  • Optimiser : AdamW
  • Learning Rate : 0.001

To learn more about optimisers click here. To learn more about learning rates click here.

We will now train the model; feel free to adjust the number of steps. Once training concludes we will generate the next 300 characters in the sequence, starting from a newline character.

input:

# Create torch optimizer 
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Define a larger batch size
batch_size = 32

for step in range(1000):
    # Sample a batch of data
    x_batch, y_batch = get_batch('train')
    
    # Make a prediction and evaluate the loss
    logits, loss = model(x_batch, y_batch)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())
print(decode(model.generate(inputs=torch.zeros((1, 1), dtype=torch.long), max_new_tokens=300)[0].tolist()))

output:

2.4722609519958496


Ann ied he y.
tht l acke siespre brt s RWALO:


t is
Qclide hedeist wote s ntis chendve

F blouratr; br:
Fau CAs'r to titopouthoul mar, th o avalfimingl ithaces, ms mulid t h swesed
There, lly you, to, brsqquloe h Genthil teacond sirngucerour se, urw hMar TEY3f CARDone sthiPve fallut ndr dZy, pil fa

Certainly not Shakespeare, but progress is being made for a simple bigram model. Recall that we are only giving the model a context of one character.

Section 5: Self Attention

In our current implementation, the tokens along the time dimension do not interact with each other. To remedy this we will implement the simplest form of self-attention: averaging each token with the tokens that precede it.

Simple attention

Let's take an example of eight tokens.

We wouldn't want the fifth token to be able to communicate with sixth, seventh and eighth as they are future tokens. The fifth token should only be able to communicate with prior tokens.

The simplest way for a token to communicate with its past context is to take an average of the previous elements. Granted, this is an extremely weak form of interaction, as we lose all information about the spatial arrangement (order) of these tokens.

We will refer to this averaged tensor as xbow. Bag of Words (BOW) refers to an unordered collection of words, which you can read more about here.

The xbow tensor will have the same dimensions as our input tensor (batch size, time, channels); the difference is that at each time step t it holds the average of the inputs from positions 0 to t.
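The most direct way to compute xbow is with explicit loops. This sketch uses the same seed and shapes as the matrix version that follows, so both produce the same x; the loops are simple but slow, which is what motivates the matrix trick.

```python
import torch

torch.manual_seed(101)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# xbow[b, t] = mean of x[b, 0], ..., x[b, t]  (past and present tokens only)
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        previous = x[b, :t + 1]            # (t+1, C)
        xbow[b, t] = previous.mean(dim=0)  # (C,)
```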

To be more computationally efficient we can use matrix multiplication with this trick:

input:

# Set seed for reproducibility 
torch.manual_seed(101)

# Set batch size, time component and channels
B, T, C = 4, 8, 2

# Create a Tensor of random values with dimensions B,T,C
x = torch.randn(B, T, C)

# Lower-triangular matrix of ones: row t marks which positions (0..t) token t may use
tril = torch.tril(torch.ones(T, T))

# Affinities between tokens start at zero, i.e. every allowed position counts equally
weights = torch.zeros((T, T))

# Mask out the future: a token cannot communicate with tokens that come after it
weights = weights.masked_fill(tril == 0, float('-inf'))

# Softmax turns each row into equal weights over positions 0..t that sum to 1
weights = F.softmax(weights, dim=-1)
# Weighted sum over the time dimension: (T, T) @ (B, T, C) -> (B, T, C)
xbow = weights @ x

By constructing a weights matrix of dimensions T x T (time x time), whose row t spreads equal weight over positions 0 to t and zero weight elsewhere, we can produce the running average simply by calling weights @ x, where x is the input tensor.
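To check that the matrix trick really computes a running average, this sketch compares weights @ x against an explicit cumulative mean (cumsum divided by the number of elements so far). The cumsum comparison is a verification technique added here, not part of the original approach.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(101)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# Causal averaging weights built with the masked-softmax trick
tril = torch.tril(torch.ones(T, T))
weights = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
print(weights[:3, :3])  # row t holds 1/(t+1) in its first t+1 columns

# Reference: explicit running mean over the time dimension
counts = torch.arange(1, T + 1).view(T, 1).float()  # 1, 2, ..., T
running_mean = x.cumsum(dim=1) / counts              # (B, T, C)

print(torch.allclose(weights @ x, running_mean, atol=1e-6))  # True
```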

Single-head attention

When we initialise the affinities (relationships) between tokens to zero in the line weights = torch.zeros((T, T)), every past token ends up with the same, uniform weight. We don't want this to be uniform, as different tokens will find other tokens more or less interesting.

For example, if you are a vowel then you might look for consonants in the past context. We want to gather information from the past, but do so in a data-dependent way.

Self-attention solves this by having every token emit two vectors:

  • A query
  • A key

Query vector -> what am I looking for

Key vector -> what do I contain

We get the affinities between tokens by taking the dot product between queries and keys: a token's query is dotted with the keys of all the other tokens, which produces the weights. Each token also emits a third vector, the value, which is what it actually communicates to the tokens that attend to it.

If a key and a query are aligned, their dot product is larger, so that token receives a higher weight and contributes more to the output.
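A tiny numeric illustration of "aligned means more attention". The query and the three keys are hypothetical 4-D vectors chosen so that one key is aligned with the query, one is orthogonal and one is opposite.

```python
import torch

# Hypothetical 4-D query for one token and keys for three earlier tokens
query = torch.tensor([1.0, 0.0, 1.0, 0.0])
keys = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],    # aligned with the query
    [0.0, 1.0, 0.0, 1.0],    # orthogonal to the query
    [-1.0, 0.0, -1.0, 0.0],  # opposite to the query
])

scores = keys @ query                    # one affinity per key: [2., 0., -2.]
weights = torch.softmax(scores, dim=-1)  # the aligned key gets most of the weight
print(scores.tolist(), weights.tolist())
```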

It is important to note that tokens do not interact across batches as they are processed independently.

Quick note:

In encoder blocks we allow the tokens to communicate with each other from the future and past

In decoder blocks we mask future context with the triangular matrix

Self-attention means that the queries, keys and values all come from the same source, x. Attention is more general than this, however: in encoder-decoder transformers the queries are produced from x, but the keys and values come from a different source (the encoder's output). This is referred to as cross-attention.
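To make the distinction concrete, here is a minimal cross-attention sketch. The tensor `enc_out` is a hypothetical stand-in for an encoder's output (random values, 10 positions); queries come from x while keys and values come from `enc_out`. There is no causal mask here, since every decoder position may look at the whole encoder output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
C, head_size = 32, 16
x = torch.randn(4, 8, C)         # decoder tokens: produce the queries
enc_out = torch.randn(4, 10, C)  # hypothetical encoder output: produces keys and values

query = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

q = query(x)        # (4, 8, 16)
k = key(enc_out)    # (4, 10, 16)
v = value(enc_out)  # (4, 10, 16)

# Scaled dot-product scores, softmax over the encoder positions
weights = F.softmax(q @ k.transpose(-2, -1) * head_size**-0.5, dim=-1)  # (4, 8, 10)
out = weights @ v   # (4, 8, 16): one output per decoder token
```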

# Set seed for reproducibility 
torch.manual_seed(101)
# Set batch size, time component and channels
B, T, C = 4, 8, 32
# Create a Tensor of random values with dimensions B,T,C
x = torch.randn(B, T, C)

# A single head performing self attention 
head_size = 16
key = nn.Linear(in_features=C, out_features=head_size, bias=False)
query = nn.Linear(in_features=C, out_features=head_size, bias=False)
value = nn.Linear(in_features=C, out_features=head_size, bias=False)

# Forward these modules on the input tensor 
k = key(x)  # Dims: B,T,head_size
q = query(x)  # Dims: B,T,head_size

# Communication
# Transpose last two dimensions for K 
weights = q @ k.transpose(-2, -1) # (B,T,16) @ (B,16,T) --> (B, T, T) 
# Apply scaling to control variance at initialisation 
weights = weights * head_size**-0.5

tril = torch.tril(torch.ones(T,T))
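# Mask out future positions so each token only attends to itself and the past
weights = weights.masked_fill(tril == 0, float('-inf'))
# Normalise over the last dimension so each token's weights sum to 1
weights = F.softmax(weights, dim=-1)

# Values: what each token communicates to the tokens that attend to it
v = value(x)  # Dims: B,T,head_size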
out = weights @ v
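Why multiply by head_size**-0.5? If queries and keys have unit-variance entries, their dot product has variance of roughly head_size, and large scores push softmax towards a one-hot distribution, so each token would effectively attend to a single other token at initialisation. This sketch estimates both variances on 100,000 random query/key pairs.

```python
import torch

torch.manual_seed(0)
head_size = 16
# 100,000 independent query/key pairs with unit-variance entries
q = torch.randn(100_000, head_size)
k = torch.randn(100_000, head_size)

raw = (q * k).sum(dim=-1)       # dot products: variance close to head_size (16)
scaled = raw * head_size**-0.5  # after scaling: variance close to 1
print(raw.var().item(), scaled.var().item())
```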

Multi-head attention

Now that we've looked at how to create a single head of self-attention, let's look at multi-head attention.

In the code below we first define the class Head, which contains the query, key and value layers discussed above, along with a forward pass that computes the weighted aggregation of the values. Note that these classes read the hyperparameters n_embd, block_size and dropout from global variables, which are defined in the full script at the end of this article.

With multi-head attention we run multiple independent channels of communication at once. Each head produces its own set of attention weights and its own output vector. These output vectors are concatenated back together and passed through a linear projection to combine them. In the example below we also pass the result through a dropout layer; read more about dropout here.

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)  # (B,T,head_size)
        q = self.query(x)  # (B,T,head_size)
        # compute attention scores ("affinities"), scaled by 1/sqrt(head_size)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
        wei = F.softmax(wei, dim=-1)  # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x)  # (B,T,head_size)
        out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out


class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
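Because the classes above depend on global hyperparameters, here is a standalone sketch of the same computation that can be run on its own. It uses a plain function, `causal_head` (a hypothetical name), instead of nn.Module classes and omits dropout. With n_embd = 64 and 4 heads, each head has size 16, so concatenating the heads gets back to 64 features, which is what the n_embd x n_embd projection expects.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def causal_head(x, key, query, value):
    """One head of masked self-attention; x is (B, T, C)."""
    T = x.shape[1]
    k, q, v = key(x), query(x), value(x)                   # (B, T, hs)
    wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5   # (B, T, T)
    mask = torch.tril(torch.ones(T, T, device=x.device)) == 0
    wei = F.softmax(wei.masked_fill(mask, float('-inf')), dim=-1)
    return wei @ v                                          # (B, T, hs)

torch.manual_seed(0)
B, T, n_embd, num_heads = 2, 8, 64, 4
head_size = n_embd // num_heads  # 16, so the concatenation is back to n_embd
# Each head gets its own (key, query, value) linear layers
heads = [tuple(nn.Linear(n_embd, head_size, bias=False) for _ in range(3))
         for _ in range(num_heads)]
proj = nn.Linear(n_embd, n_embd)

x = torch.randn(B, T, n_embd)
out = torch.cat([causal_head(x, *h) for h in heads], dim=-1)  # (B, T, 4 * 16)
out = proj(out)                                               # (B, T, n_embd)
print(out.shape)  # torch.Size([2, 8, 64])
```

Thanks to the causal mask, changing the last token of x leaves the outputs at all earlier positions unchanged.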

Feed Forward Layer

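Feed-forward neural network (FFNN) layers are positioned after each multi-head attention block in both the encoder and decoder parts of a transformer.

In our code the FFNN layer will consist of a linear layer that expands each token's embedding from n_embd to 4 * n_embd features, a ReLU activation, a second linear layer that projects back down to n_embd, and a dropout layer.

One of the main reasons we include an FFNN layer is to introduce non-linearity into the model; this allows the transformer to capture more complex abstractions. Attention lets tokens gather information from each other, while the FFNN lets each token process that information on its own.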

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
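"On its own" can be checked directly: the feed-forward layer acts on each position independently, so running it on the whole (B, T, C) tensor gives the same result as running it on one token at a time. This sketch mirrors the layer sizes of FeedFoward without dropout; the random inputs are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_embd = 64
# Same layer shapes as FeedFoward above, without dropout
ffwd = nn.Sequential(nn.Linear(n_embd, 4 * n_embd), nn.ReLU(), nn.Linear(4 * n_embd, n_embd))

x = torch.randn(2, 8, n_embd)  # (B, T, C)
out = ffwd(x)                  # applied to every token at once

# Applying it to a single position alone gives the same result:
# there is no communication between tokens inside the FFNN
single = ffwd(x[:, 3])         # (B, C)
print(torch.allclose(out[:, 3], single, atol=1e-5))  # True
```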

Section 6: Building the Transformer

Decoder Block

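We've covered all the individual parts of the transformer, so let's assemble the decoder block. In the transformer diagram at the start of this article, the decoder block is the section on the right-hand side; since GPT has no encoder, our block leaves out the decoder's cross-attention sub-layer and keeps only masked self-attention and the feed-forward layer.

To improve optimisation we will also implement residual connections in the forward function of our Block class: each sub-layer's output is added back to its input (x = x + ...).

We've also included the layer normalisation from the "Add & Norm" steps via self.ln1 = nn.LayerNorm(n_embd) and self.ln2 = nn.LayerNorm(n_embd). Note that we apply layer norm before each sub-layer (pre-norm), whereas the original paper normalises after the residual addition.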

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
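A standalone sketch of the two ingredients of Block.forward: LayerNorm normalises each token's feature vector to roughly zero mean and unit variance, and the residual connection adds a sub-layer's output back onto its input. The nn.Linear here is a hypothetical stand-in for the attention or feed-forward sub-layer, and the inputs are random values with mean about 3 and standard deviation about 5.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_embd = 64
ln = nn.LayerNorm(n_embd)
x = torch.randn(2, 8, n_embd) * 5 + 3  # features with mean ≈ 3, std ≈ 5

y = ln(x)  # normalises each token's feature vector independently
token_mean = y.mean(dim=-1)                                        # ≈ 0 for every token
token_var = ((y - y.mean(dim=-1, keepdim=True)) ** 2).mean(dim=-1)  # ≈ 1 for every token

# Residual connection: the sub-layer's output is added to its input,
# so the shape is preserved and the identity path passes gradients straight through
sublayer = nn.Linear(n_embd, n_embd)  # hypothetical stand-in for attention or the FFNN
out = x + sublayer(ln(x))             # pre-norm, as in Block.forward
print(out.shape)  # torch.Size([2, 8, 64])
```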

Bringing it all together

We've gone through in detail all the different concepts that go into a Transformer so let's bring it all together here:

import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
batch_size = 16 # how many independent sequences we process in parallel
block_size = 32 # maximum context length for predictions
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0
# ------------

torch.manual_seed(1337)

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('data/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
string_to_index = { ch:i for i,ch in enumerate(chars) }
index_to_string = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [string_to_index[c] for c in s] 
decode = lambda l: ''.join([index_to_string[i] for i in l])

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,head_size)
        q = self.query(x) # (B,T,head_size)
        # compute attention scores ("affinities"), scaled by 1/sqrt(head_size)
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,head_size)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

# GPT-style language model (the class keeps the name of the earlier bigram baseline)
class BigramLanguageModel(nn.Module):

    def __init__(self):
        super().__init__()
        # token embeddings (which character) plus position embeddings (where it sits in the block)
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

model = BigramLanguageModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))
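One detail in generate worth isolating is the crop idx[:, -block_size:]. The standalone sketch below uses the script's block_size and n_embd and a hypothetical 40-token sequence to suggest why it is needed: the position embedding table only has rows for positions 0 to block_size - 1, so looking up positions beyond that fails.

```python
import torch
import torch.nn as nn

block_size, n_embd = 32, 64
position_embedding_table = nn.Embedding(block_size, n_embd)  # rows for positions 0..31 only

idx = torch.zeros((1, 40), dtype=torch.long)  # a generated sequence longer than block_size

failed = False
try:
    position_embedding_table(torch.arange(idx.shape[1]))  # positions 0..39
except (IndexError, RuntimeError) as err:
    failed = True
    print("without cropping:", type(err).__name__)

idx_cond = idx[:, -block_size:]  # keep only the last block_size tokens
pos_emb = position_embedding_table(torch.arange(idx_cond.shape[1]))
print(pos_emb.shape)  # torch.Size([32, 64])
```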