Dataset & DataLoader

Published: 2023-06-12 18:34:22 | Author: shendawei
from torch.utils.data import Dataset, DataLoader

1. Dataset

There are 2 different types of datasets:

1.1 map-style datasets (most commonly used)

A map-style dataset implements __getitem__() and __len__(), and represents a map from indices/keys to data samples.

  • For example, such a dataset, when accessed with dataset[idx], could read the idx-th image and its corresponding label.

1.2 iterable-style datasets

An iterable-style dataset represents an iterable over data samples.

  • This type of dataset is particularly suitable for cases where random reads are expensive or even improbable, and where the batch size depends on the fetched data.

  • For example, such a dataset, when called iter(dataset), could return a stream of data reading from a database, a remote server, or even logs generated in real time.
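
As a rough illustration of the iterable-style case, the sketch below subclasses IterableDataset and implements only __iter__. The class name CountingStream and its start/end parameters are hypothetical; a plain range stands in for a real stream such as a database cursor or a log tail.

```python
import torch
from torch.utils.data import IterableDataset


class CountingStream(IterableDataset):
    """Hypothetical iterable-style dataset that yields integers from a range.

    The range is only a stand-in for a real data stream.
    """

    def __init__(self, start, end):
        super().__init__()
        self.start = start
        self.end = end

    def __iter__(self):
        # Each call to iter(dataset) starts a fresh pass over the stream.
        for value in range(self.start, self.end):
            yield torch.tensor(value)


stream = CountingStream(0, 5)
values = [int(x) for x in stream]  # a for-loop calls iter(stream) implicitly
print(values)
```

Note that the class defines neither __len__ nor __getitem__: samples are only available in the order the stream produces them.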

1.3 Demo (map-style datasets)

# A custom Dataset class must implement three methods: __init__, __len__, and __getitem__.

import os

import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform  # transform images
        self.target_transform = target_transform  # transform labels

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
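        # Column 0 of the annotations CSV holds the image file name.
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)  # load the image file as a tensor
        # Column 1 holds the label.
        label = self.img_labels.iloc[idx, 1]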
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
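
To see how such a dataset behaves when indexed, here is a minimal sketch that needs no image files. InMemoryImageDataset is a hypothetical stand-in for CustomImageDataset: it keeps random tensors in memory instead of reading from disk, and the two lambda transforms are illustrative assumptions, not part of the original demo.

```python
import torch
from torch.utils.data import Dataset


class InMemoryImageDataset(Dataset):
    """Hypothetical stand-in for CustomImageDataset that keeps samples in memory."""

    def __init__(self, images, labels, transform=None, target_transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label


# Four fake 3x8x8 uint8 "images" with integer class labels.
images = torch.randint(0, 256, (4, 3, 8, 8), dtype=torch.uint8)
labels = [0, 1, 1, 0]

dataset = InMemoryImageDataset(
    images,
    labels,
    transform=lambda img: img.float() / 255.0,   # scale pixels to [0, 1]
    target_transform=lambda y: torch.tensor(y),  # int -> tensor
)

print(len(dataset))            # calls __len__
image, label = dataset[2]      # calls __getitem__(2)
print(image.shape, image.dtype, label)
```

The transforms run inside __getitem__, so they are applied lazily, one sample at a time, each time the sample is fetched.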

2. DataLoader