SFNet_FFTBlock：验证模块的有效性

发布时间：2023-10-31 09:34:21　作者：helloWorldhelloWorld

五、序号5：使用 identityConv（1×1 卷积）进行残差连接，最后将空域特征、由“原幅值 + 增强后相位”重建的特征、由“增强后幅值 + 原相位”重建的特征进行 Concat。

class YYBlock(nn.Module):
    """Spatial conv block with a 1x1 residual, fused with two frequency branches.

    The spatial features (two 3x3 convs) are summed with a 1x1 projection of the
    input, transformed with ``torch.fft.rfft2`` and split into amplitude and
    phase. Two spatial-domain reconstructions are built from the spectrum:

    * original amplitude + enhanced phase
    * enhanced amplitude + original phase

    These are concatenated with the spatial features along the channel axis and
    fused back to ``out_channel`` channels with a 1x1 conv.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of channels of every branch and of the block output.
        relu_slope: Negative slope of the LeakyReLU activations.
    """

    def __init__(self, in_channel=3, out_channel=20, relu_slope=0.2):
        super().__init__()

        # Two 3x3 convs; padding=1 keeps H and W unchanged.
        self.spatialConv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(relu_slope, inplace=False),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(relu_slope, inplace=False),
        )

        # 1x1 projection so the residual term has out_channel channels.
        self.identity = nn.Conv2d(in_channel, out_channel, 1, 1, 0)

        # Per-frequency 1x1 conv -> LeakyReLU -> 1x1 conv.
        # NOTE(review): forward() applies these SAME weights to both the
        # amplitude and the phase maps; confirm the sharing is intended.
        self.fftConv2 = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, 1, 1, 0),
            nn.LeakyReLU(relu_slope, inplace=False),
            nn.Conv2d(out_channel, out_channel, 1, 1, 0),
        )

        # Fuses [spatial, amp+enhanced-phase, enhanced-amp+phase] -> out_channel.
        self.fusion = nn.Conv2d(out_channel * 3, out_channel, 1, 1, 0)

    def forward(self, x1):
        """Apply the block.

        Args:
            x1: Input tensor of shape (N, in_channel, H, W).

        Returns:
            Tensor of shape (N, out_channel, H, W).
        """
        out = self.spatialConv(x1) + self.identity(x1)

        # irfft2 cannot infer an odd last dimension from the half spectrum
        # produced by rfft2, so pass the spatial size explicitly. Without it an
        # odd W comes back as W - 1 and the concatenation below fails.
        spatial_size = out.shape[-2:]

        x_fft = torch.fft.rfft2(out, norm='backward')
        x_amp = torch.abs(x_fft)
        x_phase = torch.angle(x_fft)

        enhanced_phase = self.fftConv2(x_phase)
        enhanced_amp = self.fftConv2(x_amp)

        # amp * exp(1j * phase) is kept (rather than torch.polar) because
        # enhanced_amp comes from an unactivated conv and may be negative.
        x_fft_out1 = torch.fft.irfft2(
            x_amp * torch.exp(1j * enhanced_phase), s=spatial_size, norm='backward'
        )
        x_fft_out2 = torch.fft.irfft2(
            enhanced_amp * torch.exp(1j * x_phase), s=spatial_size, norm='backward'
        )

        return self.fusion(torch.cat([out, x_fft_out1, x_fft_out2], dim=1))
# NOTE(review): bare expression that evaluates the class and discards it; it has
# no effect at import time. Looks like a leftover from a notebook cell that
# echoed the class — confirm and delete.
YYBlock