TransWikia.com

Why does an attention layer in a transformer learn context?

Data Science Asked by Nick Koprowicz on June 10, 2021

I understand the transformer architecture (from "Attention Is All You Need"), as well as how the attention is computed in the multi-headed attention layers.

What I’m confused about is why the output of an attention layer is a context vector. That is to say: what is it about the way a transformer is trained that causes the attention layers to learn context? What I would expect to see in the paper is a justification along the lines of "when you train a transformer using attention on sequence-to-sequence tasks, the attention layers learn context, and here’s why…". I believe it, because I’ve seen the heatmaps that show high attention between related words, but I want to understand why that is necessarily the result of training a transformer.

Why couldn’t it be the case that the attention layers learn some other features that happen to also be beneficial in sequence-to-sequence tasks? How do we know that they learn context, other than that’s what we observe?

Again, I get the math and I know there are several posts about it. What I want to know is what about the math or the training process implies that the attention layers learn context.

One Answer

To give a simpler, less mathematical explanation, you can think of it like this:

In a simple feed-forward neural network (a black box, of course), you learn a set of weights, that is, a function that maps inputs to outputs.

But in a transformer-based architecture, you have attention. Here, the weights are structured into query, key and value projections (Q, K, V). These three sets of weights drive the attention and are responsible for learning the context. Exactly how they do so remains as much of a black box as a feed-forward network, but the mechanism works roughly like this: every token's embedding is transformed into its query, key and value vectors using the respective weight matrices. For a given token, its query vector is multiplied (dot product) with the key vector of every token in the sequence, itself included, to obtain attention scores. After scaling and a softmax, these scores become weights that determine the importance of every token with respect to the query token, and the token's output is the weighted sum of all the value vectors.

Thus, with back-propagation, you optimize these Q, K, V weights so that they learn to better map the relationships between tokens.
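As a rough illustration of the mechanism in the answer, here is a minimal NumPy sketch of single-head scaled dot-product self-attention, softmax(QKᵀ/√d_k)V, as defined in "Attention Is All You Need". It assumes no masking, no multi-head split and no output projection; the helper names (`softmax`, `self_attention`) and the randomly initialised `W_q`, `W_k`, `W_v` are hypothetical stand-ins for a trained model's weights. It shows that each token's output mixes the value vectors of all tokens, so changing one token changes the outputs of the others.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    """Single-head scaled dot-product self-attention (no mask, no output projection).

    X: (n_tokens, d_model) token embeddings.
    W_q, W_k: (d_model, d_k); W_v: (d_model, d_v) projection matrices.
    Returns (output, weights) with shapes (n_tokens, d_v) and (n_tokens, n_tokens).
    """
    Q = X @ W_q  # one query vector per token
    K = X @ W_k  # one key vector per token
    V = X @ W_v  # one value vector per token
    d_k = K.shape[-1]
    # Row i holds the dot products of token i's query with every token's key.
    scores = Q @ K.T / np.sqrt(d_k)
    # Softmax turns each row of scores into attention weights that sum to 1.
    weights = softmax(scores, axis=-1)
    # Each output row is a weighted sum of ALL tokens' value vectors.
    return weights @ V, weights

rng = np.random.default_rng(0)
n_tokens, d_model, d_k = 4, 8, 8
X = rng.normal(size=(n_tokens, d_model))
# Hypothetical random weights standing in for learned ones (scaled to keep scores moderate).
W_q, W_k, W_v = (rng.normal(scale=1 / np.sqrt(d_model), size=(d_model, d_k)) for _ in range(3))

out, weights = self_attention(X, W_q, W_k, W_v)
print(np.allclose(weights.sum(axis=1), 1.0))  # True: every row of weights sums to 1

# Perturb only token 3 and recompute.
X2 = X.copy()
X2[3] += 1.0
out2, _ = self_attention(X2, W_q, W_k, W_v)
# Typically prints False: token 0's output also depends on token 3.
print(np.allclose(out[0], out2[0]))
```

The random weights here produce arbitrary attention patterns; in a trained model, back-propagation adjusts `W_q`, `W_k` and `W_v` so that these patterns help reduce the training loss.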

Answered by Ashwin Geet D'Sa on June 10, 2021
