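This repository is for the paper "Learning to Compare: Relation Network for Few-Shot Learning".
Only the few-shot code based on the Omniglot dataset is implemented.
datas: the dataset
models: the trained models
venv: configuration files
The py files below are the implementation code.
1. Structure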
2. Problem: KeyError: '..\datas\omniglot_resized'
Error message:
File "LearningToCompare_FSL-master/omniglot/omniglot_train_few_shot.py", line 163, in main
task = tg.OmniglotTask(metatrain_character_folders,CLASS_NUM,SAMPLE_NUM_PER_CLASS,BATCH_NUM_PER_CLASS)
File "LearningToCompare_FSL-master\omniglot\task_generator.py", line 72, in <listcomp>
self.train_labels = [labels[self.get_class(x)] for x in self.train_roots]
KeyError: '..\\datas\\omniglot_resized'
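The cause is the difference between Linux and Windows path separators: get_class splits the sample path on '/', which does not match Windows paths. Changing '/' to '\\' in get_class fixes it: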
def get_class(self, sample):
    return os.path.join(*sample.split('\\')[:-1])
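A separator-independent variant avoids editing the code for each platform. The sketch below is not the repository's code: it assumes that taking the parent directory of the sample path is equivalent to the split-and-join above, and the sample paths are hypothetical. ntpath and posixpath are imported explicitly only so the demonstration behaves the same on any OS; inside the project, os.path.dirname would apply the rules of the current platform.

```python
import ntpath      # Windows path rules, usable on any OS
import posixpath   # Linux/macOS path rules, usable on any OS


def get_class(sample, pathmod):
    """Return the folder that contains the sample file (its class folder)."""
    # dirname drops the last path component without hard-coding a separator.
    return pathmod.dirname(sample)


# Hypothetical sample paths, shaped like the ones in the traceback above.
windows_sample = "../datas/omniglot_resized/Alphabet\\char01\\img.png"
linux_sample = "../datas/omniglot_resized/Alphabet/char01/img.png"

windows_class = get_class(windows_sample, ntpath)
linux_class = get_class(linux_sample, posixpath)

print(windows_class)
print(linux_class)
```

With this approach the returned string keeps whatever separators the input path already had, so it can still match dictionary keys that were built from the same folder paths.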
3. Problem: IndexError: invalid index of a 0-dim tensor.
Error message:
File "LearningToCompare_FSL-master/miniimagenet/miniimagenet_train_few_shot.py", line 212, in main
print("episode:",episode+1,"loss",loss.data[0])
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
As the error message suggests, replace loss.data[0] with loss.item():
if (episode + 1) % 100 == 0:
    print("episode:", episode + 1, "loss", loss.item())
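If the same script has to run on both newer and older PyTorch installations, a small helper can hide the difference. This is a minimal sketch, not part of the original repository: loss_to_float is a hypothetical name, and FakeScalarLoss is a stand-in object used only so the snippet runs without PyTorch installed. It assumes that older releases lack .item() and accept loss.data[0].

```python
def loss_to_float(loss):
    """Return a scalar loss as a Python float."""
    # Prefer .item(), which the error message above recommends.
    if hasattr(loss, "item"):
        return float(loss.item())
    # Assumed fallback for older releases where loss.data[0] was the idiom.
    return float(loss.data[0])


class FakeScalarLoss:
    """Hypothetical stand-in for a 0-dim tensor, for demonstration only."""

    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value


episode = 99
if (episode + 1) % 100 == 0:
    message = "episode: {} loss {:.4f}".format(
        episode + 1, loss_to_float(FakeScalarLoss(0.25))
    )
    print(message)
```

In the training script the call would be loss_to_float(loss) with the real loss tensor in place of the stand-in.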
4. Problem: RuntimeError: output with shape [1, 28, 28]
Error message:
File "LearningToCompare_FSL-master\omniglot\task_generator.py", line 107, in __getitem__
image = self.transform(image)
File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\transforms.py", line 60, in __call__
img = t(img)
File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\transforms.py", line 163, in __call__
return F.normalize(tensor, self.mean, self.std, self.inplace)
File "...\Anaconda3\envs\python36\lib\site-packages\torchvision\transforms\functional.py", line 208, in normalize
tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
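This error occurs with the Omniglot dataset. The Normalize transform from torchvision.transforms is configured with three channels, but Omniglot images are single-channel tensors of shape [1, 28, 28], so the mean and std cannot be broadcast onto the image. Simply change
normalize = transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426])
to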
normalize = transforms.Normalize(mean=[0.92206], std=[0.08426])
dataset = Omniglot(task,split=split,transform=transforms.Compose([Rotate(rotation)
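The reason the one-element lists work is that the mean and std are applied per channel, so they need one entry for each channel of the image. The sketch below illustrates that rule in plain Python on nested lists. It is an illustration only, not torchvision's implementation: the normalize function here is hypothetical and the 2x2 image is made-up data.

```python
def normalize(image, mean, std):
    """Normalize a channel-first image given as nested lists [C][H][W]."""
    # One mean/std entry per channel is required in this sketch.
    if not (len(image) == len(mean) == len(std)):
        raise ValueError(
            "got {} channel(s) but {} mean and {} std value(s)".format(
                len(image), len(mean), len(std)
            )
        )
    return [
        [[(pixel - m) / s for pixel in row] for row in channel]
        for channel, m, s in zip(image, mean, std)
    ]


# A made-up single-channel 2x2 image, shaped like [1, H, W].
image = [[[0.92206, 1.0], [1.0, 0.92206]]]

# One mean/std value for one channel: works.
ok = normalize(image, [0.92206], [0.08426])

# Three mean/std values for one channel: rejected by this sketch.
try:
    normalize(image, [0.92206] * 3, [0.08426] * 3)
    mismatch_detected = False
except ValueError:
    mismatch_detected = True

print(ok[0][0][0], mismatch_detected)
```

For an RGB dataset the three-element lists would be the matching choice, which is why the original values only fail on single-channel data.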