电子发烧友App

硬声App

0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示
电子发烧友网>电子资料下载>电子资料>PyTorch教程6.6.之文件输入输出

PyTorch教程6.6.之文件输入输出

2023-06-05 | pdf | 0.11 MB | 次下载 | 免费

资料介绍

到目前为止,我们讨论了如何处理数据以及如何构建、训练和测试深度学习模型。然而,在某些时候,我们希望对学习的模型感到满意,我们希望保存结果以供以后在各种情况下使用(甚至可能在部署中进行预测)。此外,在运行较长的训练过程时,最佳做法是定期保存中间结果(检查点),以确保如果我们被服务器的电源线绊倒,我们不会损失几天的计算量。因此,是时候学习如何加载和存储单个权重向量和整个模型了。本节解决这两个问题。

import torch
from torch import nn
from torch.nn import functional as F
from mxnet import np, npx
from mxnet.gluon import nn

npx.set_np()
import flax
import jax
from flax import linen as nn
from flax.training import checkpoints
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import numpy as np
import tensorflow as tf

6.6.1. 加载和保存张量

对于单个张量,我们可以直接调用loadsave 函数分别进行读写。这两个函数都需要我们提供一个名称,并且save需要将要保存的变量作为输入。

x = torch.arange(4)
torch.save(x, 'x-file')
x = np.arange(4)
npx.save('x-file', x)
x = jnp.arange(4)
jnp.save('x-file.npy', x)
x = tf.range(4)
np.save('x-file.npy', x)

我们现在可以将存储文件中的数据读回内存。

x2 = torch.load('x-file')
x2
tensor([0, 1, 2, 3])
x2 = npx.load('x-file')
x2
[array([0., 1., 2., 3.])]
x2 = jnp.load('x-file.npy', allow_pickle=True)
x2
Array([0, 1, 2, 3], dtype=int32)
x2 = np.load('x-file.npy', allow_pickle=True)
x2
array([0, 1, 2, 3], dtype=int32)

我们可以存储张量列表并将它们读回内存。

y = torch.zeros(4)
torch.save([x, y],'x-files')
x2, y2 = torch.load('x-files')
(x2, y2)
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
y = np.zeros(4)
npx.save('x-files', [x, y])
x2, y2 = npx.load('x-files')
(x2, y2)
(array([0., 1., 2., 3.]), array([0., 0., 0., 0.]))
y = jnp.zeros(4)
jnp.save('xy-files.npy', [x, y])
x2, y2 = jnp.load('xy-files.npy', allow_pickle=True)
(x2, y2)
(Array([0., 1., 2., 3.], dtype=float32),
 Array([0., 0., 0., 0.], dtype=float32))
y = tf.zeros(4)
np.save('xy-files.npy', [x, y])
x2, y2 = np.load('xy-files.npy', allow_pickle=True)
(x2, y2)
(array([0., 1., 2., 3.]), array([0., 0., 0., 0.]))

我们甚至可以编写和读取从字符串映射到张量的字典。当我们想要读取或写入模型中的所有权重时,这很方便。

mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')
mydict2 = torch.load('mydict')
mydict2
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
mydict = {'x': x, 'y': y}
npx.save('mydict', mydict)
mydict2 = npx.load('mydict')
mydict2
{'x': array([0., 1., 2., 3.]), 'y': array([0., 0., 0., 0.])}
mydict = {'x': x, 'y': y}
jnp.save('mydict.npy', mydict)
mydict2 = jnp.load('mydict.npy', allow_pickle=True)
mydict2
array({'x': Array([0, 1, 2, 3], dtype=int32), 'y': Array([0., 0., 0., 0.], dtype=float32)},
   dtype=object)
mydict = {'x': x, 'y': y}
np.save('mydict.npy', mydict)
mydict2 = np.load('mydict.npy', allow_pickle=True)
mydict2
array({'x': <tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 1, 2, 3], dtype=int32)>, 'y': <tf.Tensor: shape=(4,), dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>},
   dtype=object)

6.6.2. 加载和保存模型参数

保存单个权重向量(或其他张量)很有用,但如果我们想保存(并稍后加载)整个模型,它会变得非常乏味。毕竟,我们可能散布着数百个参数组。出于这个原因,深度学习框架提供了内置功能来加载和保存整个网络需要注意的一个重要细节是,这会保存模型参数而不是整个模型。例如,如果我们有一个 3 层的 MLP,我们需要单独指定架构。这样做的原因是模型本身可以包含任意代码,因此它们不能自然地序列化。因此,为了恢复模型,我们需要用代码生成架构,然后从磁盘加载参数。让我们从我们熟悉的 MLP 开始。

class MLP(nn.Module):
  def __init__(self):
    super().__init__()
    self.hidden = nn.LazyLinear(256)
    self.output = nn.LazyLinear(10)

  def forward(self, x):
    return self.output(F.relu(self.hidden(x)))

net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)

下载该资料的人也在下载 下载该资料的人还在阅读
更多 >

评论

查看更多

下载排行

本周

  1. 1DD3118电路图纸资料
  2. 0.08 MB   |  1次下载  |  免费
  3. 2AD库封装库安装教程
  4. 0.49 MB   |  1次下载  |  免费
  5. 3PC6206 300mA低功耗低压差线性稳压器中文资料
  6. 1.12 MB   |  1次下载  |  免费
  7. 4网络安全从业者入门指南
  8. 2.91 MB   |  1次下载  |  免费
  9. 5DS-CS3A P00-CN-V3
  10. 618.05 KB  |  1次下载  |  免费
  11. 6海川SM5701规格书
  12. 1.48 MB  |  次下载  |  免费
  13. 7H20PR5电磁炉IGBT功率管规格书
  14. 1.68 MB   |  次下载  |  1 积分
  15. 8IP防护等级说明
  16. 0.08 MB   |  次下载  |  免费

本月

  1. 1贴片三极管上的印字与真实名称的对照表详细说明
  2. 0.50 MB   |  103次下载  |  1 积分
  3. 2涂鸦各WiFi模块原理图加PCB封装
  4. 11.75 MB   |  89次下载  |  1 积分
  5. 3锦锐科技CA51F2 SDK开发包
  6. 24.06 MB   |  43次下载  |  1 积分
  7. 4锦锐CA51F005 SDK开发包
  8. 19.47 MB   |  19次下载  |  1 积分
  9. 5PCB的EMC设计指南
  10. 2.47 MB   |  16次下载  |  1 积分
  11. 6HC05蓝牙原理图加PCB
  12. 15.76 MB   |  13次下载  |  1 积分
  13. 7802.11_Wireless_Networks
  14. 4.17 MB   |  12次下载  |  免费
  15. 8苹果iphone 11电路原理图
  16. 4.98 MB   |  6次下载  |  2 积分

总榜

  1. 1matlab软件下载入口
  2. 未知  |  935127次下载  |  10 积分
  3. 2开源硬件-PMP21529.1-4 开关降压/升压双向直流/直流转换器 PCB layout 设计
  4. 1.48MB  |  420064次下载  |  10 积分
  5. 3Altium DXP2002下载入口
  6. 未知  |  233089次下载  |  10 积分
  7. 4电路仿真软件multisim 10.0免费下载
  8. 340992  |  191390次下载  |  10 积分
  9. 5十天学会AVR单片机与C语言视频教程 下载
  10. 158M  |  183342次下载  |  10 积分
  11. 6labview8.5下载
  12. 未知  |  81588次下载  |  10 积分
  13. 7Keil工具MDK-Arm免费下载
  14. 0.02 MB  |  73815次下载  |  10 积分
  15. 8LabVIEW 8.6下载
  16. 未知  |  65989次下载  |  10 积分