显著图:Saliency Map 详解

发布时间 2023-12-31 10:42:58作者: 独上兰舟1

https://zhuanlan.zhihu.com/p/644181243

 

泻药。最近在研究一些基于saliency map的归因方法,在这里分享一下对saliency map的开山鼻祖VGG2014:Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps()的解读,这篇文章的引用量已经达到了6k+。

相比于其他的解释方法,显著图(saliency map)是弱监督的,只需要一个分类模型就能完成显著图的生成。这篇文章由于是开山鼻祖,所以有非常严谨的推导,非常简洁,建议仔细阅读。

模型对类的判断标准的可视化

用 ��(⋅) 表示图像矩阵 � 上的每一个像素对类 � 的分数函数, � 表示图像,我们希望将这个图像迭代成模型对一个类的判断标准,则这个任务的目标函数 �(�) 可以用下面的式子表示。 |�|2 表示图像或像素的L2正则(欧几里得正则),防止图像 � 的像素值过大而产生过拟合。

(1)�(�)=arg⁡max�(��(�)−�||�||22)

这里的分数函数 ��(⋅) 是CNN后面的MLP层的输出,而不是最后的概率 �� ,如果是分类概率 �� 还要过一层sigmoid:

(2)��=�(��)=exp⁡��∑�exp⁡��

这么做的原因是,我们目前仅考虑类别 � 的显著特征。这个任务中,如果引入其他类的知识,就相当于用了后验知识,会造成显著图生成的效果变差。

注意,在训练模型的时候,我们改变模型参数来优化。在模型训练完后我们想要获得一个可视化图像,是固定模型参数,转而改变图像 � 来优化,因此优化参数其实是图像 � 。所以 �(�) 式中,我们希望通过调整图像 � 来最大化 ��(�) 。分数函数 ��(�) 越大,说明图像 � 被分类成类别 � 的概率越大,这样图像 � 就显示了分类器认为的区别类别 � 和其他类别的特征。

优化之初,图像 � 被初始化为零图像(所有像素都是0)或均值图像(训练集图像取平均)。优化的过程就是传统的反向传播,只需要计算目标函数 �(�) 相对于每个像素的梯度,就可以用这个梯度不断更新图像 � 来获得最终的显著图。

部分类别的显著图如下:

模型对某个图像属于某个类的判断标准的可视化(显著图)

仅仅弄清楚模型对某个类的总体判断标准还不够,我们往往希望能够解释模型为什么对于某个特定的图像做出了一个分类,这就是显著图(saliency map)的目标。因此我们定义一个任务,希望可视化模型对某个图像属于某个类的判断标准。正式地,这个任务输入图像 �0 ,类别 � ,以及模型对该类的分数 ��(�0) ,我们希望输出一个显著图 � ,也就是 �0 中每一个像素对于分数 ��(�0) 的影响力的排序。

这个任务与上面的任务的一个容易混淆的地方是,上面的任务的 � 表示用于迭代的图像;而这个任务出现的 � 和 �0 都是我们想要解释的图像,但 �0 是 � 的特例(输入图像)。在这个任务中,我们希望生成一个 � 来作为 �0 的显著图。

不失一般性,我们先从简单的设定开始。对于类 � ,我们假设存在一个线性打分模型:(3)��(�)=����+��这个图像 � 是被flatten掉的一维向量, ��,�� 是模型的weight(一维向量)和bias(标量)。这种情况下,显然 � 中元素的大小会决定 � 向量对分类为类 � 的重要性。 � 中某个元素越大, � 中对应的像素就对分数 ��(�) 贡献越大。

然而,CNN模型的打分函数 ��(�) 显然是非线性的,不满足上面的式子。但是我们想到了泰勒展开,可以通过计算 ��(�) 这个高阶函数的一阶泰勒展开来近似地获得 ��(�) :(4)��(�)≈���+�这里的 � 是打分函数 �� 在图像 �0 处对图像 � 的梯度,是一个常数矩阵:

(5)�=∂��(�)∂�|�0

这里有的同学会感到奇怪, ��(�) 是用 � 定义的,而 � 的定义里面又出现了 �� ,这是怎么回事?这是因为这里的 ��(�) 是个完整的CNN而不是近似。为什么这里不需要近似?其实非常简单,因为 � 是导数在 �0 处的近似,所以 (5) 式就相当于输入 �0 ,经过CNN获得 ��(�0) 的值,然后立刻停止前向计算进入反向传播,传播回输入层后取出对输入图像 � 的梯度,就是 � 了,因此 � 可以直接获得的。注意,在训练模型的时候,这个 �� 对于 � 的梯度都是被扔掉的,框架只会保留对CNN中参数的梯度用于参数的更新;因此在具体实现时,需要用requires_grad_()来保留这个梯度。

对这个梯度 � 还有一种解释:它反映了 �0 中哪些像素在变化最小的时候,对分数 �� 的影响最大。这种像素往往代表着重要物体在图片中的位置。

注意这里我们不把 ��(�) 代回目标函数式 (1) ,因为式 (1) 认为 � 是要训练的参数,而我们可以直接通过梯度直接得到 � ,不需要任何训练。假设输入图像 �0 的形状为 �×� ,设类显著图为 �∈��×� ,可以直接通过梯度 � 获得显著图 � :

  1. 用反向传播求得梯度 � 。
  2. 将 � 重排成 �×� 的形状。
  3. 对单色位(灰度)图像而言, ���=|�ℎ(�,�)| 。这个 ℎ 函数将两个一维index转换为二维index。
  4. 对三色位(彩色)图像而言, ���=max�|�ℎ(�,�,�ℎ)| 。这里的 �ℎ 是channel的意思。这个式子就是在取所有channel里面magnitude最大的那个channel的值作为 � 的值。

这种方法最大的优势是,不需要任何额外的标注,可以直接通过在分类任务上训练后的模型得到显著图,甚至不需要再进行任何训练。这样的速度非常快,只需要一次反向传播。实验结果如下:

代码实现

这里给出saliency map的基本代码(源自towards science)。下面的例子展示了一个分类模型model对某个样本image在预测中最自信的类output_idx(其分数为output_max)上的显著图saliency

# Reshape the image (because the model use 
# 4-dimensional tensor (batch_size, channel, width, height))
image = image.reshape(1, 3, 224, 224)

# Set the requires_grad_ to the image for retrieving gradients
image.requires_grad_()

# Retrieve output from the image
output = model(image)

# Catch the output
output_idx = output.argmax()
output_max = output[0, output_idx]

# Do backpropagation to get the derivative of the output based on the image
output_max.backward()

# Retireve the saliency map and also pick the maximum value from channels on each pixel.
# In this case, we look at dim=1. Recall the shape (batch_size, channel, width, height)
saliency, _ = torch.max(image.grad.data.abs(), dim=1) 
saliency = saliency.reshape(224, 224)

# Reshape the image
image = image.reshape(-1, 224, 224)

弱监督目标定位

因为只需要在分类任务上训练就可以获得显著图,而显著图又能够反映物体的位置,所以可以使用显著图来做识别和分割。因为不需要这些定位的ground truth,所以方法是弱监督的。一些挑选过的例子如下: