FCN-全卷积网络-pytorch搭建

发布时间:2023-08-06 16:44:10　作者:倦鸟已归时

代码摘自:https://github.com/sovit-123/Semantic-Segmentation-using-Fully-Convlutional-Networks

预备知识:

下载预训练权重并抽取网络层实例:运行如下代码,预训练权重会自动下载到 torch hub 的缓存目录(Windows 下为 C:\Users\<用户名>\.cache\torch\hub\checkpoints,Linux/macOS 下为 ~/.cache/torch/hub/checkpoints)。

# Build VGG16 and load its pretrained weights (downloaded into the torch hub
# checkpoint cache on first use).
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; confirm against the installed torchvision version.
vgg = models.vgg16(pretrained=True)

抽取网络层,vgg.features 是 VGG16 的特征抽取网络部分(卷积网络),vgg 还有 vgg.classifier 表示分类器部分(全连接网络)。

print("----show VGG16's features.children()----")

# `vgg.features.children()` is a lazy generator (printing it only shows
# something like "<generator object Module.children at 0x...>"), so turn it
# into a list that can be indexed and sliced.
feats = list(vgg.features.children())

# Show every layer of the feature extractor under a "=====<index>======" banner.
for layer_idx, layer_module in enumerate(feats):
    print(f"====={layer_idx}======", layer_module, sep="\n")

# Slicing tips:
#   feats[0:9]   -> a list holding the first 9 layers (indices 0-8)
#   *feats[0:9]  -> unpacks that list into 9 separate arguments; only valid
#                   inside a call, e.g. print(*feats[0:9]) or nn.Sequential(*feats[0:9])

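卷积(Conv2d)与转置卷积(ConvTranspose2d,常称“反卷积”):在相同的 kernel_size / stride / padding 下,两者对特征图尺寸的变换互逆。下面的例子中,5×5 的输入经步长为 2 的卷积变为 3×3,再经转置卷积恢复为 5×5。注意这只是形状上的互逆,数值上并不能还原原始输入。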

# Strided convolution: a 5x5 map shrinks to 3x3.
con = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=2, padding=1)
# Transposed convolution with the same kernel/stride/padding: 3x3 grows back to 5x5.
dec = nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=3, stride=2,
                         padding=1, bias=False)

# Unbatched (C, H, W) input: one channel, 5x5 pixels.
feat = torch.randn((1, 5, 5))
feat_c = con(feat)
feat_d = dec(feat_c)

# One shape per line: input, after conv, after transposed conv.
print(feat.shape, feat_c.shape, feat_d.shape, sep="\n")

模型搭建代码(仅摘出原仓库中与模型相关的部分作为参考):

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import logging
from itertools import chain

# 一个基类,定义了一个模型的“描述信息的功能”,例如logger、print
class BaseModel(nn.Module):
    def __init__(self):
        super(BaseModel, self).__init__()
        self.logger = logging.getLogger(self.__class__.__name__)
    # 子类必须重写的类
    def forward(self):
        raise NotImplementedError
    
    # 打印到log文件中
    def summary(self):
        # 计数 所有参数的个数
        total_params = sum(p.numel() for p in self.parameters())
        print(f"{total_params:,} total parameters.")
        total_trainable_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad)
        print(f"{total_trainable_params:,} training parameters.")
        self.logger.info(f'Nbr of trainable parameters: {total_trainable_params}')
    
    # 返回信息描述
    def __str__(self):
        total_params = sum(p.numel() for p in self.parameters())
        print(f"{total_params:,} total parameters.")
        total_trainable_params = sum(
            p.numel() for p in self.parameters() if p.requires_grad)
        print(f"{total_trainable_params:,} training parameters.")
        return super(BaseModel, self).__str__() + f'\nNbr of trainable parameters: {total_trainable_params}'

上采样权重(双线性插值核,用于初始化并固定转置卷积层的权重)

def get_upsampling_weight(in_channels, out_channels, kernel_size):
    """Return fixed bilinear-interpolation weights for an ``nn.ConvTranspose2d``.

    The result has shape ``(in_channels, out_channels, kernel_size, kernel_size)``,
    which is the layout of ``ConvTranspose2d.weight``. Input channel ``i`` is
    connected only to output channel ``i``, through a 2-D bilinear ("tent")
    kernel ``b``. Every cross-channel slice is zero. For 3 channels the
    channel matrix is ``[[b, 0, 0], [0, b, 0], [0, 0, b]]``.

    If ``in_channels != out_channels``, only the first
    ``min(in_channels, out_channels)`` diagonal pairs receive ``b``.

    Args:
        in_channels: number of input channels of the transposed convolution.
        out_channels: number of output channels of the transposed convolution.
        kernel_size: side length of the square kernel.

    Returns:
        A ``torch.float32`` tensor of the shape described above.
    """
    factor = (kernel_size + 1) // 2
    # The kernel centre falls on a pixel for odd sizes and between two pixels
    # for even sizes.
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    # np.ogrid yields a column vector and a row vector; broadcasting their
    # product builds the separable 2-D tent kernel.
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64)
    # Fill only the channel diagonal. Bounding it by the smaller channel count
    # keeps the fancy index in range when in_channels != out_channels.
    diag = np.arange(min(in_channels, out_channels))
    weight[diag, diag, :, :] = filt
    return torch.from_numpy(weight).float()

FCN8 模型:backbone(特征提取网络)采用 VGG16,并加载 torchvision 提供的预训练权重。

class FCN8(BaseModel):
    """FCN-8s semantic-segmentation network on a torchvision VGG16 backbone.

    The VGG16 feature extractor is split at pool3 / pool4 / pool5. The fully
    connected classifier is converted into convolutions (conv6 / conv7). Class
    scores from pool5, pool4 and pool3 are fused through fixed bilinear
    transposed convolutions (x2, x2, x8) and then cropped to the input size.

    Args:
        num_classes: number of output score maps (one per class).
        pretrained: forwarded to ``torchvision.models.vgg16`` to load its
            pretrained backbone weights.
        freeze_bn: if True, call ``freeze_bn()`` at the end of construction.
        **_: extra keyword arguments are accepted and ignored.
    """

    def __init__(self, num_classes, pretrained=True, freeze_bn=False, **_):
        super(FCN8, self).__init__()
        # Pass `pretrained` by keyword. Newer torchvision model builders take
        # keyword-only arguments and only accept a positional value through a
        # deprecation shim.
        vgg = models.vgg16(pretrained=pretrained)
        features = list(vgg.features.children())
        classifier = list(vgg.classifier.children())

        # Pad the first convolution by 100 px; forward() crops the
        # corresponding border off again at the end.
        features[0].padding = (100, 100)
        for layer in features:
            # Match max-pool layers by class name (e.g. "MaxPool2d").
            if 'MaxPool' in layer.__class__.__name__:
                # Enable ceil_mode in max pooling to avoid size mismatches
                # when upsampling.
                layer.ceil_mode = True

        # Split the VGG feature extractor at pool3, pool4 and pool5.
        self.pool3 = nn.Sequential(*features[:17])    # layers 0-16, up to pool3
        self.pool4 = nn.Sequential(*features[17:24])  # layers 17-23, up to pool4
        self.pool5 = nn.Sequential(*features[24:])    # layers 24-end, up to pool5

        # 1x1 convolutions projecting pool3 (256 ch) and pool4 (512 ch) to num_classes.
        self.adj_pool3 = nn.Conv2d(256, num_classes, kernel_size=1)
        self.adj_pool4 = nn.Conv2d(512, num_classes, kernel_size=1)

        # Replace VGG's fully connected layers with convolution layers.
        conv6 = nn.Conv2d(512, 4096, kernel_size=7)
        conv7 = nn.Conv2d(4096, 4096, kernel_size=1)
        output = nn.Conv2d(4096, num_classes, kernel_size=1)

        # Copy the pretrained FC weights into them, reshaping each weight
        # matrix to the conv kernel shape (classifier[0] -> conv6,
        # classifier[3] -> conv7).
        conv6.weight.data.copy_(classifier[0].weight.data.view(
            conv6.weight.data.size()))
        conv6.bias.data.copy_(classifier[0].bias.data)

        conv7.weight.data.copy_(classifier[3].weight.data.view(
            conv7.weight.data.size()))
        conv7.bias.data.copy_(classifier[3].bias.data)

        # Scoring head applied to pool5.
        self.output = nn.Sequential(conv6, nn.ReLU(inplace=True), nn.Dropout(),
                                    conv7, nn.ReLU(inplace=True), nn.Dropout(),
                                    output)

        # Three upsampling layers:
        #   x2 for the pool5 scores,
        #   x2 for (pool4 scores + upsampled pool5 scores),
        #   x8 for (pool3 scores + the previous fused map).
        self.up_output = nn.ConvTranspose2d(num_classes, num_classes,
                                            kernel_size=4, stride=2, bias=False)
        self.up_pool4_out = nn.ConvTranspose2d(num_classes, num_classes,
                                               kernel_size=4, stride=2, bias=False)
        self.up_final = nn.ConvTranspose2d(num_classes, num_classes,
                                           kernel_size=16, stride=8, bias=False)

        # Initialise the upsampling layers with bilinear kernels...
        self.up_output.weight.data.copy_(
            get_upsampling_weight(num_classes, num_classes, 4))
        self.up_pool4_out.weight.data.copy_(
            get_upsampling_weight(num_classes, num_classes, 4))
        self.up_final.weight.data.copy_(
            get_upsampling_weight(num_classes, num_classes, 16))

        # ...and freeze them: this is fixed bilinear upsampling, not a learned
        # deconvolution.
        for m in self.modules():
            if isinstance(m, nn.ConvTranspose2d):
                m.weight.requires_grad = False
        if freeze_bn:
            self.freeze_bn()

    def forward(self, x):
        """Return per-class score maps cropped to the spatial size of ``x``.

        ``x`` is indexed as (N, C, H, W): dims 2 and 3 give the height and
        width used for the final crop. The output has ``num_classes`` channels.
        """
        img_H, img_W = x.size()[2], x.size()[3]

        # Run the image through the three backbone stages.
        pool3 = self.pool3(x)
        pool4 = self.pool4(pool3)
        pool5 = self.pool5(pool4)

        # Class scores from pool5, upsampled x2.
        output = self.output(pool5)
        up_output = self.up_output(output)

        # Project a scaled pool4, crop it (offset 5) to the upsampled scores,
        # add the two, then upsample x2.
        # NOTE(review): the fixed 0.01 / 0.0001 skip-input scales presumably
        # mirror the reference FCN-8s setup; confirm before changing.
        adjstd_pool4 = self.adj_pool4(0.01 * pool4)
        add_out_pool4 = self.up_pool4_out(adjstd_pool4[:, :, 5: (5 + up_output.size()[2]),
                                                       5: (5 + up_output.size()[3])]
                                          + up_output)

        # Same for a scaled pool3 (crop offset 9), then upsample x8.
        adjstd_pool3 = self.adj_pool3(0.0001 * pool3)
        final_value = self.up_final(adjstd_pool3[:, :, 9: (9 + add_out_pool4.size()[2]),
                                                 9: (9 + add_out_pool4.size()[3])]
                                    + add_out_pool4)

        # Remove the border introduced by the 100 px padding to recover the
        # input size.
        final_value = final_value[:, :, 31: (31 + img_H), 31: (31 + img_W)].contiguous()
        return final_value

    def get_backbone_params(self):
        """Iterate over the parameters of the VGG stages and the conv6/conv7 scoring head."""
        return chain(self.pool3.parameters(), self.pool4.parameters(), self.pool5.parameters(),
                     self.output.parameters())

    def get_decoder_params(self):
        """Iterate over the parameters of the skip projections and upsampling layers."""
        return chain(self.up_output.parameters(), self.adj_pool4.parameters(),
                     self.up_pool4_out.parameters(), self.adj_pool3.parameters(),
                     self.up_final.parameters())

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` submodule into eval mode.

        Note: a later ``model.train()`` call switches them back to training mode.
        """
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()

定义一个张量测试一下前向推理

# Smoke-test the forward pass with a random batch of four 3x28x28 images.
fcn8 = FCN8(9)
x = torch.randn((4, 3, 28, 28))
# Inference only, so skip building the autograd graph. Printing the shape makes
# the check visible when run as a script; outside an interactive session a bare
# expression's result is silently discarded.
with torch.no_grad():
    out = fcn8(x)
print(out.shape)  # expected: torch.Size([4, 9, 28, 28])