ONNXRuntime学习笔记(三)

2022/5/1 6:16:28

本文主要是介绍ONNXRuntime学习笔记(三),对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

接上一篇完成的pytorch模型训练结果,模型结构为ResNet18+fc,参数量约为11M,最终测试集Acc达到94.83%。接下来有分两个部分:导出onnx和使用onnxruntime推理。

一、pytorch导出onnx

直接放函数吧,这部分我是直接放在test.py里面的,直接从dataloader中拿到一个batch的数据走一遍推理即可。

def export_onnx(net, testloader, output_file):
    net.eval()
    with torch.no_grad():
        for data in testloader: 
            images, labels = data

            torch.onnx.export(net, 
                            (images), 
                            output_file,
                            training=False,
                            do_constant_folding=True,
                            input_names=["img"], 
                            output_names=["output"],
                  dynamic_axes={"img": {0: "b"},"output": {0: "b"}}
                  )
            print("onnx export done!")
            break

上面函数中几个比较重要的参数:do_constant_folding是常量折叠,建议打开;输入张量通过一个tuple传入,并且最好指定每个输入和输出的名称,此外,为保证使用onnxruntime推理的时候batchsize可变,dynamic_axes的第一维需要像上述一样设置为动态的。如果是全卷积做分割的网络,类似的输入h和w也应该是动态的。

单独运行test.py计算测试集效果和平均相应时间,结果为:

Test Acc is: 94.83%
Average response time cost:  0.10121344916428192

二、使用onnxruntime推理

这里我们使用gpu版本的onnxruntime库进行推理,其python包可直接pip install onnxruntime-gpu安装。onnxruntime推理代码和测试集推理代码很类似,如下:

import numpy as np
import onnxruntime as ort
import argparse, os
from lib import CIFARDataset

def onnxruntime_test(session, testloader):
    print("Start Testing!")
    input_name = session.get_inputs()[0].name
    correct = 0
    total = 0   # 计数归零(初始化)
    for data in testloader:
        images, labels = data
        images, labels = images.numpy(), labels.numpy()
        outputs = session.run(None, {input_name:images})
        predicted = np.argmax(outputs[0], axis=1)  # 取得分最高的那个类
        total += labels.shape[0]                        # 累加样本总数
        correct += (predicted == labels).sum()        # 累加预测正确的样本个数
    acc = correct / total
    print('ONNXRuntime Test Acc is: %.2f%%' % (100*acc))
            
if __name__ == '__main__':
    # 命令行参数解析
    parser = argparse.ArgumentParser("CNN backbone on cifar10")
    parser.add_argument('--onnx', default='./output/test_resnet18_10_autoaug/densenet_best.onnx')
    args = parser.parse_args()

    NUM_CLASS =10
    BATCH_SIZE = 128  # 批处理尺寸(batch_size)

    # 数据集迭代器
    data_path="./data"
    dataset = CIFARDataset(dataset_path=data_path, batchsize=BATCH_SIZE)
    _, testloader = dataset.get_cifar10_dataloader()

    # 构建session
    sess = ort.InferenceSession(args.onnx, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

    #onnxruntime推理
    import time
    start = time.time()
    onnxruntime_test(sess, testloader)
    end = time.time()
    print("Average response time cost: ", (end-start)/len(testloader))

使用onnxruntime加载导出的onnx模型,计算测试集效果和平均响应时间,结果为:

ONNXRuntime Test Acc is: 94.83%
Average response time cost:  0.07324151147769976

三、小结

分析上面的pytorch和onnxruntime的测试结果可知,最终测试集效果是一致的,Acc均为94.83%,但onnxruntime的效率更高,耗时是pytorch的75%,但比最初目标设定的50ms高,需要进一步优化,两个方向:模型量化或并行化推理。下一篇再分析。



这篇关于ONNXRuntime学习笔记(三)的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!


扫一扫关注最新编程教程