net.apply(weights_init)的理解

发布时间 2023-08-31 22:49:32作者: 绘守辛玥

在DCGAN的学习中,Pytorch官方对于权重初始化使用了下列方法

# custom weights initialization called on ``netG`` and ``netD``
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

在这里对该代码学习后的理解做一些记录。首先是apply(fn),根据官网解释该方法是Module类的方法,作用是将fn递归地应用于每个子模块(由.children()返回),以及其自身。典型用途便是初始化模型的参数。我们这里来写一个简单的神经网络 net 并将其实例化

def weights_init(m):
    print(m)


net = nn.Sequential(nn.Linear(1, 1), nn.Conv2d(1, 1, 1))

net.apply(weights_init)

我们定义了一个weights_init函数,和一个Sequential类,该类有两层,第一层是全链接层,第二层是卷积层。将该类实例化后调用其apply()方法,我们来运行看看

>>> Linear(in_features=1, out_features=1, bias=True)
Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
Sequential(
  (0): Linear(in_features=1, out_features=1, bias=True)
  (1): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
)

可以看到apply()遍历了该类的每一层和其自身,我们这里将打印的参数再改成内建参数m.__class__看看

def weights_init(m):
    print(m.__class__)


net = nn.Sequential(nn.Linear(1, 1), nn.Conv2d(1, 1, 1))

net.apply(weights_init)
>>> <class 'torch.nn.modules.linear.Linear'>
<class 'torch.nn.modules.conv.Conv2d'>
<class 'torch.nn.modules.container.Sequential'>

在这里net对象的所在类被递归出来了,最后在把其换成m.__class__.__name__运行

def weights_init(m):
    print(m.__class__.__name__)


net = nn.Sequential(nn.Linear(1, 1), nn.Conv2d(1, 1, 1))

net.apply(weights_init)
>>> Linear
Conv2d
Sequential

可以看出该方法含义是递归神经网络并返回每层名字,如果该名字找到了字符串'Conv'或者'BatchNorm',则对其权重做归一化