Tensorflow 训练自己的数据集将数据直接导入到内存

制作自己的训练集

下图是我们数据的存放格式,在data目录下有验证集与测试集分别对应iris_test, iris_train

为了向伟大的MNIST致敬,我们采用的数据名称格式和MNIST类似

classification_index.jpg

图像的index都是5的整数倍是因为我们选择测试集的原则是每5个样本,选择一个样本作为测试集,其余的作为训练集和验证集

生成这样数据的过程相对简单,如果有需要python代码的,可以给我发邮件,或者在我的github下载

至此,我们的训练集,测试集,验证集就生成成功了,之所以我们的文件夹只有训练集和测试集是因为我们在后续的训练过程中,会在训练集中分出一部分作为验证集,所以两者暂时合称为训练集

将数据集写入到Tensorflow中

1. 直接写入到队列中

import tensorflow as tf
import numpy as np
import os

train_dir = '/home/ruyiwei/data/iris_train/'#your data directory
def get_files(file_dir):
  '''
  Args:
    file_dir: file directory
  Returns:
    list of images and labels
  '''
  iris = []
  label_iris = []
  contact = []
  label_contact = []
  for file in os.listdir(file_dir):
    name = file.split('_')
    if name[0]=="iris":
      iris.append(file_dir + file)
      label_iris.append(0)
    else:
      contact.append(file_dir + file)
      label_contact.append(1)
  print('There are %d iris\nThere are %d contact' %(len(iris), len(contact)))

  image_list = np.hstack((iris, contact))
  label_list = np.hstack((label_iris, label_contact))

  temp = np.array([image_list, label_list])
  temp = temp.transpose()
  np.random.shuffle(temp)

  image_list = list(temp[:, 0])
  label_list = list(temp[:, 1])
  label_list = [int(i) for i in label_list]

  return image_list, label_list

为了大家更方便的理解和修改代码,我们对代码进行讲解如下

1-3行 : 导入需要的模块
5行: 定义训练集合的位置,这个需要根据自己的机器进行修改
7行: 定义函数 get_files
18行: os.listdir(file_dir) 获取指定目录file_dir下的所有文件名词,也就是我们的训练图片名称
18行:for file in os.listdir(file_dir): 遍历所有的图片
19行: name为一个数组,由于我们根据MINIST来定制的图片名词,所以file.split(‘_')会将图像名称分为两部分,第一部分为classification,通过name[0]来获得分类信息。
21行、24行:iris.append(file_dir + file)/contact.append(file_dir + file)将图像的绝对路径放入到iris/contact
22行、25行:label_iris.append(0)/label_contact.append(1)给对应的图片贴标签
28-29行:将二分类的图像与标签压入到list中
31-33行:合并二分类图像,然后打乱
38行:返回打乱后对应的图像与标签

在spyder下执行如上代码后会返回如下信息

这样图像和标签信息就被load到了内存中,我们接下来就可以利用现有的模型对图像进行分类训练,模型的选择和训练的过程,我们会在后面进行讲解。

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持我们。

时间: 2018-06-16

python训练数据时打乱训练数据与标签的两种方法小结

如下所示: <code class="language-python">import numpy as np data = np.array([[1,1],[2,2],[3,3],[4,4],[5,5]]) y = np.array([1,2,3,4,5]) print '-------第1种方法:通过打乱索引从而打乱数据,好处是1:数据量很大时能够节约内存,2每次都不一样----------' data = np.array([[1,1],[2,2],[3,3],[4,4

JavaScript检查数据中是否存在相同的元素(两种方法)

这里是两个用于数组中查找重复元素的demo,可以看看啦 <!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8"> <title>Title</title> </head> <body> <input type="text" id="Values" style=

python 读取文件并把矩阵转成numpy的两种方法

在当前目录下: 方法1: file = open('filename') a =file.read() b =a.split('\n')#使用换行 len(b) #统计有多少行 for i in range(len(b)): b[i] = b[i].split()#使用空格分开 len(b[0])#可以查看第一行有多少列. B[0][311]#可以查看具体某行某列的数 import numpy as np b = np.array(b)#转成numpy形的 type(b) # 输出<输出clas

Python实现平行坐标图的两种方法小结

平行坐标图,一种数据可视化的方式.以多个垂直平行的坐标轴表示多个维度,以维度上的刻度表示在该属性上对应值,相连而得的一个折线表示一个样本,以不同颜色区分类别. 但是很可惜,才疏学浅,没办法在Python里实现不同颜色来区分不同的类别.如果对此比较在意的大神可以不要往下看了......... 上图是一个基于iris数据集所画的一个平行坐标图. 隔开隔开.......................................隔开隔开 不多扯了,下面正式上代码 方法一.基于pyecharts第三

python 字典中取值的两种方法小结

如下所示: a={'name':'tony','sex':'male'} 获得name的值的方式有两种 print a['name'],type(a['name']) print a.get('name'),type(a.get('name')) 发现这两个结果完全一致,并没有任何的差异. 怎么选择这两个不同的字典取值方式呢? 如果字典已知,我们可以任选一个,而当我们不确定字典中是否存在某个键时,我之前的做法如下 if 'age' in a.keys(): print a['age'] 因为不先

Python多线程编程(二):启动线程的两种方法

在Python中我们主要是通过thread和threading这两个模块来实现的,其中Python的threading模块是对thread做了一些包装的,可以更加方便的被使用,所以我们使用threading模块实现多线程编程.一般来说,使用线程有两种模式,一种是创建线程要执行的函数,把这个函数传递进Thread对象里,让它来执行:另一种是直接从Thread继承,创建一个新的class,把线程执行的代码放到这个新的 class里. 将函数传递进Thread对象 复制代码 代码如下: '''  Cr

Select2在使用ajax获取远程数据时显示默认数据的方法

假设我需要在我的select2中默认添加一个之前从服务器上获取过的数据,通过以下方法实现.实测可行~ var value = 1 var text = '默认文本' $('.selecter').html('<option value="' + value + '">' + text + '</option>').trigger("change") 以上代码其实就是将class="selecter"的select内容进行重

分享MYSQL插入数据时忽略重复数据的方法

使用下以两种方法时必须把字段设为"主键(PRIMARY KEY"或"唯一约束(UNIQUE)".1:使用REPLACE INTO (此种方法是利用替换的方法,有点似类于先删除再插入)  复制代码 代码如下: REPLACE INTO Syntax  REPLACE [LOW_PRIORITY | DELAYED]      [INTO] tbl_name [(col_name,...)]      {VALUES | VALUE} ({expr | DEFAULT}

使用Java构造和解析Json数据的两种方法(详解二)

JSON(JavaScript Object Notation) 是一种轻量级的数据交换格式,采用完全独立于语言的文本格式,是理想的数据交换格式.同时,JSON是 JavaScript 原生格式,这意味着在 JavaScript 中处理 JSON数据不须要任何特殊的 API 或工具包. 在www.json.org上公布了很多JAVA下的json构造和解析工具,其中org.json和json-lib比较简单,两者使用上差不多但还是有些区别.下面接着介绍用org.json构造和解析Json数据的方法

使用Java构造和解析Json数据的两种方法(详解一)

JSON(JavaScript Object Notation) 是一种轻量级的数据交换格式,采用完全独立于语言的文本格式,是理想的数据交换格式.同时,JSON是 JavaScript 原生格式,这意味着在 JavaScript 中处理 JSON数据不须要任何特殊的 API 或工具包. 在www.json.org上公布了很多JAVA下的json构造和解析工具,其中org.json和json-lib比较简单,两者使用上差不多但还是有些区别.下面首先介绍用json-lib构造和解析Json数据的方法