A `state_dict` in PyTorch is simply a Python dictionary that maps each layer to its parameter tensors. (Note that only layers with learnable parameters, such as convolutional and linear layers, have entries in the model's `state_dict`.) The optimizer object `Optimizer` also has a `state_dict`, which holds the optimizer's internal state together with the hyperparameters it was configured with.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.relu in forward()

class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = TheModelClass()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print the name and shape of every tensor in the model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print the optimizer's state and hyperparameters
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
```
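For the model defined above, the output looks roughly like this (the tensor shapes follow directly from the layer definitions; the exact optimizer entries vary across PyTorch versions):

```
Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias       torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias       torch.Size([16])
fc1.weight       torch.Size([120, 400])
fc1.bias         torch.Size([120])
fc2.weight       torch.Size([84, 120])
fc2.bias         torch.Size([84])
fc3.weight       torch.Size([10, 84])
fc3.bias         torch.Size([10])
Optimizer's state_dict:
state            {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, ...}]
```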
PyTorch offers several ways to save a model, and each must be loaded with the matching method. The file extension used when saving does not matter, although `.pt` or `.pth` is the common convention.
Saving and loading models

1. Save/load the state_dict

Save
```python
torch.save(model.state_dict(), PATH)
```
Load
```python
model = TheModelClass()
model.load_state_dict(torch.load(PATH))
model.eval()  # set dropout and batch-norm layers to evaluation mode before inference
```
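If the state_dict was saved on a GPU and is being loaded on a CPU-only machine, pass `map_location` to `torch.load`. A common variant:

```python
# Load GPU-saved weights onto the CPU
model = TheModelClass()
model.load_state_dict(torch.load(PATH, map_location=torch.device('cpu')))
model.eval()
```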
2. Save/load the complete model

Save
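Saving the complete model serializes the entire module object with pickle:

```python
torch.save(model, PATH)
```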
Load
```python
# The class definition of the model must still be available in the loading scope,
# because torch.load unpickles the entire module object. No explicit instantiation
# is needed. Caveat: pickled models are tied to the exact class and file layout at
# save time, so refactoring the code can break loading.
model = torch.load(PATH)
model.eval()
```
3. Save/load a checkpoint for inference or resuming training

Save
```python
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, PATH)
```
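In practice this save usually runs at the end of each training epoch. A minimal sketch, assuming a `num_epochs` value and a hypothetical `train_one_epoch` helper (not from the original code) that returns the epoch's loss:

```python
for epoch in range(num_epochs):
    loss = train_one_epoch(model, optimizer)  # hypothetical helper, assumed here
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, PATH)
```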
Load
```python
model = TheModelClass()
optimizer = TheOptimizerClass()  # placeholder: construct the same optimizer type used in training,
                                 # e.g. torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()  # for inference; call model.train() instead when resuming training
```
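To resume training instead of running inference, put the model in train mode and continue from the epoch after the checkpoint. A sketch reusing the hypothetical `train_one_epoch` helper and `num_epochs` from the save example above:

```python
model.train()  # training mode instead of eval mode

start_epoch = epoch + 1  # continue from the epoch after the checkpoint
for ep in range(start_epoch, num_epochs):
    loss = train_one_epoch(model, optimizer)  # hypothetical helper, assumed here
```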
4. Save multiple models in one file

Save
```python
torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict(),
}, PATH)
```
Load
```python
checkpoint = torch.load(PATH)

# modelA, modelB and both optimizers must be constructed before loading (see below)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
```
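The load code above assumes `modelA`, `modelB`, and both optimizers already exist. A minimal setup sketch, reusing `TheModelClass` for both models purely for illustration:

```python
modelA = TheModelClass()
modelB = TheModelClass()
optimizerA = torch.optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimizerB = torch.optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)
```

This pattern is convenient whenever training involves more than one network, for example a generator/discriminator pair.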