线性回归——pytorch实现
2021/8/16 23:10:08
本文主要是介绍线性回归——pytorch实现,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
1 import torch 2 import matplotlib.pyplot as plt 3 import os 4 os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' 5 6 lr = 0.001 7 see = 20000 8 x = torch.rand([1, 50]) 9 y = 3 * x + 0.8 10 11 w = torch.rand([1, 1], requires_grad=True, dtype=torch.float32) 12 b = torch.rand(1, requires_grad=True, dtype=torch.float32) 13 loss = [] 14 15 for i in range(see): 16 y_pred = torch.matmul(w, x) + b 17 cur_loss = torch.matmul(y - y_pred, (y - y_pred).T)18 loss.append(cur_loss.item()) 19 20 if i != 0: # 将梯度清零,初始时参数的梯度为None所以先计算一次后才有梯度 21 w.grad.data.zero_() 22 b.grad.data.zero_() 23 24 cur_loss.backward() 25 w.data = w.data - lr * w.grad 26 b.data = b.data - lr * b.grad 27 28 if i % 200 == 0: 29 print("w, b, loss", w.item(), b.item(), cur_loss.item()) 30 31 plt.scatter(x.numpy()[0], y.numpy()[0]) 32 y_predict = torch.matmul(w, x) + b 33 plt.plot(x.numpy()[0], y_predict.detach().numpy()[0]) 34 plt.show()
这篇关于线性回归——pytorch实现的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-07-03微信支付提示下单账户与支付账户不一致-icode9专业技术文章分享
- 2024-07-03微信支付提示订单号重复-icode9专业技术文章分享
- 2024-07-02微服务启动nacos注册上去了,但是一直没有收到请求-icode9专业技术文章分享
- 2024-07-02如何检查文件的编码格式-icode9专业技术文章分享
- 2024-07-02sublime 更改编码格式-icode9专业技术文章分享
- 2024-06-30uniAPP 实现全屏左右滚动滚动的效果-icode9专业技术文章分享
- 2024-06-30如何在本地使用授权或插件-icode9专业技术文章分享
- 2024-06-30伪静态规则配置方法汇总-icode9专业技术文章分享
- 2024-06-29易优CMS安装常见问题汇总-icode9专业技术文章分享
- 2024-06-28易优新手必读安装教程-icode9专业技术文章分享