读取和存储
import torch
from torch import nn
# Show the installed PyTorch version for reference.
print(torch.__version__)
0.4.1
读写Tensor
# Round-trip a single tensor through disk: build a length-3 tensor of ones,
# write it to 'x.pt', then read it back into a new name.
# NOTE: torch.load relies on pickle under the hood, so only load files
# that come from a trusted source.
x = torch.ones(3)
torch.save(x, 'x.pt')
x2 = torch.load('x.pt')
# Bare expression: in an interactive session this echoes the loaded tensor.
x2
tensor([1., 1., 1.])
# torch.save also accepts a list of tensors; the list comes back from
# torch.load as a list holding the same two tensors.
y = torch.zeros(4)
torch.save([x, y], 'xy.pt')
xy_list = torch.load('xy.pt')
# Bare expression: in an interactive session this echoes the loaded list.
xy_list
[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]
# A dict mapping strings to tensors can be saved and restored the same way.
torch.save({'x': x, 'y': y}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
# Bare expression: in an interactive session this echoes the loaded dict.
xy
{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}
读写模型
state_dict
class MLP(nn.Module):
    """Small multilayer perceptron: Linear(3 -> 2), ReLU, Linear(2 -> 1).

    The submodules are stored as ``hidden``, ``act`` and ``output``. These
    attribute names become the key prefixes in ``state_dict()`` (for
    example ``hidden.weight``), so renaming them would make previously
    saved parameter files fail to load.
    """

    def __init__(self):
        # Explicit two-argument super() keeps the class usable on the older
        # Python versions that PyTorch 0.4.x still supported.
        super(MLP, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()  # no learnable parameters, so absent from state_dict
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: Input tensor whose last dimension has size 3.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        a = self.act(self.hidden(x))
        return self.output(a)
# Instantiate the model and inspect its state_dict: an ordered mapping from
# parameter names (e.g. 'hidden.weight') to the corresponding tensors.
net = MLP()
net.state_dict()
OrderedDict([('hidden.weight', tensor([[ 0.1836, -0.1812, -0.1681], [ 0.0406, 0.3061, 0.4599]])), ('hidden.bias', tensor([-0.3384, 0.1910])), ('output.weight', tensor([[0.0380, 0.4919]])), ('output.bias', tensor([0.1451]))])
# Optimizers expose a state_dict as well; it holds the hyperparameters of
# each parameter group plus any per-parameter optimizer state.
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer.state_dict()
{'param_groups': [{'dampening': 0, 'lr': 0.001, 'momentum': 0.9, 'nesterov': False, 'params': [4624483024, 4624484608, 4624484680, 4624484752], 'weight_decay': 0}], 'state': {}}
保存和加载模型
# Save only the model's parameters (its state_dict), load them into a second,
# independently constructed MLP, and check that both models agree.
X = torch.randn(2, 3)
Y = net(X)
PATH = "./net.pt"
torch.save(net.state_dict(), PATH)
# net2 starts with its own freshly initialised weights; it only matches net
# once load_state_dict has copied the saved parameters into it.
net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)
# Element-wise comparison of the two outputs for the same input X.
Y2 == Y
tensor([[1], [1]], dtype=torch.uint8)
Comments NOTHING