pytorch 计算网络模型的计算量FLOPs和参数量parameter参数数量

发布时间 2023-10-08 08:09:52作者: emanlee

 
    参数量方法一:pytorch自带方法,计算模型参数总量
    参数量方法二: summary的使用:来自于torchinfo第三方库
    参数量方法三: summary的使用:来自于torchsummary第三方库
    计算量方法一:thop的使用,输出计算量FLOPs和参数量parameter

我们通常要通过计算网络模型的计算量FLOPs和参数量parameter来评估模型的性能,总结了几种常用的计算方式,大家可以尝试一下。
为了能够便于读者理解,我们选取pytorch自带的网络resnet34进行测试,也可自行更改为其他或所提网络。

参数量方法一:pytorch自带方法,计算模型参数总量

from torchvision.models import resnet34

net=resnet34() # 注意:模型内部传参数和不传参数,输出的结果是不一样的
# 计算网络参数
total = sum([param.nelement() for param in net.parameters()])
# 精确地计算:1MB=1024KB=1048576字节
print('Number of parameter: % .4fM' % (total / 1e6))

 输出:
Number of parameter:  21.7977M


参数量方法二: summary的使用:来自于torchinfo第三方库

torchinfo 的 summary 更加友好,是 print 和 torchsummary 的 summary 的结合体

from torchvision.models import resnet34
import torch
from torchinfo import summary  # 注意:当使用from torchsummary import summary时,对应的summary应该为:summary(model, input_size=(3, 512, 512), batch_size=-1)
if __name__ == "__main__":
    model = resnet34()
    tmp_0 = model(torch.rand(1, 3, 224, 224).cuda())  ### torch.rand(1, 3, 224, 224
    print(tmp_0.shape)

    summary(model, (1, 3, 224, 224))# summary的函数内部参数形式与导入的第三方库有关,否则报错


ModuleNotFoundError: No module named 'torchinfo'
pip install torchinfo



参数量方法三: summary的使用:来自于torchsummary第三方库

torchsummary 中的 summary 可以打印每一层的shape, 参数量,

from torchvision.models import resnet34
from torchsummary import summary
model = resnet34()
summary(model, input_size=(3, 512, 512), batch_size=-1)# 同样是summary函数,注意与方法二的区别       input_size=(3, 512, 512)

 

ModuleNotFoundError: No module named 'torchsummary'
pip install torchsummary


计算量方法一:thop的使用,输出计算量FLOPs和参数量parameter

注意区分FLOPs和FLOPS
FLOPs就是表示模型前向传播中计算MAC(乘法加法操作的次数),如果FLOPs的值越大,也从一定程度上说明模型越复杂,模型需要的计算力(算力)更高,因此对硬件的要求也就越高!

from torchvision.models import resnet34
import torch
from thop import profile
if __name__ == "__main__":
    # #call Transception_res

    model = resnet34()
    input = torch.randn(1, 3, 512, 512) ### ?
    Flops, params = profile(model, inputs=(input,)) # macs
    print('Flops: % .4fG'%(Flops / 1000000000))# 计算量
    print('params参数量: % .4fM'% (params / 1000000)) #参数量:等价与上面的summary输出的Total params值

 

ModuleNotFoundError: No module named 'thop'
pip install thop
 

该网络模型中包含该方法的计算:https://github.com/Barrett-python/DuAT/blob/main/DuAT.py

输出结果:输出为网络模型的总参数量(单位M,即百万)与计算量(单位G,即十亿)

Flops:  19.2174G
params参数量:  21.7977M
   

参考链接:
    CNN 模型的参数(parameters)数量和浮点运算数量(FLOPs)是怎么计算的https://blog.csdn.net/weixin_41010198/article/details/108104309
    区分FLOPs和FLOPS:https://blog.csdn.net/IT_flying625/article/details/104898152
    pytorch得到模型的计算量和参数量https://blog.csdn.net/qq_35407318/article/details/109359006
    轻量化网络中常使用的参数量和计算量评估;https://blog.csdn.net/weixin_46274756/article/details/130391999
    Pytorch 中打印网络结构及其参数的方法与实现https://blog.csdn.net/like_jmo/article/details/126903727
    CNN 模型所需的计算力flops是什么?怎么计算?https://zhuanlan.zhihu.com/p/137719986
    FLOPS的计算:https://blog.csdn.net/baidu_35848778/article/details/127571810
————————————————
原文链接:https://blog.csdn.net/weixin_40893448/article/details/130395738