读写张量

x = torch.randn(3,3)
# 这里的x可以是张量字典或张量列表等形式
# 保存张量
torch.save(x,"path_to_file.pth")
# 加载张量
x2 = torch.load("path_to_file.pth")

读写模型参数

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)
 
    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))
 
net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)
 
# 保存模型参数
torch.save(net.state_dict(), 'path_to_module.params')
# 加载模型参数
mlp = MLP()
mlp.load_state_dict(torch.load('path_to_module.params'))

读写模型

# 保存模型
torch.save(mlp,'path_to_file.pth')
# 加载模型
# 注意!使用的模型类必须先定义
mlp = torch.load('path_to_file.pth')