深度学习开发框架PyTorch(8)-- 模型保存和加载

PyTorch 中的state_dict是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系。(注意,只有那些参数可以训练的层才会被保存到模型的state_dict中,如卷积层、线性层等)
优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.nn as nn

# 定义模型
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x

# 初始化模型
model = TheModelClass()

# 初始化优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 打印模型的状态字典
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# 打印优化器的状态字典
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])

Pytorch有多种保存模型的方式,使用哪种进行保存,就要使用对应的加载方式。保存的时候模型的后缀名是无所谓的。

保存和加载模型

1、 保存/加载state_dict

保存

1
torch.save(model.state_dict(), PATH)  # PyTorch中最常见的模型保存使用‘.pt’或者是‘.pth’作为模型文件扩展名

加载

1
2
3
model = TheModelClass()
model.load_state_dict(torch.load(PATH))
model.eval()

2、 保存/加载完整模型

保存

1
torch.save(model, PATH)

加载

1
2
3
model = TheModelClass()
model = torch.load(PATH)
model.eval()

3、保存/加载Checkpoint用于预测/继续训练

保存

1
2
3
4
5
6
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss
}, PATH) # PyTorch中常见的保存checkpoint是使用.tar作为文件扩展名

加载

1
2
3
4
5
6
7
8
9
10
model = TheModelClass()
optimizer = TheOptimizerClass()

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval() # or : model.train()

4、在一个文件中保存多个模型

保存

1
2
3
4
5
6
torch.save({
'modelA_state_dict': modelA.state_dict(),
'modelB_state_dict': modelB.state_dict(),
'optimizerA_state_dict': optimizerA.state_dict(),
'optimizerB_state_dict': optimizerB.state_dict()
}, PATH)

加载

1
2
3
4
5
6
7
8
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval() # or : modelA.train()
modelB.eval() # or : modelB.train()
0%