Implement RoPE positional encodings #450

Open · wants to merge 3 commits into base: master
model.py: 30 changes (28 additions & 2 deletions)
@@ -48,6 +48,28 @@ def __init__(self, config):
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def apply_rotary_position_embeddings(self, sinusoidal_pos, q, k):
# Split the sinusoidal_pos into sin and cos parts
sin, cos = sinusoidal_pos.chunk(2, dim=-1)
# Apply the rotary embeddings to the query and key
q_rot = torch.stack((-q[..., 1::2], q[..., ::2]), dim=-1)
k_rot = torch.stack((-k[..., 1::2], k[..., ::2]), dim=-1)
q_rot = torch.reshape(q_rot, q.shape[:-1] + (q.shape[-1]//2, 2)) * torch.stack((cos, sin), dim=-1)
k_rot = torch.reshape(k_rot, k.shape[:-1] + (k.shape[-1]//2, 2)) * torch.stack((cos, sin), dim=-1)
q_rot = torch.reshape(q_rot, q.shape)
k_rot = torch.reshape(k_rot, k.shape)
return q_rot, k_rot

def get_sinusoidal_embeddings(self, n_positions, dim):
"""Generate sinusoidal positional embeddings."""
position = torch.arange(n_positions, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
sinusoidal_emb = torch.zeros((n_positions, dim))
sinusoidal_emb[:, 0::2] = torch.sin(position * div_term)
sinusoidal_emb[:, 1::2] = torch.cos(position * div_term)
return sinusoidal_emb.to(self.device)

def forward(self, x):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
@@ -58,13 +80,17 @@ def forward(self, x):
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

# apply rotary position embeddings
sinusoidal_pos = self.get_sinusoidal_embeddings(T, self.n_embd // self.n_head)
q_rot, k_rot = self.apply_rotary_position_embeddings(sinusoidal_pos, q, k)

# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
if self.flash:
# efficient attention using Flash Attention CUDA kernels
- y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
+ y = torch.nn.functional.scaled_dot_product_attention(q_rot, k_rot, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
else:
# manual implementation of attention
- att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+ att = (q_rot @ k_rot.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.attn_dropout(att)
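For reference, here is a minimal standalone sketch of the standard rotary position embedding (RoPE, Su et al., RoFormer), where each head-dimension pair (x_2i, x_2i+1) is rotated by the angle pos * theta_i with theta_i = 10000^(-2i/d). The helper names `rope_angles` and `apply_rope` and the interleaved even/odd pairing are illustrative assumptions for comparison, not code from this PR's diff:

```python
import torch

def rope_angles(seq_len, head_dim, base=10000.0):
    # theta_i = base^(-2i/d) for each even dimension index 2i
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    pos = torch.arange(seq_len).float()
    angles = torch.outer(pos, inv_freq)        # (T, head_dim/2)
    return angles.cos(), angles.sin()          # each (T, head_dim/2)

def apply_rope(x, cos, sin):
    # x: (B, n_head, T, head_dim); each pair (x[..., 2i], x[..., 2i+1]) is rotated by angles[t, i]
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

# usage sketch: rotate q and k before the attention product, v is left untouched
B, n_head, T, head_dim = 2, 4, 16, 32
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_head, T, head_dim)
cos, sin = rope_angles(T, head_dim)
q_rot, k_rot = apply_rope(q, cos, sin), apply_rope(k, cos, sin)   # shapes unchanged
```

Precomputing cos/sin once per sequence length and broadcasting over batch and heads keeps the rotation cheap, and since the rotation preserves vector norms the usual 1/sqrt(head_dim) attention scaling is unaffected.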
train.py: 2 changes (1 addition & 1 deletion)
@@ -44,7 +44,7 @@
wandb_project = 'owt'
wandb_run_name = 'gpt2' # 'run' + str(time.time())
# data
- dataset = 'openwebtext'
+ dataset = 'shakespeare_char'
gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes
batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size
block_size = 1024
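As a quick sanity check on the defaults visible in this hunk, the number of tokens consumed per optimizer step (for a single process; data-parallel training would multiply this by the number of workers) works out as follows:

```python
# tokens processed per optimizer step with the defaults shown above
gradient_accumulation_steps = 5 * 8   # 40 micro-steps per optimizer step
batch_size = 12                       # sequences per micro-step
block_size = 1024                     # tokens per sequence
tokens_per_iter = gradient_accumulation_steps * batch_size * block_size
print(tokens_per_iter)                # 491520
```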