模型权重初始化

发布时间 2023-03-23 14:30:12 作者: Truman001
def weight_init(m, low=-128, high=128):
    """Fill a module's parameters in place with random integer values.

    Intended to be passed to ``model.apply(weight_init)``, which calls it on
    every sub-module. Parameter dtypes are unchanged: floating-point weights
    receive integer-valued floats.

    Per module type:

    - ``Conv2d`` / ``Conv3d``: weight and (if present) bias are drawn uniformly
      from the integers in ``[low, high)``.
    - ``BatchNorm3d``: weight is set to 1 and bias to 0. Layers built with
      ``affine=False`` have no parameters and are skipped.
    - ``Linear``: weight is drawn from ``[low, high)``; bias (if present) is
      zeroed.
    - Any other module is left untouched.

    Args:
        m: The module to initialize.
        low: Inclusive lower bound of the sampled integers (default -128).
        high: Exclusive upper bound of the sampled integers. The default of
            128 lets the full int8 range [-128, 127] be produced.
    """
    # Mutate parameters in place without recording autograd history
    # (instead of reassigning the discouraged ``.data`` attribute).
    with torch.no_grad():
        if isinstance(m, (torch.nn.Conv2d, torch.nn.Conv3d)):
            m.weight.copy_(torch.randint_like(m.weight, low=low, high=high))
            if m.bias is not None:
                m.bias.copy_(torch.randint_like(m.bias, low=low, high=high))
        elif isinstance(m, torch.nn.BatchNorm3d):
            # With affine=False, BatchNorm's weight and bias are None.
            if m.weight is not None:
                m.weight.fill_(1)
            if m.bias is not None:
                m.bias.zero_()
        elif isinstance(m, torch.nn.Linear):
            m.weight.copy_(torch.randint_like(m.weight, low=low, high=high))
            # NOTE(review): Linear biases are zeroed while conv biases are
            # randomized (kept from the original) — confirm this is intended.
            if m.bias is not None:
                m.bias.zero_()
            
            
 # 将模型权重初始化为int8
 model.apply(weight_init)