tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例

升级到tf 2.0后, 训练的模型想转成1.x版本的.pb模型, 但之前提供的通过ckpt转pb模型的方法都不可用(因为保存的ckpt不再有.meta)文件, 尝试了好久, 终于找到了一个方法可以迂回转到1.x版本的pb模型.

Note: 本方法首先有些要求需要满足:

可以拿的到模型的网络结构定义源码

网络结构里面的所有操作都是通过tf.keras完成的, 不能出现类似tf.nn 的tensorflow自己的操作符

tf2.0下保存的模型是.h5格式的,并且仅保存了weights, 即通过model.save_weights保存的模型.

在tf1.x的环境下, 将tf2.0保存的weights转为pb模型:

如果在tf2.0下保存的模型符合上述的三个定义, 那么这个.h5文件在1.x环境下其实是可以直接用的, 因为都是通过tf.keras高级封装了,2.0版本和1.x版本不存在特别大的区别,我自己的模型是可以直接用的.

import tensorflow as tf
import os
from nets.efficientNet import *
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
# 这个代码网上说需要加上, 如果模型里有dropout , bn层的话, 我测试过加不加结果都一样, 保险起见还是加上吧
tf.keras.backend.set_learning_phase(0)

# 首先是定义你的模型, 这个需要和tf2.0下一毛一样
inputs = tf.keras.Input(shape=(224, 224, 3), name='modelInput')
outputs = yourModel(inputs, training=False)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.load_weights('save_weights.h5')
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
  """
  Freezes the state of a session into a pruned computation graph.

  Creates a new computation graph where variable nodes are replaced by
  constants taking their current value in the session. The new graph will be
  pruned so subgraphs that are not necessary to compute the requested
  outputs are removed.
  @param session The TensorFlow session to be frozen.
  @param keep_var_names A list of variable names that should not be frozen,
             or None to freeze all the variables in the graph.
  @param output_names Names of the relevant graph outputs.
  @param clear_devices Remove the device directives from the graph for better portability.
  @return The frozen graph definition.
  """
  from tensorflow.python.framework.graph_util import convert_variables_to_constants
  graph = session.graph
  with graph.as_default():
    freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
    output_names = output_names or []
    output_names += [v.op.name for v in tf.global_variables()]
    # Graph -> GraphDef ProtoBuf
    input_graph_def = graph.as_graph_def(add_shapes=True)
    if clear_devices:
      for node in input_graph_def.node:
        node.device = ""
    frozen_graph = convert_variables_to_constants(session, input_graph_def,
                           output_names, freeze_var_names)
    return frozen_graph

frozen_graph = freeze_session(tf.keras.backend.get_session(), output_names=[out.op.name for out in model.outputs])
tf.train.write_graph(frozen_graph, "model", "tf_model.pb", as_text=False)

运行成功后, 会在当前目录下生成一个model文件夹, 里面有生成的tf_model.pb文件, 至此, 我们就完成了将tf2.0下训练的模型转到tf1.x下的pb模型, 这样,就可以用这个pb模型做其它推理或者转tvm ncnn等模型转换工作.

这个转换的重点就是通过keras这个中间商来完成, 所以我们定义的模型就必须要满足这个中间商定义的条件

补充知识:tensorflow2.0降级及如何从别的版本升到2.0

代码实践《tensorflow实战GOOGLE深度学习框架》时,由于本机安装的tensorflow为2.0版本与配套书籍代码1.4的API不兼容,只得将tensorflow降级为1.4.0版本使用,降级方法如下

1 pip uninstall tensorflow

2 pip install tensorflow==1.14.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

验证

import tensorflow as tf
print(tf.version)

二 从别的版本升级到2.0

自动卸载与其相关包

pip uninstall tensorflow

安装某版本

pip install --no-cache-dir tensorflow==x.xx (此处填写2.0)

验证

以上这篇tensorflow 2.0模式下训练的模型转成 tf1.x 版本的pb模型实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

时间: 2020-06-22

安装多个版本的TensorFlow的方法步骤

TensorFlow 2.0测试版在今年春季发布,新版本比1.x版本在易用性上有了很大的提升.但是由于2.0发布还没有多久,现在大部分论文的实现代码都是1.x版本的,所以在学习TensorFlow的过程中同时安装1.x和2.0两个版本是很有必要的. 下面是具体操作 首先需要安装Anaconda 然后进入Anaconda prompt(未避免安装失败,最好以管理员身份运行) 安装第一个版本的tensorflow: 现在是默认环境,输入要安装的第一个tensorflow版本:pip install

解决Linux Tensorflow2.0安装问题

conda update conda pip install tf-nightly-gpu-2.0-preview conda install https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/linux-64/cudnn-7.3.1-cuda10.0_0.tar.bz2 conda install https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/linux-64/cud

Windows10下Tensorflow2.0 安装及环境配置教程(图文)

下载安装Anaconda 下载地址如下,根据所需版本下载 安装过程暂略(下次在安装时添加) 下载安装Pycharm 下载安装Pycharm,下载对应使用版本即可 如果你是在校学生,有学校的edu邮箱,可以免费注册Pycharm专业版,注册地址如下,本文不详细说明 下载CUDA10.0 下载地址如下CUDA Toolkit 10.0 Archive 下载之后默认安装即可 下载CUDNN 通过此处选择版本对应的CUDNN,对于本次配置就选择Windows 10对应的版本 下载CUDNN需要注册一个N

tensorflow2.0保存和恢复模型3种方法

方法1:只保存模型的权重和偏置 这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同. tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了. save_weights( filepath, overwrite=True, save_format=None ) Arguments: filepath: String,

TensorFlow2.1.0最新版本安装详细教程

TensorFlow是一款优秀的深度学习框架,支持多种常见的操作系统,例如Windows10,Mac Os等等,同时也支持运行在NVIDIA显卡上的GPU版本以及仅使用CPU进行运算的CPU版本.此篇教程将介绍如何安装最新版TensorFlow框架(2.1.0版本) 安装步骤 1.常用IDE安装 2.CUDA安装 3.cuDNN神经网络加速库安装 4.TensorFlow框架安装 常用IDE安装 用户在Python官网上可以下载到最新版本(Python3.7)的解释器.(Python官网)Pyt

微信小程序(微信应用号)开发工具0.9版安装详细教程

微信小程序全称微信公众平台·小程序,原名微信公众平台·应用号(简称微信应用号) 声明 •微信小程序开发工具类似于一个轻量级的IDE集成开发环境,目前仅开放给了少部分受微信官方邀请的人士(据说仅200个名额)进行内测,因此目前未受到邀请的人士只能使用破解版: •本破解版资源来自于网上,与本人无关,仅供技术开发人员研究之用: •由于尚属内测阶段,因此迭代更新非常快,后续很可能由于升级而导致暂时无法使用. 特别注意 •由于目前发布的0.9版本必须验证才能登录(估计是为了验证是否为内测人士),因此必须先

MySQL v5.7.18 解压版本安装详细教程

下载MySQL https://dev.mysql.com/downloads/mysql/5.1.html#downloads 个人机子是64位的,所以选择下载:Windows (x86, 64-bit), ZIP Archive 版本 解压并安装 1.将下载下载的包解压到指定目录,(本人)解压到:D:\Program Files (x86) 目录下. 因此,MySQL的(安装)包的完整路径为:D:\Program Files (x86)\mysql-5.7.18-winx64 2.解压后,配

mysql 8.0.15 下载安装详细教程 新手必备!

本文记录了mysql 8.0.15 下载安装的具体步骤,供大家参考,具体内容如下 背景:作为一个热爱技术但不懂代码的产品写的教程 1.环境 系统:windows 64位 mysql版本:mysql 8.0.15 2.下载篇 首先是下载数据库安装文件,进入mysql官网下载频道https://www.mysql.com/downloads/,依次点击Community→MySQL Community Server,如下图: 进入下载页面后,选择操作系统,这里我们选择默认的Microsoft Win

Windows下mysql 8.0.12 安装详细教程

本文为大家分享了mysql 8.0.12 安装详细教程,供大家参考,具体内容如下 一.安装 1.从官网上下载MySQL8.0.12版本,下载链接 2.下载后解压到一个文件夹下 我的解压路径:C:\Program Files\MySQL8.0.12 (将压缩后的文件夹放在D盘,通过cmd进不去指定文件夹下,将其放在C盘后就没问题了.) 3.文件配置 首先在解压的路径下通过记事本新建一个my.ini文件, 内容如下: [mysqld] # 设置3306端口 port=3306 # 设置mysql的安

mysql installer community 5.7.16安装详细教程

本文记录了mysql安装详细教程,分享给大家. 一.版本的选择 之前安装的Mysql,现在才来总结,好像有点晚,后台换系统了,现在从新装上Mysql,感觉好多坑,我是来踩坑,大家看到坑就别跳了,这样可以省点安装时间,这个折腾了两天,安装了好多个版本,终于安装好了,最终选择了最新的版本mysql-installer-community-5.7.16.0. 以前是在其他软件网站下载的,但是觉得还是在官方网站下载比较靠谱. 进入到MySql官方网站,进入到下载界面.看到这个,选择"MySQL Inst

[Oracle] Data Guard CPU/PSU补丁安装详细教程

非Data Guard的补丁安装教程可参考<[Oracle] CPU/PSU补丁安装详细教程>,Data Guard需要Primary和Standby同时打上补丁,所以步骤更复杂一些,其主要步骤如下:1.在Primary停止日志传输服务:2.关闭Standby数据库,在Standby的软件上打补丁(注意:不需要为Standby数据库打补丁),启动standby为mount状态,不启用managed recovery:3.关闭Primary,在Primary的软件和数据库本身都打上补丁:4.启动

MySQL5.6.22 绿色版 安装详细教程(图解)

1.数据库下载 从官方网站可以找到两种文件包,一种是exe安装程序,另一种是zip压缩包. 本人喜欢清爽的方式,所以下载的是ZIP压缩包.最新的5.6.22大概350M,下载还需要oracle帐号,自己注册一个好了. 2.数据库安装 解压出下载的文件mysql-5.6.22-win32.zip(有x86和x64两个版本)到任一目录,防止出现未知问题,最好放在非系统盘的非中文目录下,我的位置E:\mysql\mysql-5.6.24-win32.打开文件夹复制一份my-default.ini为my

window下homestead开发环境安装详细教程

一.资源准备 链接:http://pan.baidu.com/s/1mh7qUBe 密码:p4wx 1. virtualbox.box文件放在C盘根目录上. 2. metadata.json文件放在C盘用户目录上.比如我的是 C:\Users\pc 3. Git-2.9.3-64-bit.exe 下载安装,全部默认就好. 4. vagrant_1.8.5.msi 下载安装. 5. VirtualBox-5.1.4-110228-Win.exe 下载安装. 二.简介 1. 什么是vagrant?

MySql 5.6.35 winx64 安装详细教程

说明:因为数据库版本问题出现的项目启动没有错误,但是操作数据库的过程出现错误,为了保持数据库一致,重新检索到了安装mysql5.6的教程,不复杂,需要耐心. 若笔记本原本安装了其他数据库版本,请先将mysql数据库卸载干净,具体请参见网址:http://materliu.github.io/all/web/database/mysql/2014/04/24/uninstall-mysql-totaly.cm.html 为了防止网址不能访问或者不存在的情况,具体步骤如下: 1.首先在windows

windows10系统安装mysql-8.0.13(zip安装) 的教程详解

安装环境说明 •系统版本:windows10 •mysql版本:mysql-8.0.13-winx64.zip •下载地址:http://mirrors.163.com/mysql/Downloads/MySQL-8.0/mysql-8.0.13-winx64.zip 解压安装包 •解压路径:D:\develop\software •解压后mysql根目录:D:\develop\software\mysql-8.0.13-winx64 原则: 安装目录不能有空格.不能有中文 配置文件 •my.i