PyTorch 的 BatchNorm 层

发布时间 2024-01-07 23:14:33作者: 倒地

BatchNorm 层

为了实现输入特征标准化,batch norm 层会维护一个全局均值 running_mean 和全局方差 running_var。网络 train() 时进行统计,eval() 时使用统计值。

除此之外,可选 weight 权重和 bias 权重,这两个权重是会持续参与到网络学习的。这个变换叫做仿射变换,即 线性变换 + 平移。

所以,BatchNorm 层等价于 “标准化层 + 线性层 + bias 层”。