获取模型的参数量和计算复杂度

发布时间 2023-07-20 17:03:18作者: 孜孜不倦fly
import torch
import net.bilstm
import net.transformer
from ptflops import get_model_complexity_info
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# 统计Transformer模型的参数量和计算复杂度
model_transformer = net.transformer.AudioTransformer(80, 512, 6, 6)     #填写的是模型的参数
model_transformer.to(device)
flops_transformer, params_transformer = get_model_complexity_info(model_transformer, (2, 40, 256), as_strings=True, print_per_layer_stat=False)  #填写的是输入 
                                                                                                                                                 #网络x张量形状
print('Transformer模型参数量:' + params_transformer)
print('Transformer模型计算复杂度:' + flops_transformer)


# 统计BiLSTM模型的参数量和计算复杂度
model_bilstm = net.bilstm.BiLSTM(80, 512, 2, 6)
model_bilstm.to(device)
flops_bilstm, params_bilstm = get_model_complexity_info(model_bilstm, (2, 40, 256), as_strings=True, print_per_layer_stat=False)
print('BiLSTM模型参数量:' + params_bilstm)
print('BiLSTM模型计算复杂度:' + flops_bilstm)