pytorch torch.nn.BatchNorm1d

发布时间 2023-10-07 16:34:42作者: emanlee

pytorch torch.nn.BatchNorm1d

==================================

nn.BatchNorm1d本身不是给定输入矩阵,输出归一化结果的函数,而是定义了一个方法,再用这个方法去做归一化。
下面是一个例子。

BN = nn.BatchNorm1d(100)
input = torch.randn(20, 100)
output = m(input)

我们首先定义了一个归一化的函数BN,需要归一化的维度为100,其他参数为默认。然后随机初始化一个20×100的矩阵input,再用BN对这个矩阵归一化。
函数的input可以是二维或者三维。当input的维度为(N, C)时,BN将对C维归一化;当input的维度为(N, C, L) 时,归一化的维度同样为C维。
链接:https://blog.csdn.net/qsmx666/article/details/109527726

==================================

nn.BatchNorm1d 是 PyTorch 中的一个用于一维数据(例如序列或时间序列)的批标准化(Batch Normalization)层。

批标准化是一种常用的神经网络正则化技术,旨在加速训练过程并提高模型的收敛性和稳定性。它通过对每个输入小批次的特征进行归一化处理来规范化输入数据的分布。

在一维数据上使用 nn.BatchNorm1d 层时,它会对每个特征维度上的数据进行标准化处理。具体而言,它会计算每个特征维度的均值和方差,并将输入数据进行中心化和缩放,以使其分布接近均值为0、方差为1的标准正态分布。

使用 nn.BatchNorm1d 层可以有效地解决神经网络训练过程中出现的内部协变量偏移问题,加速训练收敛,并提高模型的泛化能力。
 链接:https://blog.csdn.net/AdamCY888/article/details/131270585

 

import torch
import torch.nn as nn

# 定义一个简单的神经网络模型
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.bn = nn.BatchNorm1d(20)
        self.fc2 = nn.Linear(20, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

# 创建模型实例
model = MyModel()

# 随机生成输入数据
input_tensor = torch.randn(32, 10)

# 前向传播
output_tensor = model(input_tensor)

 

 

 

==================================

https://zhuanlan.zhihu.com/p/622644675

在pytorch的官方文档中,对torch.nn.BatchNorm1d的叙述是这样的:

torch.nn.BatchNorm1d(num_features,eps=1e-5,momentum=0.1,affine=True,track_running_stats=True,device=None,dtype=None)

具体参数的使用这里就不啰嗦了,紧接着

Applies Batch Normalization over a 2D or 3D input as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift .

这句话是说:在2D和3D输入数据上应用批量正则化,在论文《Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》有具体描述。

并给出了具体计算公式:

 

 

这个公式怎么计算的?

首先,在torch.nn.BatchNorm1d()中:

参数num_features (int) – 表述为number of features or channels C of the input(输入的数据的通道数或者是特征数)

输入数据Input形状为: (N,C) or(N,C,L), where N is the batch size, C is the number of features or channels, and L is the sequence length,是对输入数据形状的描述,N就为批量数,无论三维,还是二维,C就是num_features数值,只是在输入数据为二维的时候,它是最后一维的量,三维的时候,它就是通道值。

所以我们来分两种情况:

第一种,在2D输入数据上:

import torch
import torch.nn as nn
x = torch.tensor([[0, 1, 2],
                  [3, 4, 5],
                  [6, 7, 8]], dtype=torch.float)
print(x)
print(x.shape)   #  x的形状为(3,3)
m = nn.BatchNorm1d(3)   #  num_features的值必须为形状的最后一数3
y = m(x)
print(y)
# 输出的结果是
tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
torch.Size([3, 3])
tensor([[-1.2247, -1.2247, -1.2247],
        [ 0.0000,  0.0000,  0.0000],
        [ 1.2247,  1.2247,  1.2247]], grad_fn=<NativeBatchNormBackward0>)

 

 

着重想要说明的就是:批量,是在第一维度上批量,第一维度是3,即3行,特征数是3个,即每一行的列数,所以它此时计算的均值就是在每一个列向量上计算正则化,均值

是列向量中每个元素的平均,比如第一个特征的均值就是 ,所以方差为 ,所以第一列元素的正则化就是 ,其他列数值同样计算。这里

的默认值分别是1和0。

 

第二种,在3D输入数据上:

x = torch.tensor([[[0, 1, 2],
                   [3, 4, 5]],
                  [[6, 7, 8],
                   [9, 10, 11]]], dtype=torch.float)
print(x)
print(x.shape)   #  x的形状为(2,2,3)
m = nn.BatchNorm1d(2)   #  num_features的值必须第二维度的数值,即通道数2
y = m(x)
print(y)
# 输出的结果是
tensor([[[ 0.,  1.,  2.],
         [ 3.,  4.,  5.]],

        [[ 6.,  7.,  8.],
         [ 9., 10., 11.]]])
torch.Size([2, 2, 3])
tensor([[[-1.2865, -0.9649, -0.6433],
         [-1.2865, -0.9649, -0.6433]],

        [[ 0.6433,  0.9649,  1.2865],
         [ 0.6433,  0.9649,  1.2865]]], grad_fn=<NativeBatchNormBackward0>)
 

着重想要说明的就是:批量,也是在第一维度上批量,第一维度是2,即2个二维矩阵之间批量,通道数(也就是特征数)是2。

现在,可以把3D数据想象为虚假的二维(2D)数据,只是在此时最后一维它的每一个单独元素是行向量,而在纯粹的真正二维2D数据中,最后一维它的每一个单独元素是标量,对照2D数据的计算过程,此时计算具体步骤是:该数据按照两个通道计算,

第一个通道均值是由[0,1,2]和[6,7,8]这两个行元素相加得到,这两个行向量横跨于第一批和第二批数据,就像2D数据中两个标量相加;

第二个通道均值是由[3,4,5]和[9,10,11]这两个行元素相加得到。

既然是计算数值均值,还需把他们的元素之间全部相加,所以第一个通道的均值是:

,方差为 ,所以[0,1,2]元素的正则化就是

,其他数值同样计算可以得到。

大家可以尝试计算一下,大致过程就是如此。

 

==================================