nn.Sequential 、 nn.ModuleList 、 nn.ModuleDict区别

发布时间 2023-08-23 16:58:55作者: rose_halo

1、nn.Sequential 、 nn.ModuleList 、 nn.ModuleDict 类都继承自 Module 类。


 

2、各自用法

net = nn.Sequential(nn.Linear(128, 256), nn.ReLU())

net = nn.ModuleList([nn.Linear(128, 256), nn.ReLU()])

net = nn.ModuleDict({'linear': nn.Linear(128, 256), 'act': nn.ReLU()})

3、区别

  • ModuleList 仅仅是一个储存各种模块的列表,这些模块之间没有联系也没有顺序(所以不用保证相邻层的输入输出维度匹配),而且没有实现 forward 功能需要自己实现

  • 和 ModuleList 一样, ModuleDict 实例仅仅是存放了一些模块的字典,并没有定义 forward 函数需要自己定义

  • 而 Sequential 内的模块需要按照顺序排列,要保证相邻层的输入输出大小相匹配,内部 forward 功能已经实现,所以,直接如下写模型,是可以直接调用的,不再需要写forward,sequential 内部已经有 forward

4、转换

  • 将 nn.ModuleList 转换成 nn.Sequential


    module_list = nn.ModuleList([nn.Linear(128, 256), nn.ReLU()])
    net = nn.Sequential(*module_list)
  • nn.ModuleDict转换为nn.Sequential

    module_dict = nn.ModuleDict({'linear': nn.Linear(128, 256), 'act': nn.ReLU()})
    net = nn.Sequential(*module_dict.values())