读写张量
x = torch.randn(3,3)
# 这里的x可以是张量字典或张量列表等形式
# 保存张量
torch.save(x,"path_to_file.pth")
# 加载张量
x2 = torch.load("path_to_file.pth")
读写模型参数
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.hidden = nn.Linear(20, 256)
self.output = nn.Linear(256, 10)
def forward(self, x):
return self.output(F.relu(self.hidden(x)))
net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)
# 保存模型参数
torch.save(net.state_dict(), 'path_to_module.params')
# 加载模型参数
mlp = MLP()
mlp.load_state_dict(torch.load('path_to_module.params'))
读写模型
# 保存模型
torch.save(mlp,'path_to_file.pth')
# 加载模型
# 注意!使用的模型类必须先定义
mlp = torch.load('path_to_file.pth')