【pytorch】土堆pytorch教程学习(六)神经网络的基本骨架——nn.module的使用

发布时间 2023-05-09 13:02:29作者: hzyuan

torch.nn 是 pytorch 的一个神经网络库(nn 是 neural network 的简称)。

Containers

torch.nn 构建神经网络的模型容器(Containers,骨架)有以下六个:

  • Module
  • Sequential
  • ModuleList
  • ModuleDict
  • ParameterList
  • ParameterDict

本博文将介绍神经网络的基本骨架——nn.module的使用。

Module

所有神经网络模块的基类。自定义的模型也应该继承该类。

自定义模型继承该类要重写 __init__()forward()

  • __init__() 里构建子模块,将子模块作为当前模块类的常规属性。一般将网络中具有可学习参数的层放在__init__中。
  • forward() 前向传播函数,定义每次调用时执行的计算,应该被所有子类重写。
# 官方案例
import torch.nn as nn
import torch.nn.functional as F

# 自定义模型
class Model(nn.Module):
    def __init__(self):
        super().__init__() # 在对子类进行赋值之前,必须对父类进行__init__调用。
        # 构建子模块
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

	# 前向传播函数
    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
        
# 模型调用
x = torch.randn(3, 1, 10, 20)
model = Model()
y = model(x)

为什么 forward() 方法能在model(x)时自动调用?
在 python 中当一个类定义了 __call__方法,则这个类实例就成为了可调用对象。而nn.Module 中的 __call__ 方法中调用了 forward() 方法,因此继承了 nn.Module 的子类对象就可以通过 model(x) 来调用 forward() 函数。


只要在 nn.Module 的子类中定义了 forward 函数,backward 函数就会被自动实现(利用Autograd)。