Zexian Li

Understanding PyTorch Data Loading in One Article

2020-03-19 · 4 min read
Pytorch

PyTorch's data loading relies mainly on two modules, torch.utils.data.Dataset and torch.utils.data.DataLoader, which together enable dead-simple loading in the following form.

train_dataset = MyDataset(train_data_path) # 'MyDataset' subclasses 'torch.utils.data.Dataset'
train_loader = torch.utils.data.DataLoader(train_dataset)   

This certainly looks appealing. Below, I walk through the whole data-loading pipeline and its details, with reference to the source code and a hands-on example.

torch.utils.data.Dataset

In PyTorch, any class that reads data by index (map-style: mapping keys to data samples) must subclass torch.utils.data.Dataset, which defines the interface for reading data. The structure of the class can be seen in the torch.utils.data.Dataset source code:

class Dataset(object):
    def __getitem__(self, index):
        # Subclasses must implement index-based access.
        raise NotImplementedError
    def __add__(self, other):
        # 'dataset_a + dataset_b' concatenates the two datasets.
        return ConcatDataset([self, other])

Note: any class that reads data through an iterator (iterable-style) should subclass torch.utils.data.IterableDataset instead. This article covers only the more commonly used torch.utils.data.Dataset.
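For contrast, a minimal iterable-style sketch might look like this (MyIterableDataset is a made-up name for illustration, not part of PyTorch):

from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, start, end):
        self.start, self.end = start, end

    def __iter__(self):
        # Iterable-style datasets implement __iter__ instead of __getitem__.
        return iter(range(self.start, self.end))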

A subclass must override __getitem__() so that a sample can be fetched for a given index. Overriding __len__() to return the dataset size is optional as far as the abstract class is concerned, but the default sampler used by torch.utils.data.DataLoader relies on it, so in practice you should implement it as well.

Taking object detection (Object Detection) as an example, __getitem__() typically uses the index argument to look up and return the corresponding image, candidate boxes, and class labels; a sketch follows below. In practice, a Dataset is usually paired with torch.utils.data.DataLoader to read data in batches.
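Here is a minimal sketch of such a detection-style Dataset, assuming the annotations are already in memory as a list of dicts; the class name DetectionDataset and the field names 'path', 'boxes', and 'labels' are illustrative, not a PyTorch convention.

from PIL import Image
from torch.utils.data import Dataset

class DetectionDataset(Dataset):
    def __init__(self, annotations):
        # annotations: list of dicts with keys 'path', 'boxes', 'labels'.
        self.annotations = annotations

    def __getitem__(self, index):
        ann = self.annotations[index]
        image = Image.open(ann['path']).convert('RGB')  # guard against grayscale images
        return image, ann['boxes'], ann['labels']

    def __len__(self):
        return len(self.annotations)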

torch.utils.data.DataLoader

torch.utils.data.DataLoader is the actual data sampler: it iterates over the wrapped Dataset in a single process or in multiple worker processes. Concretely, calling iter() on a DataLoader returns an iterator, and each call to next() on that iterator yields one batch.
Note that the default collate_fn only knows how to batch tensors (and a few other basic types), so the source data usually has to be converted to tensors inside the Dataset, typically with modules such as torchvision.transforms.

From the torch.utils.data.DataLoader source code, the main parameters of the class are as follows:

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None)
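
The parameters used most often are batch_size, shuffle, num_workers, and drop_last. A minimal sketch of how they interact (ToyDataset is a made-up example, not part of PyTorch):

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    def __init__(self, n=10):
        self.data = torch.arange(n, dtype=torch.float32)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

loader = DataLoader(ToyDataset(),
                    batch_size=4,    # samples per batch
                    shuffle=True,    # reshuffle order every epoch
                    num_workers=0,   # 0 = load in the main process
                    drop_last=True)  # discard the final incomplete batch

for batch in loader:
    print(batch.shape)  # torch.Size([4]), twice; the last 2 samples are dropped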

Hands-on Example

Let's use reading the COCO dataset as an example and build a minimal data-loading pipeline.

# coding: utf-8
# Author: Zexian Li

import torch
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

class COCO_Dataset(Dataset):

    def __init__(self, txt_path, transform=None):
        # Image paths are recorded in a txt file prepared in advance,
        # one path per line; strip() removes the trailing newline.
        with open(txt_path, "r") as f:
            self.img_paths = [line.strip() for line in f.readlines()]
        self.transform = transform

    def __getitem__(self, index):
        image_path = self.img_paths[index]
        # convert('RGB') guards against grayscale images in COCO.
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

    def __len__(self):
        return len(self.img_paths)


transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])


if __name__ == "__main__":

    train_data = COCO_Dataset(txt_path='./image_path.txt', transform=transform)
    print(type(train_data))     # <class '__main__.COCO_Dataset'>
    print(type(train_data[0]))  # <class 'torch.Tensor'>

    # Data visualization.
    image_vis = transforms.ToPILImage()(train_data[0])
    print(type(image_vis))  # <class 'PIL.Image.Image'>
    image_vis.save('./image_vis.png')
    
    train_loader = DataLoader(dataset=train_data, batch_size=1, shuffle=True, num_workers=1)
    
    loader_iter = iter(train_loader)
    # len(loader_iter) == ceil(image_number / batch_size), since drop_last=False by default.

    # Save images through 'loader_iter'.
    for i in range(len(loader_iter)):
        image_save = next(loader_iter)
        # type(image_save): <class 'torch.Tensor'>
        # image_save.shape: torch.Size([1, 3, 483, 640])
        image_save = transforms.ToPILImage()(image_save.squeeze_(0))
        image_save.save('./save/' + str(i) + '.png')