Pytorch中的DataLoader的相关记录

star2017 1年前 ⋅ 471 阅读

主要内容:关于DataLoader学习的一些笔记,需要先了解batch和epoch等输入数据的相关概念,以及python中类的基本知识比如继承和函数复写。


DataLoader简单介绍

DataLoader是Pytorch中用来处理模型输入数据的一个工具类。通过使用DataLoader,我们可以方便地对数据进行相关操作,比如我们可以很方便地设置batch_size,对于每一个epoch是否随机打乱数据,是否使用多线程等等。


咱们先通过下图先来窥探DataLoader的基本处理流程。

1. 首先会将原始数据加载到DataLoader中去,如果需要shuffle的话,会对数据进行随机打乱操作,这样能够输入顺序对于数据的影响。

2. 再使用一个迭代器来按照设置好的batch大小来迭代输出shuffle之后的数据。


Tips: 通过使用迭代器能够有效地降低内存的损耗,会在需要使用的时候才将数据加载到内存中去。

好了,知道了DataLoader的基本使用流程,下面开始正式进入我们的介绍。


 使用Dataset来创建自己的数据类

当我们拿到数据之后,首先需要做的就是写一个属于自己的数据类。

我们通过继承torch.utils.data.Dataset这个类来构造。因为Dataset这个类比较简单,我们可以先来看看源码。

class Dataset(object):
    """An abstract class representing a Dataset.    All other datasets should subclass it. All subclasses should override    ``__len__``, that provides the size of the dataset, and ``__getitem__``,    supporting integer indexing in range from 0 to len(self) exclusive.    """

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other]) 


其中, __getitem__ 和 __len__ 这两个方法在我们每次自定义自己的类的时候是需要去复写的。

下面结合一个例子来进行介绍:

class MyDataset(Dataset): 
    """ my dataset."""
    
    # Initialize your data, download, etc.
    def __init__(self):
        # 读取csv文件中的数据
        xy = np.loadtxt('data-diabetes.csv', delimiter=',', dtype=np.float32) 
        self.len = xy.shape[0]
        # 除去最后一列为数据位,存在x_data中
        self.x_data = torch.from_numpy(xy[:, 0:-1])
        # 最后一列为标签为,存在y_data中
        self.y_data = torch.from_numpy(xy[:, [-1]])
        
    def __getitem__(self, index):
        # 根据索引返回数据和对应的标签
        return self.x_data[index], self.y_data[index]
        
    def __len__(self): 
        # 返回文件数据的数目
        return self.len


简单分析如下:

  1. 继承Dataset来创建自己的数据类。将数据的下载,加载等,写入到这个类的初始化方法__init__中去,这样后面直接通过创建这个类即可获得数据并直接进行使用。
  2. 通过复写 __getitem__ 方法可以通过索引来访问数据,能够同时返回数据和对应的标签(label)。
  3. 通过复写 __len__ 方法来获取数据的个数。


2. 使用DataLoader来控制数据的输入输出

结合上一节自己创建的Dataset,DataLoader的使用方式如下:

dataset = Mydataset()
train_loader = DataLoader(dataset=dataset, batch_size=32, 
shuffle=True, num_workers=2)


下面来对DataLoader中的常用参数进行介绍:

  • dataset(Dataset) - 输入自己先前创建好的自己的数据集
  • batch_size(int, optional) - 每一个batch包括的样本数(默认为1)
  • shuffle (bool, optional) - 每一个epoch进行的时候是否要进行随机打乱(默认为False)
  • num_workers(int, optional) - 使用多少个子进程来同时处理数据加载。(默认为0)
  • pin_memory(bool, optional) - 如果为True会将数据放置到GPU上去(默认为false)
  • drop_last (bool, optional) - 如果最后一点数据如果无法满足batch的大小,那么通过设置True可以将最后部分的数据进行丢弃,否则最后一个batch将会较小。(默认为False)


这样,我们就可以通过循环来迭代来高效地获取数据啦。

for i, data in enumerate(train_loader, 0):
    # get the inputs
    inputs, labels = data

    # wrap them in Variable
    inputs, labels = Variable(inputs), Variable(labels)

    # Run your training process
    print(step, i, "inputs", inputs.data, "labels", labels.data)


更多内容请访问:IT源点

相关文章推荐

全部评论: 0

    我有话说: