通道注意力和空间注意力(CBAM)

发布时间 2024-01-03 13:48:17作者: 倒地

用实例说明通道注意力机制与空间注意力机制的内容。包含代码示例。

参考的博文:

pytorch中加入注意力机制(CBAM),以ResNet为例。解析到底要不要用ImageNet预训练?如何加预训练参数?

(六十一)通俗易懂理解——通道注意力机制和空间注意力机制(CBAM)

CBAM

Convolutional Block Attention Module (CBAM) 表示卷积模块的注意力机制模块,是一种结合了空间和通道的注意力机制模块[1]

通道注意力机制

例如对 batch x 7 x 7 x 512 的张量使用通道注意力机制,会获得 batch x 512 的张量作为这 512 个通道的权重。

即,通道注意力机制是使网络获得对通道的权重。

具体来说,若输入张量为 $x$,则运算为:

$$
\text{softmax}(\text{Linear}(\text{AvgPool}(x))+\text{Linear}(\text{MaxPool}(x)))
$$

实现代码:

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1   = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

空间注意力机制

例如对 batch x 7 x 7 x 512 的张量使用通道注意力机制,会获得 batch x 7 x 7 的张量作为这 7 x 7 空间的权重。

实现代码:

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

ResNet 添加注意力机制

以 Resnet 为例。

例如输入 Resnet 网络的图像大小为 224 x 224,网络最后会生成一个维度在 batch x 7 x 7 x 512 的张量,然后经过 nn.AdaptiveAvgPool2d((1, 1))nn.Linear 获得维度在 batch x 1000 的张量。

在送入 nn.AdaptiveAvgPool2d((1, 1))nn.Linear 层之前,进行通道注意力和空间注意力。

class ResNet(nn.Module):

    def __init__():
        ···
        self.ca = ChannelAttention(self.inplanes)
        self.sa = SpatialAttention()

    def forward(self, x):
        ···

        x = self.layer4(x)

        x = self.ca(x) * x
        x = self.sa(x) * x

        ···

        return x

  1. Woo, Sanghyun, et al. "Cbam: Convolutional block attention module." Proceedings of the European conference on computer vision (ECCV). 2018. https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf ↩︎