2022 CVPR: SNR-Aware Low-light Image Enhancement (SNR)


I. Motivation

II. Contribution

III. Network

1. For the low-light input image, the SNR map is first obtained using Eq. (2) of the paper, i.e.:

(1) $I_g' = \mathrm{blur}(I_g)$

$I_g$: the low-light input image

$I_g'$: the image after mean filtering with cv.blur

(2) Convert $I_g$ and $I_g'$ to grayscale and take the absolute difference to obtain the noise estimate: $N = \lvert \mathrm{gray}(I_g) - \mathrm{gray}(I_g') \rvert$

(3) SNR map (mask): divide the mean-filtered image by the noise to get $S = I_g' / N$
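A minimal sketch of this SNR-map computation, assuming OpenCV/NumPy, a 5×5 blur kernel, and a small epsilon in the denominator to avoid division by zero (these specifics are assumptions, not taken from the paper's code):

import cv2
import numpy as np

def snr_map(img_low, ksize=5, eps=1e-4):
    """Estimate an SNR map S from a low-light BGR image img_low of shape (H, W, 3)."""
    # I_g': mean-filtered (roughly denoised) version of the input
    img_blur = cv2.blur(img_low, (ksize, ksize))
    # grayscale versions of I_g and I_g'
    gray = cv2.cvtColor(img_low.astype(np.float32), cv2.COLOR_BGR2GRAY)
    gray_blur = cv2.cvtColor(img_blur.astype(np.float32), cv2.COLOR_BGR2GRAY)
    # noise estimate N = |gray(I_g) - gray(I_g')|
    noise = np.abs(gray - gray_blur)
    # SNR map S = I_g' / N (eps keeps the division stable)
    return gray_blur / (noise + eps)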

2. Shallow feature extraction is applied first to obtain the initial feature map fea, as sketched below.
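A minimal sketch of the shallow feature extractor, assuming a single 3×3 convolution producing 64-channel features (layer count and channel width are assumptions):

import torch.nn as nn

# Shallow feature extraction: map the RGB input to an initial feature map `fea`
shallow_extractor = nn.Conv2d(in_channels=3, out_channels=64,
                              kernel_size=3, stride=1, padding=1)
# fea = shallow_extractor(img_low_tensor)   # (B, 64, H, W)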

3. Deep feature extraction is then applied to fea, which is fed into two branches: a short branch made of convolutional blocks with residual connections (better at capturing local information), and a long branch made of an SNR-guided transformer (better at capturing global information).

(1) Short-branch structure
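A minimal sketch of one short-branch building block, assuming a plain two-layer residual convolution block (channel width and layer count are assumptions, not the paper's exact design):

import torch.nn as nn

class ResidualConvBlock(nn.Module):
    """Conv -> ReLU -> Conv with a skip connection; convolutions capture local context."""
    def __init__(self, channels=64):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, fea):
        return fea + self.body(fea)   # residual connection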

(2) Long-branch structure

The transformer consists of an Attention module and a FeedForward module.

In Attention, the input features are normalized and then used to form q, k, and v.

Multiplying q by the transpose of k gives the similarity between patches.

Mask handling: the mask is split into s' patches and averaged along dim=2; following the formula, any patch whose mean mask value is < 0.5 is set to 0, which is then used in the attention operation.

      

Where the mask is 0, the corresponding entries of attn are filled with a very large negative number, so that after the softmax these positions receive essentially zero attention weight (a sketch covering both steps follows).
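A minimal sketch of the SNR-guided attention, covering both the patch-wise mask binarization and the masked softmax; the token shapes, the choice to set patches with mean ≥ 0.5 to 1, and the -1e9 fill value are assumptions, not the paper's exact code:

import torch

def snr_guided_attention(q, k, v, snr_mask_patches):
    """
    q, k, v:           (B, N, C)  -- N patch tokens with C channels each
    snr_mask_patches:  (B, N, P)  -- SNR mask split into N patches of P pixels
    """
    # Binarize the mask patch-wise: mean over the pixels of each patch (dim=2),
    # then set low-SNR patches (< 0.5) to 0 and the rest to 1 (assumed).
    mask = snr_mask_patches.mean(dim=2)                     # (B, N)
    mask = (mask >= 0.5).float()

    # Patch-to-patch similarity: q @ k^T, scaled by sqrt(C)
    attn = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)   # (B, N, N)

    # Fill entries belonging to low-SNR patches with a large negative number,
    # so the softmax assigns them (near-)zero attention weight.
    attn = attn.masked_fill(mask.unsqueeze(1) == 0, -1e9)
    attn = torch.softmax(attn, dim=-1)

    return attn @ v                                         # (B, N, C)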

FeedForward: first normalize, then a fully connected layer + ReLU, then normalize again, wrapped with a skip connection. The final output of the transformer structure is fea_unfold (a sketch follows).
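A minimal sketch of such a feed-forward block, assuming LayerNorm for the normalizations and a single linear layer that keeps the channel dimension (both are assumptions):

import torch.nn as nn

class FeedForward(nn.Module):
    """Norm -> Linear -> ReLU -> Norm, wrapped with a skip connection."""
    def __init__(self, dim=64):
        super().__init__()
        self.body = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim),
            nn.ReLU(inplace=True),
            nn.LayerNorm(dim),
        )

    def forward(self, x):          # x: (B, N, dim) patch tokens
        return x + self.body(x)    # skip connection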

 

The SNR mask is then used to fuse the features from the two branches.
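A minimal sketch of this mask-guided fusion, following the paper's idea that high-SNR regions should rely on the short (convolutional) branch and low-SNR regions on the long (transformer) branch; normalizing S to [0, 1] before using it as a weight is an assumption:

import torch

def fuse_branches(fea_short, fea_long, snr_map):
    """
    fea_short, fea_long: (B, C, H, W) features from the short and long branches
    snr_map:             (B, 1, H, W) SNR mask S
    """
    # normalize the SNR map to [0, 1] so it acts as a per-pixel weight
    s = snr_map / (snr_map.amax(dim=(2, 3), keepdim=True) + 1e-8)
    # high-SNR pixels -> local (short) branch, low-SNR pixels -> global (long) branch
    return fea_short * s + fea_long * (1.0 - s)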

 

 

IV. Loss

1. CharbonnierLoss2:

Reference: https://blog.csdn.net/weixin_43135178/article/details/120865709

import torch
import torch.nn as nn


class CharbonnierLoss2(nn.Module):
    """Charbonnier Loss (L1)"""

    def __init__(self, eps=1e-6):
        super(CharbonnierLoss2, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        diff = x - y
        loss = torch.mean(torch.sqrt(diff * diff + self.eps))
        return loss
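A quick usage sketch (pred and gt are assumed to be tensors of the same shape, e.g. the enhanced output and the ground-truth normal-light image):

# pred, gt: (B, 3, H, W)
charbonnier = CharbonnierLoss2()
loss_char = charbonnier(pred, gt)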

2. perceptual loss

Notes on the reduction argument of nn.L1Loss: https://blog.csdn.net/qq_39450134/article/details/121745209

import torch
import torch.nn as nn
import torchvision

class VGG19(torch.nn.Module):
    def __init__(self, requires_grad=False):
        super().__init__()
        vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
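        # slice1..slice5 run vgg19.features up to relu1_1, relu2_1, relu3_1, relu4_1 and relu5_1 respectively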
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out


class VGGLoss(nn.Module):
    def __init__(self):
        super(VGGLoss, self).__init__()
        self.vgg = VGG19().cuda()
        # self.criterion = nn.L1Loss()
        self.criterion = nn.L1Loss(reduction='sum')
        self.criterion2 = nn.L1Loss()
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, x, y):
        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
        loss = 0
        for i in range(len(x_vgg)):
            # print(x_vgg[i].shape, y_vgg[i].shape)
            loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
        return loss
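A usage sketch combining the two losses; the 0.1 weight on the perceptual term is an assumption (the exact weighting is not given here), and a GPU is required because VGGLoss moves VGG19 to CUDA:

# pred: network output, gt: ground-truth normal-light image, both (B, 3, H, W) on the GPU
charbonnier = CharbonnierLoss2()
vgg_loss = VGGLoss()

total_loss = charbonnier(pred, gt) + 0.1 * vgg_loss(pred, gt)
total_loss.backward()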