ONNXRuntime学习笔记(三)
2022/5/1 6:16:28
本文主要是介绍ONNXRuntime学习笔记(三),对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
接上一篇完成的pytorch模型训练结果,模型结构为ResNet18+fc,参数量约为11M,最终测试集Acc达到94.83%。接下来有分两个部分:导出onnx和使用onnxruntime推理。
一、pytorch导出onnx
直接放函数吧,这部分我是直接放在test.py里面的,直接从dataloader中拿到一个batch的数据走一遍推理即可。
def export_onnx(net, testloader, output_file): net.eval() with torch.no_grad(): for data in testloader: images, labels = data torch.onnx.export(net, (images), output_file, training=False, do_constant_folding=True, input_names=["img"], output_names=["output"], dynamic_axes={"img": {0: "b"},"output": {0: "b"}} ) print("onnx export done!") break
上面函数中几个比较重要的参数:do_constant_folding是常量折叠,建议打开;输入张量通过一个tuple传入,并且最好指定每个输入和输出的名称,此外,为保证使用onnxruntime推理的时候batchsize可变,dynamic_axes的第一维需要像上述一样设置为动态的。如果是全卷积做分割的网络,类似的输入h和w也应该是动态的。
单独运行test.py计算测试集效果和平均相应时间,结果为:
Test Acc is: 94.83% Average response time cost: 0.10121344916428192
二、使用onnxruntime推理
这里我们使用gpu版本的onnxruntime库进行推理,其python包可直接pip install onnxruntime-gpu
安装。onnxruntime推理代码和测试集推理代码很类似,如下:
import numpy as np import onnxruntime as ort import argparse, os from lib import CIFARDataset def onnxruntime_test(session, testloader): print("Start Testing!") input_name = session.get_inputs()[0].name correct = 0 total = 0 # 计数归零(初始化) for data in testloader: images, labels = data images, labels = images.numpy(), labels.numpy() outputs = session.run(None, {input_name:images}) predicted = np.argmax(outputs[0], axis=1) # 取得分最高的那个类 total += labels.shape[0] # 累加样本总数 correct += (predicted == labels).sum() # 累加预测正确的样本个数 acc = correct / total print('ONNXRuntime Test Acc is: %.2f%%' % (100*acc)) if __name__ == '__main__': # 命令行参数解析 parser = argparse.ArgumentParser("CNN backbone on cifar10") parser.add_argument('--onnx', default='./output/test_resnet18_10_autoaug/densenet_best.onnx') args = parser.parse_args() NUM_CLASS =10 BATCH_SIZE = 128 # 批处理尺寸(batch_size) # 数据集迭代器 data_path="./data" dataset = CIFARDataset(dataset_path=data_path, batchsize=BATCH_SIZE) _, testloader = dataset.get_cifar10_dataloader() # 构建session sess = ort.InferenceSession(args.onnx, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) #onnxruntime推理 import time start = time.time() onnxruntime_test(sess, testloader) end = time.time() print("Average response time cost: ", (end-start)/len(testloader))
使用onnxruntime加载导出的onnx模型,计算测试集效果和平均响应时间,结果为:
ONNXRuntime Test Acc is: 94.83% Average response time cost: 0.07324151147769976
三、小结
分析上面的pytorch和onnxruntime的测试结果可知,最终测试集效果是一致的,Acc均为94.83%,但onnxruntime的效率更高,耗时是pytorch的75%,但比最初目标设定的50ms高,需要进一步优化,两个方向:模型量化或并行化推理。下一篇再分析。
这篇关于ONNXRuntime学习笔记(三)的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-11-23Springboot应用的多环境打包入门
- 2024-11-23Springboot应用的生产发布入门教程
- 2024-11-23Python编程入门指南
- 2024-11-23Java创业入门:从零开始的编程之旅
- 2024-11-23Java创业入门:新手必读的Java编程与创业指南
- 2024-11-23Java对接阿里云智能语音服务入门详解
- 2024-11-23Java对接阿里云智能语音服务入门教程
- 2024-11-23JAVA对接阿里云智能语音服务入门教程
- 2024-11-23Java副业入门:初学者的简单教程
- 2024-11-23JAVA副业入门:初学者的实战指南