k0b0's record.

Computer Engineering, Arts and Books

Pytorchでの学習済みモデルの保存と読み込み

モデルの保存

params = net.state_dict() #netはモデル名
torch.save(params, "ファイル名.prm", pickle_protocol=4)

モデルの読み込み

params = torch.load("ファイル名.prm", map_location="cpu")
net.load_state_dict(params)