tf.train.Example的用法

发布时间 2023-05-04 11:04:30作者: 独上兰舟1

目录
前言
tf.train.BytesList等
tf.train.Feature
tf.train.Features
tf.train.Example
前言
最近在看到一个代码时,里面用到了tf.train.Example,于是学习了其用法,这里记录一下,也希望能对其他朋友有用。
另外,本文涉及的代码基于python 3.6.5 tensorflow 1.8.0
tf.train.Example主要用在将数据处理成二进制方面,一般是为了提升IO效率和方便管理数据。下面按调用顺序介绍使用tf.train.Example涉及的几个类。

tf.train.BytesList等
现在我们有一个data.txt文件,内容如下:

21
This is a test data file.
We will convert this text file to bin file.
1
2
3
文件中第一行是个整数,第二行和第三行都是字符串。这是我们处理的原始数据。
我们先使用下面的代码将数据读进来:

import struct
import tensorflow as tf


def read_text_file(text_file):
lines = []
with open(text_file, "r") as f:
for line in f:
lines.append(line.strip())
return lines

def text_to_binary(in_file, out_file):
inputs = read_text_file(in_file)

with open(out_file, 'wb') as writer:
pass

if __name__ == '__main__':
text_to_binary('data.txt', 'data.bin')
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
格式化原始数据可以使用tf.train.BytesList tf.train.Int64List tf.train.FloatList三个类。这三个类都有实例属性value用于我们将值传进去,一般tf.train.Int64List tf.train.FloatList对应处理整数和浮点数,tf.train.BytesList用于处理其他类型的数据。
这里第一行数据我们可以用tf.train.Int64List处理,而第二、第三行数据我们使用tf.train.BytesList处理。下面我们看代码实现,我们将上述代码的pass替换如下:

data_id = tf.train.Int64List(value=[int(inputs[0])])
data = tf.train.BytesList(value=[bytes(' '.join(inputs[1:]), encoding='utf-8')])
1
2
注意到,tf.train.Int64List的value值需要是int数据的列表,而tf.train.BytesList的value值需要是bytes数据的列表。
我们分别打印data_id和data的值可以看到:

value: 21

value: "This is a test data file. We will convert this text file to bin file."
1
2
3
这样我们就完成了第一步操作。

tf.train.Feature
tf.train.Feature有三个属性为tf.train.bytes_list tf.train.float_list tf.train.int64_list,显然我们只需要根据上一步得到的值来设置tf.train.Feature的属性就可以了,如下所示:

tf.train.Feature(int64_list=data_id)
tf.train.Feature(bytes_list=data)
1
2
tf.train.Features
从名字来看,我们应该能猜出tf.train.Features是tf.train.Feature的复数,事实上tf.train.Features有属性为feature,这个属性的一般设置方法是传入一个字典,字典的key是字符串(feature名),而值是tf.train.Feature对象。因此,我们可以这样得到tf.train.Features对象:

feature_dict = {
"data_id": tf.train.Feature(int64_list=data_id),
"data": tf.train.Feature(bytes_list=data)
}
features = tf.train.Features(feature=feature_dict)
1
2
3
4
5
tf.train.Example
终于到我们的主角了。tf.train.Example有一个属性为features,我们只需要将上一步得到的结果再次当做参数传进来即可。
另外,tf.train.Example还有一个方法SerializeToString()需要说一下,这个方法的作用是把tf.train.Example对象序列化为字符串,因为我们写入文件的时候不能直接处理对象,需要将其转化为字符串才能处理。
当然,既然有对象序列化为字符串的方法,那么肯定有从字符串反序列化到对象的方法,该方法是FromString(),需要传递一个tf.train.Example对象序列化后的字符串进去做为参数才能得到反序列化的对象。
在我们这里,只需要构建tf.train.Example对象并序列化就可以了,这一步的代码为:

example = tf.train.Example(features=features)
example_str = example.SerializeToString()
1
2
好了,那么现在我们看一下将data.txt处理成data.bin的完整代码:

import struct
import tensorflow as tf


def read_text_file(text_file):
lines = []
with open(text_file, "r") as f:
for line in f:
lines.append(line.strip())
return lines


def text_to_binary(in_file, out_file):
inputs = read_text_file(in_file)

with open(out_file, 'wb') as writer:
data_id = tf.train.Int64List(value=[int(inputs[0])])
data = tf.train.BytesList(value=[bytes(' '.join(inputs[1:]), encoding='utf-8')])

feature_dict = {
"data_id": tf.train.Feature(int64_list=data_id),
"data": tf.train.Feature(bytes_list=data)
}
features = tf.train.Features(feature=feature_dict)

example = tf.train.Example(features=features)
example_str = example.SerializeToString()

str_len = len(example_str)

writer.write(struct.pack('H', str_len))
writer.write(struct.pack('%ds' % str_len, example_str))


if __name__ == '__main__':
text_to_binary('data.txt', 'data.bin')

代码里还涉及到了struct模块,关于struct模块的用法可以参考我的这篇文章:Python二进制数据处理
————————————————
版权声明:本文为CSDN博主「hfutdog」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/hfutdog/article/details/86244944