解决pytorch训练的过程中内存一直增加的问题
2021/9/19 7:06:31
本文主要是介绍解决pytorch训练的过程中内存一直增加的问题,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
代码中存在累加loss,但每步的loss没加item()。
pytorch中,.item()方法 是得到一个元素张量里面的元素值
具体就是 用于将一个零维张量转换成浮点数,比如计算loss,accuracy的值
就比如:
loss = (y_pred - y).pow(2).sum()
print(loss.item())
for epoch in range(100): index=np.arange(train_sample.shape[0]) np.random.shuffle(index) train_set=train_sample[index].tolist() model.train() loss,s=0,0 for s in tqdm(range(0,train_sample.shape[0],batch_size)): if s+batch_size>train_sample.shape[0]: break batch_loss=model(train_set[s:s+batch_size]) optimizer.zero_grad() batch_loss.backward() optimizer.step() # 会导致内存一直增加,需改为loss+=batch_loss.item() loss+=batch_loss s+=batch_size loss/=total_batch print(epoch,loss) if (epoch+1) % 10 ==0: model.eval() model.save_embedding(epoch)
以上代码会导致内存占用越来越大,解决的方法是:loss+=batch_loss.item()。值得注意的是,要复现内存越来越大的问题,模型中需要切换model.train() 和 model.eval(),train_loss以及eval_loss的作用是保存模型的平均误差(这里是累积误差),保存到tensorboard中。
这篇关于解决pytorch训练的过程中内存一直增加的问题的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-11-23增量更新怎么做?-icode9专业技术文章分享
- 2024-11-23压缩包加密方案有哪些?-icode9专业技术文章分享
- 2024-11-23用shell怎么写一个开机时自动同步远程仓库的代码?-icode9专业技术文章分享
- 2024-11-23webman可以同步自己的仓库吗?-icode9专业技术文章分享
- 2024-11-23在 Webman 中怎么判断是否有某命令进程正在运行?-icode9专业技术文章分享
- 2024-11-23如何重置new Swiper?-icode9专业技术文章分享
- 2024-11-23oss直传有什么好处?-icode9专业技术文章分享
- 2024-11-23如何将oss直传封装成一个组件在其他页面调用时都可以使用?-icode9专业技术文章分享
- 2024-11-23怎么使用laravel 11在代码里获取路由列表?-icode9专业技术文章分享
- 2024-11-22怎么实现ansible playbook 备份代码中命名包含时间戳功能?-icode9专业技术文章分享