A convolution kernel problem when converting PyTorch to ONNX

Published: 2023-12-29 16:47:48 | Author: ilk012

Exporting a PyTorch model to ONNX failed with the following error:

RuntimeError: Unsupported: ONNX export of convolution for kernel of unknown shape.

The part of my code that triggered the error:

def forward(self, input):
    n, c, h, w = input.size()
    s = self.scale_factor

    # pad input (left, right, top, bottom)
    input = F.pad(input, (0, 1, 0, 1), mode='replicate')

    # calculate output (height)
    kernel_h = self.kernels.repeat(c, 1).view(-1, 1, s, 1)
    output = F.conv2d(input, kernel_h, stride=1, padding=0, groups=c)
    output = output.reshape(
        n, c, s, -1, w + 1).permute(0, 1, 3, 2, 4).reshape(n, c, -1, w + 1)

    # calculate output (width)
    kernel_w = self.kernels.repeat(c, 1).view(-1, 1, 1, s)
    output = F.conv2d(output, kernel_w, stride=1, padding=0, groups=c)
    output = output.reshape(
        n, c, s, h * s, -1).permute(0, 1, 3, 4, 2).reshape(n, c, h * s, -1)

    return output

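The cause: when you call the functional convolution (torch.nn.functional.conv2d, as opposed to the torch.nn.Conv2d layer), the kernel size must not be variable during export. In other words, the shape of the weight tensor must not depend on the input.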

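In the code above, kernel_h and kernel_w are both derived from self.kernels, and the repeat() call depends on c, the channel dimension of input. The kernel shape therefore depends on the input, which triggers the error.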
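To make the dependency concrete, here is a small eager-mode sketch. It assumes a scale factor of 2 and uses a 2x2 tensor of ones as a stand-in for self.kernels (the post does not show its real shape or values); kernel_h_shape is a hypothetical helper, not part of the original code.

```python
import torch


def kernel_h_shape(c, s=2):
    """Shape of kernel_h for c channels (hypothetical helper for illustration)."""
    kernels = torch.ones(s, s)  # assumed stand-in for self.kernels
    kernel_h = kernels.repeat(c, 1).view(-1, 1, s, 1)
    return tuple(kernel_h.shape)


# The number of output channels of the kernel grows with c,
# so the weight shape is only fixed once c is fixed.
print(kernel_h_shape(3))  # (6, 1, 2, 1)
print(kernel_h_shape(8))  # (16, 1, 2, 1)
```

During export, c comes from input.size() and is traced rather than treated as a plain number, so the exporter cannot settle on one of these shapes.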

The simplest fix is to add c = int(c) at the start. To the ONNX exporter, c is then a constant Python int rather than a torch.Tensor, so kernel_h and kernel_w have fixed sizes. The final code:

def forward(self, input):
    n, c, h, w = input.size()
    s = self.scale_factor
    c = int(c)

    # pad input (left, right, top, bottom)
    input = F.pad(input, (0, 1, 0, 1), mode='replicate')

    # calculate output (height)
    kernel_h = self.kernels.repeat(c, 1).view(-1, 1, s, 1)
    output = F.conv2d(input, kernel_h, stride=1, padding=0, groups=c)
    output = output.reshape(
        n, c, s, -1, w + 1).permute(0, 1, 3, 2, 4).reshape(n, c, -1, w + 1)

    # calculate output (width)
    kernel_w = self.kernels.repeat(c, 1).view(-1, 1, 1, s)
    output = F.conv2d(output, kernel_w, stride=1, padding=0, groups=c)
    output = output.reshape(
        n, c, s, h * s, -1).permute(0, 1, 3, 4, 2).reshape(n, c, h * s, -1)

    return output

However, this in turn produces a warning when exporting to ONNX:

TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  c = int(c)

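In my application c is always the same value, so I don't have to worry about the trace failing to generalize to a different c, and I simply ignore the warning.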
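If the channel count really is fixed, another option is to build the kernels once at construction time so that forward never derives a kernel shape from the input. The following is only a sketch under several assumptions: the class name FixedChannelUpsample and the channels argument are hypothetical, self.kernels is assumed to be an s x s tensor with s = 2, and the kernels are stored as buffers, which suits inference and export but not training them as parameters. I have not verified it against every exporter version.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FixedChannelUpsample(nn.Module):
    """Hypothetical variant with the channel count fixed at construction."""

    def __init__(self, channels, kernels, scale_factor=2):
        super().__init__()
        self.channels = channels  # plain Python int, known up front
        self.scale_factor = scale_factor
        s = scale_factor
        tiled = kernels.repeat(channels, 1)
        # Buffers have concrete shapes that do not depend on the traced input.
        self.register_buffer("kernel_h", tiled.view(-1, 1, s, 1).clone())
        self.register_buffer("kernel_w", tiled.view(-1, 1, 1, s).clone())

    def forward(self, input):
        n, _, h, w = input.size()
        c, s = self.channels, self.scale_factor

        # pad input (left, right, top, bottom)
        input = F.pad(input, (0, 1, 0, 1), mode='replicate')

        # calculate output (height)
        output = F.conv2d(input, self.kernel_h, stride=1, padding=0, groups=c)
        output = output.reshape(
            n, c, s, -1, w + 1).permute(0, 1, 3, 2, 4).reshape(n, c, -1, w + 1)

        # calculate output (width)
        output = F.conv2d(output, self.kernel_w, stride=1, padding=0, groups=c)
        output = output.reshape(
            n, c, s, h * s, -1).permute(0, 1, 3, 4, 2).reshape(n, c, h * s, -1)

        return output


# Eager-mode check with placeholder kernel values.
model = FixedChannelUpsample(channels=3, kernels=torch.ones(2, 2))
y = model(torch.randn(1, 3, 4, 5))
print(tuple(y.shape))  # (1, 3, 8, 10)
```

Because c is never read from the input here, there is no int(c) conversion and so nothing for the tracer to warn about on that line. The trade-off is the same as before: the exported model only works for the channel count it was built with.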