一文详解如何实现PyTorch模型编译

目录
  • 准备
  • 加载预训练的 PyTorch 模型​
  • 加载测试图像​
  • 将计算图导入 Relay​
  • Relay 构建​
  • 在 TVM 上执行可移植计算图​
  • 查找分类集名称​

准备

本篇文章译自英文文档 Compile PyTorch Models

作者是 Alex Wong

更多 TVM 中文文档可访问 →TVM 中文站

本文介绍了如何用 Relay 部署 PyTorch 模型。

首先应安装 PyTorch。此外,还应安装 TorchVision,并将其作为模型合集 (model zoo)。

可通过 pip 快速安装:

pip install torch==1.7.0
pip install torchvision==0.8.1

或参考官网:pytorch.org/get-started…

PyTorch 版本应该和 TorchVision 版本兼容。

目前 TVM 支持 PyTorch 1.7 和 1.4,其他版本可能不稳定。

import tvm
from tvm import relay
import numpy as np
from tvm.contrib.download import download_testdata
# 导入 PyTorch
import torch
import torchvision

加载预训练的 PyTorch 模型​

model_name = "resnet18"
model = getattr(torchvision.models, model_name)(pretrained=True)
model = model.eval()
# 通过追踪获取 TorchScripted 模型
input_shape = [1, 3, 224, 224]
input_data = torch.randn(input_shape)
scripted_model = torch.jit.trace(model, input_data).eval()
输出结果:

Downloading: "download.pytorch.org/models/resn…" to /workspace/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth

0%| | 0.00/44.7M [00:00<?, ?B/s] 11%|# | 4.87M/44.7M [00:00<00:00, 51.0MB/s] 22%|##1 | 9.73M/44.7M [00:00<00:00, 49.2MB/s] 74%|#######3 | 32.9M/44.7M [00:00<00:00, 136MB/s] 100%|##########| 44.7M/44.7M [00:00<00:00, 129MB/s]

加载测试图像​

经典的猫咪示例:

from PIL import Image
img_url = "https://github.com/dmlc/mxnet.js/blob/main/data/cat.png?raw=true"
img_path = download_testdata(img_url, "cat.png", module="data")
img = Image.open(img_path).resize((224, 224))
# 预处理图像,并将其转换为张量
from torchvision import transforms
my_preprocess = transforms.Compose(
 [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 ]
)
img = my_preprocess(img)
img = np.expand_dims(img, 0)

将计算图导入 Relay​

将 PyTorch 计算图转换为 Relay 计算图。input_name 可以是任意值。

input_name = "input0"
shape_list = [(input_name, img.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)

Relay 构建​

用给定的输入规范,将计算图编译为 llvm target。

target = tvm.target.Target("llvm", host="llvm")
dev = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

输出结果:

/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
 "target_host parameter is going to be deprecated. "

在 TVM 上执行可移植计算图​

将编译好的模型部署到 target 上:

from tvm.contrib import graph_executor
dtype = "float32"
m = graph_executor.GraphModule(lib["default"](dev))
# 设置输入
m.set_input(input_name, tvm.nd.array(img.astype(dtype)))
# 执行
m.run()
# 得到输出
tvm_output = m.get_output(0)

查找分类集名称​

在 1000 个类的分类集中,查找分数最高的第一个:

synset_url = "".join(
 [
 "https://raw.githubusercontent.com/Cadene/",
 "pretrained-models.pytorch/master/data/",
 "imagenet_synsets.txt",
 ]
)
synset_name = "imagenet_synsets.txt"
synset_path = download_testdata(synset_url, synset_name, module="data")
with open(synset_path) as f:
    synsets = f.readlines()
synsets = [x.strip() for x in synsets]
splits = [line.split(" ") for line in synsets]
key_to_classname = {spl[0]: " ".join(spl[1:]) for spl in splits}
class_url = "".join(
 [
 "https://raw.githubusercontent.com/Cadene/",
 "pretrained-models.pytorch/master/data/",
 "imagenet_classes.txt",
 ]
)
class_name = "imagenet_classes.txt"
class_path = download_testdata(class_url, class_name, module="data")
with open(class_path) as f:
    class_id_to_key = f.readlines()
class_id_to_key = [x.strip() for x in class_id_to_key]
# 获得 TVM 的前 1 个结果
top1_tvm = np.argmax(tvm_output.numpy()[0])
tvm_class_key = class_id_to_key[top1_tvm]
# 将输入转换为 PyTorch 变量,并获取 PyTorch 结果进行比较
with torch.no_grad():
    torch_img = torch.from_numpy(img)
    output = model(torch_img)
 # 获得 PyTorch 的前 1 个结果
    top1_torch = np.argmax(output.numpy())
    torch_class_key = class_id_to_key[top1_torch]
print("Relay top-1 id: {}, class name: {}".format(top1_tvm, key_to_classname[tvm_class_key]))
print("Torch top-1 id: {}, class name: {}".format(top1_torch, key_to_classname[torch_class_key]))

输出结果:

Relay top-1 id: 281, class name: tabby, tabby cat
Torch top-1 id: 281, class name: tabby, tabby cat

下载 Python 源代码:from_pytorch.py

下载 Jupyter Notebook:from_pytorch.ipynb

以上就是一文详解如何实现PyTorch 模型编译 的详细内容,更多关于PyTorch 模型编译 的资料请关注我们其它相关文章!

(0)

相关推荐

  • Pytorch模型定义与深度学习自查手册

    目录 定义神经网络 权重初始化 方法1:net.apply(weights_init) 方法2:在网络初始化的时候进行参数初始化 常用的操作 利用nn.Parameter()设计新的层 nn.Flatten nn.Sequential 常用的层 全连接层nn.Linear() torch.nn.Dropout 卷积torch.nn.ConvNd() 池化 最大池化torch.nn.MaxPoolNd() 均值池化torch.nn.AvgPoolNd() 反池化 最大值反池化nn.MaxUnpoo

  • pytorch模型的保存加载与续训练详解

    目录 前面 模型保存与加载 方式1 方式2 方式3 总结 前面 最近,看到不少小伙伴问pytorch如何保存和加载模型,其实这部分pytorch官网介绍的也是很清楚的,感兴趣的点击了解详情

  • AMP Tensor Cores节省内存PyTorch模型详解

    目录 导读 什么是Tensor Cores? 那么,我们如何使用Tensor Cores? 使用PyTorch进行混合精度训练: 基准测试 导读 只需要添加几行代码,就可以得到更快速,更省显存的PyTorch模型. 你知道吗,在1986年Geoffrey Hinton就在Nature论文中给出了反向传播算法? 此外,卷积网络最早是由Yann le cun在1998年提出的,用于数字分类,他使用了一个卷积层.但是直到2012年晚些时候,Alexnet才通过使用多个卷积层来实现最先进的imagene

  • 利用Pytorch实现ResNet网络构建及模型训练

    目录 构建网络 训练模型 构建网络 ResNet由一系列堆叠的残差块组成,其主要作用是通过无限制地增加网络深度,从而使其更加强大.在建立ResNet模型之前,让我们先定义4个层,每个层由多个残差块组成.这些层的目的是降低空间尺寸,同时增加通道数量. 以ResNet50为例,我们可以使用以下代码来定义ResNet网络: class ResNet(nn.Module): def __init__(self, num_classes=1000): super().__init__() self.con

  • 详解利用Pytorch实现ResNet网络

    目录 正文 评估模型 训练 ResNet50 模型 正文 每个 batch 前清空梯度,否则会将不同 batch 的梯度累加在一块,导致模型参数错误. 然后我们将输入和目标张量都移动到所需的设备上,并将模型的梯度设置为零.我们调用model(inputs)来计算模型的输出,并使用损失函数(在此处为交叉熵)来计算输出和目标之间的误差.然后我们通过调用loss.backward()来计算梯度,最后调用optimizer.step()来更新模型的参数. 在训练过程中,我们还计算了准确率和平均损失.我们

  • 一文详解JS中的事件循环机制

    目录 前言 1.JavaScript是单线程的 2.同步和异步 3.事件循环 前言 我们知道JavaScript 是单线程的编程语言,只能同一时间内做一件事,按顺序来处理事件,但是在遇到异步事件的时候,js线程并没有阻塞,还会继续执行,这又是为什么呢?本文来总结一下js 的事件循环机制. 1.JavaScript是单线程的 JavaScript 是一种单线程的编程语言,只有一个调用栈,决定了它在同一时间只能做一件事.在代码执行的时候,通过将不同函数的执行上下文压入执行栈中来保证代码的有序执行.在

  • 一文详解Java中的类加载机制

    目录 一.前言 二.类加载的时机 2.1 类加载过程 2.2 什么时候类初始化 2.3 被动引用不会初始化 三.类加载的过程 3.1 加载 3.2 验证 3.3 准备 3.4 解析 3.5 初始化 四.父类和子类初始化过程中的执行顺序 五.类加载器 5.1 类与类加载器 5.2 双亲委派模型 5.3 破坏双亲委派模型 六.Java模块化系统 一.前言 Java虚拟机把描述类的数据从Class文件加载到内存,并对数据进行校验.转换解析和初始化,最 终形成可以被虚拟机直接使用的Java类型,这个过程

  • 一文详解Golang协程调度器scheduler

    目录 1. 调度器scheduler的作用 2. GMP模型 3. 调度机制 1. 调度器scheduler的作用 我们都知道,在Go语言中,程序运行的最小单元是gorouines. 然而程序的运行最终都是要交给操作系统来执行的,以Java为例,Java中的一个线程对应的就是操作系统中的线程,以此来实现在操作系统中的运行.在Go中,gorouines比线程更轻量级,其与操作系统的线程也不是一一对应的关系,然而,最终我们想要执行程序,还是要借助操作系统的线程来完成,调度器scheduler的工作就

  • 一文详解MySQL Binlog日志与主从复制

    目录 1. Binlog日志的介绍 2. 主从复制 2.1 主从复制的流程 2.2 GTID 2.3 复制模型 2.4 MGR模式 2.5 并行回放 1. Binlog日志的介绍 Binlog是Binary log的缩写,即二进制日志.Binlog主要有三个作用:持久化时将随机IO转化为顺序IO,主从复制以及数据恢复.本文重点主从复制相关的问题. Binlog日志由一个索引文件与很多日志文件组成,每个日志文件由魔数以及事件组成,每个日志文件都会以一个Rotate类型的事件结束. 对于每个事件,都

  • 详解如何使用Pytorch进行多卡训练

    目录 1.DP 2.DDP 2.1Pytorch分布式基础 2.2Pytorch分布式训练DEMO 当一块GPU不够用时,我们就需要使用多卡进行并行训练.其中多卡并行可分为数据并行和模型并行.具体区别如下图所示: 由于模型并行比较少用,这里只对数据并行进行记录.对于pytorch,有两种方式可以进行数据并行:数据并行(DataParallel, DP)和分布式数据并行(DistributedDataParallel, DDP). 在多卡训练的实现上,DP与DDP的思路是相似的: 1.每张卡都复制

  • 一文详解Java etcd的应用场景及编码实战

    目录 一.白话etcd与zookeeper 二.etcd的4个核心机制 三.Leader选举与客户端交互 四.etcd的应用场景 4.1. kubernetes大脑 4.2. 服务注册与发现 4.3. 健康检查与状态变更通知 4.4.分布式锁 4.5.实现消息队列(纯扯淡) 五.etcd安装 六.jetcd的编码实现配置管理 本文首先用大白话给大家介绍一下etcd是什么?这部分内容网上已经有很多了. etcd有哪些应用场景?这些应用场景的核心原理是什么? 最后不能光动嘴不动手.先搭建一个etcd

  • 一文详解CNN 解决 Flowers 图像分类任务

    目录 前言 加载并展示数据 构件处理图像的 pipeline 搭建深度学习分类模型 训练模型并观察结果 加入了抑制过拟合措施并重新进行模型的训练和测试 前言 本文主要任务是使用通过 tf.keras.Sequential 搭建的模型进行各种花朵图像的分类,主要涉及到的内容有三个部分: 使用 tf.keras.Sequential 搭建模型. 使用 tf.keras.utils.image_dataset_from_directory 从磁盘中高效加载数据. 使用了一定的防止过拟合的方法,如丰富训

  • 一文详解Python灰色预测模型实现示例

    目录 前言 一.模型理论 特点 二.模型场景 1.预测种类 2.适用条件 三.建模流程 1.级比校验 3.系数求解 4.残差检验与级比偏差检验 四.Python实例实现 总结 前言 博主参与过大大小小十次数学建模比赛,也获得了不少建模奖项.对于一些小批量样本数据去做预测或者是评估其规律性的话,比较适合的模型一般都是选择灰色预测模型.该模型解释性强而且易于理解,建模手段也比较简单.在一些不确定是否存在相关标量或者是存在位置特征的时候,用灰色预测模型尤为明显,牵扯太多变量时候可以以量曾量减的方式显现

  • 一文详解Dart如何实现多任务并行

    目录 Isolate(隔离区域) async/await Stream Compute Function Isolate(隔离区域) Dart 是一种支持多任务并行的编程语言,它提供了多种机制来实现并发和并行.下面是 Dart 实现多任务并行的几种方式: Dart 中的 Isolate 是一种轻量级的并发机制,类似于线程.每个隔离区域都是独立的内存空间,每个隔离区域都有自己的内存空间和执行线程,因此不同的隔离区域之间可以独立地执行代码,每个隔离区都在自己的核心上运行,不会阻塞其他 Isolate

  • 一文详解typeScript的extends关键字

    目录 前言 extends 的几个语义 extends 与 类型组合/类继承 extends 与类型约束 extends 与条件类型 extends 与 {} extends 与 any extends 与 never extends 与 联合类型 extends 判断类型严格相等 extends 与类型推导 总结 前言 声明: 以下文章所包含的结论都是基于 typeScript@4.9.4 版本所取得的. extends 是 typeScript 中的关键字.在 typeScript 的类型编

随机推荐