代码实现-小样本-RN

发布时间 2023-07-23 17:37:50作者: 太好了还有脑子可以用

此篇为《Learning to Compare Relation Network for Few-Shot Learning》

只实现了基于Omniglot数据集的小样本代码
datas为数据集
models为训练好的模型
venv为配置文件
下面的py文件是具体实现代码

1.结构

image

2.问题:KeyError: '..\datas\omniglot_resized'

报错信息:
  File "LearningToCompare_FSL-master/omniglot/omniglot_train_few_shot.py", line 163, in main
    task = tg.OmniglotTask(metatrain_character_folders,CLASS_NUM,SAMPLE_NUM_PER_CLASS,BATCH_NUM_PER_CLASS)
  File "LearningToCompare_FSL-master\omniglot\task_generator.py", line 72, in <listcomp>
    self.train_labels = [labels[self.get_class(x)] for x in self.train_roots]
KeyError: '..\\datas\\omniglot_resized'

由于linux和window路径的转换,需要把把'/'改成'\'即可。
image

def get_class(self, sample):
return os.path.join(*sample.split('\\')[:-1])

3.问题:IndexError: invalid index of a 0-dim tensor.

报错信息:
File "LearningToCompare_FSL-master/miniimagenet/miniimagenet_train_few_shot.py", line 212, in main
    print("episode:",episode+1,"loss",loss.data[0])
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

按要求改成
image

if (episode + 1) % 100 == 0:
print("episode:", episode + 1, "loss", loss.item())

4.问题:RuntimeError: output with shape [1, 28, 28]

报错信息:
File "LearningToCompare_FSL-master\omniglot\task_generator.py", line 107, in __getitem__
    image = self.transform(image)
  File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\transforms.py", line 60, in __call__
    img = t(img)
  File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\transforms.py", line 163, in __call__
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\functional.py", line 208, in normalize
    tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

这个是使用Omniglot数据集时的报错,主要原因在于使用 torch.transforms 中 normalize 用了 3 通道,而实际使用的数据集Omniglot 图片大小是 [1, 28, 28],只需要把
normalize =transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426])
改成
image

normalize = transforms.Normalize(mean=[0.92206], std=[0.08426])
dataset = Omniglot(task,split=split,transform=transforms.Compose([Rotate(rotation)

5.问题:

6.问题:

7.问题: