tensorflow模型继续训练 fineturn实例

解决tensoflow如何在已训练模型上继续训练fineturn的问题。

训练代码

任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf

# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])

# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')

# 操作
result = tf.matmul(x, W) + b

# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))

# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)

with tf.Session() as sess:
  # 初始化变量
  sess.run(tf.global_variables_initializer())
  saver = tf.train.Saver(max_to_keep=3)

  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[100.0]]

  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)

    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print ''
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

训练完成之后生成模型文件:

训练输出:

step: 1000, loss: 4.89526428282e-08
step: 2000, loss: 4.89526428282e-08
step: 3000, loss: 4.89526428282e-08
step: 4000, loss: 4.89526428282e-08
step: 5000, loss: 4.89526428282e-08

final result of x×W+b = [[99.99978]](目标值是100.0)

模型保存的W值 : 29.999931
模型保存的b : 9.999982

保存在模型中的W值是 29.999931,b是 9.999982。

以下代码从保存的模型中恢复出训练状态,继续训练

任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。

# -*- coding: utf-8 -*-)
import tensorflow as tf

# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])

with tf.Session() as sess:

  # 初始化变量
  sess.run(tf.global_variables_initializer())

  # saver = tf.train.Saver(max_to_keep=3)
  saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
  saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据

  # 从保存模型中恢复变量
  graph = tf.get_default_graph()
  W = graph.get_tensor_by_name("w:0")
  b = graph.get_tensor_by_name("b:0")

  print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
  print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))

  # 操作
  result = tf.matmul(x, W) + b
  # 损失函数
  lost = tf.reduce_sum(tf.pow((result - y), 2))
  # 优化
  train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)

  # 这里x、y给固定的值
  x_s = [[3.0]]
  y_s = [[200.0]]

  step = 0
  while (True):
    step += 1
    feed = {x: x_s, y: y_s}
    # 通过sess.run执行优化
    sess.run(train_step, feed_dict=feed)
    if step % 1000 == 0:
      print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
      if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
        print ''
        # print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
        print 'final result of {0} = {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
        print("模型保存的W值 : %f" % sess.run(W))
        print("模型保存的b : %f" % sess.run(b))
        break
  saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

训练输出:

从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000, loss: 1.95810571313e-07
step: 2000, loss: 1.95810571313e-07
step: 3000, loss: 1.95810571313e-07
step: 4000, loss: 1.95810571313e-07
step: 5000, loss: 1.95810571313e-07

final result of x×W+b = [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958

从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。

总结

从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:

saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型

在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:

saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型

注:特殊情况下(如本例)需要从恢复的模型中加载出数据:

# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")

以上这篇tensorflow模型继续训练 fineturn实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

时间: 2020-01-20

浅谈tensorflow中张量的提取值和赋值

tf.gather和gather_nd从params中收集数值,tf.scatter_nd 和 tf.scatter_nd_update用updates更新某一张量.严格上说,tf.gather_nd和tf.scatter_nd_update互为逆操作. 已知数值的位置,从张量中提取数值:tf.gather, tf.gather_nd tf.gather indices每个元素(标量)是params某个axis的索引,tf.gather_nd 中indices最后一个阶对应于索引值. tf.ga

TensorFlow 模型载入方法汇总(小结)

一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 参数名称 功能说明 默认值 var_list Saver中存储变量集合 全局变量集合 reshape 加载时是否恢复变量形状 True sharded 是否将变量轮循放在所有设备上 True max_to_keep 保留最近检查点个数 5 restore_sequentially 是否按顺序恢复变量,模型较大时顺序恢复内存消耗小 True var_list是字典

Python Tensor FLow简单使用方法实例详解

本文实例讲述了Python Tensor FLow简单使用方法.分享给大家供大家参考,具体如下: 1.基础概念 Tensor表示张量,是一种多维数组的数据结构.Flow代表流,是指张量之间通过计算而转换的过程.TensorFLow通过一个计算图的形式表示编程过程,数据在每个节点之间流动,经过节点加工之后流向下一个节点. 计算图是一个有向图,其组成如下:节点:代表一个操作.边:代表节点之间的数据传递和控制依赖,其中实线代表两个节点之间的数据传递关系,虚线代表两个节点之间存在控制相关. 张量是所有数

Python生成短uuid的方法实例详解

python的uuid都是32位的,比较长,处理起来效率比较低, 本算法利用62个可打印字符,通过随机生成32位UUID,由于UUID都为十六进制,所以将UUID分成8组,每4个为一组,然后通过模62操作,结果作为索引取出字符, 最后生成的Uuid,只有8位,代码如下: uuid4,可以换成uuid1 from uuid import uuid4 uuidChars = ("a", "b", "c", "d", "e

Python编程之属性和方法实例详解

本文实例讲述了Python编程中属性和方法使用技巧.分享给大家供大家参考.具体分析如下: 一.属性 在python中,属性分为公有属性和私有属性,公有属性可以在类的外部调用,私有属性不能在类的外部调用.公有属性可以是任意变量,私有属性是以双下划线开头的变量. 下面我们定义一个People类,它有一个公有属性name,和一个私有属性__age. class People(): def __init(self): self.name='张珊' self.__age=24 我们创建一个People类的

Python基础学习之函数方法实例详解

本文实例讲述了Python基础学习之函数方法.分享给大家供大家参考,具体如下: 前言 与其他编程语言一样,函数(或者方法)是组织好的,可重复使用的,用来实现单一,或相关联功能的代码段. python的函数具有非常高的灵活性,可以在单个函数里面封装和定义另一个函数,使编程逻辑更具模块化. 一.Python的函数方法定义 函数方法定义的简单规则: 1. 函数代码块以 def 关键词开头,后接函数标识符名称和圆括号(). 2. 任何传入参数和自变量必须放在圆括号中间.圆括号之间可以用于定义参数. 3.

python实现DES加密解密方法实例详解

本文实例讲述了python实现DES加密解密方法.分享给大家供大家参考.具体分析如下: 实现功能:加密中文等字符串 密钥与明文可以不等长 这里只贴代码,加密过程可以自己百度,此处python代码没有优化 1. desstruct.py DES加密中要使用的结构体 ip= (58, 50, 42, 34, 26, 18, 10, 2, 60, 52, 44, 36, 28, 20, 12, 4, 62, 54, 46, 38, 30, 22, 14, 6, 64, 56, 48, 40, 32,

Python设计模式之简单工厂模式实例详解

本文实例讲述了Python设计模式之简单工厂模式.分享给大家供大家参考,具体如下: 简单工厂模式(Simple Factory Pattern):是通过专门定义一个类来负责创建其他类的实例,被创建的实例通常都具有共同的父类. 下面使用简单工厂模式实现一个简单的四则运算 #!/usr/bin/env python # -*- coding:utf-8 -*- __author__ = 'Andy' ''' 大话设计模式 用任意一种面向对象语言实现一个计算器控制台程序.要求输入两个数和运算符号,得到

Python比较配置文件的方法实例详解

工作中最常见的配置文件有四种:普通key=value的配置文件.Json格式的配置文件.HTML格式的配置文件以及YMAML配置文件. 这其中以第一种居多,后三种在成熟的开源产品中较为常见,本文只针对第一种配置文件. 一般来说Linux shell下提供了diff命令来比较普通文本类的配置文件,Python的difflib也提供了str和HTML的比较接口,但是实际项目中这些工具其实并不好用,主要是因为我们的配置文件并不是标准化统一化的. 为了解决此类问题,最好针对特定的项目写特定的配置文件比较

Python数据类型之列表和元组的方法实例详解

引言 我们前面的文章介绍了数字和字符串,比如我计算今天一天的开销花了多少钱我可以用数字来表示,如果是整形用 int ,如果是小数用 float ,如果你想记录某件东西花了多少钱,应该使用 str 字符串型,如果你想记录表示所有开销的物品名称,你应该用什么表示呢? 可能有人会想到我可以用一个较长的字符串表示,把所有开销物品名称写进去,但是问题来了,如果你发现你记录错误了,想删除掉某件物品的名称,那你是不是要在这个长字符串中去查找到,然后删除,这样虽然可行,那是不是比较麻烦呢. 这种情况下,你是不是

Android 实现夜间模式的快速简单方法实例详解

ChangeMode 项目地址:ChangeMode Implementation of night mode for Android. 用最简单的方式实现夜间模式,支持ListView.RecyclerView. Preview Usage xml android:background="?attr/zzbackground" app:backgroundAttr="zzbackground"//如果当前页面要立即刷新,这里传入属性名称 比如 R.attr.zzb

python字符串string的内置方法实例详解

下面给大家分享python 字符串string的内置方法,具体内容详情如下所示: #__author: "Pizer Wang" #__date: 2018/1/28 a = "Let's go" print(a) print("-------------------") a = 'Let\'s go' print(a) print("-------------------") print("hello"