torchvision中的数据集使用

发布时间 2023-08-11 12:13:12作者: ydky

torchvision中的数据集使用

1.torchvision介绍

torchvision是pytorch的一个图形库,它服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型,一般包括左侧几个模块。

pytorch官网-Docs-torchvision(左侧修改为0.90版本就可以直接看到datasets)

torchvision.datasets:包含常用的数据集API文档,设置一些参数即可下载和使用这些数据集。

COCO数据集:常用于目标检测、语义分割

MNIST数据集:手写文字数据集(一般为入门数据集)

CIFAR数据集:常用于物体识别

torchvision.io:输入输出模块。

torchvision.models:包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等。

torchvision.ops:提供一些少见的特殊的操作。

torchvision.transforms:常用的图片变换,例如类型转换、裁剪等。
torchvision.utils:其他的一些有用的方法。

2.举例说明

本次以CIFAR10为例进行数据集的使用(观察参数设置):

数据集的使用代码

import torchvision

# 将数据集下载到本地的文件夹中用作训练集和测试集
train_set = torchvision.datasets.CIFAR10(root="./dataset2",train=True,download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset2",train=False,download=True)

print(test_set[0])
print(test_set.classes)

img, target = test_set[0]
print(img)
print(target)
print(train_set.classes[target])
img.show()

dataset和transforms的结合使用:

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

# 添加transforms参数可以对数据集进行转换操作
train_set = torchvision.datasets.CIFAR10(root="./dataset2",train=True,transform=dataset_transforms, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset2",train=False,transform=dataset_transforms, download=True)

# print(test_set[0])

writer = SummaryWriter("logs")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)

writer.close()