Tensorflow2.1 完成权重或模型的保存和加载

目录
  • 前言
  • 实现方法
    • 1. 读取数据
    • 2. 搭建深度学习模型
    • 3. 使用回调函数在每个 epoch 后自动保存模型权重
    • 4. 使用回调函数每经过 5 个 epoch 对模型权重保存一次
    • 5. 手动保存模型权重到指定目录
    • 6. 手动保存整个模型结构和权重

前言

本文主要使用 cpu 版本的 tensorflow-2.1 来完成深度学习权重参数/模型的保存和加载操作。

在我们进行项目期间,很多时候都要在模型训练期间、训练结束之后对模型或者模型权重进行保存,然后我们可以从之前停止的地方恢复原模型效果继续进行训练或者直接投入实际使用,另外为了节省存储空间我们还可以自定义保存内容和保存频率。

实现方法

1. 读取数据

(1)本文重点介绍模型或者模型权重的保存和读取的相关操作,使用到的是 MNIST 数据集仅是为了演示效果,我们无需关心模型训练的质量好坏。

(2)这里是常规的读取数据操作,我们为了能较快介绍本文重点内容,只使用了 MNIST 前 1000 条数据,然后对数据进行归一化操作,加快模型训练收敛速度,并且将每张图片的数据从二维压缩成一维。

import os
import tensorflow as tf
from tensorflow import keras
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

2. 搭建深度学习模型

(1)这里主要是搭建一个最简单的深度学习模型。

(2)第一层将图片的长度为 784 的一维向量转换成 256 维向量的全连接操作,并且用到了 relu 激活函数。

(3)第二层紧接着使用了防止过拟合的 Dropout 操作,神经元丢弃率为 50% 。

(4)第三层为输出层,也就是输出每张图片属于对应 10 种类别的分布概率。

(5)优化器我们选择了最常见的 Adam 。

(6)损失函数选择了 SparseCategoricalCrossentropy 。

(7)评估指标选用了 SparseCategoricalAccuracy 。

def create_model():
    model = tf.keras.Sequential([keras.layers.Dense(256, activation='relu', input_shape=(784,)),
                                 keras.layers.Dropout(0.5),
                                 keras.layers.Dense(10) ])
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
    return model

3. 使用回调函数在每个 epoch 后自动保存模型权重

(1)这里介绍一种在模型训练期间保存权重参数的方法,我们定义一个回调函数 callback ,它可以在训练过程中将权重保存在自定义目录中 weights_path ,在训练过程中一共执行 5 次 epoch ,每次 epoch 结束之后就会保存一次模型的权重到指定的目录。

(2)可以看到最后使用测试集进行评估的 loss 为 0.4952 ,分类准确率为 0.8500 。

weights_path = "training_weights/cp.ckpt"
weights_dir = os.path.dirname(weights_path)
callback = tf.keras.callbacks.ModelCheckpoint(filepath=weights_path, save_weights_only=True,  verbose=1)
model = create_model()
model.fit(train_images,
          train_labels,
          epochs=5,
          validation_data=(test_images, test_labels),
          callbacks=[callback])

输出结果为:

val_loss: 0.4952 - val_sparse_categorical_accuracy: 0.8500

(3)我们浏览目标文件夹里,只有三个文件,每个 epoch 后自动都会保存三个文件,在下一次 epoch 之后会自动更新这三个文件的内容。

os.listdir(weights_dir)

结果为:

['checkpoint', 'cp.ckpt.data-00000-of-00001', 'cp.ckpt.index']

(4) 我们通过 create_model 定义了一个新的模型实例,然后让其在没有训练的情况下使用测试数据进行评估,结果可想而知,准确率差的离谱。

NewModel = create_model()
loss, acc = NewModel.evaluate(test_images, test_labels, verbose=2)

结果为:

loss: 2.3694 - sparse_categorical_accuracy: 0.1330

(5) tensorflow 中只要两个模型有相同的模型结构,就可以在它们之间共享权重,所以我们使用 NewModel 读取了之前训练好的模型权重,再使用测试集对其进行评估发现,损失值和准确率和旧模型的结果完全一样,说明权重被相同结构的新模型成功加载并使用。

NewModel.load_weights(checkpoint_path)
loss, acc = NewModel.evaluate(test_images, test_labels, verbose=2)

输出结果:

loss: 0.4952 - sparse_categorical_accuracy: 0.8500

4. 使用回调函数每经过 5 个 epoch 对模型权重保存一次

(1)如果我们想保留多个中间 epoch 的模型训练的权重,或者我们想每隔几个 epoch 保存一次模型训练的权重,这时候我们可以通过设置保存频率 period 来完成,我这里让新建的模型训练 30 个 epoch ,在每经过 10 epoch 后保存一次模型训练好的权重。

(2)使用测试集对此次模型进行评估,损失值为 0.4047 ,准确率为 0.8680 。

weights_path = "training_weights2/cp-{epoch:04d}.ckpt"
weights_dir = os.path.dirname(weights_path)
batch_size = 64
cp_callback = tf.keras.callbacks.ModelCheckpoint( filepath=weights_path,
                                                  verbose=1,
                                                  save_weights_only=True,
                                                  period=10)
model = create_model()
model.save_weights(weights_path.format(epoch=1))
model.fit(train_images,
          train_labels,
          epochs=30,
          batch_size=batch_size,
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=1)

结果输出为:

val_loss: 0.4047 - val_sparse_categorical_accuracy: 0.8680

(3)这里我们能看到指定目录中的文件组成,这里的 0001 是因为训练时指定了要保存的 epoch 的权重,其他都是每 10 个 epoch 保存的权重参数文件。目录中有一个 checkpoint ,它是一个检查点文本文件,文件保存了一个目录下所有的模型文件列表,首行记录的是最后(最近)一次保存的模型名称。

(4)每个 epoch 保存下来的文件都包含:

  • 一个索引文件,指示哪些权重存储在哪个分片中
  • 一个或多个包含模型权重的分片

浏览文件夹内容

os.listdir(weights_dir)

结果如下:

['checkpoint', 'cp-0001.ckpt.data-00000-of-00001', 'cp-0001.ckpt.index', 'cp-0010.ckpt.data-00000-of-00001', 'cp-0010.ckpt.index', 'cp-0020.ckpt.data-00000-of-00001', 'cp-0020.ckpt.index', 'cp-0030.ckpt.data-00000-of-00001', 'cp-0030.ckpt.index']

(5)我们将最后一次保存的权重读取出来,然后创建一个新的模型去读取刚刚保存的最新的之前训练好的模型权重,然后通过测试集对新模型进行评估,发现损失值准确率和之前完全一样,说明权重被成功读取并使用。

latest = tf.train.latest_checkpoint(weights_dir)
newModel = create_model()
newModel.load_weights(latest)
loss, acc = newModel.evaluate(test_images, test_labels, verbose=2)

结果如下:

loss: 0.4047 - sparse_categorical_accuracy: 0.8680

5. 手动保存模型权重到指定目录

(1)有时候我们还想手动将模型训练好的权重保存到指定的目录下,我们可以使用 save_weights 函数,通过我们新建了一个同样的新模型,然后使用 load_weights 函数去读取权重并使用测试集对其进行评估,发现损失值和准确率仍然和之前的两种结果完全一样。

model.save_weights('./training_weights3/my_cp')
newModel = create_model()
newModel.load_weights('./training_weights3/my_cp')
loss, acc = newModel.evaluate(test_images, test_labels, verbose=2)

结果如下:

loss: 0.4047 - sparse_categorical_accuracy: 0.8680

6. 手动保存整个模型结构和权重

(1)有时候我们还需要保存整个模型的结构和权重,这时候我们直接使用 save 函数即可将这些内容保存到指定目录,使用该方法要保证目录是存在的否则会报错,所以这里我们要创建文件夹。我们能看到损失值为 0.4821,准确率为 0.8460 。

model = create_model()
model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels), verbose=1)
!mkdir my_model
modelPath = './my_model'
model.save(modelPath)

输出结果:

val_loss: 0.4821 - val_sparse_categorical_accuracy: 0.8460

(2)然后我们通过函数 load_model 即可生成出一个新的完全一样结构和权重的模型,我们使用测试集对其进行评估,发现准确率和损失值和之前完全一样,说明模型结构和权重被完全读取恢复。

new_model = tf.keras.models.load_model(modelPath)
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)

输出结果:

loss: 0.4821 - sparse_categorical_accuracy: 0.8460

以上就是Tensorflow2.1 完成权重或模型的保存和加载的详细内容,更多关于Tensorflow完成权重模型保存加载的资料请关注我们其它相关文章!

时间: 2022-11-16

python深度学习tensorflow训练好的模型进行图像分类

目录 正文 随机找一张图片 读取图片进行分类识别 最后输出 正文 谷歌在大型图像数据库ImageNet上训练好了一个Inception-v3模型,这个模型我们可以直接用来进来图像分类. 下载链接: https://pan.baidu.com/s/1XGfwYer5pIEDkpM3nM6o2A 提取码: hu66 下载完解压后,得到几个文件: 其中 classify_image_graph_def.pb 文件就是训练好的Inception-v3模型. imagenet_synset_to_huma

Tensorflow2.4使用Tuner选择模型最佳超参详解

目录 前言 实现过程 1. 获取 MNIST 数据并进行处理 2. 搭建超模型 3. 实例化调节器并进行模型超调 4. 训练模型获得最佳 epoch 5. 使用最有超参数集进行模型训练和评估 前言 本文使用 cpu 版本的 tensorflow 2.4 ,选用 Keras Tuner 工具以 Fashion 数据集的分类任务为例,完成最优超参数的快速选择任务. 当我们搭建完成深度学习模型结构之后,我们在训练模型的过程中,有很大一部分工作主要是通过验证集评估指标,来不断调节模型的超参数,这是比较耗

Tensorflow高性能数据优化增强工具Pipeline使用详解

目录 安装方法 功能 高级用户部分 用例1,为训练创建数据Pipeline 用例2,为验证创建数据Pipeline 初学者部分 Keras 兼容性 配置 增强: GridMask MixUp RandomErase CutMix Mosaic CutMix , CutOut, MixUp Mosaic Grid Mask 安装方法 给大家介绍一个非常好用的TensorFlow数据pipeline工具. 高性能的Tensorflow Data Pipeline,使用SOTA的增强和底层优化. pi

python深度学习tensorflow1.0参数初始化initializer

目录 正文 所有初始化方法定义 1.tf.constant_initializer() 2.tf.truncated_normal_initializer() 3.tf.random_normal_initializer() 4.random_uniform_initializer = RandomUniform() 5.tf.uniform_unit_scaling_initializer() 6.tf.variance_scaling_initializer() 7.tf.orthogona

python深度学习tensorflow1.0参数和特征提取

目录 tf.trainable_variables()提取训练参数 具体实例 tf.trainable_variables()提取训练参数 在tf中,参与训练的参数可用 tf.trainable_variables()提取出来,如: #取出所有参与训练的参数 params=tf.trainable_variables() print("Trainable variables:------------------------") #循环列出参数 for idx, v in enumera

python深度学习tensorflow卷积层示例教程

目录 一.旧版本(1.0以下)的卷积函数:tf.nn.conv2d 二.1.0版本中的卷积函数:tf.layers.conv2d 一.旧版本(1.0以下)的卷积函数:tf.nn.conv2d 在tf1.0中,对卷积层重新进行了封装,比原来版本的卷积层有了很大的简化. conv2d( input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=None ) 该函数定义在tensorflow/pytho

Tensorflow 2.4加载处理图片的三种方式详解

目录 前言 数据准备 使用内置函数读取并处理磁盘数据 自定义方式读取和处理磁盘数据 从网络上下载数据 前言 本文通过使用 cpu 版本的 tensorflow 2.4 ,介绍三种方式进行加载和预处理图片数据. 这里我们要确保 tensorflow 在 2.4 版本以上 ,python 在 3.8 版本以上,因为版本太低有些内置函数无法使用,然后要提前安装好 pillow 和 tensorflow_datasets ,方便进行后续的数据加载和处理工作. 由于本文不对模型进行质量保证,只介绍数据的加

基于vue中css预加载使用sass的配置方式详解

1.安装sass的依赖包 npm install --save-dev sass-loader //sass-loader依赖于node-sass npm install --save-dev node-sass 2.在build文件夹下的webpack.base.conf.js的rules里面添加配置,如下红色部分 { test: /\.sass$/, loaders: ['style', 'css', 'sass'] } <span style="color:#454545;"

Python图片存储和访问的三种方式详解

目录 前言 数据准备 一个可以玩的数据集 图像存储的设置 LMDB HDF5 单一图像的存储 存储到 磁盘 存储到 LMDB 存储 HDF5 存储方式对比 多个图像的存储 多图像调整代码 准备数据集对比 单一图像的读取 从 磁盘 读取 从 LMDB 读取 从 HDF5 读取 读取方式对比 多个图像的读取 多图像调整代码 准备数据集对比 读写操作综合比较 数据对比 并行操作 前言 ImageNet 是一个著名的公共图像数据库,用于训练对象分类.检测和分割等任务的模型,它包含超过 1400 万张图像

Pandas保存csv数据的三种方式详解

目录 方法一 方法二 方法三 补充 方法一 import os import pandas as pd path = 'data/train/' img_label_list=[] testList = os.listdir(path) for file in testList: label='aa' img_label_list.append([file, label]) df1 = pd.DataFrame(data=img_label_list, columns=['id', 'label

Python绘制散点密度图的三种方式详解

目录 方式一 方式二 方式三 方式一 import matplotlib.pyplot as plt import numpy as np from scipy.stats import gaussian_kde from mpl_toolkits.axes_grid1 import make_axes_locatable from matplotlib import rcParams config = {"font.family":'Times New Roman',"fo

Python写入MySQL数据库的三种方式详解

目录 场景一:数据不需要频繁的写入mysql 场景二:数据是增量的,需要自动化并频繁写入mysql 方式一 方式二 总结 大家好,Python 读取数据自动写入 MySQL 数据库,这个需求在工作中是非常普遍的,主要涉及到 python 操作数据库,读写更新等,数据库可能是 mongodb. es,他们的处理思路都是相似的,只需要将操作数据库的语法更换即可. 本篇文章会给大家分享数据如何写入到 mysql,分为两个场景,三种方式. 场景一:数据不需要频繁的写入mysql 使用 navicat 工

Java实现AOP代理的三种方式详解

目录 1.JDK实现 2.CGLIB实现 3.boot注解实现[注意只对bean有效] 业务场景:首先你有了一个非常好的前辈无时无刻的在“教育”你.有这么一天,它叫你将它写好的一个方法进行改进测试,这时出现了功能迭代的情况.然后前辈好好“教育”你的说,不行改我的代码!改就腿打折!悲催的你有两条路可走,拿出你10年跆拳道的功夫去火拼一波然后拍拍屁股潇洒走人,要么就是悲催的开始百度...这时你会发现,我擦怎么把AOP代理这种事给忘了?[其实在我们工作中很少去手写它,但是它又是很常见的在使用(控制台日

Android Flutter实现搜索的三种方式详解

目录 示例 1 :使用搜索表单创建全屏模式 编码 示例 2:AppBar 内的搜索字段(最常见于娱乐应用程序) 编码 示例 3:搜索字段和 SliverAppBar 编码 结论 示例 1 :使用搜索表单创建全屏模式 我们要构建的小应用程序有一个应用程序栏,右侧有一个搜索按钮.按下此按钮时,将出现一个全屏模式对话框.它不会突然跳出来,而是带有淡入淡出动画和幻灯片动画(从上到下).在圆形搜索字段旁边,有一个取消按钮,可用于关闭模式.在搜索字段下方,我们会显示一些搜索历史记录(您可以添加其他内容,如建

命令行运行Python脚本时传入参数的三种方式详解

如果在运行python脚本时需要传入一些参数,例如gpus与batch_size,可以使用如下三种方式. python script.py 0,1,2 10 python script.py -gpus=0,1,2 --batch-size=10 python script.py -gpus=0,1,2 --batch_size=10 这三种格式对应不同的参数解析方式,分别为sys.argv, argparse, tf.app.run, 前两者是python自带的功能,最后一个是tensorfl

CodeIgniter中使用cookie的三种方式详解

cookie在php程序设计中应用十分广泛,本文所述CodeIgniter中使用cookie主要有以下三种方式,读者可以根据自身项目需求酌情采用. 第一种方式:采用php原生态的方法设置的cookie的值 setcookie("user_id",$user_info['user_id'],86500); setcookie("username",$user_info['username'],86500); setcookie("password"

基于java解析JSON的三种方式详解

本文实例分析了基于java解析JSON的三种方式.分享给大家供大家参考,具体如下: 一.什么是JSON? JSON是一种取代XML的数据结构,和xml相比,它更小巧但描述能力却不差,由于它的小巧所以网络传输数据将减少更多流量从而加快速度. JSON就是一串字符串 只不过元素会使用特定的符号标注. {} 双括号表示对象 [] 中括号表示数组 "" 双引号内是属性或值 : 冒号表示后者是前者的值(这个值可以是字符串.数字.也可以是另一个数组或对象) 所以 {"name"