主要内容:关于DataLoader学习的一些笔记,需要先了解batch和epoch等输入数据的相关概念,以及python中类的基本知识比如继承和函数复写。
DataLoader简单介绍
DataLoader是Pytorch中用来处理模型输入数据的一个工具类。通过使用DataLoader,我们可以方便地对数据进行相关操作,比如我们可以很方便地设置batch_size,对于每一个epoch是否随机打乱数据,是否使用多线程等等。
咱们先通过下图先来窥探DataLoader的基本处理流程。
1. 首先会将原始数据加载到DataLoader中去,如果需要shuffle的话,会对数据进行随机打乱操作,这样能够输入顺序对于数据的影响。
2. 再使用一个迭代器来按照设置好的batch大小来迭代输出shuffle之后的数据。
Tips: 通过使用迭代器能够有效地降低内存的损耗,会在需要使用的时候才将数据加载到内存中去。
好了,知道了DataLoader的基本使用流程,下面开始正式进入我们的介绍。
使用Dataset来创建自己的数据类
当我们拿到数据之后,首先需要做的就是写一个属于自己的数据类。
我们通过继承torch.utils.data.Dataset这个类来构造。因为Dataset这个类比较简单,我们可以先来看看源码。
class Dataset(object):
"""An abstract class representing a Dataset. All other datasets should subclass it. All subclasses should override ``__len__``, that provides the size of the dataset, and ``__getitem__``, supporting integer indexing in range from 0 to len(self) exclusive. """
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
其中, __getitem__ 和 __len__ 这两个方法在我们每次自定义自己的类的时候是需要去复写的。
下面结合一个例子来进行介绍:
class MyDataset(Dataset):
""" my dataset."""
# Initialize your data, download, etc.
def __init__(self):
# 读取csv文件中的数据
xy = np.loadtxt('data-diabetes.csv', delimiter=',', dtype=np.float32)
self.len = xy.shape[0]
# 除去最后一列为数据位,存在x_data中
self.x_data = torch.from_numpy(xy[:, 0:-1])
# 最后一列为标签为,存在y_data中
self.y_data = torch.from_numpy(xy[:, [-1]])
def __getitem__(self, index):
# 根据索引返回数据和对应的标签
return self.x_data[index], self.y_data[index]
def __len__(self):
# 返回文件数据的数目
return self.len
简单分析如下:
- 继承Dataset来创建自己的数据类。将数据的下载,加载等,写入到这个类的初始化方法__init__中去,这样后面直接通过创建这个类即可获得数据并直接进行使用。
- 通过复写 __getitem__ 方法可以通过索引来访问数据,能够同时返回数据和对应的标签(label)。
- 通过复写 __len__ 方法来获取数据的个数。
2. 使用DataLoader来控制数据的输入输出
结合上一节自己创建的Dataset,DataLoader的使用方式如下:
dataset = Mydataset()
train_loader = DataLoader(dataset=dataset, batch_size=32,
shuffle=True, num_workers=2)
下面来对DataLoader中的常用参数进行介绍:
- dataset(Dataset) - 输入自己先前创建好的自己的数据集
- batch_size(int, optional) - 每一个batch包括的样本数(默认为1)
- shuffle (bool, optional) - 每一个epoch进行的时候是否要进行随机打乱(默认为False)
- num_workers(int, optional) - 使用多少个子进程来同时处理数据加载。(默认为0)
- pin_memory(bool, optional) - 如果为True会将数据放置到GPU上去(默认为false)
- drop_last (bool, optional) - 如果最后一点数据如果无法满足batch的大小,那么通过设置True可以将最后部分的数据进行丢弃,否则最后一个batch将会较小。(默认为False)
这样,我们就可以通过循环来迭代来高效地获取数据啦。
for i, data in enumerate(train_loader, 0):
# get the inputs
inputs, labels = data
# wrap them in Variable
inputs, labels = Variable(inputs), Variable(labels)
# Run your training process
print(step, i, "inputs", inputs.data, "labels", labels.data)
注意:本文归作者所有,未经作者允许,不得转载