千家信息网

怎么解决pytorch训练神经网络爆内存

发表于:2025-11-09 作者:千家信息网编辑
千家信息网最后更新 2025年11月09日,小编给大家分享一下怎么解决pytorch训练神经网络爆内存,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!在建立人工神经网络
千家信息网最后更新 2025年11月09日怎么解决pytorch训练神经网络爆内存

小编给大家分享一下怎么解决pytorch训练神经网络爆内存,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!

在建立人工神经网络时整体的步骤主要有以下四步:

1、载入原始数据

2、构建具体神经网络

3、进行数据的训练

4、数据测试和验证

pytorch神经网络的数据载入,以MINIST书写字体的原始数据为例:

import torchimport matplotlib.pyplot as  pltdef plot_curve(data):    fig=plt.figure()    plt.plot(range(len(data)),data,color="blue")    plt.legend(["value"],loc="upper right")    plt.xlabel("step")    plt.ylabel("value")    plt.show() def plot_image(img,label,name):    fig=plt.figure()    for i in range(6):        plt.subplot(2,3,i+1)        plt.tight_layout()        plt.imshow(img[i][0]*0.3081+0.1307,cmap="gray",interpolation="none")        plt.title("{}:{}".format(name, label[i].item()))        plt.xticks([])        plt.yticks([])    plt.show()def one_hot(label,depth=10):    out=torch.zeros(label.size(0),depth)    idx=torch.LongTensor(label).view(-1,1)    out.scatter_(dim=1,index=idx,value=1)    return out batch_size=512import torchfrom torch import nn                         #完成神经网络的构建包from torch.nn import functional as F         #包含常用的函数包from torch import optim                      #优化工具包import torchvision                           #视觉工具包import  matplotlib.pyplot as pltfrom utils import plot_curve,plot_image,one_hot#step1 load dataset   加载数据包train_loader=torch.utils.data.DataLoader(    torchvision.datasets.MNIST("minist_data",train=True,download=True,transform=torchvision.transforms.Compose(        [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))         ])),    batch_size=batch_size,shuffle=True)test_loader=torch.utils.data.DataLoader(    torchvision.datasets.MNIST("minist_data",train=True,download=False,transform=torchvision.transforms.Compose(        [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,),(0.3081,))         ])),    batch_size=batch_size,shuffle=False)x,y=next(iter(train_loader))print(x.shape,y.shape)plot_image(x,y,"image")print(x)print(y)

以构建一个简单的回归问题的神经网络为例,

其具体的实现代码如下所示:

import torchimport torch.nn.functional as F  # 激励函数都在这 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)y = x.pow(2) + 0.2 * torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1) class Net(torch.nn.Module):  # 继承 torch 的 Module(固定)    def __init__(self, n_feature, n_hidden, n_output):  # 定义层的信息,n_feature多少个输入, n_hidden每层神经元, n_output多少个输出        super(Net, self).__init__()  # 继承 __init__ 功能(固定)        # 定义每层用什么样的形式        self.hidden = torch.nn.Linear(n_feature, n_hidden)  # 定义隐藏层,线性输出        self.predict = torch.nn.Linear(n_hidden, n_output)  # 定义输出层线性输出     def forward(self, x):  # x是输入信息就是data,同时也是 Module 中的 forward 功能,定义神经网络前向传递的过程,把__init__中的层信息一个一个的组合起来        # 正向传播输入值, 神经网络分析出输出值        x = F.relu(self.hidden(x))  # 定义激励函数(隐藏层的线性值)        x = self.predict(x)  # 输出层,输出值        return x  net = Net(n_feature=1, n_hidden=10, n_output=1) print(net)  # net 的结构"""Net (  (hidden): Linear (1 -> 10)  (predict): Linear (10 -> 1))"""# optimizer 是训练的工具optimizer = torch.optim.SGD(net.parameters(), lr=0.2)  # 传入 net 的所有参数, 学习率loss_func = torch.nn.MSELoss()  # 预测值和真实值的误差计算公式 (均方差) for t in range(100):  # 训练的步数100步    prediction = net(x)  # 喂给 net 训练数据 x, 每迭代一步,输出预测值     loss = loss_func(prediction, y)  # 计算两者的误差     # 优化步骤:    optimizer.zero_grad()  # 清空上一步的残余更新参数值    loss.backward()  # 误差反向传播, 计算参数更新值    optimizer.step()  # 将参数更新值施加到 net 的 parameters 上 import matplotlib.pyplot as plt plt.ion()  # 实时画图something about plotting for t in range(200):    prediction = net(x)  # input x and predict based on x     loss = loss_func(prediction, y)  # must be (1. nn output, 2. target)     optimizer.zero_grad()  # clear gradients for next train    loss.backward()  # backpropagation, compute gradients    optimizer.step()  # apply gradients     if t % 5 == 0:  # 每五步绘一次图        # plot and show learning process        plt.cla()        plt.scatter(x.data.numpy(), y.data.numpy())        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)        plt.text(0.5, 0, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})        plt.pause(0.1) plt.ioff()plt.show()

以上是"怎么解决pytorch训练神经网络爆内存"这篇文章的所有内容,感谢各位的阅读!相信大家都有了一定的了解,希望分享的内容对大家有所帮助,如果还想学习更多知识,欢迎关注行业资讯频道!

神经 神经网络 网络 输出 数据 训练 参数 信息 函数 工具 篇文章 线性 误差 更新 输入 内存 原始 内容 功能 工具包 数据库的安全要保护哪些东西 数据库安全各自的含义是什么 生产安全数据库录入 数据库的安全性及管理 数据库安全策略包含哪些 海淀数据库安全审计系统 建立农村房屋安全信息数据库 易用的数据库客户端支持安全管理 连接数据库失败ssl安全错误 数据库的锁怎样保障安全 go打印数据库查询结果乱码 数据库不同表如何查询数据 手机在桌面显示未连接到服务器 服务器能上youtube 网络安全有奖问答方案 联想650服务器管理端口 惠普笔记本网络安全密匙 泰山杯网络安全大赛团体得分 安卓软件开发哪里学 上号器服务器 橙光文字教育软件开发 东莞pc软件开发费用是多少 中华人民共和国维护网络安全法 查看sql数据库连接池 济南应用软件开发怎么收费 数据库演化技术 我的世界服务器指令进服提示 恒易达直销软件开发公司 池州金融软件开发公司 怎么给自己的服务器上高防 局域网服务器需要网线吗 中同网络技术有限公司怎么样 汽车嵌入式软件开发简历 天上网络技术有限公司 云南ipfs云服务器云主机 怎么把吃鸡安装到自己服务器里 软件开发只是前端后端 网络安全之我的故事 数据库热块问题 买阿里云服务器可以打游戏吗
0