Self-attention小小实践

发布时间:2023-12-24 20:57:58 作者:你好,一多

公式 1 不带权重的自注意力机制

\[Attention(X) = softmax(\frac{X\cdot{X^T}}{\sqrt{dim_X}})\cdot X \]

示例程序:

import numpy as np
# Sizes for this example: X below has seq_len rows and emb_dim columns.
emb_dim = 3
# qkv_dim is not used in this first (unweighted) example; no projection is applied here.
qkv_dim = 4
# seq_len is not referenced by the computation; it records the row count of X.
seq_len = 5
# Input matrix. Its entries are in the hundreds/thousands, so the scaled dot
# products computed from it are on the order of 1e5 to 1e7.
X = np.array([
    [1501, 502, 503],
    [2502, 501, 503],
    [503, 501, 502],
    [503, 502, 501],
    [501, 503, 5020]
])
X
array([[1501,  502,  503],  
       [2502,  501,  503],  
       [ 503,  501,  502],  
       [ 503,  502,  501],  
       [ 501,  503, 5020]])
def softmax(mtx):
    """Return the softmax of ``mtx`` computed along its last axis.

    Every slice along the last axis is mapped to non-negative values that sum
    to 1. The slice maximum is subtracted before exponentiating: softmax is
    unchanged by adding a constant to a slice, and the shift keeps ``np.exp``
    from overflowing to ``inf`` (which would turn the result into ``nan``)
    when the inputs are large.

    Args:
        mtx: Array-like of scores (any shape; normalised over the last axis).

    Returns:
        A float ndarray with the same shape as ``mtx``.
    """
    mtx = np.asarray(mtx)
    if mtx.size == 0:
        # Nothing to normalise; np.max would raise on an empty last axis.
        return np.exp(mtx)
    # Largest exponent becomes exp(0) == 1, so the denominator is always >= 1.
    shifted = mtx - np.max(mtx, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)
# Pairwise dot products between the rows of X, divided by sqrt(emb_dim) as in formula 1.
scores = X.dot(X.T) / np.sqrt(emb_dim)
# Shift every row so its maximum is 0. Softmax is invariant to a per-row
# constant, and the shift keeps the np.exp call inside softmax from
# overflowing on these very large values.
scores = scores - np.max(scores, axis=-1, keepdims=True)
scores
array([[  -867179.52697255,         0.        ,  -1732629.31253861,
         -1732629.88988887,   -421723.19472849],
       [ -1445685.65140109,         0.        ,  -2887906.62383678,
         -2887907.77853732,  -1578157.51596611],
       [ -1019043.43237911,   -728635.09227619,  -1309448.88573069,
         -1309449.46308095,         0.        ],
       [ -1016436.11856345,   -726028.3558108 ,  -1306841.57191502,
         -1306840.99456476,         0.        ],
       [-12802651.57528769, -12513400.24512423, -13094514.2607187 ,
        -13097122.15188463,         0.        ]])
# Row-wise softmax of the shifted scores gives the attention weights.
# The gaps between scores are so large that every row comes out one-hot.
aw = softmax(scores)
aw
array([[0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 1.]])
# Weighted sum of the rows of X (the final "· X" of formula 1). With one-hot
# weights each output row is simply a copy of one row of X.
out = aw.dot(X)
out
array([[2502.,  501.,  503.],
       [2502.,  501.,  503.],
       [ 501.,  503., 5020.],
       [ 501.,  503., 5020.],
       [ 501.,  503., 5020.]])
# Sanity check: each row of weights sums to 1, so the total equals the number of rows.
np.sum(aw)
5.0

公式 2 带权重的自注意力机制

\[\begin{aligned} Attention(X) &= softmax(\frac{(X\cdot w^Q)\cdot{({X\cdot w^K})^T}}{\sqrt{dim_K}})\cdot (X\cdot w^V)\\[1ex] &= softmax(\frac{Q\cdot K^T}{\sqrt{dim_K}})\cdot V \end{aligned}\]

示例程序:

import numpy as np
# Sizes for this example: X has seq_len rows and emb_dim columns; each
# projection matrix below has emb_dim rows and qkv_dim columns.
emb_dim = 3
qkv_dim = 4
# seq_len is not referenced by the computation; it records the row count of X.
seq_len = 5
# Query projection weights (w^Q in formula 2).
wq = np.array([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 1, 2, 3]
])
# Key projection weights (w^K in formula 2).
wk = np.array([
    [9, 8, 7, 6],
    [5, 4, 3, 2],
    [1, 9, 8, 7]
])
# Value projection weights (w^V in formula 2).
wv = np.array([
    [3, 6, 9, 7],
    [1, 8, 3, 6],
    [4, 5, 2, 2]
])
# Input matrix that is projected into Q, K and V.
X = np.array([
    [1, 2, 3],
    [2, 2, 4],
    [5, 9, 7],
    [6, 6, 6],
    [8, 1, 4]
])
X
array([[1, 2, 3],
       [2, 2, 4],
       [5, 9, 7],
       [6, 6, 6],
       [8, 1, 4]])
def softmax(mtx):
    """Return the softmax of ``mtx`` computed along its last axis.

    Every slice along the last axis is mapped to non-negative values that sum
    to 1. The slice maximum is subtracted before exponentiating: softmax is
    unchanged by adding a constant to a slice, and the shift keeps ``np.exp``
    from overflowing to ``inf`` (which would turn the result into ``nan``)
    when the inputs are large.

    Args:
        mtx: Array-like of scores (any shape; normalised over the last axis).

    Returns:
        A float ndarray with the same shape as ``mtx``.
    """
    mtx = np.asarray(mtx)
    if mtx.size == 0:
        # Nothing to normalise; np.max would raise on an empty last axis.
        return np.exp(mtx)
    # Largest exponent becomes exp(0) == 1, so the denominator is always >= 1.
    shifted = mtx - np.max(mtx, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)
# Queries: Q = X · w^Q, one row per row of X and qkv_dim columns.
Q = X.dot(wq)
Q
array([[ 38,  17,  23,  29],
       [ 48,  20,  28,  36],
       [113,  71,  92, 113],
       [ 90,  54,  72,  90],
       [ 49,  26,  39,  52]])
# Keys: K = X · w^K.
K = X.dot(wk)
K
array([[ 22,  43,  37,  31],
       [ 32,  60,  52,  44],
       [ 97, 139, 118,  97],
       [ 90, 126, 108,  90],
       [ 81, 104,  91,  78]])
# Values: V = X · w^V.
V = X.dot(wv)
V
array([[ 17,  37,  21,  25],
       [ 24,  48,  32,  34],
       [ 52, 137,  86, 103],
       [ 48, 114,  84,  90],
       [ 41,  76,  83,  70]])
# Scaled dot-product scores Q · K^T / sqrt(qkv_dim), as in formula 2.
# sqrt(4) == 2, which is why the values end in .0 or .5.
scores = Q.dot(K.T) / np.sqrt(qkv_dim)
scores
array([[ 1658.5,  2354. ,  5788. ,  5328. ,  4600.5],
       [ 2034. ,  2888. ,  7116. ,  6552. ,  5662. ],
       [ 6223. ,  8816. , 21323.5, 19611. , 16861.5],
       [ 4878. ,  6912. , 16731. , 15390. , 13239. ],
       [ 2625.5,  3722. ,  9006.5,  8289. ,  7139. ]])
# Shift every row so its maximum is 0. Softmax is invariant to a per-row
# constant, and the shift keeps the np.exp call inside softmax from
# overflowing on scores in the thousands.
scores = scores - np.max(scores, axis=-1, keepdims=True)
scores
array([[ -4129.5,  -3434. ,      0. ,   -460. ,  -1187.5],
       [ -5082. ,  -4228. ,      0. ,   -564. ,  -1454. ],
       [-15100.5, -12507.5,      0. ,  -1712.5,  -4462. ],
       [-11853. ,  -9819. ,      0. ,  -1341. ,  -3492. ],
       [ -6381. ,  -5284.5,      0. ,   -717.5,  -1867.5]])
# Row-wise softmax of the shifted scores gives the attention weights.
# Column index 2 holds the maximum in every row, so it receives
# (numerically) all of the weight.
aw = softmax(scores)
aw
array([[0.00000000e+000, 0.00000000e+000, 1.00000000e+000,
        1.67702032e-200, 0.00000000e+000],
       [0.00000000e+000, 0.00000000e+000, 1.00000000e+000,
        1.14264732e-245, 0.00000000e+000],
       [0.00000000e+000, 0.00000000e+000, 1.00000000e+000,
        0.00000000e+000, 0.00000000e+000],
       [0.00000000e+000, 0.00000000e+000, 1.00000000e+000,
        0.00000000e+000, 0.00000000e+000],
       [0.00000000e+000, 0.00000000e+000, 1.00000000e+000,
        2.47576395e-312, 0.00000000e+000]])
# Attention output: weights times V (the final "· V" of formula 2). Because
# every row of aw selects index 2, every output row equals V[2].
out = aw.dot(V)
out
array([[ 52., 137.,  86., 103.],
       [ 52., 137.,  86., 103.],
       [ 52., 137.,  86., 103.],
       [ 52., 137.,  86., 103.],
       [ 52., 137.,  86., 103.]])