August 11, 2022

Transformer 1: The Model Definition and Naive Training Dataset for Machine Translation

deep-learning

pytorch

Graphical Structure of Transformer

Constants in the Config

d_model = 512  # Embedding Size
d_ff = 2048  # FeedForward dimension
d_k = d_v = 64  # dimension of K(=Q), V
n_layers = 6  # number of Encoder of Decoder Layer
n_heads = 8  # number of heads in Multi-Head Attention
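
The snippets in this post also rely on a few imports and a device handle that are not shown in the listings; a minimal sketch of what is assumed:

import math
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")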

Positional Encoding

A positional encoding is a mapping that takes a positional index to a vector of the word-embedding dimension $d_{\text{model}}$. In terms of the $(\text{max\_len} \times d_{\text{model}})$-array $pe$ built below,

$$
pe[pos, 2i] = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right),
\qquad
pe[pos, 2i+1] = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right),
$$

where $0 \le pos < \text{max\_len}$ and $0 \le i < d_{\text{model}}/2$.
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # reshape to [max_len, 1, d_model] so it broadcasts over the batch dimension
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        x: [seq_len, batch_size, d_model]
        '''
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
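
A quick shape check (a sketch, assuming the constants above): the registered pe buffer has shape [max_len, 1, d_model], so it broadcasts over the batch dimension of an input of shape [seq_len, batch_size, d_model].

pos_enc = PositionalEncoding(d_model)
dummy = torch.zeros(5, 2, d_model)   # [seq_len=5, batch_size=2, d_model]
out = pos_enc(dummy)
print(pos_enc.pe.shape)              # torch.Size([5000, 1, 512])
print(out.shape)                     # torch.Size([5, 2, 512])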

PoswiseFeedForwardNet

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self):
        super(PoswiseFeedForwardNet, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(d_model, d_ff, bias=False),
            nn.ReLU(),
            nn.Linear(d_ff, d_model, bias=False)
        )
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inputs):
        '''
        inputs: [batch_size, seq_len, d_model]
        '''
        residual = inputs
        output = self.fc(inputs)
        # [batch_size, seq_len, d_model]
        return self.layer_norm(output + residual)

Attentions

Scaled Dot Product Attention
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        '''
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v]
        attn_mask: [batch_size, n_heads, seq_len, seq_len]
        '''
        # scores : [batch_size, n_heads, len_q, len_k]
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)
        # Fills elements of self tensor with value where mask is True.
        scores.masked_fill_(attn_mask, -1e9)

        attn = nn.Softmax(dim=-1)(scores)
        # [batch_size, n_heads, len_q, d_v]:
        context = torch.matmul(attn, V)
        return context, attn

attn_mask is used to mask out values at undesired positions (illustrated by the small demo after this list), such as

  • those coming from padding, or
  • those coming from "future words" in the decoder's self-attention module.
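
A tiny standalone demo of the masking mechanics (hypothetical numbers, not tied to the model): positions where the mask is True get a score of -1e9 and therefore an effectively zero softmax weight.

scores = torch.tensor([[1.0, 2.0, 3.0, 0.5]])        # one query row over four keys
mask = torch.tensor([[False, False, False, True]])   # the last key is a padding position
scores.masked_fill_(mask, -1e9)
print(nn.Softmax(dim=-1)(scores))                    # last weight ~0, the other three sum to ~1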
Repeated Use of ScaledDotProductAttention: Multi-head Attention
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, input_Q, input_K, input_V, attn_mask):
        '''
        input_Q: [batch_size, len_q, d_model]
        input_K: [batch_size, len_k, d_model]
        input_V: [batch_size, len_v(=len_k), d_model]
        attn_mask: [batch_size, seq_len, seq_len]
        '''
        residual, batch_size = input_Q, input_Q.size(0)
        # (B, S, D) -proj-> (B, S, D_new) -split-> (B, S, H, W) -trans-> (B, H, S, W)
        # Q: [batch_size, n_heads, len_q, d_k]
        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        # K: [batch_size, n_heads, len_k, d_k]
        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)
        # V: [batch_size, n_heads, len_v(=len_k), d_v]
        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)

        # attn_mask : [batch_size, n_heads, seq_len, seq_len]
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

        # context: [batch_size, n_heads, len_q, d_v], attn: [batch_size, n_heads, len_q, len_k]
        context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)
        # context: [batch_size, len_q, n_heads * d_v]
        context = context.transpose(1, 2).reshape(
            batch_size, -1, n_heads * d_v)
        output = self.fc(context)  # [batch_size, len_q, d_model]
        return self.layer_norm(output + residual), attn
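
As a quick sanity check of the shapes (a sketch assuming the constants and the device handle above):

mha = MultiHeadAttention().to(device)
x = torch.randn(2, 5, d_model).to(device)                      # [batch_size=2, seq_len=5, d_model]
no_mask = torch.zeros(2, 5, 5, dtype=torch.bool).to(device)    # nothing masked
out, attn = mha(x, x, x, no_mask)
print(out.shape)    # torch.Size([2, 5, 512])
print(attn.shape)   # torch.Size([2, 8, 5, 5])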

Mask Creation

get_attn_pad_mask
1def get_attn_pad_mask(seq_q, seq_k, pad=0):
2    '''
3    seq_q: [batch_size, seq_len]
4    seq_k: [batch_size, seq_len]
5    seq_len could be src_len or it could be tgt_len
6    seq_len in seq_q and seq_len in seq_k may not be equal
7    '''
8    batch_size, len_q = seq_q.size()
9    batch_size, len_k = seq_k.size()
10    # seq_k.eq(pad) is True exactly at PAD positions
11    # [batch_size, 1, len_k], True means masked
12    pad_attn_mask = seq_k.eq(pad).unsqueeze(1)
13    # [batch_size, len_q, len_k]
14    return pad_attn_mask.expand(batch_size, len_q, len_k)

When applying get_attn_pad_mask, the sequence whose padded positions we want to mask out is fed in as the second argument (seq_k); a minimal check follows.
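
For instance, with a toy batch whose last index is the padding index 0 (the same convention as the naive dataset later), every row of the mask marks that key position as True:

seq = torch.tensor([[1, 2, 3, 4, 0]])   # one sentence, last token is padding
pad_mask = get_attn_pad_mask(seq, seq)
print(pad_mask.shape)                   # torch.Size([1, 5, 5])
print(pad_mask[0, 0])                   # tensor([False, False, False, False,  True])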

get_attn_subsequence_mask
def get_attn_subsequence_mask(seq):
    '''
    seq: [batch_size, tgt_len]
    '''
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    # Upper triangular matrix
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()
    return subsequence_mask  # [batch_size, tgt_len, tgt_len]
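
As a quick illustration (a sketch, not part of the original listing): for a target of length 4 the mask is strictly upper triangular, so each position can only attend to itself and earlier positions.

dummy_tgt = torch.zeros(1, 4, dtype=torch.long)
print(get_attn_subsequence_mask(dummy_tgt)[0])
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.uint8)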
How the Attention Masks Will Be Used

In get_attn_pad_mask, seq_q is a sequence of word indexes $(i_1, \dots, i_{\text{len\_q}})$, which is embedded and projected to form the queries $Q$; the same is true for seq_k and the keys $K$. The two are multiplied together to get

$$
\frac{QK^\top}{\sqrt{d_k}}
$$

for scaled dot-product attention.

On Padding Logic

Note that the padding logic depends only on seq_k and pad (see line 12); seq_q is used only to expand the number of rows so that the mask matches the shape of the score matrix $QK^\top$ (see line 14).

The mask is applied right before the softmax, so we eventually get scores whose rows each sum to one.

Focus on a single sentence and a single head, and denote the dimensions in the attention module by

$$
Q \in \mathbb{R}^{L \times d_k}, \qquad K \in \mathbb{R}^{L \times d_k}, \qquad V \in \mathbb{R}^{L \times d_v},
$$

where $L$ is the (padded) sequence length. We get the scores (the attention weights) by soft-maxing along the last dimension:

$$
A = \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right) \in \mathbb{R}^{L \times L},
$$

where $M$ holds $-10^9$ at masked positions and $0$ elsewhere. Since the masked entries are exactly the last $p$ columns, the matrix $A$ can be decomposed as $A = \begin{bmatrix} A' & 0 \end{bmatrix}$ for some $A' \in \mathbb{R}^{L \times (L-p)}$ and a zero block $0 \in \mathbb{R}^{L \times p}$, with $p$ being the number of paddings counted from the end of the sentence. The attention output (the context) is a linear combination of the values (the rows) of $V$:

$$
\text{context} = A V = A' V_{1:L-p},
$$

so the last $p$ rows of $V$ (the $v_j$ with $j > L - p$) have no contribution to the context.

Note that in each flow of computation we are focusing on one sentence with multiple indexes; the last few $v_j$'s are values from the padding positions that we are not concerned about. The small numeric check below confirms this.
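
A hypothetical numeric check (numbers made up for illustration): mask the last p = 1 key column, soft-max, and observe that the context A @ V does not change when the value rows at the padded positions change.

L, d_v_toy, p = 4, 3, 1
scores = torch.randn(L, L)
scores[:, L - p:] = -1e9                  # mask the padded key column(s)
A = nn.Softmax(dim=-1)(scores)            # rows sum to 1, last column is ~0
V1 = torch.randn(L, d_v_toy)
V2 = V1.clone()
V2[L - p:] = 123.0                        # change only the value rows at padded positions
print(torch.allclose(A @ V1, A @ V2))     # True: padded values do not affect the context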

Encoder Layer and Encoder

In the Transformer, an Encoder is a stack of several EncoderLayer modules:

EncoderLayer
class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        '''
        enc_inputs: [batch_size, src_len, d_model]
        enc_self_attn_mask: [batch_size, src_len, src_len]
        '''
        # enc_outputs: [batch_size, src_len, d_model],
        # attn: [batch_size, n_heads, src_len, src_len]
        enc_outputs, attn = self.enc_self_attn(
            # enc_inputs to same Q,K,V
            enc_inputs,
            enc_inputs,
            enc_inputs,
            enc_self_attn_mask
        )
        # enc_outputs: [batch_size, src_len, d_model]
        enc_outputs = self.pos_ffn(enc_outputs)
        return enc_outputs, attn

Here enc_self_attn_mask is just the padding mask: it prevents values at padding positions from contributing to the attention computation.

Encoder
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])

    def forward(self, enc_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        '''
        # [batch_size, src_len, d_model]:
        enc_outputs = self.src_emb(enc_inputs)
        # [batch_size, src_len, d_model]:
        enc_outputs = self.pos_emb(enc_outputs.transpose(0, 1)).transpose(0, 1)
        # [batch_size, src_len, src_len]:
        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)
        enc_self_attns = []
        for layer in self.layers:
            # enc_outputs: [batch_size, src_len, d_model],
            # enc_self_attn: [batch_size, n_heads, src_len, src_len]
            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)
            enc_self_attns.append(enc_self_attn)
        # enc_self_attns is collected only for plotting/inspection purposes:
        return enc_outputs, enc_self_attns
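
As a usage sketch (assuming the device handle above and the src_vocab_size defined with the naive dataset below):

encoder = Encoder().to(device)
toy_enc_inputs = torch.tensor([[1, 2, 3, 4, 0]]).to(device)   # [batch_size=1, src_len=5]
enc_outputs, enc_self_attns = encoder(toy_enc_inputs)
print(enc_outputs.shape)      # torch.Size([1, 5, 512])
print(len(enc_self_attns))    # 6, one attention map per EncoderLayer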

Decoder Layer and Decoder

DecoderLayer
class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        '''
        dec_inputs: [batch_size, tgt_len, d_model]
        enc_outputs: [batch_size, src_len, d_model]
        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask: [batch_size, tgt_len, src_len]
        '''
        # dec_outputs: [batch_size, tgt_len, d_model],
        # dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len]
        dec_outputs, dec_self_attn = self.dec_self_attn(
            dec_inputs,
            dec_inputs,
            dec_inputs,
            dec_self_attn_mask
        )
        # dec_outputs: [batch_size, tgt_len, d_model],
        # dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
        dec_outputs, dec_enc_attn = self.dec_enc_attn(
            dec_outputs,
            enc_outputs,
            enc_outputs,
            dec_enc_attn_mask
        )
        # [batch_size, tgt_len, d_model]
        dec_outputs = self.pos_ffn(dec_outputs)
        return dec_outputs, dec_self_attn, dec_enc_attn
Decoder
1class Decoder(nn.Module):
2    def __init__(self):
3        super(Decoder, self).__init__()
4        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
5        self.pos_emb = PositionalEncoding(d_model)
6        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])
7
8    def forward(self, dec_inputs, enc_inputs, enc_outputs):
9        '''
10        dec_inputs: [batch_size, tgt_len]
11        enc_inputs: [batch_size, src_len]
12        enc_outputs: [batch_size, src_len, d_model]
13        '''
14
15        # [batch_size, tgt_len, d_model]:
16        dec_outputs = self.tgt_emb(dec_inputs)
17        # [batch_size, tgt_len, d_model]:
18        dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1))\
19            .transpose(0, 1)\
20            .to(device)
21
22        # [batch_size, tgt_len, tgt_len]:
23        dec_self_attn_pad_mask = get_attn_pad_mask(
24            dec_inputs,
25            dec_inputs
26        ).to(device)
27
28        # [batch_size, tgt_len, tgt_len]:
29        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).to(device)
30        # [batch_size, tgt_len, tgt_len]:
31        dec_self_attn_mask = torch.gt(
32            (dec_self_attn_pad_mask + dec_self_attn_subsequence_mask),
33            0
34        ).to(device)
35
36        # [batch_size, tgt_len, src_len]:
37        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)
38
39        dec_self_attns, dec_enc_attns = [], []
40        for layer in self.layers:
41            # dec_outputs: [batch_size, tgt_len, d_model],
42            # dec_self_attn: [batch_size, n_heads, tgt_len, tgt_len],
43            # dec_enc_attn: [batch_size, n_heads, tgt_len, src_len]
44            dec_outputs, dec_self_attn, dec_enc_attn = layer(
45                dec_outputs,
46                enc_outputs,
47                dec_self_attn_mask,
48                dec_enc_attn_mask
49            )
50            dec_self_attns.append(dec_self_attn)
51            dec_enc_attns.append(dec_enc_attn)
52        return dec_outputs, dec_self_attns, dec_enc_attns
Why Feed enc_inputs into Decoder?
  • From the graphical structure of the Transformer it may seem odd to also feed enc_inputs into the Decoder.

  • In fact, enc_inputs is only used to create the cross-attention padding mask (see line 37 of the Decoder code block), which helps the decoder ignore the last few rows (the padded positions) of enc_outputs (think of enc_outputs as another form of embedding of the original sequence of word indexes). A short illustration follows this list.
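
Using the toy indexes from the naive dataset below: the mask has shape [batch_size, tgt_len, src_len], and the True column marks the padded source position that every decoder query should ignore.

toy_dec_inputs = torch.tensor([[6, 1, 2, 3, 4, 8]])   # [1, tgt_len=6]
toy_enc_inputs = torch.tensor([[1, 2, 3, 4, 0]])      # [1, src_len=5], last token is padding
cross_mask = get_attn_pad_mask(toy_dec_inputs, toy_enc_inputs)
print(cross_mask.shape)     # torch.Size([1, 6, 5])
print(cross_mask[0, 0])     # tensor([False, False, False, False,  True])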

The Transformer

class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        self.encoder = Encoder().to(device)
        self.decoder = Decoder().to(device)
        self.projection = nn.Linear(
            d_model, tgt_vocab_size, bias=False
        ).to(device)

    def forward(self, enc_inputs, dec_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]
        '''

        # enc_outputs: [batch_size, src_len, d_model]
        # enc_self_attns: [n_layers, batch_size, n_heads, src_len, src_len]
        enc_outputs, enc_self_attns = self.encoder(enc_inputs)
        # dec_outputs: [batch_size, tgt_len, d_model]
        # dec_self_attns: [n_layers, batch_size, n_heads, tgt_len, tgt_len]
        # dec_enc_attns: [n_layers, batch_size, n_heads, tgt_len, src_len]
        dec_outputs, dec_self_attns, dec_enc_attns = self.decoder(
            dec_inputs,
            enc_inputs,
            enc_outputs
        )
        # dec_logits: [batch_size, tgt_len, tgt_vocab_size]
        dec_logits = self.projection(dec_outputs)
        return (
            dec_logits.view(-1, dec_logits.size(-1)),
            enc_self_attns,
            dec_self_attns,
            dec_enc_attns
        )
  • Note that dec_logits will be of shape [batch_size, tgt_len, tgt_vocab_size].
  • dec_logits.view(-1, dec_logits.size(-1)) will be of shape [batch_size * tgt_len, tgt_vocab_size], which matches the flattened targets dec_outputs.view(-1) used with nn.CrossEntropyLoss in the training loop; see the sketch below.
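
A shape-only sketch with toy numbers (vocab = 9 mirrors tgt_vocab_size of the naive dataset) of how the flattened logits pair with flattened targets:

batch_size, tgt_len_toy, vocab = 2, 6, 9
dec_logits = torch.randn(batch_size, tgt_len_toy, vocab)
targets = torch.randint(1, vocab, (batch_size, tgt_len_toy))
loss = nn.CrossEntropyLoss(ignore_index=0)(
    dec_logits.view(-1, vocab),    # [batch_size * tgt_len, vocab]
    targets.view(-1)               # [batch_size * tgt_len]
)
print(loss)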

Training: Naive Dataset

Dataset and data_loader

Let's get a rough feeling from a naive dataset:

sentences = [
    # enc_input                 dec_input                dec_output
    ['ich mochte ein bier <P>', '<sos> i want a beer .', 'i want a beer . <eos>'],
    ['ich mochte ein cola <P>', '<sos> i want a coke .', 'i want a coke . <eos>']
]

It is not necessary to write <P> in the dataset; we could instead pad the sequences of word indexes with 0's in the data pipeline.
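
For example, raw index sequences could be right-padded to src_len with a small helper like the following (a hypothetical function, not part of the code below):

def pad_to(indexes, length, pad=0):
    # right-pad a list of word indexes with the pad index
    return indexes + [pad] * (length - len(indexes))

print(pad_to([1, 2, 3, 4], 5))   # [1, 2, 3, 4, 0]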

Next we define our data_loader:

src_word_index = {'<P>': 0, 'ich': 1,
                  'mochte': 2, 'ein': 3, 'bier': 4, 'cola': 5}
src_vocab_size = len(src_word_index)

tgt_word_index = {'<P>': 0, 'i': 1, 'want': 2, 'a': 3,
                  'beer': 4, 'coke': 5, '<sos>': 6, '<eos>': 7, '.': 8}
src_index_word = {i: w for w, i in src_word_index.items()}
tgt_index_word = {i: w for w, i in tgt_word_index.items()}
tgt_vocab_size = len(tgt_word_index)

src_len = 5  # enc_input max sequence length
tgt_len = 6  # dec_input(=dec_output) max sequence length


def make_data(sentences):
    enc_inputs, dec_inputs, dec_outputs = [], [], []
    for i in range(len(sentences)):
        enc_input_, dec_input_, dec_output_ = sentences[i]

        enc_input = [src_word_index[n] for n in enc_input_.split()]
        # ^^^^ e.g. [1, 2, 3, 4, 0] for the first sentence

        dec_input = [tgt_word_index[n] for n in dec_input_.split()]
        # ^^^^ e.g. [6, 1, 2, 3, 4, 8] for the first sentence

        dec_output = [tgt_word_index[n] for n in dec_output_.split()]
        # ^^^^ e.g. [1, 2, 3, 4, 8, 7] for the first sentence

        enc_inputs.append(enc_input)
        dec_inputs.append(dec_input)
        dec_outputs.append(dec_output)

    return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)


class MyDataSet(Data.Dataset):
    def __init__(self):
        super(MyDataSet, self).__init__()
        enc_inputs, dec_inputs, dec_outputs = make_data(sentences)
        self.enc_inputs = enc_inputs
        self.dec_inputs = dec_inputs
        self.dec_outputs = dec_outputs

    def __len__(self):
        return self.enc_inputs.shape[0]

    def __getitem__(self, idx):
        return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]


data_loader = Data.DataLoader(MyDataSet(), batch_size=2, shuffle=True)
Train to see if the Model can Overfit
def train():
    transformer = Transformer().to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = optim.SGD(transformer.parameters(), lr=1e-3, momentum=0.99)
    epochs = 30

    for epoch in range(epochs):
        for enc_inputs, dec_inputs, dec_outputs in data_loader:
            '''
            enc_inputs: [batch_size, src_len]
            dec_inputs: [batch_size, tgt_len]
            dec_outputs: [batch_size, tgt_len]
            '''
            enc_inputs = enc_inputs.to(device)
            dec_inputs = dec_inputs.to(device)
            dec_outputs = dec_outputs.to(device)

            outputs, enc_self_attns, dec_self_attns, dec_enc_attns = transformer(
                enc_inputs,
                dec_inputs
            )
            loss = criterion(outputs, dec_outputs.view(-1))
            print('Epoch:', '%04d' % (epoch + 1),
                  'loss =', '{:.6f}'.format(loss))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    os.makedirs("pths", exist_ok=True)
    state_dict = transformer.state_dict()
    torch.save(state_dict, os.path.join("pths", f"model_epoch_{epochs}.pth"))

Upon executing the training script, our model can indeed overfit the naive dataset:

Epoch: 0001 loss = 2.579853
Epoch: 0002 loss = 2.416333
Epoch: 0003 loss = 2.135146
Epoch: 0004 loss = 1.839263
Epoch: 0005 loss = 1.538042
Epoch: 0006 loss = 1.303803
Epoch: 0007 loss = 1.134294
Epoch: 0008 loss = 0.898169
Epoch: 0009 loss = 0.751822
Epoch: 0010 loss = 0.609819
Epoch: 0011 loss = 0.477917
Epoch: 0012 loss = 0.374170
Epoch: 0013 loss = 0.280897
Epoch: 0014 loss = 0.225249
Epoch: 0015 loss = 0.175177
Epoch: 0016 loss = 0.143766
Epoch: 0017 loss = 0.139490
Epoch: 0018 loss = 0.114785
Epoch: 0019 loss = 0.090137
Epoch: 0020 loss = 0.087496
Epoch: 0021 loss = 0.076527
Epoch: 0022 loss = 0.069732
Epoch: 0023 loss = 0.061058
Epoch: 0024 loss = 0.055305
Epoch: 0025 loss = 0.042023
Epoch: 0026 loss = 0.040661
Epoch: 0027 loss = 0.039919
Epoch: 0028 loss = 0.023459
Epoch: 0029 loss = 0.022219
Epoch: 0030 loss = 0.027067
Define Translator

Unlike in training, where teacher forcing feeds the ground-truth target tokens to the decoder, at inference time we predict the target sentence word by word, feeding each predicted token back into the decoder.

class Translator():
    def __init__(self, transformer: Transformer):
        self.transformer = transformer

    def translate(self, enc_input, start_index):
        dec_input = torch.zeros(1, 0).type_as(enc_input)
        terminated = False
        next_tgt_word_index = start_index
        while not terminated:
            dec_input = torch.cat(
                [
                    dec_input.detach(),
                    torch.tensor([[next_tgt_word_index]], dtype=enc_input.dtype).to(device)
                ],
                -1
            )
            dec_output_logits, _, _, _ = self.transformer(enc_input, dec_input)
            next_tgt_word_index = torch.argmax(dec_output_logits[-1])

            if next_tgt_word_index == tgt_word_index["."]:
                terminated = True

            print("next_word", tgt_index_word[next_tgt_word_index.item()])

        # remove batch, remove <sos>
        return dec_input.squeeze(0)[1:]
Translation Script

Now we test our function by using the script:

transformer = Transformer().to(device)
model_path = "pths/model_epoch_30.pth"

if model_path is not None:
    transformer.load_state_dict(torch.load(model_path))

translator = Translator(transformer)
enc_inputs, _, _ = next(iter(data_loader))
enc_inputs = enc_inputs.to(device)
#   e.g. enc_inputs = tensor([
#       [1, 2, 3, 4, 0], [1, 2, 3, 5, 0]
#   ], device='cuda:0')
for i in range(len(enc_inputs)):
    enc_input = enc_inputs[i]
    sentence = " ".join([src_index_word[idx.item()] for idx in enc_input])
    print("source sentence:", sentence)
    predict = translator.translate(
        enc_input.unsqueeze(0), # expand as batch
        start_index=tgt_word_index["<sos>"]
    )
    print(enc_input, '->', [tgt_index_word[n.item()]
                            for n in predict.squeeze()])

yields

source sentence: ich mochte ein bier <P>
next_word i
next_word want
next_word a
next_word beer
next_word .
tensor([1, 2, 3, 4, 0], device='cuda:0') -> ['i', 'want', 'a', 'beer']
source sentence: ich mochte ein cola <P>
next_word i
next_word want
next_word a
next_word coke
next_word .
tensor([1, 2, 3, 5, 0], device='cuda:0') -> ['i', 'want', 'a', 'coke']