Problems converting a model to ONNX, errors: 1. _thnn_fused_lstm_cell, 2. _thnn_fused_gru_cell, 3. Exporting the operator numpy_T to ONNX opset version 11 is not supported

Published 2023-09-22 13:22:05   Author: 无左无右


RuntimeError: Exporting the operator _thnn_fused_lstm_cell to ONNX opset version 11 is not supported. Please open a bug to request ONNX export support for the missing operator.
RuntimeError: Exporting the operator _thnn_fused_gru_cell to ONNX opset version 11 is not supported. Please open a bug to request ONNX export support for the missing operator.

GRU

import torch
from torch import nn

class GRU_test(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(GRU_test, self).__init__()
        self.rnn_cell = nn.GRUCell(input_size, hidden_size)

    def forward(self, input, hx):
        out = self.rnn_cell(input, hx)
        return out


device = "cuda:0"
# device = "cpu"
net = GRU_test(10, 20).to(device)
input = torch.randn(3, 10).to(device)
hx = torch.randn(3, 20).to(device)
out = net(input, hx)

#########################
torch.onnx.export(net, (input, hx),
                  "./only-gru.onnx", keep_initializers_as_inputs=True, opset_version=11, verbose=True,
                  input_names=["input", "hx"],
                  output_names=["out"])

print("convert_onnx succesfully!!")

LSTM

import torch
from torch import nn


class LSTM_TEST(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTM_TEST, self).__init__()
        self.rnn_cell = nn.LSTMCell(input_size, hidden_size)

    def forward(self, input, hx):
        out = self.rnn_cell(input, hx)
        return out



device = "cuda:0"
# device = "cpu"

net = LSTM_TEST(204, 64).to(device)
input = torch.randn(10, 204).to(device)
hx = (torch.randn(10, 64).to(device), torch.randn(10, 64).to(device))

# hx = torch.randn(3, 20).to(device)
out = net(input, hx)
# print("out shape=", out.shape)

# input = torch.randn(10, 32, 100)
# lstm = nn.LSTMCell(100, 8)
# output = []
# for time_data in input:
#     out, _ = lstm(time_data)
#     output.append(out)
# output = torch.stack(output)
# print(output.shape)


#########################
torch.onnx.export(net, (input, hx),
                  "./only-lstm.onnx", keep_initializers_as_inputs=True, opset_version=11, verbose=True,
                  input_names=["input", "hx"],
                  output_names=["out"])

print("convert_onnx succesfully!!")

Solution: run the export on the CPU, i.e. device = "cpu". The _thnn_fused_lstm_cell / _thnn_fused_gru_cell ops are the fused CUDA kernels behind LSTMCell / GRUCell; when tracing on the CPU, the cells fall back to plain matmul / sigmoid / tanh ops, which the exporter does support at opset 11.
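For reference, a minimal sketch of the GRU export with the device switched to the CPU (the output file name only-gru-cpu.onnx is my own choice; the LSTM export is fixed the same way):

device = "cpu"
net = GRU_test(10, 20).to(device)
input = torch.randn(3, 10).to(device)
hx = torch.randn(3, 20).to(device)

# On the CPU the GRUCell forward is traced as matmul/sigmoid/tanh instead of
# the fused CUDA kernel, so the export succeeds at opset 11.
torch.onnx.export(net, (input, hx),
                  "./only-gru-cpu.onnx", keep_initializers_as_inputs=True, opset_version=11, verbose=True,
                  input_names=["input", "hx"],
                  output_names=["out"])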

Exporting the operator numpy_T to ONNX opset version 11 is not supported

import torch

# create a simple module
class MyModule(torch.nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()

    def forward(self, x):
        return x.T

dummy_input = torch.randn(1, 300)
torch.onnx.export(MyModule(), dummy_input, "aten_transpose_issue.onnx", opset_version=11, verbose=True)

The error is as follows:

RuntimeError: Exporting the operator numpy_T to ONNX opset version 11 is not supported. Please open a bug to request ONNX export support for the missing operator.

Process finished with exit code 1
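For the toy module above the fix is straightforward: spell the transpose out instead of using .T. A minimal sketch (the output file name is my own choice):

import torch

class MyModule(torch.nn.Module):
    def forward(self, x):
        # transpose(0, 1) maps to the ONNX Transpose op; x.T traces as the
        # unsupported aten::numpy_T
        return x.transpose(0, 1)

dummy_input = torch.randn(1, 300)
torch.onnx.export(MyModule(), dummy_input, "aten_transpose_fixed.onnx", opset_version=11, verbose=True)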

The error is indeed caused by XXX.T. In my project, however, searching through the network code turned up no use of XXX.T at all; it turned out that a library function being called used XXX.T internally. The solution is as follows:

import torch
from torch.distributions import Categorical
class MyCategorical(Categorical):
    def __init__(self, logits):
        super().__init__(logits=logits)

    def sample(self, sample_shape=torch.Size()):
        if not isinstance(sample_shape, torch.Size):
            sample_shape = torch.Size(sample_shape)
        probs_2d = self.probs.reshape(-1, self._num_events)
        samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True)
        samples_2d = samples_2d.transpose(0, 1) # instead of .T
        return samples_2d.reshape(self._extended_shape(sample_shape))
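With this subclass, sampling goes through transpose(0, 1) instead of .T, and MyCategorical can be swapped in wherever the library's Categorical was constructed. Roughly (the logits tensor here is just a placeholder for whatever the model head produces):

logits = torch.randn(4, 6)            # placeholder logits
dist = MyCategorical(logits=logits)   # drop-in replacement for Categorical
samples = dist.sample()               # no aten::numpy_T emitted during tracing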