Torchvision: 数据读取

Torchvision 是一个和 PyTorch 配合使用的 Python 包,包含很多图像处理的工具。

数据读取

PyTorch 为我们提供了一种十分方便的数据读取机制,即使用 Dataset 类与 DataLoader 类的组合,来得到数据迭代器。在训练或预测时,数据迭代器能够输出每一批次所需的数据,并且对数据进行相应的预处理与数据增强操作。

DataSet

PyTorch 中的 Dataset 类是一个抽象类,它可以用来表示数据集。我们通过继承 Dataset类来自定义数据集的格式、大小和其它属性,后面就可以供 DataLoader 类直接使用。其实这就表示,无论使用自定义的数据集,还是官方为我们封装好的数据集,其本质都是继承了 Dataset 类。而在继承 Dataset 类时,至少需要重写以下几个方法:

**__init__()**:构造函数,可自定义数据读取方法以及进行数据预处理;

**__len__()**:返回数据集大小;

**__getitem__()**:索引数据集中的某一个数据。

示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
# 构造函数
def __init__(self, data_tensor, target_tensor):
self.data_tensor = data_tensor
self.target_tensor = target_tensor
# 返回数据集大小
def __len__(self):
return self.data_tensor.size(0)
# 返回索引的数据与标签
def __getitem__(self, index):
return self.data_tensor[index], self.target_tensor[index]

使用:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 生成数据
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,)) # 标签是0或1
# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)
# 查看数据集大小
print('Dataset size:', len(my_dataset))
'''
输出:
Dataset size: 10
'''
# 使用索引调用数据
print('tensor_data[0]: ', my_dataset[0])
'''
输出:
tensor_data[0]: (tensor([ 0.4931, -0.0697, 0.4171]), tensor(0))
'''

DataLoader

多进程迭代加载数据(考虑到内存有限、I/O 速度等问题)

DataLoader 是一个迭代器,最基本的使用方法就是传入一个 Dataset 对象,它会根据参数 batch_size 的值生成一个 batch 的数据,节省内存的同时,它还可以实现多进程、数据打乱等处理。