import math

import torch
import torch.nn as nn
My PyTorch implementation of the Transformer model proposed in the original Transformer paper [1].
(Figure: the Transformer architecture. Image source: Vaswani et al. [1])
The components of the Transformer architecture
Attention
Scaled dot-product attention
The attention output is computed as a weighted sum of the values, where the weight of each value comes from the softmax of the scaled dot product between the query and the corresponding key,
\[ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \]
def scaled_dot_product_attention(q, k, v, mask, dropout):
    # q: (batch_size, n_heads, q_length, dim_key)
    # k: (batch_size, n_heads, k_length, dim_key)
    # v: (batch_size, n_heads, k_length, dim_value)
    # mask:
    #   encoder self-attn: (batch_size, 1, 1, k_length)
    #   decoder self-attn: (batch_size, 1, q_length, k_length)
    #   decoder cross-attn: (batch_size, 1, 1, k_length)
    scores = torch.matmul(
        q, k.transpose(-2, -1)
    ) / math.sqrt(k.size(-1))  # (batch_size, n_heads, q_length, k_length)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, value=-1e10)

    attention_weights = scores.softmax(dim=-1)  # (batch_size, n_heads, q_length, k_length)

    # optionally apply dropout to the attention weights before weighting the values
    attentions = dropout(attention_weights) if dropout is not None else attention_weights
    attentions = torch.matmul(attentions, v)  # (batch_size, n_heads, q_length, dim_value)

    return attentions, attention_weights
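As a quick sanity check (a minimal sketch with random tensors; the batch size, head count and sequence lengths are arbitrary choices for illustration), the function should return outputs with the documented shapes, and attention weights that sum to 1 over the key dimension:

q = torch.randn(2, 8, 5, 64)
k = torch.randn(2, 8, 7, 64)
v = torch.randn(2, 8, 7, 64)
mask = torch.ones(2, 1, 1, 7)  # nothing masked

out, weights = scaled_dot_product_attention(q, k, v, mask, dropout=None)
print(out.shape)      # torch.Size([2, 8, 5, 64])
print(weights.shape)  # torch.Size([2, 8, 5, 7])
print(torch.allclose(weights.sum(dim=-1), torch.ones(2, 8, 5)))  # True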
Multi-head attention
Multi-head attention computes attention functions on multiple projections of the input queries, keys and values. This allows the model to capture different features of the input data from different subspaces and at different locations. The output of each attention function is then concatenated and projected to produce the final values.
\[ \mathrm{MultiHead}(Q,K,V)=\mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)W^O \] \[ \mathrm{where}\ \mathrm{head}_i=\mathrm{Attention}(QW_i^Q,KW_i^K,VW_i^V) \] \[ W_i^Q \in \mathbb{R}^{d_{model} \times d_k},\ W_i^K \in \mathbb{R}^{d_{model} \times d_k},\ W_i^V \in \mathbb{R}^{d_{model} \times d_v},\ W^O \in \mathbb{R}^{hd_v \times d_{model}} \]
class MultiheadAttention(nn.Module):
    def __init__(self, n_heads, dim_model, dim_value, dropout_p):
        super().__init__()
        self.dim_model = dim_model
        self.n_heads = n_heads
        self.dim_value = dim_value
        # dim_model should be divisible by n_heads
        # the values from the paper are: dim_model=512, n_heads=8, dim_value=64
        assert dim_model % n_heads == 0, 'Invalid values: dim_model should be divisible by n_heads.'
        self.dim_key = dim_model // n_heads
        # q, k, v projections
        self.wq = nn.Linear(dim_model, dim_model)
        self.wk = nn.Linear(dim_model, dim_model)
        self.wv = nn.Linear(dim_model, n_heads*dim_value)
        self.wo = nn.Linear(n_heads*dim_value, dim_model)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, q, k, v, mask):
        query = self.wq(q)  # (batch_size, q_length, dim_model)

        # key and value will have the same sequence length
        key = self.wk(k)    # (batch_size, k_length, dim_model)
        value = self.wv(v)  # (batch_size, k_length, n_heads*dim_value)

        # split the projections into heads
        query = query.view(
            query.size(0), query.size(1), self.n_heads, self.dim_key
        ).transpose(1, 2)  # (batch_size, n_heads, q_length, dim_key)
        key = key.view(
            key.size(0), key.size(1), self.n_heads, self.dim_key
        ).transpose(1, 2)  # (batch_size, n_heads, k_length, dim_key)
        value = value.view(
            value.size(0), value.size(1), self.n_heads, self.dim_value
        ).transpose(1, 2)  # (batch_size, n_heads, k_length, dim_value)

        attentions, attention_weights = scaled_dot_product_attention(
            query, key, value, mask, self.dropout
        )

        # concatenate the heads and apply the output projection
        attentions = attentions.transpose(1, 2).contiguous()  # (batch_size, q_length, n_heads, dim_value)
        attentions = attentions.view(
            attentions.size(0), attentions.size(1), -1
        )  # (batch_size, q_length, n_heads*dim_value)
        attentions = self.wo(attentions)

        # attentions: (batch_size, q_length, dim_model)
        # attention_weights: (batch_size, n_heads, q_length, k_length)
        return attentions, attention_weights
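A similar shape check for the module (again with arbitrary illustrative sizes; here q = k = v, i.e., self-attention):

mha = MultiheadAttention(n_heads=8, dim_model=512, dim_value=64, dropout_p=0.1)
x = torch.randn(2, 10, 512)      # a batch of 2 sequences of length 10
mask = torch.ones(2, 1, 1, 10)   # nothing masked
out, w = mha(x, x, x, mask)
print(out.shape)  # torch.Size([2, 10, 512])
print(w.shape)    # torch.Size([2, 8, 10, 10])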
Position-wise Feed-Forward Networks
It consists of two linear transformations with a ReLU activation in between, \[ \mathrm{FFN}(x)=\max(0, xW_1 + b_1) W_2 + b_2 \]
class FFN(nn.Module):
    def __init__(self, dim_model, dim_ffn, dropout_p):
        super().__init__()
        self.linear1 = nn.Linear(dim_model, dim_ffn)
        self.linear2 = nn.Linear(dim_ffn, dim_model)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x
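The FFN is applied to each position independently, so transforming a single position in isolation gives the same result as slicing the full output. A small sketch (the paper's dimensions used only for illustration, dropout disabled to keep it deterministic):

ffn = FFN(dim_model=512, dim_ffn=2048, dropout_p=0.0)
x = torch.randn(2, 10, 512)
y = ffn(x)
print(y.shape)  # torch.Size([2, 10, 512])
# position-wise: the first position transformed alone matches the first slice of the full output
print(torch.allclose(ffn(x[:, :1, :]), y[:, :1, :], atol=1e-5))  # True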
Embedding & Positional Encoding
Embeddings are learned, and the embedding outputs are multiplied by \(\sqrt{d_{model}}\).
class Embeddings(nn.Module):
    def __init__(self, dim_model, vocab_size):
        super().__init__()
        self.dim_model = dim_model
        self.embedding = nn.Embedding(vocab_size, dim_model)

    def forward(self, x):
        # x: (batch_size, seq_length)
        x = self.embedding(x)  # (batch_size, seq_length, dim_model)
        x = x * math.sqrt(self.dim_model)
        return x
Positional encoding is used to inject information about positions of the tokens in the sequence. It’s computed using a combination of sine and cosine functions, \[ \begin{aligned} PE_{(pos, 2i)} &= \sin(pos/10000^{2i/d_{model}}) \\ PE_{(pos, 2i+1)} &= \cos(pos/10000^{2i/d_{model}}) \end{aligned} \] where \(pos\) is the position, and \(i\) is the dimension.
Since it’s more efficient to compute the power in log space, the inverse-frequency term can be calculated as, \[ \begin{aligned} u &= e^{\ln(u)}, \quad \mathrm{where}\ u = 1/10000^{2i/d_{model}} \\ \ln(u) &= 2i\left(-\ln(10000)/d_{model}\right) \\ u &= e^{2i\left(-\ln(10000)/d_{model}\right)} \end{aligned} \]
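A quick numeric check of this equivalence (assuming d_model = 512 as in the paper, purely for illustration):

d_model = 512
i2 = torch.arange(0, d_model, 2).float()              # the 2i term
direct = 1.0 / (10000.0 ** (i2 / d_model))            # 1/10000^(2i/d_model)
log_space = torch.exp(i2 * (-math.log(10000.0) / d_model))
print(torch.allclose(direct, log_space))              # True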
class PositionalEncoding(nn.Module):
    def __init__(self, max_seq_length, dim_model, dropout_p):
        super().__init__()
        self.dropout = nn.Dropout(dropout_p)

        pe = torch.zeros(max_seq_length, dim_model)  # (max_seq_length, dim_model)
        positions = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)  # (max_seq_length, 1)
        u = torch.exp(
            torch.arange(0, dim_model, 2).float() * (-math.log(10000.0) / dim_model)
        )  # (dim_model/2,)
        pe[:, 0::2] = torch.sin(positions * u)
        pe[:, 1::2] = torch.cos(positions * u)
        pe = pe.unsqueeze(0)  # (1, max_seq_length, dim_model)

        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch_size, seq_length, dim_model)
        x = x + (self.pe[:, :x.size(1), :]).requires_grad_(False)
        return self.dropout(x)  # (batch_size, seq_length, dim_model)
Add & Norm
This is done by applying a residual connection around each sublayer (the multi-head attention layer and the FFN layer), followed by layer normalization, i.e., \[\mathrm{LayerNorm}(x+\mathrm{Sublayer}(x))\]
Layer normalization is computed as, \[
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta
\] where \(\gamma\) and \(\beta\) are learnable parameters. The PyTorch implementation is nn.LayerNorm.
class AddNorm(nn.Module):
    def __init__(self, normalized_shape, dropout_p):
        super().__init__()
        self.layernorm = nn.LayerNorm(normalized_shape=normalized_shape)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x, sublayer_output):
        return self.layernorm(x + self.dropout(sublayer_output))
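As a sanity check on the formula above (arbitrary example values), the manual computation matches nn.LayerNorm with its default affine parameters \(\gamma=1\), \(\beta=0\):

x = torch.randn(2, 5, 8)
ln = nn.LayerNorm(normalized_shape=8)
manual = (x - x.mean(dim=-1, keepdim=True)) / torch.sqrt(x.var(dim=-1, unbiased=False, keepdim=True) + ln.eps)
print(torch.allclose(ln(x), manual, atol=1e-6))  # True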
Encoder
The encoder maps the input tokens to a sequence of context vectors. It stacks N identical blocks, each containing a self-attention sublayer and an FFN sublayer.
class EncoderBlock(nn.Module):
    def __init__(self, self_attn, ffn, dropout_p):
        super().__init__()
        self.self_attn = self_attn
        self.ffn = ffn
        self.add_norm = nn.ModuleList(
            [AddNorm(self_attn.dim_model, dropout_p) for _ in range(2)]
        )

    def forward(self, x, src_mask):
        # x: (batch_size, src_length, dim_model)
        # src_mask: (batch_size, 1, 1, src_length)
        sublayer_output, _ = self.self_attn(x, x, x, src_mask)
        x = self.add_norm[0](x, sublayer_output)  # (batch_size, src_length, dim_model)
        x = self.add_norm[1](x, self.ffn(x))      # (batch_size, src_length, dim_model)
        return x
class Encoder(nn.Module):
    def __init__(self, N, n_heads, dim_model, dim_value, dim_ffn, dropout_p):
        super().__init__()
        self.blocks = nn.ModuleList()
        for _ in range(N):
            encoder_block = EncoderBlock(
                MultiheadAttention(n_heads, dim_model, dim_value, dropout_p),
                FFN(dim_model, dim_ffn, dropout_p),
                dropout_p
            )
            self.blocks.append(encoder_block)

    def forward(self, x, src_mask):
        # x: (batch_size, src_length, dim_model)
        # src_mask: (batch_size, 1, 1, src_length)
        for block in self.blocks:
            x = block(x, src_mask)
        return x  # (batch_size, src_length, dim_model)
Decoder
The decoder takes the context vectors from the encoder and produces the output tokens. Each block has a masked self-attention layer, a cross-attention layer over the encoder output, and an FFN layer.
class DecoderBlock(nn.Module):
    def __init__(self, self_attn, cross_attn, ffn, dropout_p):
        super().__init__()
        self.self_attn = self_attn
        self.cross_attn = cross_attn
        self.ffn = ffn
        self.add_norm = nn.ModuleList(
            [AddNorm(self_attn.dim_model, dropout_p) for _ in range(3)]
        )

    def forward(self, x, tgt_mask, encoder_memory, src_mask):
        # x: (batch_size, tgt_length, dim_model)
        # tgt_mask: (batch_size, 1, tgt_length, tgt_length)
        # encoder_memory: (batch_size, src_length, dim_model)
        # src_mask: (batch_size, 1, 1, src_length)
        sublayer_output, _ = self.self_attn(x, x, x, tgt_mask)
        x = self.add_norm[0](x, sublayer_output)  # (batch_size, tgt_length, dim_model)

        sublayer_output, attention_weights = self.cross_attn(x, encoder_memory, encoder_memory, src_mask)
        x = self.add_norm[1](x, sublayer_output)  # (batch_size, tgt_length, dim_model)

        x = self.add_norm[2](x, self.ffn(x))

        # attention_weights: (batch_size, n_heads, tgt_length, src_length)
        # x: (batch_size, tgt_length, dim_model)
        return x, attention_weights
class Decoder(nn.Module):
    def __init__(self, N, n_heads, dim_model, dim_value, dim_ffn, dropout_p):
        super().__init__()
        self.blocks = nn.ModuleList()
        for _ in range(N):
            decoder_block = DecoderBlock(
                MultiheadAttention(n_heads, dim_model, dim_value, dropout_p),
                MultiheadAttention(n_heads, dim_model, dim_value, dropout_p),
                FFN(dim_model, dim_ffn, dropout_p),
                dropout_p
            )
            self.blocks.append(decoder_block)

    def forward(self, x, tgt_mask, encoder_memory, src_mask):
        for block in self.blocks:
            x, attention_weights = block(x, tgt_mask, encoder_memory, src_mask)

        # return the cross-attention weights of the last block
        return x, attention_weights
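A toy end-to-end check of the encoder and decoder stacks (tiny, made-up hyperparameters and random, already-embedded inputs, just to confirm that the shapes line up):

enc = Encoder(N=2, n_heads=4, dim_model=128, dim_value=32, dim_ffn=256, dropout_p=0.1)
dec = Decoder(N=2, n_heads=4, dim_model=128, dim_value=32, dim_ffn=256, dropout_p=0.1)

src = torch.randn(2, 7, 128)   # already-embedded source
tgt = torch.randn(2, 5, 128)   # already-embedded target
src_mask = torch.ones(2, 1, 1, 7)
tgt_mask = torch.tril(torch.ones(5, 5)).unsqueeze(0).unsqueeze(0)  # causal mask

memory = enc(src, src_mask)
out, cross_attn = dec(tgt, tgt_mask, memory, src_mask)
print(memory.shape)      # torch.Size([2, 7, 128])
print(out.shape)         # torch.Size([2, 5, 128])
print(cross_attn.shape)  # torch.Size([2, 4, 5, 7])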
Model head
A learned linear projection.
class ModelHead(nn.Module):
    def __init__(self, dim_model, vocab_size):
        super().__init__()
        self.linear = nn.Linear(dim_model, vocab_size)

    def forward(self, x):
        return self.linear(x)
The Transformer model
Put everything together.
class Transformer(nn.Module):
    def __init__(self, encoder, decoder, src_embed, tgt_embed, model_head, src_pad_token_id, tgt_pad_token_id):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.model_head = model_head
        self.src_pad_token_id = src_pad_token_id
        self.tgt_pad_token_id = tgt_pad_token_id

    def get_src_mask(self, src_input):
        # src_input: (batch_size, src_length)
        src_mask = (
            src_input != self.src_pad_token_id
        ).unsqueeze(1).unsqueeze(1).int()
        return src_mask  # src_mask: (batch_size, 1, 1, src_length)

    def get_tgt_mask(self, tgt_input):
        # tgt_input: (batch_size, tgt_length)
        tgt_padding_mask = (
            tgt_input != self.tgt_pad_token_id
        ).unsqueeze(1).unsqueeze(1).int()  # tgt_padding_mask: (batch_size, 1, 1, tgt_length)

        tgt_decoding_mask = torch.tril(
            torch.ones(tgt_input.size(1), tgt_input.size(1))
        ).unsqueeze(0).unsqueeze(0).int().to(tgt_input.device)  # tgt_decoding_mask: (1, 1, tgt_length, tgt_length)

        tgt_mask = tgt_padding_mask & tgt_decoding_mask
        return tgt_mask  # tgt_mask: (batch_size, 1, tgt_length, tgt_length)

    def forward(self, src_input, tgt_input):
        # src_input: (batch_size, src_length)
        # tgt_input: (batch_size, tgt_length)
        src_mask = self.get_src_mask(src_input)
        encoder_output = self.encoder(
            self.src_embed(src_input),  # (batch_size, src_length, dim_model)
            src_mask                    # (batch_size, 1, 1, src_length)
        )  # (batch_size, src_length, dim_model)

        tgt_mask = self.get_tgt_mask(tgt_input)
        decoder_output, attention_weights = self.decoder(
            self.tgt_embed(tgt_input),  # (batch_size, tgt_length, dim_model)
            tgt_mask,                   # (batch_size, 1, tgt_length, tgt_length)
            encoder_output,             # (batch_size, src_length, dim_model)
            src_mask                    # (batch_size, 1, 1, src_length)
        )  # (batch_size, tgt_length, dim_model)

        output = self.model_head(decoder_output)  # (batch_size, tgt_length, tgt_vocab_size)

        return output, attention_weights
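To make the masks concrete, here is a small illustration (a toy sequence where 0 is assumed to be the padding id; the encoder/decoder arguments are placeholders since only the mask helpers are exercised):

# toy model just to exercise the mask helpers; 0 is assumed to be the padding id
toy = Transformer(None, None, None, None, None, src_pad_token_id=0, tgt_pad_token_id=0)

ids = torch.tensor([[2, 7, 9, 3, 0, 0]])  # e.g. <sos> w1 w2 <eos> <pad> <pad>
print(toy.get_src_mask(ids))
# tensor([[[[1, 1, 1, 1, 0, 0]]]], dtype=torch.int32)
print(toy.get_tgt_mask(ids)[0, 0])
# a lower-triangular causal mask with the padded columns zeroed out:
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0, 0]], dtype=torch.int32)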
Training
1. load & prepare data for training
1.1. load data
Use the Multi30k dataset [2], which contains translations from English to German. It has 3 splits: train, validation and test.
from datasets import load_dataset

train_ds, val_ds, test_ds = load_dataset('bentrevett/multi30k', split=['train', 'validation', 'test'])
print(train_ds, val_ds, test_ds)

Dataset({
    features: ['en', 'de'],
    num_rows: 29000
}) Dataset({
    features: ['en', 'de'],
    num_rows: 1014
}) Dataset({
    features: ['en', 'de'],
    num_rows: 1000
})

train_ds[0]
{'en': 'Two young, White males are outside near many bushes.',
'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.'}
1.2. init tokenizers & build vocabularies
spaCy is a good choice for multilingual tokenization.
import spacy
import torchtext.transforms as T
from torchtext.vocab import build_vocab_from_iterator

# languages
EN = 'en'
DE = 'de'

# special tokens
PAD = '<pad>'
UNK = '<unk>'
SOS = '<sos>'
EOS = '<eos>'

tokenizers = {EN: spacy.load('en_core_web_sm').tokenizer,
              DE: spacy.load('de_core_news_sm').tokenizer}
def tokenize_sentence(sentence, tokenizers, lang):
    return [token.text for token in tokenizers[lang](sentence)]

def get_tokens(ds, tokenizers, lang):
    for sample in ds:
        yield tokenize_sentence(sample[lang], tokenizers, lang)
Tokenize the first sample.
print(next(get_tokens(train_ds, tokenizers, EN)))
print(next(get_tokens(train_ds, tokenizers, DE)))
['Two', 'young', ',', 'White', 'males', 'are', 'outside', 'near', 'many', 'bushes', '.']
['Zwei', 'junge', 'weiße', 'Männer', 'sind', 'im', 'Freien', 'in', 'der', 'Nähe', 'vieler', 'Büsche', '.']
Build vocabularies for both English and German.
en_vocab = build_vocab_from_iterator(
    get_tokens(train_ds, tokenizers, EN),
    min_freq=2,
    specials=[PAD, UNK, SOS, EOS])
en_vocab.set_default_index(en_vocab[UNK])

de_vocab = build_vocab_from_iterator(
    get_tokens(train_ds, tokenizers, DE),
    min_freq=2,
    specials=[PAD, UNK, SOS, EOS])
de_vocab.set_default_index(de_vocab[UNK])

vocabs = {EN: en_vocab, DE: de_vocab}
print(f'EN vocab size: {len(en_vocab)}\nDE vocab size: {len(de_vocab)}')
EN vocab size: 6191
DE vocab size: 8014
1.3. tokenize and transform data
def tokenize_data(vocabs, tokenizers):
    transforms = dict()
    for lang, vocab in vocabs.items():
        transforms[lang] = T.Sequential(
            T.VocabTransform(vocab),
            T.AddToken(vocab[SOS], begin=True),
            T.AddToken(vocab[EOS], begin=False)
        )

    def process(sample):
        result = dict()
        for lang, sentence in sample.items():
            result[lang] = transforms[lang](
                [token.text for token in tokenizers[lang](sentence)]
            )
        return result
    return process

tokenized_train_ds = train_ds.map(tokenize_data(vocabs, tokenizers))
tokenized_val_ds = val_ds.map(tokenize_data(vocabs, tokenizers))
tokenized_test_ds = test_ds.map(tokenize_data(vocabs, tokenizers))
print(tokenized_train_ds, tokenized_val_ds, tokenized_test_ds)
Dataset({
    features: ['en', 'de'],
    num_rows: 29000
}) Dataset({
    features: ['en', 'de'],
    num_rows: 1014
}) Dataset({
    features: ['en', 'de'],
    num_rows: 1000
})
The first training sample.
tokenized_train_ds[0]
{'en': [2, 19, 25, 15, 1169, 808, 17, 57, 84, 336, 1339, 5, 3],
'de': [2, 21, 85, 257, 31, 87, 22, 94, 7, 16, 112, 7910, 3209, 4, 3]}
Decode token ids to words.
def decode_tokens(ids, vocab):
    words_list = vocab.get_itos()
    s = [words_list[id] for id in ids]
    return ' '.join(s)

decode_tokens(tokenized_train_ds[0][EN], vocabs[EN])
'<sos> Two young , White males are outside near many bushes . <eos>'
decode_tokens(tokenized_train_ds[0][DE], vocabs[DE])
'<sos> Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche . <eos>'
1.4. init dataloaders
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
def get_collate_fn(vocabs):
    """
    Dynamically pad to the max length of the batch.
    """
    def collate_fn(batch):
        langs = list(batch[0].keys())
        src_lang = langs[0]
        tgt_lang = langs[1]

        src_batch = []
        tgt_batch = []
        label_batch = []
        for row in batch:
            src_batch.append(torch.tensor(row[src_lang]))
            tgt_batch.append(torch.tensor(row[tgt_lang])[:-1])   # decoder input: all tokens except the last (<eos>)
            label_batch.append(torch.tensor(row[tgt_lang])[1:])  # labels: all tokens except the first (<sos>), shifted by one

        src_batch = pad_sequence(src_batch, batch_first=True, padding_value=vocabs[src_lang][PAD])
        tgt_batch = pad_sequence(tgt_batch, batch_first=True, padding_value=vocabs[tgt_lang][PAD])
        label_batch = pad_sequence(label_batch, batch_first=True, padding_value=vocabs[tgt_lang][PAD])

        return src_batch, tgt_batch, label_batch
    return collate_fn

train_dataloader = DataLoader(tokenized_train_ds, batch_size=64, shuffle=True, collate_fn=get_collate_fn(vocabs))
val_dataloader = DataLoader(tokenized_val_ds, batch_size=64, shuffle=True, collate_fn=get_collate_fn(vocabs))
Take a quick look at a sample.
batch = next(iter(train_dataloader))
print(f'Source: {decode_tokens(batch[0][0], vocabs[EN])}\n')
print(f'Target: {decode_tokens(batch[1][0], vocabs[DE])}\n')
print(f'Label: {decode_tokens(batch[2][0], vocabs[DE])}')
Source: <sos> A woman looks down from a high point above a calm blue ocean . <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
Target: <sos> Eine Frau blickt von einem hohen Aussichtspunkt über den ruhigen blauen Ozean . <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
Label: Eine Frau blickt von einem hohen Aussichtspunkt über den ruhigen blauen Ozean . <eos> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
2. train
from tqdm import tqdm
def build_model(src_vocab_size,
                tgt_vocab_size,
                max_seq_length,
                src_pad_token_id,
                tgt_pad_token_id,
                dim_model=512,
                dim_value=64,
                dim_ffn=2048,
                n_heads=8,
                N=6,
                dropout_p=0.1):
    src_embed = nn.Sequential(
        Embeddings(dim_model, src_vocab_size),
        PositionalEncoding(max_seq_length, dim_model, dropout_p)
    )
    tgt_embed = nn.Sequential(
        Embeddings(dim_model, tgt_vocab_size),
        PositionalEncoding(max_seq_length, dim_model, dropout_p)
    )
    encoder = Encoder(N, n_heads, dim_model, dim_value, dim_ffn, dropout_p)
    decoder = Decoder(N, n_heads, dim_model, dim_value, dim_ffn, dropout_p)
    model_head = ModelHead(dim_model, tgt_vocab_size)

    model = Transformer(encoder, decoder, src_embed, tgt_embed, model_head, src_pad_token_id, tgt_pad_token_id)

    # initialize parameters with Xavier uniform.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model
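Before training on real data, a quick smoke test on random token ids (made-up vocabulary sizes and a deliberately tiny configuration) confirms the assembled model produces logits of the expected shape:

tiny = build_model(src_vocab_size=100, tgt_vocab_size=120, max_seq_length=20,
                   src_pad_token_id=0, tgt_pad_token_id=0,
                   dim_model=64, dim_value=8, dim_ffn=128, n_heads=8, N=2)
src = torch.randint(1, 100, (2, 9))   # (batch_size, src_length)
tgt = torch.randint(1, 120, (2, 7))   # (batch_size, tgt_length)
logits, cross_attn = tiny(src, tgt)
print(logits.shape)      # torch.Size([2, 7, 120])
print(cross_attn.shape)  # torch.Size([2, 8, 7, 9])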
def run_one_epoch(model, dataloader, optimizer, loss_fn, device):
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    model.train()
    total_loss = 0
    for batch in dataloader:
        src_input = batch[0].to(device)
        tgt_input = batch[1].to(device)
        labels = batch[2].to(device)

        optimizer.zero_grad()
        output, _ = model(src_input, tgt_input)
        loss = loss_fn(
            output.view(-1, output.size(-1)),
            labels.view(-1)
        )
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        dataloader.set_postfix({'training loss': f'{loss.item():8.4f}'})

    return total_loss / len(dataloader)
def evaluation(model, dataloader, loss_fn, device):
    model.eval()

    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            src_input = batch[0].to(device)
            tgt_input = batch[1].to(device)
            labels = batch[2].to(device)

            output, _ = model(src_input, tgt_input)
            loss = loss_fn(
                output.view(-1, output.size(-1)),
                labels.view(-1)
            )
            total_loss += loss.item()

    return total_loss / len(dataloader)
# find out the max sequence length that is required by the model's positional encoding
def get_max_length(datasets):
    max_len = 0
    for ds in datasets:
        for sample in ds:
            for _, ids in sample.items():
                max_len = max(max_len, len(ids))
    return max_len

max_seq_length = get_max_length([tokenized_train_ds, tokenized_val_ds, tokenized_test_ds])
max_seq_length
46
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = build_model(
    len(vocabs[EN]),
    len(vocabs[DE]),
    max_seq_length,
    vocabs[EN][PAD],
    vocabs[DE][PAD]
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, eps=1e-9)
loss_fn = nn.CrossEntropyLoss(
    ignore_index=vocabs[DE][PAD],
    label_smoothing=0.1
).to(device)
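Since the labels are padded with the <pad> id, ignore_index makes those positions contribute nothing to the loss. A tiny illustration with made-up logits and labels (pad id 0, no label smoothing, just for the example):

demo_loss = nn.CrossEntropyLoss(ignore_index=0)
logits = torch.randn(4, 10)          # 4 target positions, vocabulary of 10
labels = torch.tensor([3, 7, 0, 0])  # the last two positions are padding (id 0)
print(demo_loss(logits, labels))
# identical to averaging the loss over the two non-pad positions only:
print(nn.CrossEntropyLoss()(logits[:2], labels[:2]))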
def total_params(model):
    return sum([param.numel() for param in model.parameters() if param.requires_grad])

print(f'Total parameters: {total_params(model):,}')
Total parameters: 55,522,638
epochs = 6

training_losses = []
val_losses = []
for epoch in range(epochs):
    train_iterator = tqdm(train_dataloader, desc=f'epoch {epoch:02d}')
    avg_loss = run_one_epoch(model, train_iterator, optimizer, loss_fn, device)
    training_losses.append(avg_loss)

    # get validation loss
    avg_loss = evaluation(model, val_dataloader, loss_fn, device)
    val_losses.append(avg_loss)
epoch 00: 100%|██████████| 454/454 [00:26<00:00, 17.33it/s, training loss=5.2297]
epoch 01: 100%|██████████| 454/454 [00:26<00:00, 17.14it/s, training loss=4.4458]
epoch 02: 100%|██████████| 454/454 [00:26<00:00, 17.10it/s, training loss=4.2431]
epoch 03: 100%|██████████| 454/454 [00:26<00:00, 17.07it/s, training loss=3.8277]
epoch 04: 100%|██████████| 454/454 [00:26<00:00, 17.08it/s, training loss=3.3935]
epoch 05: 100%|██████████| 454/454 [00:26<00:00, 17.03it/s, training loss=3.3311]
import matplotlib.pyplot as plt

plt.plot(training_losses, label='training loss')
plt.plot(val_losses, label='val loss')
plt.legend()
plt.show()
Inference
During inference, output tokens are generated autoregressively, one at a time: the decoder is fed the tokens generated so far, and the most likely next token (greedy decoding) is appended until <eos> is produced or the generation limit is reached.
def translate(sample, max_gen_length=50):
    process_fn = tokenize_data(vocabs, tokenizers)
    sample_ids = process_fn(sample)

    model.eval()

    # compute context vectors
    src_input = sample_ids[EN]
    src_input = torch.LongTensor(src_input).unsqueeze(0).to(device)
    src_mask = model.get_src_mask(src_input)
    with torch.no_grad():
        encoder_output = model.encoder(model.src_embed(src_input), src_mask)

    # generate output tokens
    translation = []
    tgt_input = [vocabs[DE][SOS]]
    tgt_input = torch.LongTensor(tgt_input).unsqueeze(0).to(device)
    for _ in range(max_gen_length):
        tgt_mask = model.get_tgt_mask(tgt_input)
        with torch.no_grad():
            decoder_output, attention_weights = model.decoder(
                model.tgt_embed(tgt_input),
                tgt_mask,
                encoder_output,
                src_mask
            )
            output = model.model_head(decoder_output)

        pred = output.argmax(-1)[:, -1]  # the prediction at the last position is the next token
        if pred.item() == vocabs[DE][EOS]:
            # reached <eos> token
            break
        translation.append(pred.item())
        tgt_input = torch.cat((tgt_input, pred.unsqueeze(0)), dim=-1)
    return translation, attention_weights
de_words_list = vocabs[DE].get_itos()

sample = test_ds[0]

translation, attention_weights = translate(sample)
translation = ' '.join([de_words_list[id] for id in translation])
print(f'label: {sample[DE]}\npreds: {translation}')
label: Ein Mann mit einem orangefarbenen Hut, der etwas anstarrt.
preds: Ein Mann mit einem Hut schaut auf einen Tisch .
Visualize attentions
en_words = ['<sos>'] + tokenize_sentence(sample[EN], tokenizers, EN) + ['<eos>']
de_words = translation.split(' ') + ['<eos>']
values = attention_weights.detach().cpu().numpy().squeeze(0)

fig = plt.figure(figsize=(10, 22))
xticks = list(range(len(en_words)))
yticks = list(range(len(de_words)))
for i in range(8):
    ax = fig.add_subplot(4, 2, i+1)
    ax.matshow(values[i], cmap='gray')
    ax.set_xticks(xticks)
    ax.set_xticklabels(en_words, rotation=50)
    ax.set_yticks(yticks)
    ax.set_yticklabels(de_words)
plt.show()