利用 FCN 使得 ResNet 允许任意大小图片输入

发布时间 2023-12-27 16:21:57作者: 倒地

阅读这个网站写的一些备忘。

通过少量修改 ResNet18 网络结构的形式,对全卷积网络方案一窥究竟。

允许网络输入任意大小的图像

一般的卷积网络,会因为全连接层 nn.Linear 的存在,而仅允许固定大小的图像输入。

全卷积网络 FCN 使用 1×1 的卷积核,回避了全连接层的缺陷。

不摒弃全连接层的解决方法

ResNet 的 torchvision 实现中,在最后的全连接层之前有一个 nn.AdaptiveAvgPool2d((1, 1))

class ResNet:
 
    # ...
    self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
    self.bn1 = norm_layer(self.inplanes)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    self.layer1 = self._make_layer(block, 64, layers[0])
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate = replace_stride_with_dilation[0])
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate = replace_stride_with_dilation[1])
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate = replace_stride_with_dilation[2])
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Linear(512 * block.expansion, num_classes)

nn.AdaptiveAvgPool2d((1, 1)) 会让每个通道上只有一个 (1, 1) 的像素点,保证全连接层接收的参数数量固定。输入图像因此不局限于固定长宽比和大小。

例如,输入 tensor 大小为 8×5×16 或 2×2×16,经过此层会都变为 1×1×16

使用 FCN 替换全连接层

对上一章的 ResNet 代码进行修改。

具体来说:

  • nn.AdaptiveAvgPool2d((1, 1))修改为 nn.AvgPool2d((7, 7))
  • 将全连接层换为 torch.nn.Conv2d( in_channels = self.fc.in_features, out_channels = num_classes, kernel_size = 1),即卷积核为 1 的、输入输出维度不变的 二维卷积层
class FullyConvolutionalResnet18(models.ResNet):
    def __init__(self, num_classes=1000, pretrained=False, **kwargs):
        # Start with standard resnet18 defined here 
        super().__init__(block = models.resnet.BasicBlock, layers = [2, 2, 2, 2], num_classes = num_classes, **kwargs)
        if pretrained:
            state_dict = load_state_dict_from_url( models.resnet.model_urls["resnet18"], progress=True)
            self.load_state_dict(state_dict)
 
        # Replace AdaptiveAvgPool2d with standard AvgPool2d 
        self.avgpool = nn.AvgPool2d((7, 7))
 
        # Convert the original fc layer to a convolutional layer.  
        self.last_conv = torch.nn.Conv2d( in_channels = self.fc.in_features, out_channels = num_classes, kernel_size = 1)
        self.last_conv.weight.data.copy_( self.fc.weight.data.view ( *self.fc.weight.data.shape, 1, 1))
        self.last_conv.bias.data.copy_ (self.fc.bias.data)

最后三行,是将原来全连接层的权重,转移到了二维卷积层。
全连接层的 weight 权重 size 为 $\text{out_features} \times \text{in_features}$;二维卷积层的 weight 权重 size 为 $1 \times 1 \times \text{out_features} \times \text{in_features}$。可见,仅需一点变换,就能迁移权重。

整体代码

import torch
import torch.nn as nn
from torchvision import models
from PIL import Image
import cv2
import numpy as np
from matplotlib import pyplot as plt
from torchvision import transforms
from einops import rearrange, reduce, repeat
class FNC_Resnet18(models.ResNet):
    def __init__(self):
        # 创建 resnet18 的网络结构
        super(FNC_Resnet18, self).__init__(
            block=models.resnet.BasicBlock, layers=[2, 2, 2, 2], num_classes=1000)

        # 需要提前下载权重 https://download.pytorch.org/models/resnet18-f37072fd.pth
        state_dict = torch.load('resnet18-f37072fd.pth')
        self.load_state_dict(state_dict)

        self.avgpool = nn.AvgPool2d((6, 6))

        self.last_conv = torch.nn.Conv2d(
            in_channels=self.fc.in_features, out_channels=self.fc.out_features, kernel_size=1)
        self.last_conv.weight.data.copy_(
            self.fc.weight.data.view(*self.fc.weight.data.shape, 1, 1))
        self.last_conv.bias.data.copy_(self.fc.bias.data)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.last_conv(x)

        return x
model = FNC_Resnet18()

# 需要下载 https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
with open('imagenet_classes.txt') as f:
    labels = [line.strip() for line in f.readlines()]

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
input_image = Image.open('R.jpg')  # 随意选择一张图片
input_tensor = transform(input_image)
input_batch = input_tensor.unsqueeze(0)

model.eval()
with torch.no_grad():
    output = model(input_batch)
    print(f"output.shape: {output.shape}")
    output = rearrange(output, 'b c h w -> b (h w) c')

    summed_output = output.sum(dim=1)

    # 找到前五个最大值的值和索引
    for batch in summed_output:
        top5_values, top5_indices = torch.topk(batch, 5)
        # 打印 label
        for value, index in zip(top5_values, top5_indices):
            print(f"{labels[index]}: {value}")