tensorflow学习笔记之tfrecord文件的生成与读取

训练模型时,我们并不是直接将图像送入模型,而是先将图像转换为tfrecord文件,再将tfrecord文件送入模型。为进一步理解tfrecord文件,本例先将6幅图像及其标签转换为tfrecord文件,然后读取tfrecord文件,重现6幅图像及其标签。
1、生成tfrecord文件

import os
import numpy as np
import tensorflow as tf
from PIL import Image

filenames = [
'images/cat/1.jpg',
'images/cat/2.jpg',
'images/dog/1.jpg',
'images/dog/2.jpg',
'images/pig/1.jpg',
'images/pig/2.jpg',]

labels = {'cat':0, 'dog':1, 'pig':2}

def int64_feature(values):
	if not isinstance(values, (tuple, list)):
		values = [values]
	return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def bytes_feature(values):
	return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))

with tf.Session() as sess:
	output_filename = os.path.join('images/train.tfrecords')
	with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
		for filename in filenames:
			#读取图像
			image_data = Image.open(filename)
			#图像灰度化
			image_data = np.array(image_data.convert('L'))
			#将图像转化为bytes
			image_data = image_data.tobytes()
			#读取label
			label = labels[filename.split('/')[-2]]
			#生成protocol数据类型
			example = tf.train.Example(features=tf.train.Features(feature={'image': bytes_feature(image_data),
																			'label': int64_feature(label)}))
			tfrecord_writer.write(example.SerializeToString())

2、读取tfrecord文件

import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image

# 根据文件名生成一个队列
filename_queue = tf.train.string_input_producer(['images/train.tfrecords'])
reader = tf.TFRecordReader()
# 返回文件名和文件
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,
									features={'image': tf.FixedLenFeature([], tf.string),
												'label': tf.FixedLenFeature([], tf.int64)})
# 获取图像数据
image = tf.decode_raw(features['image'], tf.uint8)
# 恢复图像原始尺寸[高,宽]
image = tf.reshape(image, [60, 160])
# 获取label
label = tf.cast(features['label'], tf.int32)

with tf.Session() as sess:
	# 创建一个协调器,管理线程
	coord = tf.train.Coordinator()
	# 启动QueueRunner, 此时文件名队列已经进队
	threads = tf.train.start_queue_runners(sess=sess, coord=coord)

	for i in range(6):
		image_b, label_b = sess.run([image, label])
		img = Image.fromarray(image_b, 'L')
		plt.imshow(img)
		plt.axis('off')
		plt.show()
		print(label_b)

	# 通知其他线程关闭
	coord.request_stop()
	# 其他所有线程关闭之后,这一函数才能返回
	coord.join(threads)

到此这篇关于tensorflow学习笔记之tfrecord文件的生成与读取的文章就介绍到这了,更多相关tfrecord文件的生成与读取内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

时间: 2021-03-30

tensorflow入门:TFRecordDataset变长数据的batch读取详解

在上一篇文章tensorflow入门:tfrecord 和tf.data.TFRecordDataset的使用里,讲到了使用如何使用tf.data.TFRecordDatase来对tfrecord文件进行batch读取,即使用dataset的batch方法进行:但如果每条数据的长度不一样(常见于语音.视频.NLP等领域),则不能直接用batch方法获取数据,这时则有两个解决办法: 1.在把数据写入tfrecord时,先把数据pad到统一的长度再写入tfrecord:这个方法的问题在于:若是有大量

tensorflow将图片保存为tfrecord和tfrecord的读取方式

tensorflow官方提供了3种方法来读取数据: 预加载数据(preloaded data):在TensorFlow图中定义常量或变量来保存所有的数据,适用于数据量不太大的情况.填充数据(feeding):通过Python产生数据,然后再把数据填充到后端. 从文件读取数据(reading from file):从文件中直接读取,然后通过队列管理器从文件中读取数据. 本文主要介绍第三种方法,通过tfrecord文件来保存和读取数据,对于前两种读取数据的方式也会进行一个简单的介绍. 项目下载git

Tensorflow中批量读取数据的案列分析及TFRecord文件的打包与读取

单一数据读取方式: 第一种:slice_input_producer() # 返回值可以直接通过 Session.run([images, labels])查看,且第一个参数必须放在列表中,如[...] [images, labels] = tf.train.slice_input_producer([images, labels], num_epochs=None, shuffle=True) 第二种:string_input_producer() # 需要定义文件读取器,然后通过读取器中的

Tensorflow中使用tfrecord方式读取数据的方法

前言 本博客默认读者对神经网络与Tensorflow有一定了解,对其中的一些术语不再做具体解释.并且本博客主要以图片数据为例进行介绍,如有错误,敬请斧正. 使用Tensorflow训练神经网络时,我们可以用多种方式来读取自己的数据.如果数据集比较小,而且内存足够大,可以选择直接将所有数据读进内存,然后每次取一个batch的数据出来.如果数据较多,可以每次直接从硬盘中进行读取,不过这种方式的读取效率就比较低了.此篇博客就主要讲一下Tensorflow官方推荐的一种较为高效的数据读取方式--tfre

tensorflow TFRecords文件的生成和读取的方法

TensorFlow提供了TFRecords的格式来统一存储数据,理论上,TFRecords可以存储任何形式的数据. TFRecords文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的.以下的代码给出了tf.train.Example的定义. message Example { Features features = 1; }; message Features { map<string, Feature> feature = 1; }; mes

tensorflow生成多个tfrecord文件实例

我就废话不多说了,直接上代码吧! import tensorflow as tf from PIL import Image import matplotlib.pyplot as plt import numpy as np import os i = 0 j = 0 num_shards = 100#总共写入的文件个数 instances_per_shard = 2#每个文件中的数据个数 sess=tf.InteractiveSession() cwd = "F:/寒假/google--da

python生成lmdb格式的文件实例

在crnn训练的时候需要用到lmdb格式的数据集,下面是python生成lmdb个是数据集的代码,注意一定要在linux系统下,否则会读入图像的时候出问题,可能遇到的问题都在代码里面注释了,看代码即可. #-*- coding:utf-8 -*- import os import lmdb#先pip install这个模块哦 import cv2 import glob import numpy as np def checkImageIsValid(imageBin): if imageBin

php生成txt文件实例代码介绍

这是一个朋友过来的 php 生成 txt 文件代码,这只是一个实例,需要我来给他生成多个 txt 文件实例的,但我觉得他这个代码有点意思,所以就分享上来了. 先说下这个 php 生成 txt 文件代码都是什么功能吧,肯定是要生成 txt 文件的,有点废话了,不说其它的了,这个 php 代码可以生成指定目录下的一个 txt 文件,并在 txt 文件里面写入三行文字,这个是在 php 里面定义好的. 夏日博客分享下实例的代码如下: <!doctype html> <html> <

python 读取excel文件生成sql文件实例详解

python 读取excel文件生成sql文件实例详解 学了python这么久,总算是在工作中用到一次.这次是为了从excel文件中读取数据然后写入到数据库中.这个逻辑用java来写的话就太重了,所以这次考虑通过python脚本来实现. 在此之前需要给python添加一个xlrd模块,这个模块是专门用来操作excel文件的. 在mac中可以通过easy_install xlrd命令实现自动安装模块 import xdrlib ,sys import xlrd def open_excel(fil

Tensorflow 实现将图像与标签数据转化为tfRecord文件

tensorflow中如果要对神经网络模型进行训练,需要把训练数据转换为tfrecord格式才能被读取,tensorflow的model文件里直接提供了相应的脚本文件在下面的文件夹中: cd tensorflow/models/research/object_detection/dataset_tools 其中包括: 1.create_coco_tf_record.py:注意,这个代码需要解析json格式的标签文件 2.create_pascal_tf_record.py:注意,这个代码需要解析

TFRecord文件查看包含的所有Features代码

TFRecord作为tensorflow中广泛使用的数据格式,它跨平台,省空间,效率高.因为 Tensorflow开发者众多,统一训练时数据的文件格式是一件很有意义的事情,也有助于降低学习成本和迁移成本. 但是TFRecord数据是二进制格式,没法直接查看.因此,如何能够方便的查看TFRecord格式和数据,就显得尤为重要了. 为什么需要查看TFReocrd数据?首先我们先看下常规的写入和读取TFRecord数据的关键过程. # 1. 写入过程 # 一张图片,我写入了其内容,label,长和宽几

tensorflow使用range_input_producer多线程读取数据实例

先放关键代码: i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue() inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE]) 原理解析: 第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生.shuffle指定是否打乱顺

linux 驱动之Kconfig文件和Makefile文件实例

linux 驱动之Kconfig文件和Makefile文件实例 在Linux编写驱动的过程中,有两个文件是我们必须要了解和知晓的.这其中,一个是Kconfig文件,另外一个是Makefile文件.如果大家比较熟悉的话,那么肯定对内核编译需要的.config文件不陌生,在.config文件中,我们发现有的模块被编译进了内核,有的只是生成了一个module.这中间,我们如何让内核发现我们编写的模块呢,这就需要在Kconfig中进行说明.至于如何生成模块,那么就需要利用Makefile告诉编译器,怎么

Mybatis映射文件实例详解

 一.输入映射 parameterType 指定输入参数的Java类型,可以使用别名或者类的全限定名.它可以接收简单类型.POJO.HashMap. 1.传递简单类型 根据用户ID查询用户信息: <select id="findUserById" parameterType="int" resultType="com.itheima.mybatis.po.User"> SELECT * FROM USER WHERE id =#{id