import torch
import torch.nn as nn

class MultiHeadAttn(nn.Module):
    """Multi-head causal self-attention."""

    def __init__(self, d_in, d_out, ctx_len, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Upper-triangular mask (1s above the diagonal) blocks attention to future tokens.
        self.register_buffer('mask', torch.triu(torch.ones(ctx_len, ctx_len), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # Project to Q, K, V and split into heads: (b, num_heads, num_tokens, head_dim).
        keys = self.W_key(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        queries = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values = self.W_value(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        # Scaled dot-product attention with dropout applied to the attention weights.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        # Merge the heads back into a single d_out-dimensional representation.
        context_vec = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)

class FeedForward(nn.Module):
    """Position-wise feed-forward network with a 4x hidden expansion."""

    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
            GELU(),
            nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
        )

    def forward(self, x):
        return self.layers(x)

class GELU(nn.Module):
    """Tanh approximation of the GELU activation."""

    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(
            torch.sqrt(torch.tensor(2.0 / torch.pi)) * (x + 0.044715 * x**3)
        ))

class LayerNorm(nn.Module):
    """Layer normalization over the embedding dimension with learnable scale and shift."""

    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift

class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each wrapped in a residual connection."""

    def __init__(self, cfg):
        super().__init__()
        self.att = MultiHeadAttn(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            ctx_len=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=cfg["drop_rate"],
            qkv_bias=cfg["qkv_bias"])
        self.ff = FeedForward(cfg)
        self.norm1 = LayerNorm(cfg["emb_dim"])
        self.norm2 = LayerNorm(cfg["emb_dim"])
        self.drop_resid = nn.Dropout(cfg["drop_rate"])

    def forward(self, x):
        # Attention sub-block with residual connection.
        save = x
        x = self.norm1(x)
        x = self.att(x)
        x = self.drop_resid(x)
        x = x + save
        # Feed-forward sub-block with residual connection.
        save = x
        x = self.norm2(x)
        x = self.ff(x)
        x = self.drop_resid(x)
        x = x + save
        return x

class GPT(nn.Module):
    """Decoder-only GPT: token + positional embeddings, a stack of transformer blocks, and a linear output head."""

    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])
        self.trf_blocks = nn.Sequential(
            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )
        self.final_norm = LayerNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
        x = tok_embeds + pos_embeds
        x = self.drop_emb(x)
        x = self.trf_blocks(x)
        x = self.final_norm(x)
        return self.out_head(x)  # logits of shape (batch_size, seq_len, vocab_size)

def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
    """Autoregressively sample up to max_new_tokens token ids and append them to idx."""
    for _ in range(max_new_tokens):
        # Condition only on the most recent context_size tokens.
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # Top-k filtering: keep only the k largest logits, set the rest to -inf.
        if top_k is not None:
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)

        # Temperature sampling if temperature > 0, greedy decoding otherwise.
        if temperature > 0.0:
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        # Stop early on the end-of-sequence token (assumes batch size 1).
        if eos_id is not None and idx_next == eos_id:
            break

        idx = torch.cat((idx, idx_next), dim=1)

    return idx

def generate_and_print_sample(model, tokenizer, device, start_context):
    """Generate a short continuation of start_context with the model in eval mode."""
    model.eval()
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate(model=model, idx=encoded, max_new_tokens=100, context_size=context_size)
    decoded_text = token_ids_to_text(token_ids, tokenizer)
    model.train()
    return decoded_text

def text_to_token_ids(text, tokenizer):
    """Encode text into a (1, num_tokens) tensor of token ids."""
    encoded = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
    encoded_tensor = torch.tensor(encoded).unsqueeze(0)
    return encoded_tensor


def token_ids_to_text(token_ids, tokenizer):
    """Decode a (1, num_tokens) tensor of token ids back into text."""
    flat = token_ids.squeeze(0)
    return tokenizer.decode(flat.tolist())
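

# Usage sketch (not part of the original commit): shows how the config dict,
# GPT model, and sampling helpers fit together. The hyperparameter values below
# and the use of tiktoken's GPT-2 tokenizer are assumptions for illustration;
# any tokenizer with compatible encode/decode methods should work.
if __name__ == "__main__":
    import tiktoken

    cfg = {
        "vocab_size": 50257,     # GPT-2 BPE vocabulary size (assumed)
        "context_length": 256,   # maximum sequence length (assumed)
        "emb_dim": 768,
        "n_heads": 12,
        "n_layers": 12,
        "drop_rate": 0.1,
        "qkv_bias": False,
    }

    torch.manual_seed(123)
    model = GPT(cfg)
    tokenizer = tiktoken.get_encoding("gpt2")

    # The model is untrained here, so the continuation will be nonsensical;
    # this only exercises the forward pass and the sampling loop.
    sample = generate_and_print_sample(
        model, tokenizer, device="cpu", start_context="Every effort moves you"
    )
    print(sample)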