Mathematics behind GPT3 - Masked Multihead Self Attention

Hey everyone! Not sure if this is the right place to post, but recently in my free time I've been reviewing Transformers and the maths / guts behind them. I re-skimmed Attention Is All You Need (arXiv 1706.03762), the University of Waterloo's CS480/680 Lecture 19: Attention and Transformer Networks (YouTube), Ali Ghodsi's Lect 13 (Fall 2020): Deep learning, Transformer, BERT, GPT (YouTube), and an Attention Is All You Need video walkthrough (YouTube).

I know GPT-x is just the decoder with masked multihead self-attention operating on learnt word embeddings X, with a final softmax layer predicting the next token.
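
In other words, a sketch of the usual form (h_t here is just my name for the final hidden state at position t, and W_vocab for the output projection to vocabulary logits):

```latex
p(x_{t+1} \mid x_{\le t}) = \mathrm{softmax}\!\left(h_t\, W_{\mathrm{vocab}}\right)
```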

I omitted the layer normalization and residual connections for simplicity.

Below are normal attention's equations, where the W matrices are learnt parameters, X is the learnt embedding matrix, and σ is the softmax function.
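
Written out, this is the standard single-head form from the paper, with σ applied row-wise:

```latex
Q = X W_Q, \qquad K = X W_K, \qquad V = X W_V
\mathrm{Attention}(Q, K, V) = \sigma\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```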

However, with masking, since GPT-x attends ONLY to tokens on its left, we use a masking matrix M. M is 0 on the lower triangle and diagonal, and -inf on the upper triangle. This works because exp(-inf) = 0.
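
So the masked version just adds M inside the softmax, which zeroes out everything above the diagonal after exponentiation:

```latex
M_{ij} =
\begin{cases}
0       & \text{if } j \le i \\
-\infty & \text{if } j > i
\end{cases}
\qquad
\mathrm{MaskedAttention}(Q, K, V) = \sigma\!\left(\frac{Q K^{\top}}{\sqrt{d_k}} + M\right) V
```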

Using matrix diagrams, we see that the shapes of the multiplications above become:
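
Spelling those shapes out in my notation, with n the sequence length:

```latex
X \in \mathbb{R}^{n \times d_{\mathrm{model}}}, \quad
W_Q, W_K \in \mathbb{R}^{d_{\mathrm{model}} \times d_k}, \quad
W_V \in \mathbb{R}^{d_{\mathrm{model}} \times d_v}
\Longrightarrow \quad
Q, K \in \mathbb{R}^{n \times d_k}, \quad
\frac{Q K^{\top}}{\sqrt{d_k}} + M \in \mathbb{R}^{n \times n}, \quad
\sigma(\cdot)\, V \in \mathbb{R}^{n \times d_v}
```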

For multihead attention, we concatenate every head's output and apply one large learnt linear projection:
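
That is, with h heads and a learnt output projection W^O:

```latex
\mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O,
\qquad
\mathrm{head}_i = \sigma\!\left(\frac{(X W_Q^i)(X W_K^i)^{\top}}{\sqrt{d_k}} + M\right) X W_V^i
```

And a minimal numpy sketch of the whole masked multihead block (residuals and layer norm omitted as above; the function and variable names are just my own labels, not from any library):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def masked_multihead_attention(X, W_Q, W_K, W_V, W_O):
    # X: (n, d_model); W_Q/W_K/W_V: per-head lists of (d_model, d_k); W_O: (h*d_k, d_model)
    n = X.shape[0]
    mask = np.triu(np.full((n, n), -np.inf), k=1)  # -inf strictly above the diagonal, 0 elsewhere
    heads = []
    for Wq, Wk, Wv in zip(W_Q, W_K, W_V):
        Q, K, V = X @ Wq, X @ Wk, X @ Wv
        scores = Q @ K.T / np.sqrt(Q.shape[-1]) + mask  # (n, n)
        heads.append(softmax(scores) @ V)               # (n, d_k)
    return np.concatenate(heads, axis=-1) @ W_O         # (n, d_model)

# toy sizes: n = 5 tokens, d_model = 16, h = 4 heads, d_k = d_v = 4
rng = np.random.default_rng(0)
n, d_model, h, d_k = 5, 16, 4, 4
X = rng.normal(size=(n, d_model))
W_Q = [rng.normal(size=(d_model, d_k)) for _ in range(h)]
W_K = [rng.normal(size=(d_model, d_k)) for _ in range(h)]
W_V = [rng.normal(size=(d_model, d_k)) for _ in range(h)]
W_O = rng.normal(size=(h * d_k, d_model))
print(masked_multihead_attention(X, W_Q, W_K, W_V, W_O).shape)  # (5, 16)
```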

Likewise, the positional encoding is just sines and cosines across each dimension.
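
That is, from the paper, for position pos and dimension index i:

```latex
PE_{(pos,\, 2i)} = \sin\!\left(\frac{pos}{10000^{\,2i/d_{\mathrm{model}}}}\right),
\qquad
PE_{(pos,\, 2i+1)} = \cos\!\left(\frac{pos}{10000^{\,2i/d_{\mathrm{model}}}}\right)
```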

I’m still working on the backpropagation derivatives!

Are my formulations / summary correct? Did I do something wrong in trying to understand transformers?


Cool!! I’m currently still trying to derive the derivatives for backpropagation. I’m really, really struggling with the softmax derivative :frowning: I’m confused about whether its derivative is a 2D matrix or a 3D output.
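
For a single row z ∈ R^n, the softmax derivative is the 2D Jacobian below; when σ is applied row-wise to an n×n score matrix you get one such Jacobian per row, which is where the 3D view comes from:

```latex
\frac{\partial \sigma_i(z)}{\partial z_j} = \sigma_i(z)\,\bigl(\delta_{ij} - \sigma_j(z)\bigr)
```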