網(wǎng)站文化建設(shè)軟文新聞發(fā)布網(wǎng)站
以下用形狀來描述矩陣。對(duì)于向量,為了方便理解,也寫成了類似(1,64)這種形狀的表示形式,這個(gè)你理解為64維的向量即可。下面講的矩陣相乘都是默認(rèn)的叉乘。
詞嵌入矩陣形狀:以BERT_BASE為例,我們知道其有12層Encoder,12個(gè)Head。對(duì)于中文版的BERT_BASE來說,詞嵌入矩陣的形狀為(21128,768),其中21128就是詞典的大小,768是詞典中的每個(gè)字對(duì)應(yīng)的維度。
需要注意的是這個(gè)維度其實(shí)可以是其他值,只不過官方恰巧給的是768=64×12(12個(gè)head,每個(gè)head是64維),對(duì)于Transformer的Encoder來說,這個(gè)維度是512,這個(gè)時(shí)候512≠64×6(6個(gè)head,每個(gè)head為64維)。一般來說Encoder層數(shù)越多,該詞向量維度也應(yīng)該越大,畢竟整個(gè)網(wǎng)絡(luò)參數(shù)數(shù)量增大之后,有能力學(xué)習(xí)更多維度的信息。
詞向量維度:然后我們知道,每個(gè)位置x的輸入其實(shí)一開始是一個(gè)序數(shù),通過這個(gè)序數(shù)便可以在上述詞嵌入矩陣中查找到相應(yīng)的詞向量,每個(gè)位置x的詞向量維度為(1,768)。對(duì)于整個(gè)BERT序列來說,其序列長(zhǎng)度為512,所以BERT序列的形狀為(512,768)。
Q、K、V向量的維度:這個(gè)是論文中固定的,維度都是(1,64)。而由詞向量x到Q、K、V向量是分別乘以一個(gè)權(quán)重矩陣(Wq、Wk、Wv)得到的,所以權(quán)重矩陣的形狀為(768,64)。上述都是一個(gè)head的情況,擴(kuò)展到12個(gè)head,那么整個(gè)權(quán)重矩陣的形狀就變成了(768,768)。這樣詞向量x和這個(gè)權(quán)重矩陣相乘后得到維度為(1,768)維度的向量,然后經(jīng)過切分在單個(gè)head上為(1,64)維度的向量。
注意力計(jì)算后的維度:注意力的計(jì)算如下,這里盜個(gè)圖,鏈接為:
https://zhuanlan.zhihu.com/p/48508221
可知Q向量(1,64)和K的轉(zhuǎn)置(64,1)相乘后其實(shí)就變成了一個(gè)數(shù),該數(shù)再和V向量進(jìn)行數(shù)乘得到的z向量,維度和V一樣,為(1,64)。多個(gè)head中的z向量進(jìn)行拼接得到(1,768)維的Z’向量。Z’向量再乘以一個(gè)轉(zhuǎn)換矩陣Wo(768,768)得到最終的Z向量(1,768)。
需要注意的是,上述圖中的Q、K、V向量均有兩個(gè),最終得到兩個(gè)z向量。并且這里公式?jīng)]有考慮掩碼的情況,但是掩碼并不影響矩陣的形狀。
前饋神經(jīng)網(wǎng)絡(luò)(FFNN)的形狀:前饋神經(jīng)網(wǎng)絡(luò)用一句話概括就是對(duì)于多頭注意力的輸出先進(jìn)行線性變化,然后經(jīng)過激活函數(shù)之后再進(jìn)行線性變換。前饋神經(jīng)網(wǎng)絡(luò)的維度為3072,由于單個(gè)時(shí)刻多頭注意力的輸出維度為(1,768),第一個(gè)線性變換的矩陣形狀為(768,3072),第二個(gè)線性變換矩陣的形狀為(3072,768)。