Hand-Writing the Conformer Network Structure

Published 2023-12-10 00:41:40  Author: 365/24/60
import torch
from torch import nn

# Fake input batch: (batch=5, time=280, feature=80), plus per-utterance lengths
x = torch.randint(0, 10, size=(5, 280, 80))
length = torch.tensor([10, 9, 9, 9, 9])
x.size(), x.shape, x[0].shape, length

# Build a padding mask from the lengths (wrapped into a function in the sketch below)
batch = length.size(0)
max_len = length.max().item()
seq_t = torch.arange(0, max_len).unsqueeze(0).expand(batch, max_len)  # dimensions must line up with length
length, seq_t

seq_mask = seq_t >= length.unsqueeze(1)  # True at padded positions
mask = ~seq_mask                         # True at valid positions
mask
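
The commented-out def in the original cell hints this was meant to become a helper; a minimal sketch of such a function (the name make_pad_mask is my assumption, mirroring the logic just shown):

def make_pad_mask(length: torch.Tensor) -> torch.Tensor:
    # (batch, max_len) boolean mask: True at valid frames, False at padding.
    # Hypothetical helper wrapping the cell above; not from the original post.
    batch = length.size(0)
    max_len = length.max().item()
    seq_t = torch.arange(0, max_len, device=length.device).unsqueeze(0).expand(batch, max_len)
    return seq_t < length.unsqueeze(1)

make_pad_mask(length)  # same result as `mask` above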

odim = 256
idim = 80
# Embedding: 4x subsampling with two stride-2 Conv2d layers
conv1 = torch.nn.Conv2d(1, 256, 3, 2)
relu1 = torch.nn.ReLU()
conv2 = torch.nn.Conv2d(256, 256, 3, 2)

# Convs need a channel dim and float input: (b, t, f) -> (b, 1, t, f)
x = x.unsqueeze(1).to(dtype=torch.float32)

x1 = conv1(x)
x1 = relu1(x1)
x2 = conv2(x1)
x2 = relu1(x2)

# After two (kernel=3, stride=2) convs the feature dim shrinks to ((idim-1)//2 - 1)//2 = 19
linear1 = torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
b, c, t, f = x2.size()
x = linear1(x2.transpose(1, 2).contiguous().view(b, t, c * f))

conv1, conv2, x1.shape, x2.shape, linear1, x.shape
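
The time axis shrinks exactly like the feature axis under these two convs, so the shapes can be checked by hand; a quick sanity-check sketch (pure arithmetic derived from the conv hyperparameters above, not from the original post):

# Each (kernel=3, stride=2, no padding) conv maps n frames to (n - 1) // 2.
def conv_out_len(n: int) -> int:
    return (n - 1) // 2

t_in, f_in = 280, 80
t_out = conv_out_len(conv_out_len(t_in))  # 69
f_out = conv_out_len(conv_out_len(f_in))  # 19
assert x.shape == (5, t_out, 256)
t_out, f_out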


# encoder block, pre-norm style
layernorm1 = nn.LayerNorm(256, eps=1e-12)
layernorm2 = nn.LayerNorm(256, eps=1e-12)
dropout1 = nn.Dropout(0.1)
## ffn1: first macaron feed-forward, added with a 0.5 scale
drop1 = nn.Dropout(0.1)
ff_linear1 = nn.Linear(256, 2048)  # renamed to avoid shadowing the embedding's linear1
ff_linear2 = nn.Linear(2048, 256)
activation1 = nn.ReLU()

residual = x
x1 = layernorm1(x)  # pre-norm: normalize before the feed-forward
x2 = residual + 0.5 * ff_linear2(drop1(activation1(ff_linear1(x1))))  # feed the normalized x1, not x
x = x2
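
The same half-step feed-forward appears a second time at the end of a Conformer block, so it is handy as a module; a minimal sketch (the class name is my assumption, and it keeps the ReLU from the cell above, although the Conformer paper uses Swish):

class HalfStepFeedForward(nn.Module):
    # Hypothetical wrapper around the macaron FFN cell above; not from the post.
    def __init__(self, d_model=256, d_ff=2048, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, eps=1e-12)
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.drop = nn.Dropout(dropout)
        self.act = nn.ReLU()  # the Conformer paper uses Swish (nn.SiLU)

    def forward(self, x):
        # pre-norm, then residual with the Conformer's 0.5 scale
        return x + 0.5 * self.w2(self.drop(self.act(self.w1(self.norm(x)))))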

# multi_head_attn
x3 = layernorm2(x)

def generate_qkv(query, key, value):
    # Note: building the projections inside the function means fresh random
    # weights on every call; fine for a shape walkthrough, not for training.
    linear_q = nn.Linear(256, 256)
    linear_k = nn.Linear(256, 256)
    linear_v = nn.Linear(256, 256)

    n_head = 4
    n_batch = query.size(0)
    n_feat = query.size(-1)
    d_k = n_feat // n_head  # 64 per head

    # (batch, time, n_feat) -> (batch, head, time, d_k), ready for attention
    q = linear_q(query).view(n_batch, -1, n_head, d_k).transpose(1, 2)
    k = linear_k(key).view(n_batch, -1, n_head, d_k).transpose(1, 2)
    v = linear_v(value).view(n_batch, -1, n_head, d_k).transpose(1, 2)

    return q, k, v
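
The post stops before actually attending; a minimal sketch of the scaled dot-product step on top of generate_qkv (self-attention on x3; the output projection linear_out is my assumption, and the masking line is left commented because the toy lengths above do not match the subsampled time axis):

import math

q, k, v = generate_qkv(x3, x3, x3)  # self-attention: q, k, v from the same input
d_k = q.size(-1)

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (b, head, t, t)
# With a matching pad mask (True = valid), padded keys would get -inf here, e.g.:
# scores = scores.masked_fill(~pad_mask.unsqueeze(1).unsqueeze(2), float('-inf'))
attn = torch.softmax(scores, dim=-1)
context = torch.matmul(attn, v)  # (b, head, t, d_k)

# Merge the heads back and project; linear_out is a hypothetical output projection.
n_batch = context.size(0)
context = context.transpose(1, 2).contiguous().view(n_batch, -1, 256)
linear_out = nn.Linear(256, 256)
out = x + dropout1(linear_out(context))  # residual, as in the Conformer block
out.shape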

Inspect the parameters



for name, p in conv1.named_parameters():
    print(name, p.shape, p.numel())
for name, p in conv2.named_parameters():
    print(name, p.shape, p.numel())
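
To get a single total instead of per-tensor counts, summing numel() over both modules works; a one-liner sketch:

# Total trainable parameters in the two subsampling convs
# conv1: 1*256*3*3 + 256 = 2560; conv2: 256*256*3*3 + 256 = 590080
total = sum(p.numel() for m in (conv1, conv2) for p in m.parameters())
print(total)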