制作catvsdog_path_dataset.tfrecords的代码 数据集制作完成路径为: E:\catanddog\train1\catvsdog_path_dataset.tfrecords

发布时间 2023-08-12 22:44:19作者: blackstrom
# -*- coding:utf-8 -*- -#
#PROJECT_NAME:081200
#Name:01
#Author:GG
#Date:2023/8/12

import tensorflow as tf
import os
import numpy as np
import cv2

file_dir = "E:\\catanddog\\train0"

save_dir ="E:\\catanddog\\train1"

images = [] # 每张图片的路径组成的列表
temp = [] # 保存cat dog文件夹路径
for root, sub_folders, files in os.walk(file_dir):

for name in files:
images.append(os.path.join(root, name))

for name in sub_folders:
temp.append(os.path.join(root, name))

labels = [] # 保存注释列表

# 此时temp为根目录下所有文件夹的路径列表 一次取出一个文件夹 对文件夹里面的所有数据图片进行注释
for one_folder in temp: #在遍历一个名为 temp 的可迭代对象中的每个元素,将每个元素赋值给变量 one_folder。
n_img = len(os.listdir(one_folder)) # 得到图片总数 os.listdir(one_folder) 函数用于返回指定目录下的所有文件和文件夹的名称列表。len() 函数用于获取列表的长度,即其中元素的个数,
letter = one_folder.split('\\')[-1] # 按照“\\”分割 取出最后一个也就是文件夹的名称

# 标注数据集 将cat标注为0 dog标注为1
if letter == 'cat':
labels = np.append(labels, n_img * [0])
print(labels)

else:
labels = np.append(labels, n_img * [1])
print(labels)

temp = np.array([images, labels]) # 重新创建数组temp 将images 和 labels 最为一对键值对写入
print(temp)
temp = temp.transpose() # 将temp转置
print(temp)
np.random.shuffle(temp) # 打乱数据集的顺序
print(temp)

image_list = list(temp[:, 0]) # 取出数组中的第一维 也就是图片的路径列表
print(image_list )
label_list = list(temp[:, 1]) # 取出数组中的第二维 也就是图片的标签列表
print(label_list )
label_list = [int(float(i)) for i in label_list]#将label_list中的每个元素转换为浮点数,然后再转换为整数类型。该代码逐个处理label_list中的元素,并根据给定的转换过程创建一个新的整数列表
print(label_list )

filename = os.path.join(save_dir, 'catvsdog_path_dataset.tfrecords')
print(filename)
n_samples = len(label_list)
writer = tf.python_io.TFRecordWriter(filename)
print('\n开始制作数据集...')
for i in np.arange(0, n_samples):
# try:
print("正在制作第 %d 张 \n" % i)
image = cv2.imread(image_list[i])
image_raw = image.tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=label_list)),
'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw]))}))
writer.write(example.SerializeToString())

# except:
# print('无法读取此文件:' , images[i])

# writer.close()
print('\n数据集制作完成路径为: %s' % filename)