PyTorch的数据工具

在解决深度学习问题的过程中,往往需要花费大量的精力去处理数据,包括图像、文本、语音或其他二进制数据等。数据的处理对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练,也会提高模型效果。

图像数据加载与处理

1、自定义数据集

在PyTorch中,数据加载可通过自定义的数据集对象实现。数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset类,并实现两个Python的魔术方法,__ getitem __ 和 __ len __。
在训练神经网络时是对一个batch的数据进行操作,同时还需要对数据进行shuffle和并行加速等。对此,PyTorch提供了DataLoader类帮助我们实现这些功能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
from PIL import Image
import numpy as np
import torch
from torch.utils import data
from torch.utils.data import DataLoader
from torchvision import transforms

#========================================================
# 数据集对象
#========================================================

transform = transforms.Compose([
transforms.Resize(256), # 缩放图片, 保持长宽比不变, 最短边256
transforms.CenterCrop(224), # 从图片中间切出224x224的图片
transforms.ToTensor(), # 将图片(Image)转化为Tensor, 归一化到[0,1]
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 标准化到[-1,1]
])

class DogCat(data.Dataset):
def __init__(self, root, transform=None):
imgs = os.listdir(root)
self.imgs = [os.path.join(root, img) for img in imgs]
self.transform = transform

def __getitem__(self, index):
'''
返回一条数据或样本
'''
img_path = self.imgs[index]
label = 0 if 'dog' in img_path.split('/')[-1] else 1
data = Image.open(img_path)
if self.transform:
data = self.transform(data)
return data, label

def __len__(self):
'''
返回样本的数量
'''
return len(self.imgs)

dataset = DogCat('DogCat/data/', transform=transform)
# img, label = dataset[0]
# print(img.size(), label)
# for img, label in dataset:
# print(img.size(), label)

#========================================================
# 数据加载对象
#========================================================

from torch.utils.data.sampler import WeightedRandomSampler

# 猫的图片被取出的概率是狗的两倍
# weight = [2 if label == 1 else 1 for data, label in dataset]
# 狗的图片被取出的概率是猫的两倍
weight = [0.5 if label == 1 else 1 for data, label in dataset]
# print(weight)

sampler = WeightedRandomSampler(weight,
num_samples = 5, # 选取样本总数
replacement=True # 是否可重复选取同一样本
)

dataloader = DataLoader(dataset,
batch_size=8, # 批大小
# shuffle=True, # 是否将数据打乱
sampler=sampler, # 样本抽样(与shuffle只会有一个生效)
num_workers=0 # 使用多进程加载的进程数, 0代表不使用多进程
)

for datas, labels in dataloader:
print(labels.tolist())

2、torchvision定义的数据集

torchvision已经预先实现了常用的Dataset,包括MNIST、CIFAR-10、ImageNet、COCO等数据集,可通过torchvision.datasets下相应对象调用相关数据集。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
from PIL import Image
import numpy as np
import torch
from torch.utils import data
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder

#========================================================
# 数据集对象
#========================================================

transform = transforms.Compose([
transforms.Resize(256), # 缩放图片, 保持长宽比不变, 最短边256
transforms.CenterCrop(224), # 从图片中间切出224x224的图片
transforms.ToTensor(), # 将图片(Image)转化为Tensor, 归一化到[0,1]
transforms.Normalize(mean=[0.5], std=[0.5]) # 标准化到[-1,1]
])

dataset = torchvision.datasets.MNIST(
root='data/',
train=True,
download=True,
transform=transform)

# img, label = dataset[0]
# print(img.size(), label)
# for img, label in dataset:
# print(img.size(), label)

#========================================================
# 数据加载对象
#========================================================

from torch.utils.data.sampler import WeightedRandomSampler

# 猫的图片被取出的概率是狗的两倍
# weight = [2 if label == 1 else 1 for data, label in dataset]
# 狗的图片被取出的概率是猫的两倍
weight = [0.5 if label == 1 else 1 for data, label in dataset]
# print(weight)

sampler = WeightedRandomSampler(weight,
num_samples = 5, # 选取样本总数
replacement=True # 是否可重复选取同一样本
)

dataloader = DataLoader(dataset,
batch_size=8, # 批大小
# shuffle=True, # 是否将数据打乱
sampler=sampler, # 样本抽样(与shuffle只会有一个生效)
num_workers=0 # 使用多进程加载的进程数, 0代表不使用多进程
)

for datas, labels in dataloader:
print(labels.tolist())

对于自定义数据集,可以采用ImageFloder来定义。ImageFloder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹为类名。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
from PIL import Image
import numpy as np
import torch
from torch.utils import data
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

#========================================================
# 数据集对象
#========================================================

transform = transforms.Compose([
transforms.Resize(256), # 缩放图片, 保持长宽比不变, 最短边256
transforms.CenterCrop(224), # 从图片中间切出224x224的图片
transforms.ToTensor(), # 将图片(Image)转化为Tensor, 归一化到[0,1]
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) # 标准化到[-1,1]
])

dataset = ImageFolder('DogCat/data/', transform=transform)
# cat文件夹对应label 0, dog对应1
print(dataset.class_to_idx)

# img, label = dataset[0]
# print(img.size(), label)
# for img, label in dataset:
# print(img.size(), label)

#========================================================
# 数据加载对象
#========================================================

from torch.utils.data.sampler import WeightedRandomSampler

# 猫的图片被取出的概率是狗的两倍
# weight = [2 if label == 1 else 1 for data, label in dataset]
# 狗的图片被取出的概率是猫的两倍
weight = [0.5 if label == 1 else 1 for data, label in dataset]
# print(weight)

sampler = WeightedRandomSampler(weight,
num_samples = 5, # 选取样本总数
replacement=True # 是否可重复选取同一样本
)

dataloader = DataLoader(dataset,
batch_size=8, # 批大小
# shuffle=True, # 是否将数据打乱
sampler=sampler, # 样本抽样(与shuffle只会有一个生效)
num_workers=0 # 使用多进程加载的进程数, 0代表不使用多进程
)

for datas, labels in dataloader:
print(labels.tolist())
0%