GPT-2 Stripped: Comparative Analysis

Comparing various Positional Encodings and Attention Mechanisms

Why GPT-2? GPT-2 is a decoder-only model that predicts the next token autoregressively. Encoder-only models are typically used for understanding tasks such as classification, and encoder-decoder (sequence-to-sequence) models for tasks such as translation, while a decoder-only model is the natural fit for open-ended text generation. GPT-2 is also a widely used, well-documented baseline for such experiments, which makes it a good choice for this comparison.

Why causal attention? In causal (masked) self-attention each token attends only to itself and the tokens before it, which is exactly the constraint next-token prediction needs: the model cannot peek at the future it is trying to predict. It is the standard attention pattern for autoregressive text generation.
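
As a minimal illustration (not part of the training code), causality is just a lower-triangular mask over the attention scores: token t may attend to positions 0..t and nothing after.

import torch

# A 1 means "may attend", a 0 means "masked out" (set to -inf before the softmax)
T = 4
causal_mask = torch.tril(torch.ones(T, T))
print(causal_mask)
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])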

Data Pipeline

We use the FineWeb-Edu dataset, a collection of educational articles from the web, specifically its sample-10BT version, which contains 10 billion GPT-2 tokens.

The dataset is tokenized with either the Hugging Face tokenizers library or the tiktoken library, and the resulting token ids are stored as shards of 1e8 tokens each. This reduces the dataset from 40+ GB of raw text to about 19 GB and makes it easy to stream shards during training and validation.
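
Below is a minimal sketch of the sharding step. It assumes the tiktoken GPT-2 encoding, documents available as an iterable of strings, and a hypothetical write_shards helper with a shards/ output directory; the project's actual preprocessing script may differ.

import numpy as np
import tiktoken

enc = tiktoken.get_encoding("gpt2")
SHARD_SIZE = int(1e8)  # tokens per shard

def write_shards(documents, out_dir="shards"):
    '''Tokenize documents and write fixed-size shards of uint16 token ids.'''
    buffer, shard_idx = [], 0
    for doc in documents:
        # <|endoftext|> delimits document boundaries
        buffer.extend(enc.encode_ordinary(doc) + [enc.eot_token])
        while len(buffer) >= SHARD_SIZE:
            shard = np.array(buffer[:SHARD_SIZE], dtype=np.uint16)  # ids < 65536, so 2 bytes each
            np.save(f"{out_dir}/shard_{shard_idx:04d}.npy", shard)
            buffer, shard_idx = buffer[SHARD_SIZE:], shard_idx + 1

Below is the configuration for the GPT-2 model.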

from dataclasses import dataclass

@dataclass
class GPTConfig:
    block_size: int = 1024  # Maximum sequence length
    vocab_size: int = 50257 # 50,000 BPE merges + 256 byte tokens + 1 <|endoftext|> token
    # The special <|endoftext|> token delimits document boundaries and can also seed generation
    n_layer: int = 12       # Number of transformer blocks (how deep the model is)
    n_head: int = 12        # Number of heads in multi-head attention (how wide the model is)
    n_embed: int = 768      # Embedding dimensionality
            

To summarize: the sequence length T = 1024, the batch size B is chosen to fit the available GPU memory, the vocabulary has 50,257 tokens, the model has 12 layers of multi-head attention with 12 heads each, and the embedding size C = 768.

To read a batch from a shard, B * T + 1 tokens are read. The input is the first B * T tokens reshaped to (B, T); the target is the same window shifted by one token, i.e. the last B * T tokens reshaped to (B, T). For example, if B * T + 1 = 9 and the tokens are [1, 2, 3, 4, 5, 6, 7, 8, 9], the input is [1, 2, 3, 4, 5, 6, 7, 8] and the target is [2, 3, 4, 5, 6, 7, 8, 9], so at every position the model is trained to predict the next token. The <|endoftext|> token separates documents within a shard. The input and target tensors are then passed to the model for training or validation.

If you train on multiple GPUs, each step reads B * T * num_gpus + 1 tokens per shard position, and each process takes the B * T slice corresponding to its rank, as sketched below.
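
A minimal sketch of such a loader (DataLoaderLite is an illustrative name, and the shards are assumed to be stored as uint16 numpy arrays as sketched above; the project's actual loader may differ):

import numpy as np
import torch

class DataLoaderLite:
    '''Serves (B, T) input/target batches from token shards, one slice per process.'''
    def __init__(self, B, T, shard_paths, rank=0, world_size=1):
        self.B, self.T = B, T
        self.rank, self.world_size = rank, world_size
        self.shard_paths = shard_paths
        self.shard_idx = 0
        self.tokens = torch.tensor(np.load(shard_paths[0]).astype(np.int64))
        self.pos = B * T * rank  # each process starts at its own offset

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.pos : self.pos + B * T + 1]
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)   # targets, shifted by one token
        # advance past the tokens consumed by *all* processes this step
        self.pos += B * T * self.world_size
        if self.pos + B * T * self.world_size + 1 > len(self.tokens):
            self.shard_idx = (self.shard_idx + 1) % len(self.shard_paths)
            self.tokens = torch.tensor(np.load(self.shard_paths[self.shard_idx]).astype(np.int64))
            self.pos = B * T * self.rank
        return x, y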

Model Pipeline

The GPT-2 model maps tokens to 768-dimensional embeddings, adds positional encodings, and passes the result through the transformer blocks. There are 12 blocks, each containing one multi-head attention layer with 12 heads followed by a feed-forward (MLP) layer. The output of the final block goes through a LayerNorm and a linear head to produce logits, which a softmax turns into probabilities over the next token. Below is the code for the GPT-2 structure.

import torch
import torch.nn as nn
import torch.nn.functional as F

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Developing Transformer
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embed), # Token embedding weights
            'wpe': nn.Embedding(config.block_size, config.n_embed), # Positional embedding weights
            'h': nn.ModuleList([Block(config) for _ in range(config.n_layer)]), # All transformer blocks
            'ln_f': nn.LayerNorm(config.n_embed)
        })

        # Final Linear layer after all transformer blocks
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)
    
        # Weight sharing scheme
        # The token embedding and the output (lm_head) projection share weights,
        # as in GPT-2: both map between token ids and the embedding space
        self.lm_head.weight = self.transformer['wte'].weight

        # Initialize parameters with mean 0 and std 0.02, roughly 1/sqrt(n_embed) for GPT-2 sizes (768-1600)
        self.apply(self._init_weights)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size, \
            f"Cannot forward sequence of length {T}, model block size is {self.config.block_size}"

        # IMP: Token and Positional Embeddings
        pos = torch.arange(0, T, device=idx.device, dtype=torch.long)
        pos_emb = self.transformer.wpe(pos)  #Positional Embeddings of shape (T, n_embed)
        tok_emb = self.transformer.wte(idx)  #Token Embeddings of shape (B, T, n_embed)
        x = tok_emb + pos_emb # broadcast along the batch dimension

        # Forward pass through each transformer block
        for block in self.transformer.h:
            x = block(x)
        
        # Final Linear layer
        x = self.transformer.ln_f(x)
        x = self.lm_head(x)

        # Loss function
        loss = None
        if targets is not None:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
        return x, loss
            
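
The Block class referenced above is not shown in this excerpt. Here is a minimal sketch of what it looks like in a GPT-2-style model (pre-norm: LayerNorm before attention and before the MLP, each wrapped in a residual connection; the 4x expansion and GELU follow the GPT-2 design), using the CausalSelfAttention module defined in the Attention Mechanisms section below; the project's actual Block may differ in details:

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embed, 4 * config.n_embed)    # expand
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(4 * config.n_embed, config.n_embed)  # project back
        self.c_proj.NANOGPT_SCALE_INIT = 1.0

    def forward(self, x):
        return self.c_proj(self.gelu(self.c_fc(x)))

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embed)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embed)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))  # pre-norm attention + residual
        x = x + self.mlp(self.ln_2(x))   # pre-norm MLP + residual
        return x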

Chinchilla Scaling Law: The Chinchilla scaling law suggests roughly 20 training tokens per parameter for compute-optimal training. With 10 billion tokens and 124 million parameters (about 80 tokens per parameter), this budget is comfortably satisfied.

Attention Mechanisms

Attention mechanisms let the model weigh how much each token should contribute to the representation of every other token; the attention scores encode this importance. There are several variants, such as standard (normal) attention, FlashAttention, and Linformer attention, which we compare in this project.

Normal Attention

Normal attention is the standard scaled dot-product attention: it materializes the full T x T matrix of attention scores, applies the causal mask and a softmax, and takes a weighted average of the value vectors. It is defined as follows:

class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__() 
        assert config.n_embed % config.n_head == 0
        
        # Query/key/value projections
        self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)
        
        # Output projection
        self.c_proj = nn.Linear(config.n_embed, config.n_embed)
        self.c_proj.NANOGPT_SCALE_INIT = 1.0

        # Configuration
        self.n_head = config.n_head
        self.n_embed = config.n_embed
        
        # Causal mask
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                            .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()
        
        # Split QKV
        q, k, v = self.c_attn(x).split(self.n_embed, dim=2)
        
        # Reshape for multi-head attention
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # Attention computation
        att = (q @ k.transpose(-2, -1)) * (1.0 / (k.size(-1) ** 0.5))
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        y = att @ v

        # Re-assemble heads and apply the output projection
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
            

Flash Attention

Flash Attention is an optimized implementation of normal attention: it computes the same result but fuses the operations into a single kernel that never materializes the full T x T attention matrix, which saves memory and runs much faster on GPU. In PyTorch it is available as F.scaled_dot_product_attention, so the explicit attention computation above collapses to one line:

# Attention mechanism
# att = (q @ k.transpose(-2, -1)) * (1.0 / ((k.size(-1)) ** 0.5))
# # Masked Attention
# att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
# att = F.softmax(att, dim=-1)
# y = att @ v

# Flash Attention
y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # wow who knew flash attention was so easy to implement
            
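
For completeness, here is a sketch of the full forward for the CausalSelfAttention class with the flash path; everything except the fused attention call is identical to the normal-attention version above:

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embed, dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # Fused, memory-efficient attention; never materializes the T x T matrix
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble heads
        return self.c_proj(y)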

Linformer Attention

Linformer attention reduces the cost of attention by projecting the keys and values along the sequence dimension from length T down to a smaller dimension k, bringing the complexity from O(T^2) to O(T * k). It is defined as follows:

def get_EF(input_size, dim, method="no_params", head_dim=None, bias=True):
    '''Returns E/F matrix with xavier initialization'''
    assert method in ["learnable", "convolution", "no_params"], "Invalid method!"
    
    if method == "convolution":
        return nn.Conv1d(head_dim, head_dim, kernel_size=int(input_size/dim))
    
    elif method == "no_params":
        mat = torch.zeros((dim, input_size))
        torch.nn.init.normal_(mat, std=1/dim)
        return mat
    
    elif method == "learnable":
        lin = nn.Linear(input_size, dim, bias)
        torch.nn.init.xavier_normal_(lin.weight)
        return lin

def gen_causal_mask(input_size, dim_k, full_attention=False):
    return (torch.triu(torch.ones(input_size, input_size)) if full_attention 
            else torch.tril(torch.ones(input_size, dim_k))).bool()

class CausalSelfAttention(nn.Module):
    def __init__(self, config, dim_k=None, eps=0.4):
        super().__init__()
        assert config.n_embed % config.n_head == 0
        
        # Projections
        self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed)
        self.c_proj = nn.Linear(config.n_embed, config.n_embed)
        self.c_proj.NANOGPT_SCALE_INIT = 1.0

        # Config
        self.n_head = config.n_head
        self.n_embed = config.n_embed
        self.block_size = config.block_size
        self.eps = eps
        self.dim_k = dim_k or int(min(9*self.n_embed*np.log(self.n_embed)/self.eps**2, 
                                       5*np.log(self.block_size)/self.eps**2))

        # Linformer projections
        delta = 1/(2**self.block_size)
        self.causal_mask = gen_causal_mask(self.block_size, self.dim_k)
        self.E_proj = nn.Parameter(delta * torch.stack([get_EF(self.block_size, self.dim_k) 
                                  for _ in range(self.n_head)]))
        self.F_proj = nn.Parameter(np.exp(-delta) * torch.stack([get_EF(self.block_size, self.dim_k) 
                                  for _ in range(self.n_head)]))

        # Causal mask
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)))

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embed, dim=2)
        
        # Reshape for multi-head attention
        q = q.view(B, T, self.n_head, -1).transpose(1, 2)
        k = k.view(B, T, self.n_head, -1).transpose(1, 2)
        v = v.view(B, T, self.n_head, -1).transpose(1, 2)

        # Project keys/values
        K_proj = self.E_proj.expand(B, -1, -1, -1) @ k
        V_proj = self.F_proj.expand(B, -1, -1, -1) @ v

        # Flash attention implementation
        y = F.scaled_dot_product_attention(q, K_proj, V_proj, is_causal=True)
        
        # Output projection
        return self.c_proj(y.transpose(1,2).contiguous().view(B,T,C))
            

Positional Encodings

Positional encodings give the model a sense of token order; without them, self-attention is permutation-invariant. They are either added to the token embeddings before the transformer blocks or injected directly into the attention scores. We compare learned, sinusoidal, rotary (RoPE), KERPLE, ALiBi, and FIRE positional encodings in this project.

Learned Positional Encoding

Learned positional encoding stores one trainable embedding vector per position (the wpe table), which is the scheme GPT-2 itself uses. This is the baseline GPT class shown earlier, repeated here with the weight initialization included:

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Developing Transformer
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embed), # Token embedding
            'wpe': nn.Embedding(config.block_size, config.n_embed), # Positional embedding
            'h': nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            'ln_f': nn.LayerNorm(config.n_embed)
        })

        # Final Linear layer
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)
        self.lm_head.weight = self.transformer['wte'].weight

        # Initialize parameters
        self.apply(self._init_weights)
    
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            std = 0.02
            if hasattr(module, 'NANOGPT_SCALE_INIT'):
                std *= (2 * self.config.n_layer) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size, f"Cannot forward sequence of length {T}, block size is {self.config.block_size}"

        # Embeddings
        pos = torch.arange(0, T, device=idx.device)
        pos_emb = self.transformer.wpe(pos)
        tok_emb = self.transformer.wte(idx)
        x = tok_emb + pos_emb

        # Transformer blocks
        for block in self.transformer.h:
            x = block(x)
        
        # Final output
        x = self.transformer.ln_f(x)
        x = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
        return x, loss
            

Sinusoidal Positional Encoding

Sinusoidal positional encoding is the fixed encoding from the original Transformer paper, built from sine and cosine functions at geometrically spaced frequencies. It is defined as follows:

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Developing Transformer
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embed), # Token embedding weights
            'h': nn.ModuleList([Block(config) for _ in range(config.n_layer)]), # All transformer blocks
            'ln_f': nn.LayerNorm(config.n_embed)
        })
        ...

    def get_sinusoidal_encoding(self, T):
        position = torch.arange(0, T, dtype=torch.float).unsqueeze(1)
        div_term = 10000 ** (-2 * torch.arange(self.config.n_embed // 2) / self.config.n_embed)
        encoding = torch.zeros(T, self.config.n_embed)
        encoding[:, 0::2] = torch.sin(position * div_term)
        encoding[:, 1::2] = torch.cos(position * div_term)
        return encoding

    def forward(self, idx, targets=None):
        B, T = idx.size()
        assert T <= self.config.block_size, \
            f"Cannot forward sequence of length {T}, block size is {self.config.block_size}"

        # Token Embeddings
        tok_emb = self.transformer.wte(idx)

        # Positional Encodings
        pos_emb = self.get_sinusoidal_encoding(T).to(idx.device)
        pos_emb = pos_emb.unsqueeze(0).expand(B, -1, -1)

        x = tok_emb + pos_emb

        for block in self.transformer.h:
            x = block(x)
        
        x = self.transformer.ln_f(x)
        x = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.view(-1))
        return x, loss
            

RoPE (Rotary Positional Encoding)

Rotary Positional Encoding (RoPE) encodes position by rotating pairs of query and key dimensions by an angle proportional to the token's position, so that the dot product between a query and a key depends only on their relative offset. It is defined as follows:

class RotaryPositionEmbeddings(nn.Module):
    '''Rotary Position Embeddings, as described in the RoPE paper'''
    def __init__(self, config, base=10_000):
        super().__init__()
        self.base = base
        self.dim = config.n_embed
        self.max_seq_len = config.block_size
        self.config = config
        self.rope_init()

    def rope_init(self):
        '''Initialize the RoPE cache with sin and cos values for each position.'''
        theta = torch.pow(self.base, -2 * torch.arange(0, self.dim // 2).float() / self.dim)
        self.register_buffer('theta', theta, persistent=False)
        self.build_rope_cache()

    def build_rope_cache(self):
        '''Build the RoPE cache for the given block size.'''
        seq_idx = torch.arange(self.max_seq_len, dtype=self.theta.dtype, device=self.theta.device)

        idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float()
        hs_half = self.config.n_embed // self.config.n_head // 2
        idx_theta = idx_theta[:, :hs_half]

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
        self.register_buffer('cache', cache, persistent=False)

    def forward(self, x):
        '''Apply the rotation to a tensor of shape (B, T, n_head, head_size).'''
        b, seq_len, nh, hs = x.shape

        # (seq_len, hs//2, 2) slice of the precomputed cos/sin cache
        rope_cache = self.cache[:seq_len, :hs // 2].to(x.device)

        # View the last dimension as hs//2 pairs: (B, T, n_head, hs//2, 2)
        x = x.reshape(*x.shape[:-1], -1, 2)

        # Broadcast the cache over batch and head dimensions: (1, T, 1, hs//2, 2)
        rope_cache = rope_cache.unsqueeze(0).unsqueeze(2)

        # Rotate each pair by its position-dependent angle
        rotated = torch.stack([
            x[..., 0] * rope_cache[..., 0] - x[..., 1] * rope_cache[..., 1],
            x[..., 1] * rope_cache[..., 0] + x[..., 0] * rope_cache[..., 1]
        ], dim=-1)

        # Flatten the pairs back to (B, T, n_head, head_size)
        return rotated.flatten(-2).type_as(x)
            
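
A sketch of how this module is used inside the attention layer: the rotation is applied to the queries and keys (not the values) while they are still shaped (B, T, n_head, head_size), before the heads are moved into the batch dimension. It assumes a self.rope = RotaryPositionEmbeddings(config) attribute created in __init__ and that the learned wpe table is dropped from the GPT class in this variant; the project's exact integration may differ:

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embed, dim=2)

        # RoPE expects (B, T, n_head, head_size)
        q = q.view(B, T, self.n_head, C // self.n_head)
        k = k.view(B, T, self.n_head, C // self.n_head)
        v = v.view(B, T, self.n_head, C // self.n_head)

        q = self.rope(q)  # rotate queries
        k = self.rope(k)  # rotate keys; values are left untouched

        q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # (B, n_head, T, head_size)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)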

Kerple Positional Encoding

KERPLE (Kernelized Relative Positional Embedding) derives an additive attention bias from a kernel applied to the relative distance |i - j|, with learned per-head parameters. When using flash attention, the causal flag must be set to False and the KERPLE bias passed as the attention mask, since the bias already contains the causal -inf entries. It is defined as follows:

class ClampedKerpleParameterModule(nn.Module):
    def __init__(self, num_heads, scale=1, min_value=1e-2):
        super().__init__()
        self.param = nn.Parameter(torch.rand(num_heads) * scale)
        self.min_value = min_value

    def forward(self, x):
        clamped_param = torch.clamp(self.param, min=self.min_value)
        return clamped_param.view(-1, 1, 1) * x

class KerplePositionalEncoding(nn.Module):
    def __init__(self, num_heads, block_size=1024, scale_r1=1, scale_r2=2, min_value=1e-2):
        super().__init__()
        self.r1 = ClampedKerpleParameterModule(num_heads, scale=scale_r1, min_value=min_value)
        self.r2 = ClampedKerpleParameterModule(num_heads, scale=scale_r2, min_value=min_value)
        self.n_head = num_heads
        self.block_size = block_size

        # Precompute the |i - j| distance matrix for the full block size
        # (assumes a global `device` is defined elsewhere in the script)
        self.distance_matrix = torch.arange(self.block_size).view(-1, 1) - torch.arange(self.block_size).view(1, -1)
        self.distance_matrix = self.distance_matrix.abs().float().to(device)
        self.distance_matrix = self.distance_matrix.repeat(self.n_head, 1, 1)

    def forward(self, q_shape):
        # q_shape is the shape of q: (B, n_head, T, head_size)
        applied_distance_matrix = self.distance_matrix
        if q_shape[2] != self.block_size:
            pos = torch.arange(q_shape[2]).view(-1, 1) - torch.arange(q_shape[2]).view(1, -1)
            distance_matrix = pos.abs().float().to(device)
            distance_matrix = distance_matrix.repeat(self.n_head, 1, 1)
            applied_distance_matrix = distance_matrix
        
        distance_matrix = self.r1(applied_distance_matrix)
        distance_matrix = -self.r2(1 + torch.log(1 + distance_matrix))

        mask = torch.ones(q_shape[2], q_shape[2]).tril(diagonal=0).repeat(self.n_head, 1, 1)
        distance_matrix = distance_matrix.masked_fill(mask.logical_not().to(device), float('-inf'))
        
        return distance_matrix
            
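
A sketch of how such a bias is consumed in the attention layer when flash attention is used: because the bias already contains -inf above the diagonal, it is passed as attn_mask with is_causal=False. The same pattern applies to the ALiBi and FIRE biases below. It assumes a self.kerple = KerplePositionalEncoding(config.n_head, config.block_size) attribute created in __init__:

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embed, dim=2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # Additive bias of shape (n_head, T, T); -inf above the diagonal keeps it causal
        bias = self.kerple(q.shape)

        # is_causal=False because the causal structure is already baked into attn_mask
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=bias, is_causal=False)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)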

Alibi Positional Encoding

ALiBi (Attention with Linear Biases) does not add anything to the embeddings; instead it adds a bias to each attention score that is proportional to the distance between the query and key positions, with a fixed per-head slope. As with KERPLE, when using flash attention the causal flag must be set to False and the ALiBi bias passed as the attention mask. It is defined as follows:

class CausalSelfAttention(nn.Module):
    def alibi_mask(self, q_shape):
        # ALiBi bias: attention scores become q @ k_T + m * (j - i), with (j - i) <= 0 for allowed positions
        T = q_shape[2]
        # Relative position j - i; future positions (j > i) are masked to -inf
        pos = -torch.arange(T).view(-1, 1) + torch.arange(T).view(1, -1)
        pos = pos.float().masked_fill(pos > 0, float('-inf'))  # shape: (T, T)
        # Per-head slopes m, spanning 2^-8 ... 2^-1 across the heads
        head_m_value = torch.tensor([2.0 ** (-8 / i) for i in torch.linspace(1, 8, self.n_head)])
        return head_m_value.view(-1, 1, 1) * pos  # shape: (n_head, T, T)
            

FIRE Positional Encoding

FIRE (Functional Interpolation for Relative Position Encoding) learns the positional bias with a small MLP applied to a log-transformed, normalized relative distance, which helps length generalization. As with KERPLE and ALiBi, when using flash attention the causal flag must be set to False and the FIRE bias passed as the attention mask. It is defined as follows:

class FIRE(nn.Module):
    def __init__(self, num_heads=12, mlp_width=32, init_c=0.1, init_L=512., eps=1e-6):
        super(FIRE, self).__init__()

        # Define MLP layers
        self.mlp = nn.Sequential(
            nn.Linear(1, mlp_width),
            nn.ReLU(),
            nn.Linear(mlp_width, num_heads)
        )

        # Initialize parameters
        self.c = nn.Parameter(torch.tensor(init_c))
        self.init_L = nn.Parameter(torch.tensor(init_L), requires_grad=False)
        self.L_multiplier = nn.Parameter(torch.tensor(1.0))
        self.eps = eps

    def forward(self, x: torch.Tensor):
        seq_length = x.size(2)
        positions = torch.arange(seq_length, dtype=torch.float, device=x.device)
        rel_distance = positions[:, None] - positions[None, :]

        # Thresholding and normalization
        threshold = torch.abs(self.L_multiplier * self.init_L)
        pos_normalizer = torch.max(positions, threshold).unsqueeze(1)

        # Log transformations
        rel_distance = torch.log(torch.abs(self.c * rel_distance) + 1)
        pos_normalizer = torch.log(torch.abs(self.c * pos_normalizer) + 1) + self.eps

        # Compute FIRE bias
        normalized_distance = rel_distance / pos_normalizer
        fire_bias = self.mlp(normalized_distance.unsqueeze(-1))
        fire_bias = fire_bias.permute(2, 0, 1)
        
        # Apply causal mask: future positions get a -inf bias
        mask = torch.ones(seq_length, seq_length, device=x.device).tril(diagonal=0).repeat(fire_bias.shape[0], 1, 1)
        fire_bias = fire_bias.masked_fill(mask.logical_not(), float('-inf')).unsqueeze(0)
        
        return fire_bias
            

References

1. Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., Sutskever, I., et al. (2019). Language Models are Unsupervised Multitask Learners. OpenAI blog.