Overview of the Matching Network Algorithm

Published 2023-10-21 16:05:59  Author: HoroSherry

What is a Matching Network?

1. Paper: Matching Networks for One Shot Learning

2. Summary: builds on ideas from metric learning and augments the network with external memory to improve its ability to learn.

3. Contributions

  • Draws on prior work on attention and external memory to build the network
  • Trains with tasks in a meta-learning fashion, rather than feeding images from a fixed set of classes as in metric learning

4. Algorithm description

A Matching Network takes two inputs:

  1. An input task S, which is an N-way K-shot task (the figure in the paper shows a 4-way 1-shot task), where \(S=\left\{\left(x_i, y_i\right)\right\}_{i=1}^{k}\)
  2. An image \(\hat{x}\) whose class needs to be predicted

The output of a Matching Network is defined as:

the predicted class \(\hat{y}\) of the image \(\hat{x}\).

The Matching Network can therefore be formulated as modeling \(P(\hat{y} \mid \hat{x}, S)\),

where \(P(\cdot)\) is the parametric mapping implemented by the network, namely the attention mechanism and the external memory.
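To make the task structure concrete, here is a minimal sketch (not from the paper) of how an N-way K-shot episode could be sampled from a labeled image collection; sample_episode and images_by_class are hypothetical names used only for illustration.

import random
import torch

def sample_episode(images_by_class, n_way=4, k_shot=1):
    # images_by_class: dict mapping class id -> list of image tensors of shape (C, H, W)
    classes = random.sample(list(images_by_class.keys()), n_way)
    support_x, support_y, query_x, query_y = [], [], [], []
    for label, c in enumerate(classes):
        picks = random.sample(images_by_class[c], k_shot + 1)
        support_x += picks[:k_shot]            # the K labeled examples (x_i, y_i) per class
        support_y += [label] * k_shot
        query_x.append(picks[-1])              # a held-out query image x_hat
        query_y.append(label)
    return (torch.stack(support_x), torch.tensor(support_y),
            torch.stack(query_x), torch.tensor(query_y))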

(1) Attention
  1. Simple form

The paper gives a simple attention formulation: \(\hat{y}=\sum_{i=1}^k a\left(\hat{x}, x_i\right) y_i\). Here \(a(\cdot)\) is an attention kernel that measures the relation between \(\hat{x}\) and every labeled input \(x_i\); these weights are then combined with the labels \(y_i\) to obtain the predicted class \(\hat{y}\) for the query input \(\hat{x}\).
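As a concrete illustration, the sketch below (assumed names, not from the paper) shows this readout as a weighted sum of one-hot support labels: once the weights \(a(\hat{x}, x_i)\) are known, the prediction is just their inner product with the label vectors.

import torch

def attention_readout(attention_weights, support_labels_one_hot):
    # attention_weights: (k,) values a(x_hat, x_i) over the k support examples
    # support_labels_one_hot: (k, n_classes) one-hot labels y_i
    return attention_weights @ support_labels_one_hot   # (n_classes,) class distribution

weights = torch.tensor([0.7, 0.1, 0.1, 0.1])   # attention over a 4-way 1-shot support set
labels = torch.eye(4)                           # one-hot labels y_i
print(attention_readout(weights, labels))       # most of the mass falls on class 0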

  2. Cosine-distance attention

Intuitively, a natural way to define the attention \(a(\cdot)\) is to pick a metric (e.g., cosine distance) and measure the similarity/distance between two images in an embedding space.

The paper defines a cosine-distance attention:

\[a\left(\hat{x}, x_i\right)=e^{c\left(f(\hat{x}), g\left(x_i\right)\right)} / \sum_{j=1}^k e^{c\left(f(\hat{x}), g\left(x_j\right)\right)} \]

where \(c(\cdot)\) is the cosine distance, \(f(\hat{x})\) is the embedding of the query input \(\hat{x}\), and \(g(x_j)\) is the embedding of the labeled input \(x_j\). In the paper, \(f(\cdot)\) and \(g(\cdot)\) share parameters (i.e., they are the same CNN).
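A minimal sketch of this attention, assuming the embeddings \(f(\hat{x})\) and \(g(x_i)\) have already been computed: the weights are a softmax over the cosine similarities between the query embedding and each support embedding.

import torch
import torch.nn.functional as F

def cosine_attention(query_embedding, support_embeddings):
    # query_embedding: (d,) = f(x_hat); support_embeddings: (k, d) = g(x_1), ..., g(x_k)
    cos = F.cosine_similarity(query_embedding.unsqueeze(0), support_embeddings, dim=1)  # c(f(x_hat), g(x_i))
    return F.softmax(cos, dim=0)   # a(x_hat, x_i) = exp(c_i) / sum_j exp(c_j)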

(2) External memory

The authors argue that in the cosine attention defined above, the embedding of each labeled input \(x_i\) (from the input task S) produced by the CNN, i.e., \(g(x_i)\), is computed independently of the others and is then compared with \(f(\hat{x})\) one by one. This is somewhat crude: it does not take into account that the input task S should be able to change how an input is embedded, i.e., \(f(\cdot)\) should be influenced by \(g(S)\).

To address this, the authors introduce a bidirectional LSTM (full context embeddings).
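A minimal sketch of this idea, assuming the \(k+1\) CNN embeddings have already been stacked: a bidirectional LSTM runs over the sequence so that every embedding is conditioned on the rest of the task, with a residual connection (one common implementation choice) preserving the original embedding dimension.

import torch
import torch.nn as nn

class FullContextEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # hidden_size = dim // 2 per direction keeps the output width equal to dim (dim assumed even)
        self.lstm = nn.LSTM(input_size=dim, hidden_size=dim // 2,
                            bidirectional=True, batch_first=True)

    def forward(self, embeddings):
        # embeddings: (batch, k + 1, dim) -- k support embeddings followed by the query embedding
        contextual, _ = self.lstm(embeddings)
        return contextual + embeddings   # residual connection over the task-conditioned outputs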

5. Network design

Algorithm steps

  1. Pass every image \(x_i\) in the task S and the target image \(\hat{x}\) through the CNN to obtain their embeddings, then stack these \(k+1\) vectors
  2. Feed the stacked embeddings into a bidirectional LSTM to obtain \(k+1\) outputs, then use cosine distance to measure the similarity between each of the first \(k\) outputs and the last one
  3. From the computed similarities and the label information in the task \(S\), derive the class label of the target image \(\hat{x}\)

Core code

The implementation below assumes that Classifier, BidirectionalLSTM, DistanceNetwork, and AttentionalClassify are defined in the project's companion modules.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class MatchingNetwork(nn.Module):
    def __init__(self, keep_prob, \
                 batch_size=100, num_channels=1, learning_rate=0.001, fce=False, num_classes_per_set=5, \
                 num_samples_per_class=1, nClasses=0, image_size=28):
        super(MatchingNetwork, self).__init__()

        """
        Builds a matching network, the training and evaluation ops as well as data augmentation routines.
        :param keep_prob: A tf placeholder of type tf.float32 denotes the amount of dropout to be used
        :param batch_size: The batch size for the experiment
        :param num_channels: Number of channels of the images
        :param is_training: Flag indicating whether we are training or evaluating
        :param rotate_flag: Flag indicating whether to rotate the images
        :param fce: Flag indicating whether to use full context embeddings (i.e. apply an LSTM on the CNN embeddings)
        :param num_classes_per_set: Integer indicating the number of classes per set
        :param num_samples_per_class: Integer indicating the number of samples per class
        :param nClasses: total number of classes. It changes the output size of the classifier g with a final FC layer.
        :param image_input: size of the input image. It is needed in case we want to create the last FC classification 
        """
        self.batch_size = batch_size
        self.fce = fce
        self.g = Classifier(layer_size = 64, num_channels=num_channels,
                            nClasses= nClasses, image_size = image_size )
        if fce:
            self.lstm = BidirectionalLSTM(layer_sizes=[32], batch_size=self.batch_size, vector_dim = self.g.outSize)
        self.dn = DistanceNetwork()
        self.classify = AttentionalClassify()
        self.keep_prob = keep_prob
        self.num_classes_per_set = num_classes_per_set
        self.num_samples_per_class = num_samples_per_class
        self.learning_rate = learning_rate

    def forward(self, support_set_images, support_set_labels_one_hot, target_image, target_label):
        """
        Builds graph for Matching Networks, produces losses and summary statistics.
        :param support_set_images: A tensor containing the support set images [batch_size, sequence_size, n_channels, 28, 28]
        :param support_set_labels_one_hot: A tensor containing the support set labels [batch_size, sequence_size, n_classes]
        :param target_image: A tensor containing the target image (image to produce label for) [batch_size, n_channels, 28, 28]
        :param target_label: A tensor containing the target label [batch_size, 1]
        :return: mean accuracy and mean cross-entropy loss over the target images
        """
        # produce embeddings for support set images
        # support_set_images shape: (batch_size, shot_num, n_channels, img_size, img_size)
        encoded_images = []
        for i in np.arange(support_set_images.size(1)):
            gen_encode = self.g(support_set_images[:,i,:,:,:])
            encoded_images.append(gen_encode)

        # produce an embedding for each target image in turn; the rest of the
        # loop body classifies that target against the support set
        for i in np.arange(target_image.size(1)):
            gen_encode = self.g(target_image[:,i,:,:,:])
            encoded_images.append(gen_encode)
            outputs = torch.stack(encoded_images)

            if self.fce:
                outputs, hn, cn = self.lstm(outputs)

            # get similarity between support set embeddings and target
            similarities = self.dn(support_set=outputs[:-1], input_image=outputs[-1])
            similarities = similarities.t()

            # produce predictions for target probabilities
            preds = self.classify(similarities,support_set_y=support_set_labels_one_hot)

            # calculate accuracy and crossentropy loss
            values, indices = preds.max(1)
            if i == 0:
                accuracy = torch.mean((indices.squeeze() == target_label[:,i]).float())
                crossentropy_loss = F.cross_entropy(preds, target_label[:,i].long())
            else:
                accuracy = accuracy + torch.mean((indices.squeeze() == target_label[:, i]).float())
                crossentropy_loss = crossentropy_loss + F.cross_entropy(preds, target_label[:, i].long())

            # remove this target's embedding so that the next target image's
            # embedding is appended at the end of the list
            encoded_images.pop()

        return accuracy/target_image.size(1), crossentropy_loss/target_image.size(1)
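A hypothetical usage sketch, continuing from the class above (tensor shapes only; the companion modules mentioned earlier must be importable for this to run):

import torch

model = MatchingNetwork(keep_prob=1.0, batch_size=16, num_channels=1, fce=True,
                        num_classes_per_set=5, num_samples_per_class=1,
                        nClasses=0, image_size=28)

support_images = torch.randn(16, 5, 1, 28, 28)                 # 5-way 1-shot support sets
support_labels = torch.eye(5).unsqueeze(0).repeat(16, 1, 1)    # one-hot labels, (16, 5, 5)
target_images = torch.randn(16, 1, 1, 28, 28)                  # one query image per episode
target_labels = torch.randint(0, 5, (16, 1))

accuracy, loss = model(support_images, support_labels, target_images, target_labels)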