3_transforms (pytorch tutorial)

发布时间 2023-04-23 10:11:05作者: 白尔特

Transforms

Data does not always come in its final processed form that is required for training machine learning algorithms. We use transforms to perform some manipulation of the data and make it suitable for training.

数据不总是以被处理好只需要机器学习的算法形式出现。我们使用transforms来操作数据,让它适合被训练。

All TorchVision datasets have two parameters -transform to modify the features and target_transform to modify the labels - that accept callables containing the transformation logic. The torchvision.transforms module offers several commonly-used transforms out of the box.

所有的TorchVision数据集都有2个参数transform调整特征,target_transform调整标签(在transformation逻辑下可调用的)。 torchvision.transforms模块提供几个常用的transforms可直接使用。

The FashionMNIST features are in PIL Image format, and the labels are integers. For training, we need the features as normalized tensors, and the labels as one-hot encoded tensors. To make these transformations, we use ToTensor and Lambda.

FashionMUIST特征是PIL数据格式,标签是整数。为了训练,我们需要特征变成归一化的tensors,标签变成one-hot编码的tensors。为了实现这些transformation,我们使用ToTensorLamba

one-hot:就是把一列类别ABCD变成一行表头isA isB isC isD,然后每一列就是1代表是A,0就是不是A,这样就避免了潜在的数字关系

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

ToTensor()

ToTensor converts a PIL image or NumPy ndarray into a FloatTensor. and scales the image's pixel intensity values in the range [0., 1.]

ToTensor把一个PIL图像/NumPy数组转换成FloatTensor,并把图像每个像素的值归一化到[0., 1.]之间

Lambda Transforms

Lambda transforms apply any user-defined lambda function. Here, we define a function to turn the integer into a one-hot encoded tensor. It first creates a zero tensor of size 10 (the number of labels in our dataset) and calls
scatter_ which assigns a value=1 on the index as given by the label y.

Lambda transforms应用任何用户定义的lambda函数。这里我们定义了一个函数把整数转化成一个one-hot编码的tensor。它首先创建一个都是size为10(我们数据集里的标签个数)0的tensor,然后调用scatter_根据标签y把对应索引的值变成1。

target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

Further Reading