训练模型的保存包括两种:
1、保存整个模型框架以及模型参数(存储文件过大,不推荐)
torch.save(model,path)
2、仅仅保存模型的参数文件(推荐)
torch.save(model.state_dict(),path)
"state_dict"表示state dictionary,即字典类型的参数,模型本身的参数。
例如
torch.save(model.state_dict(),'{}/moilenetV2_{}_{}.pth'.format('./models',epoch,acc))
模型的断点继续训练
Resume = True
# Resume = False
if Resume:
path_checkpoint = 'your/new/model/path.pth'
checkpoint = torch.load(path_checkpoint, map_location = torch.device('cuda'))
model.load_state_dict(checkpoint)