GPT-2 Stripped: Comparative Analysis
Comparing various Positional Encodings and Attention Mechanisms
Why GPT-2? GPT-2 is a decoder-only model that predicts the next token autoregressively. Encoder-only models are better suited to understanding tasks such as classification, and encoder-decoder (sequence-to-sequence) models to tasks such as translation. For open-ended text generation, a decoder-only model is the natural fit, and GPT-2 remains one of the most widely used and studied baselines for such tasks.
Why Causal Attention? Causal attention masks out future positions so that each token can only attend to the tokens before it. This matches the autoregressive next-token objective exactly, and it keeps the attention computation simple and efficient for text generation.
Data Pipeline
We use the FineWeb-Edu dataset, a collection of educational web pages. We take its sample-10BT version, which contains roughly 10 billion GPT-2 tokens.
The dataset is tokenized with either the Hugging Face tokenizers library or the tiktoken Python library. The tokens are stored as shards of 1e8 (100 million) tokens each, which shrinks the dataset from 40+ GB of raw text to about 19 GB of token shards and makes it easy to read during training and validation. Below is the configuration for the GPT-2 model.
@dataclass
class GPTConfig:
block_size: int = 1024 # Maximum sequence length
    vocab_size: int = 50257 # 50,000 BPE (Byte Pair Encoding) merges + 256 byte tokens + 1 <|endoftext|> token
# special end of sequence token delimits document boundaries and can start generation as well
n_layer: int = 12 # Number of transformer blocks (how deep is the model)
n_head: int = 12 # Number of heads in the multi-head attention (how wide is the model)
n_embed: int = 768 # Embedding dimensionality
To unpack this: the maximum sequence length is T = 1024, the batch size B is chosen to fit the available GPU memory, and the vocabulary contains 50257 tokens. The model has 12 transformer layers, each using multi-head attention with 12 heads, and an embedding size of C = 768.
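As a quick sanity check on the model size quoted later (~124M parameters), here is a rough parameter count computed from the config values above; this is an illustrative calculation, not code from the project:

# Approximate GPT-2 parameter count from GPTConfig (biases and LayerNorms ignored).
# lm_head shares its weight matrix with wte, so it adds no extra parameters.
n_layer, n_embed, vocab_size, block_size = 12, 768, 50257, 1024

wte  = vocab_size * n_embed                        # token embedding table
wpe  = block_size * n_embed                        # learned positional embeddings
attn = n_layer * (3 * n_embed**2 + n_embed**2)     # qkv projection + output projection
mlp  = n_layer * (2 * 4 * n_embed**2)              # 4x-expansion feed-forward
print((wte + wpe + attn + mlp) / 1e6)              # ~124.3 (million parameters)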
To read a batch, B * T + 1 tokens are read from the current shard. The first B * T tokens form the input and are reshaped to (B, T); the targets are the same tokens shifted right by one, i.e. the last B * T tokens, also reshaped to (B, T). For example, if B * T + 1 is 9 and the tokens are [1, 2, 3, 4, 5, 6, 7, 8, 9], the input is [1, 2, 3, 4, 5, 6, 7, 8] and the target is [2, 3, 4, 5, 6, 7, 8, 9]. Within the token stream, the <|endoftext|> token delimits document boundaries and can also be used to start generation. The inputs and targets are then passed to the model for training or validation.
When training on multiple GPUs in parallel, B * T * number_of_gpus + 1 tokens are read from the shard per step, and each process then takes its own B * T slice based on its process rank, as sketched below.
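For concreteness, here is a minimal sketch of this batching scheme. The get_batch helper is hypothetical (the project's actual loader differs) and assumes a shard has already been loaded as a 1-D tensor of token ids:

import torch

def get_batch(tokens, B, T, process_rank=0):
    # Each process reads its own B*T + 1 slice of the step's window, offset by its rank;
    # the next step advances the window by B * T * number_of_gpus tokens.
    offset = process_rank * B * T
    buf = tokens[offset : offset + B * T + 1]
    x = buf[:-1].view(B, T)   # inputs: the first B*T tokens
    y = buf[1:].view(B, T)    # targets: the same tokens shifted right by one
    return x, y

tokens = torch.arange(1, 10)   # [1, 2, ..., 9], i.e. B*T + 1 = 9 tokens
x, y = get_batch(tokens, B=2, T=4)
print(x.tolist())  # [[1, 2, 3, 4], [5, 6, 7, 8]]
print(y.tolist())  # [[2, 3, 4, 5], [6, 7, 8, 9]]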
Model Pipeline
The GPT-2 model maps tokens to 768-dimensional embeddings, adds positional encodings, and then passes the result through the transformer blocks. There are 12 transformer blocks; each block contains one multi-head attention layer with 12 attention heads, followed by a feed-forward layer. The output of the final block goes through a linear layer (the language-model head) to produce the logits, and a softmax turns the logits into probabilities over the next token. Below is the code for the GPT-2 structure.
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
# Developing Transformer
self.transformer = nn.ModuleDict({
'wte': nn.Embedding(config.vocab_size, config.n_embed), # Token embedding weights
'wpe': nn.Embedding(config.block_size, config.n_embed), # Positional embedding weights
'h': nn.ModuleList([Block(config) for _ in range(config.n_layer)]), # All transformer blocks
'ln_f': nn.LayerNorm(config.n_embed)
})
# Final Linear layer after all transformer blocks
self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)
        # Weight sharing scheme
        # The token embedding (wte) and the final lm_head share the same weight matrix.
        # Reason: both map between token identities and the embedding space, and tying
        # them saves about 38M parameters (50257 x 768).
        self.lm_head.weight = self.transformer['wte'].weight
        # Initialize parameters with mean 0 and standard deviation 0.02,
        # which is on the order of 1/sqrt(n_embed) for GPT-2 widths (768 to 1600)
self.apply(self._init_weights)
def forward(self, idx, targets=None):
B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, model block size is {self.config.block_size}"
# IMP: Token and Positional Embeddings
pos = torch.arange(0, T, device=idx.device, dtype=torch.long)
pos_emb = self.transformer.wpe(pos) #Positional Embeddings of shape (T, n_embed)
tok_emb = self.transformer.wte(idx) #Token Embeddings of shape (B, T, n_embed)
x = tok_emb + pos_emb # broadcast along the batch dimension
# Forward pass through each transformer block
for block in self.transformer.h:
x = block(x)
# Final Linear layer
x = self.transformer.ln_f(x)
x = self.lm_head(x)
# Loss function
loss = None
if targets is not None:
loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
return x, loss
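As a quick illustration (not part of the project code), the model can be exercised on a dummy batch as follows, assuming GPTConfig and the full GPT class (including the _init_weights method shown later) are defined:

import torch

config = GPTConfig()
model = GPT(config)

B, T = 2, 32
idx = torch.randint(0, config.vocab_size, (B, T))       # random token ids
targets = torch.randint(0, config.vocab_size, (B, T))

logits, loss = model(idx, targets)
print(logits.shape)   # torch.Size([2, 32, 50257])
print(loss.item())    # ~10.8 at initialization, i.e. roughly ln(50257)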
Chinchilla Scaling Law: The Chinchilla scaling law suggests that compute-optimal training uses roughly 20 training tokens per model parameter. With 124 million parameters that works out to about 2.5 billion tokens, so training on 10 billion tokens satisfies the guideline several times over.
Attention Mechanisms
Attention mechanisms let the model weigh how much each token should attend to the other tokens in the sequence by computing attention scores between positions. There are many variants; in this project we compare standard (normal) causal attention, Flash Attention, and Linformer attention.
Normal Attention
Normal attention is the baseline mechanism: it explicitly materializes the T x T matrix of attention scores between every pair of positions, applies the causal mask, and uses the softmax of the scores to mix the value vectors. It is defined as follows:
class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
assert config.n_embed % config.n_head == 0
# Query/key/value projections
self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)
# Output projection
self.c_proj = nn.Linear(config.n_embed, config.n_embed)
self.c_proj.NANOGPT_SCALE_INIT = 1.0
# Configuration
self.n_head = config.n_head
self.n_embed = config.n_embed
# Causal mask
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))
def forward(self, x):
B, T, C = x.size()
# Split QKV
q, k, v = self.c_attn(x).split(self.n_embed, dim=2)
# Reshape for multi-head attention
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
# Attention computation
att = (q @ k.transpose(-2, -1)) * (1.0 / (k.size(-1) ** 0.5))
att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
        y = att @ v
        # Re-assemble the head outputs: (B, n_head, T, head_size) -> (B, T, C), then project
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
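Materializing that T x T score matrix is exactly what makes normal attention expensive at long context. A back-of-the-envelope illustration (not project code; the batch size is chosen arbitrarily):

B, n_head, T = 8, 12, 1024
scores = B * n_head * T * T                  # one fp32 score per (head, query, key) pair
print(f"{scores * 4 / 1e9:.2f} GB")          # ~0.40 GB per attention layer, per forward pass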
Flash Attention
Flash Attention computes exactly the same result as normal attention, but with a fused, memory-efficient kernel that processes the scores in tiles instead of materializing the full T x T matrix. In PyTorch it is available through F.scaled_dot_product_attention, so the manual computation above is replaced as follows:
# Attention mechanism
# att = (q @ k.transpose(-2, -1)) * (1.0 / ((k.size(-1)) ** 0.5))
# # Masked Attention
# att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
# att = F.softmax(att, dim=-1)
# y = att @ v
# Flash Attention
y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # wow who knew flash attention was so easy to implement
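As a quick sanity check (illustrative only), the fused kernel matches the manual computation up to floating-point tolerance:

import torch
import torch.nn.functional as F

B, nh, T, hs = 2, 12, 64, 64
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))

# Manual causal attention
att = (q @ k.transpose(-2, -1)) / (hs ** 0.5)
att = att.masked_fill(torch.tril(torch.ones(T, T)) == 0, float('-inf'))
y_manual = F.softmax(att, dim=-1) @ v

# Fused kernel
y_flash = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(y_manual, y_flash, atol=1e-5))  # True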
Linformer Attention
Linformer attention reduces the quadratic cost of attention by projecting the keys and values from sequence length T down to a smaller dimension k with linear projections, so the score matrix becomes T x k instead of T x T. It is defined as follows:
def get_EF(input_size, dim, method="no_params", head_dim=None, bias=True):
    '''Returns the E/F projection: a learnable linear layer, a convolution, or a fixed random matrix'''
assert method in ["learnable", "convolution", "no_params"], "Invalid method!"
if method == "convolution":
return nn.Conv1d(head_dim, head_dim, kernel_size=int(input_size/dim))
elif method == "no_params":
mat = torch.zeros((dim, input_size))
torch.nn.init.normal_(mat, std=1/dim)
return mat
elif method == "learnable":
lin = nn.Linear(input_size, dim, bias)
torch.nn.init.xavier_normal_(lin.weight)
return lin
def gen_causal_mask(input_size, dim_k, full_attention=False):
return (torch.triu(torch.ones(input_size, input_size)) if full_attention
else torch.tril(torch.ones(input_size, dim_k))).bool()
class CausalSelfAttention(nn.Module):
def __init__(self, config, dim_k=None, eps=0.4):
super().__init__()
assert config.n_embed % config.n_head == 0
# Projections
self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)
self.c_proj = nn.Linear(config.n_embed, config.n_embed)
self.c_proj.NANOGPT_SCALE_INIT = 1.0
# Config
self.n_head = config.n_head
self.n_embed = config.n_embed
self.block_size = config.block_size
self.eps = eps
self.dim_k = dim_k or int(min(9*self.n_embed*np.log(self.n_embed)/self.eps**2,
5*np.log(self.block_size)/self.eps**2))
# Linformer projections
delta = 1/(2**self.block_size)
self.causal_mask = gen_causal_mask(self.block_size, self.dim_k)
self.E_proj = nn.Parameter(delta * torch.stack([get_EF(self.block_size, self.dim_k)
for _ in range(self.n_head)]))
self.F_proj = nn.Parameter(np.exp(-delta) * torch.stack([get_EF(self.block_size, self.dim_k)
for _ in range(self.n_head)]))
# Causal mask
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)))
def forward(self, x):
B, T, C = x.size()
q, k, v = self.c_attn(x).split(self.n_embed, dim=2)
# Reshape for multi-head attention
q = q.view(B, T, self.n_head, -1).transpose(1, 2)
k = k.view(B, T, self.n_head, -1).transpose(1, 2)
v = v.view(B, T, self.n_head, -1).transpose(1, 2)
# Project keys/values
K_proj = self.E_proj.expand(B, -1, -1, -1) @ k
V_proj = self.F_proj.expand(B, -1, -1, -1) @ v
# Flash attention implementation
y = F.scaled_dot_product_attention(q, K_proj, V_proj, is_causal=True)
# Output projection
return self.c_proj(y.transpose(1,2).contiguous().view(B,T,C))
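A minimal sketch of the projection step in isolation, with illustrative shapes only (it assumes the sequence is at the full block size so that it lines up with the precomputed per-head projections):

import torch

B, n_head, T, hs, dim_k = 2, 12, 1024, 64, 256
k = torch.randn(B, n_head, T, hs)
E = torch.randn(n_head, dim_k, T)        # one (dim_k x T) projection per head

K_proj = E.expand(B, -1, -1, -1) @ k     # (B, n_head, dim_k, hs)
print(K_proj.shape)                      # scores against K_proj are (T, dim_k) instead of (T, T)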
Positional Encodings
Positional encodings give the model a sense of token order; without them, self-attention is permutation-invariant. Absolute schemes (learned and sinusoidal) are added to the token embeddings before the transformer blocks, while relative schemes (RoPE, KERPLE, ALiBi, FIRE) act inside the attention computation. We compare all of these in this project.
Learned Positional Encoding
Learned positional encoding uses a trainable embedding table (wpe, of shape block_size x n_embed) whose rows are added to the token embeddings; this is GPT-2's default scheme and matches the baseline model shown earlier. It is defined as follows:
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
# Developing Transformer
self.transformer = nn.ModuleDict({
'wte': nn.Embedding(config.vocab_size, config.n_embed), # Token embedding
'wpe': nn.Embedding(config.block_size, config.n_embed), # Positional embedding
'h': nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
'ln_f': nn.LayerNorm(config.n_embed)
})
# Final Linear layer
self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)
self.lm_head.weight = self.transformer['wte'].weight
# Initialize parameters
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
std = 0.02
if hasattr(module, 'NANOGPT_SCALE_INIT'):
std *= (2 * self.config.n_layer) ** -0.5
torch.nn.init.normal_(module.weight, mean=0, std=std)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0, std=0.02)
def forward(self, idx, targets=None):
B, T = idx.size()
assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is {self.config.block_size}"
# Embeddings
pos = torch.arange(0, T, device=idx.device)
pos_emb = self.transformer.wpe(pos)
tok_emb = self.transformer.wte(idx)
x = tok_emb + pos_emb
# Transformer blocks
for block in self.transformer.h:
x = block(x)
# Final output
x = self.transformer.ln_f(x)
x = self.lm_head(x)
loss = None
if targets is not None:
loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
return x, loss
Sinusoidal Positional Encoding
Sinusoidal positional encoding is a fixed (non-learned) encoding built from sine and cosine functions of different frequencies, as in the original Transformer paper. It is defined as follows:
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
# Developing Transformer
self.transformer = nn.ModuleDict({
'wte': nn.Embedding(config.vocab_size, config.n_embed), # Token embedding weights
'h': nn.ModuleList([Block(config) for _ in range(config.n_layer)]), # All transformer blocks
'ln_f': nn.LayerNorm(config.n_embed)
})
...
def get_sinusoidal_encoding(self, T):
position = torch.arange(0, T, dtype=torch.float).unsqueeze(1)
div_term = 10000 ** (-2 * torch.arange(self.config.n_embed // 2) / self.config.n_embed)
encoding = torch.zeros(T, self.config.n_embed)
encoding[:, 0::2] = torch.sin(position * div_term)
encoding[:, 1::2] = torch.cos(position * div_term)
return encoding
def forward(self, idx, targets=None):
B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is {self.config.block_size}"
# Token Embeddings
tok_emb = self.transformer.wte(idx)
# Positional Encodings
pos_emb = self.get_sinusoidal_encoding(T).to(idx.device)
pos_emb = pos_emb.unsqueeze(0).expand(B, -1, -1)
x = tok_emb + pos_emb
for block in self.transformer.h:
x = block(x)
x = self.transformer.ln_f(x)
x = self.lm_head(x)
loss = None
if targets is not None:
loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
return x, loss
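For reference, get_sinusoidal_encoding implements the standard formulation PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). A standalone version of the same computation (illustrative only):

import torch

T, d = 8, 768
pos = torch.arange(T, dtype=torch.float).unsqueeze(1)   # (T, 1)
i = torch.arange(d // 2, dtype=torch.float)              # dimension-pair index
angle = pos / 10000 ** (2 * i / d)                        # (T, d/2)
pe = torch.zeros(T, d)
pe[:, 0::2] = torch.sin(angle)   # even dimensions
pe[:, 1::2] = torch.cos(angle)   # odd dimensions
print(pe.shape)                  # torch.Size([8, 768])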
RoPE (Rotary Positional Encoding)
Rotary Positional Encoding (RoPE) encodes position by rotating pairs of dimensions of the query and key vectors by an angle proportional to their position, so the attention dot product depends only on the relative distance between tokens. It is defined as follows:
class RotaryPositionEmbeddings(nn.Module):
'''Rotary Position Embeddings, as described in the RoPE paper'''
def __init__(self, config, base=10_000):
super().__init__()
self.base = base
self.dim = config.n_embed
self.max_seq_len = config.block_size
self.config = config
self.rope_init()
def rope_init(self):
'''Initialize the RoPE cache with sin and cos values for each position.'''
theta = torch.pow(self.base, -2 * torch.arange(0, self.dim // 2).float() / self.dim)
self.register_buffer('theta', theta, persistent=False)
self.build_rope_cache()
def build_rope_cache(self):
'''Build the RoPE cache for the given block size.'''
seq_idx = torch.arange(self.max_seq_len, dtype=self.theta.dtype, device=self.theta.device)
idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float()
hs_half = self.config.n_embed // self.config.n_head // 2
idx_theta = idx_theta[:, :hs_half]
cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
self.register_buffer('cache', cache, persistent=False)
    def forward(self, x):
        # x: (B, T, n_head, head_size)
        b, seq_len, nh, hs = x.shape
        rope_cache = self.cache[:seq_len, :hs // 2].to(x.device)
        # View the head dimension as (hs // 2) pairs: (B, T, n_head, hs // 2, 2)
        x = x.reshape(*x.shape[:-1], -1, 2)
        # Broadcast the cache over the batch and head dimensions: (1, T, 1, hs // 2, 2)
        rope_cache = rope_cache.unsqueeze(0).unsqueeze(2)
        # Rotate each pair by its position-dependent angle
        rotated = torch.stack([
            x[..., 0] * rope_cache[..., 0] - x[..., 1] * rope_cache[..., 1],
            x[..., 1] * rope_cache[..., 0] + x[..., 0] * rope_cache[..., 1]
        ], dim=-1)
        # Flatten the pairs back to (B, T, n_head, head_size)
        return rotated.flatten(-2).type_as(x)
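RoPE is applied to the query and key tensors inside the attention layer, before the scores are computed. A minimal sketch of how it would be wired in (illustrative only, not the project's exact attention code; it assumes GPTConfig and RotaryPositionEmbeddings as defined above):

import torch
import torch.nn.functional as F

config = GPTConfig()
rope = RotaryPositionEmbeddings(config)

B, T = 2, 128
nh, hs = config.n_head, config.n_embed // config.n_head
q, k, v = (torch.randn(B, T, nh, hs) for _ in range(3))

# Rotate q and k (v is left untouched), then run causal attention as usual
q, k = rope(q), rope(k)
y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2),
                                   v.transpose(1, 2), is_causal=True)
print(y.shape)   # torch.Size([2, 12, 128, 64])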
Kerple Positional Encoding
KERPLE (Kernelized Relative Positional Embedding) derives a bias for the attention logits from a logarithmic kernel of the distance between positions, with learnable per-head parameters. To use it with flash attention, set is_causal to False and pass the KERPLE mask (which already contains the causal -inf entries) as the attention mask; a usage sketch follows the code. It is defined as follows:
class ClampedKerpleParameterModule(nn.Module):
def __init__(self, num_heads, scale=1, min_value=1e-2):
super().__init__()
self.param = nn.Parameter(torch.rand(num_heads) * scale)
self.min_value = min_value
def forward(self, x):
clamped_param = torch.clamp(self.param, min=self.min_value)
return clamped_param.view(-1, 1, 1) * x
class KerplePositionalEncoding(nn.Module):
def __init__(self, num_heads, block_size=1024, scale_r1=1, scale_r2=2, min_value=1e-2):
super().__init__()
self.r1 = ClampedKerpleParameterModule(num_heads, scale=scale_r1, min_value=min_value)
self.r2 = ClampedKerpleParameterModule(num_heads, scale=scale_r2, min_value=min_value)
self.n_head = num_heads
self.block_size = block_size
# Precompute distance matrix
self.distance_matrix = torch.arange(self.block_size).view(-1, 1) - torch.arange(self.block_size).view(1, -1)
self.distance_matrix = self.distance_matrix.abs().float().to(device)
self.distance_matrix = self.distance_matrix.repeat(self.n_head, 1, 1)
def forward(self, q_shape):
applied_distance_matrix = self.distance_matrix
        if q_shape[2] != self.block_size:
pos = torch.arange(q_shape[2]).view(-1, 1) - torch.arange(q_shape[2]).view(1, -1)
distance_matrix = pos.abs().float().to(device)
distance_matrix = distance_matrix.repeat(self.n_head, 1, 1)
applied_distance_matrix = distance_matrix
distance_matrix = self.r1(applied_distance_matrix)
distance_matrix = -self.r2(1 + torch.log(1 + distance_matrix))
mask = torch.ones(q_shape[2], q_shape[2]).tril(diagonal=0).repeat(self.n_head, 1, 1)
distance_matrix = distance_matrix.masked_fill(mask.logical_not().to(device), float('-inf'))
return distance_matrix
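A minimal sketch of how this bias is consumed in the attention forward pass (illustrative only; the shapes and the module-level device variable are assumptions, not the project's exact wiring):

import torch
import torch.nn.functional as F

device = "cpu"   # the class above references a module-level variable named device
B, n_head, T, hs = 2, 12, 128, 64
q, k, v = (torch.randn(B, n_head, T, hs) for _ in range(3))

kerple = KerplePositionalEncoding(num_heads=n_head)
bias = kerple(q.shape).unsqueeze(0)   # (1, n_head, T, T), causal -inf already included

# is_causal=False because the causal mask is baked into the bias itself
y = F.scaled_dot_product_attention(q, k, v, attn_mask=bias, is_causal=False)
print(y.shape)                        # torch.Size([2, 12, 128, 64])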
Alibi Positional Encoding
ALiBi (Attention with Linear Biases) adds a bias to the attention logits that decreases linearly with the distance between the query and key positions, using a fixed per-head slope. As with KERPLE, when using flash attention, set is_causal to False and pass the ALiBi mask as the attention mask. It is defined as follows:
class CausalSelfAttention(nn.Module):
def alibi_mask(self, q_shape):
# Create alibi_mask: q* k_T + m * [a - b]
# [a - b] is lower triangular matrix with negative values
pos = - torch.arange(q_shape[2]).view(-1, 1) + torch.arange(q_shape[2]).view(1, -1)
pos = pos.float().masked_fill(pos > 0, float('-inf')) # Shape: (block_size, block_size)
head_m_value = torch.tensor([2.0**(-8/i) for i in torch.linspace(1, 8, self.n_head)])
return head_m_value.view(-1, 1, 1) * pos
FIRE Positional Encoding
FIRE (Functional Interpolation for Relative positional Encoding) produces the attention bias with a small MLP applied to log-transformed, progressively normalized relative distances. To apply it, do the same as for KERPLE and ALiBi in the attention layer: set is_causal to False and pass the FIRE mask as the attention mask; a usage sketch follows the code. It is defined as follows:
class FIRE(nn.Module):
def __init__(self, num_heads=12, mlp_width=32, init_c=0.1, init_L=512., eps=1e-6):
super(FIRE, self).__init__()
# Define MLP layers
self.mlp = nn.Sequential(
nn.Linear(1, mlp_width),
nn.ReLU(),
nn.Linear(mlp_width, num_heads)
)
# Initialize parameters
self.c = nn.Parameter(torch.tensor(init_c))
self.init_L = nn.Parameter(torch.tensor(init_L), requires_grad=False)
self.L_multiplier = nn.Parameter(torch.tensor(1.0))
self.eps = eps
def forward(self, x: torch.Tensor):
seq_length = x.size(2)
positions = torch.arange(seq_length, dtype=torch.float, device=x.device)
rel_distance = positions[:, None] - positions[None, :]
# Thresholding and normalization
threshold = torch.abs(self.L_multiplier * self.init_L)
pos_normalizer = torch.max(positions, threshold).unsqueeze(1)
# Log transformations
rel_distance = torch.log(torch.abs(self.c * rel_distance) + 1)
pos_normalizer = torch.log(torch.abs(self.c * pos_normalizer) + 1) + self.eps
# Compute FIRE bias
normalized_distance = rel_distance / pos_normalizer
fire_bias = self.mlp(normalized_distance.unsqueeze(-1))
fire_bias = fire_bias.permute(2, 0, 1)
# Apply causal mask
mask = torch.ones(seq_length, seq_length).tril(diagonal=0).repeat(fire_bias.shape[0], 1, 1)
fire_bias = fire_bias.masked_fill(mask.logical_not().to(device), float('-inf')).unsqueeze(0)
return fire_bias
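And a matching usage sketch (illustrative only, under the same assumptions as the KERPLE example above):

import torch
import torch.nn.functional as F

device = "cpu"   # the FIRE class above references a module-level variable named device
B, n_head, T, hs = 2, 12, 128, 64
q, k, v = (torch.randn(B, n_head, T, hs) for _ in range(3))

fire = FIRE(num_heads=n_head)
bias = fire(q)   # (1, n_head, T, T), causal -inf already included

y = F.scaled_dot_product_attention(q, k, v, attn_mask=bias, is_causal=False)
print(y.shape)   # torch.Size([2, 12, 128, 64])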