Parsing TFRecords in TensorFlow

This post walks through config-driven TFRecord parsing in TensorFlow: a feature-config dict describes labels, context fields, and features; get_feature_description turns that config into a feature_description for tf.io.parse_single_example; a sample TFRecord is written for testing; and a second, BERT-style config parses fixed-length input_ids / attention_mask / token_type_ids features.
import tensorflow as tf
import json
aa = {
    "label": {
        "binary_label": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        },
        "triple_label": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        },
        "four_label": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        }
    },
    "context": {
        "item_code": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        },
        "query": {
            "is_use": 1,
            "data_type": "string",
            "default_value": "",
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        },
        "prod_name": {
            "is_use": 0,
            "data_type": "string",
            "default_value": "",
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        }
    },
    "feature": {
        "query_ids": {
            "is_use": 1,
            "data_type": "int64_list",
            "default_value": 0,
            "feature_length": "var_length",
            "feature_type": "int_sequence",
            "vocab_list": ["unk", "1", "2", "3", "4", "5"],
            "shape": [-1, -1],
            "preprocess": "pad_sequence",
            "description": "query token id sequence"
        },
        "title_ids": {
            "is_use": 1,
            "data_type": "int64_list",
            "default_value": 0,
            "feature_length": "var_length",
            "feature_type": "int_sequence",
            "vocab_list": ["unk", "1", "2", "3", "4", "5"],
            "shape": [-1, -1],
            "preprocess": "pad_sequence",
            "description": "product title token id sequence"
        },
        "query_token_ids": {
            "is_use": 1,
            "data_type": "string",
            "default_value": "",
            "feature_length": "fixed_length",
            "feature_type": "raw",
            "shape": [-1, 1],
            "preprocess": "pad_sequence",
            "description": "raw query token id sequence"
        },
        "ic_token_ids": {
            "is_use": 1,
            "data_type": "string",
            "default_value": "",
            "feature_length": "fixed_length",
            "feature_type": "raw",
            "shape": [-1, 1],
            "preprocess": "pad_sequence",
            "description": "raw product title token id sequence"
        }
    }
}




def get_feature_description(feature_config):
    feature_description = {}
    # 1. Parse the feature config
    for feature_name, feature_stats in feature_config["feature"].items():
        if feature_stats["is_use"] != 1:
            continue
        feature_length = feature_stats["feature_length"]
        feature_shape = feature_stats["shape"][1]
        default_single_value = feature_stats["default_value"]
        data_type = feature_stats["data_type"]
        if feature_length == "fixed_length":
            if data_type == "int64":
                if feature_shape > 1:
                    default_value = [int(default_single_value)] * feature_shape
                else:
                    default_value = int(default_single_value)
                feature_description[feature_name] = tf.io.FixedLenFeature(
                    shape=(feature_shape,), dtype=tf.int64, default_value=default_value)
            elif data_type == "float32":
                if feature_shape > 1:
                    default_value = [float(default_single_value)] * feature_shape
                else:
                    default_value = float(default_single_value)
                feature_description[feature_name] = tf.io.FixedLenFeature(
                    shape=(feature_shape,), dtype=tf.float32, default_value=default_value)
            elif data_type == "string":
                if feature_shape > 1:
                    default_value = [str(default_single_value)] * feature_shape
                else:
                    default_value = str(default_single_value)
                feature_description[feature_name] = tf.io.FixedLenFeature(
                    shape=(feature_shape,), dtype=tf.string, default_value=default_value)
            else:
                raise ValueError(f"fixed_length data_type {data_type} is not supported yet!")
        elif feature_length == "var_length":
            if data_type == "int64_list":
                feature_description[feature_name] = tf.io.VarLenFeature(tf.int64)
            elif data_type == "float32_list":
                # TODO
                pass
            elif data_type == "string_list":
                # TODO
                pass
            else:
                raise ValueError(f"var_length data_type {data_type} is not supported yet!")
        else:
            # TODO
            raise ValueError(f"feature_length {feature_length} is not supported yet!")

    # 2. Parse the label config
    for label_name, label_stats in feature_config["label"].items():
        if label_stats["is_use"] != 1:
            continue
        feature_length = label_stats["feature_length"]
        data_type = label_stats["data_type"]
        feature_shape = label_stats["shape"][1]
        default_value = label_stats["default_value"]
        assert data_type == "int64"
        assert feature_length == "fixed_length"
        assert feature_shape == 1
        feature_description[label_name] = tf.io.FixedLenFeature(
            shape=(feature_shape,), dtype=tf.int64, default_value=int(default_value))

    # 3. Parse the context config
    for context_name, context_stats in feature_config["context"].items():
        if context_stats["is_use"] != 1:
            continue
        feature_length = context_stats["feature_length"]
        feature_shape = context_stats["shape"][1]
        default_value = context_stats["default_value"]
        data_type = context_stats["data_type"]
        assert feature_shape == 1
        if feature_length == "fixed_length":
            if data_type == "int64":
                feature_description[context_name] = tf.io.FixedLenFeature(
                    shape=(feature_shape,), dtype=tf.int64, default_value=int(default_value))
            elif data_type == "float32":
                feature_description[context_name] = tf.io.FixedLenFeature(
                    shape=(feature_shape,), dtype=tf.float32, default_value=float(default_value))
            elif data_type == "string":
                feature_description[context_name] = tf.io.FixedLenFeature(
                    shape=(feature_shape,), dtype=tf.string, default_value=str(default_value))
            else:
                raise ValueError(f"fixed_length data_type {data_type} is not supported yet!")

    return feature_description
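It can help to print what the helper actually builds; with the config above, query_ids and title_ids come out as VarLenFeature entries and everything else as FixedLenFeature entries with the configured defaults. A quick inspection snippet, not part of the original pipeline:

# Inspect the parsing spec generated from the config defined above.
for name, desc in get_feature_description(aa).items():
    print(name, "->", desc)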

feature_description = get_feature_description(aa)
# Define the parsing function
def parse_tfrecord_fn(example):
    parsed_example = tf.io.parse_single_example(example, feature_description)
    return parsed_example


# Path of the TFRecord file to parse
tfrecord_file = './data/eval.tfrecord'

# Create a TFRecordDataset and apply the parsing function
dataset = tf.data.TFRecordDataset(tfrecord_file)
dataset = dataset.map(parse_tfrecord_fn)

# Iterate over the dataset and print each example
for example in dataset:
    print(example)
    # var_length features are parsed as SparseTensors, so read .values
    print(example['query_ids'].values.numpy())
    print(example['title_ids'].values.numpy())
    print(example['binary_label'].numpy())
    print(example['four_label'].numpy())
    print(example['ic_token_ids'].numpy())
    print(example['item_code'].numpy())
    print(example['query'].numpy())
    print(example['query_token_ids'].numpy())
    print(example['triple_label'].numpy())
    print(".............................................................")



# Generating a TFRecord file
import tensorflow as tf

# Prepare one sample
data = {
    'query': 'apple',
    'query_ids': [1, 2, 3],
    'title_ids': [4, 5, 6],
    'binary_label': 1,
    'four_label': 2,
    'ic_token_ids': '12345',
    'item_code': 123456,
    'query_token_ids': '67890',
    'triple_label': 3
}

# Path of the output TFRecord file
train_file = './train_test.tfrecord'

# Create a TFRecordWriter
with tf.io.TFRecordWriter(train_file) as writer:
    # Build the feature dict
    feature = {
        'query': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data['query'].encode()])),
        'query_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=data['query_ids'])),
        'title_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=data['title_ids'])),
        'binary_label': tf.train.Feature(int64_list=tf.train.Int64List(value=[data['binary_label']])),
        'four_label': tf.train.Feature(int64_list=tf.train.Int64List(value=[data['four_label']])),
        'ic_token_ids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data['ic_token_ids'].encode()])),
        'item_code': tf.train.Feature(int64_list=tf.train.Int64List(value=[data['item_code']])),
        'query_token_ids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data['query_token_ids'].encode()])),
        'triple_label': tf.train.Feature(int64_list=tf.train.Int64List(value=[data['triple_label']])),
    }
    # Wrap the dict in a Features object
    features = tf.train.Features(feature=feature)
    # Build the Example
    example_proto = tf.train.Example(features=features)
    # Serialize the Example and write it to the TFRecord file
    writer.write(example_proto.SerializeToString())

print(f'Generated TFRecord file: {train_file}')
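To sanity-check the file that was just written, the raw records can also be decoded straight into tf.train.Example protos, without going through feature_description. A small inspection sketch:

# Read the serialized records back and decode them as Example protos.
for raw_record in tf.data.TFRecordDataset(train_file).take(1):
    example = tf.train.Example.FromString(raw_record.numpy())
    print(example)  # prints every feature that was written, with its type and values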




# val_file = './data/train.tfrecord'
#
# # Create a TFRecordWriter (this variant assumes data is a list of sample dicts)
# with tf.io.TFRecordWriter(val_file) as writer:
#     # Iterate over the samples and write each one to the TFRecord file
#     for example in data:
#         # Build the feature dict
#         feature = {
#             'query': tf.train.Feature(bytes_list=tf.train.BytesList(value=[example['query'].encode()])),
#             'query_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=example['query_ids'])),
#             'title_ids': tf.train.Feature(int64_list=tf.train.Int64List(value=example['title_ids'])),
#             'binary_label': tf.train.Feature(int64_list=tf.train.Int64List(value=[example['binary_label']])),
#             'four_label': tf.train.Feature(int64_list=tf.train.Int64List(value=[example['four_label']])),
#             'ic_token_ids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[example['ic_token_ids'].encode()])),
#             'item_code': tf.train.Feature(int64_list=tf.train.Int64List(value=[example['item_code']])),
#             'query_token_ids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[example['query_token_ids'].encode()])),
#             'triple_label': tf.train.Feature(int64_list=tf.train.Int64List(value=[example['triple_label']])),
#         }
#         # Wrap the dict in a Features object
#         features = tf.train.Features(feature=feature)
#         # Build the Example
#         example_proto = tf.train.Example(features=features)
#         # Serialize the Example and write it to the TFRecord file
#         writer.write(example_proto.SerializeToString())
#
# print(f'Generated TFRecord file: {val_file}')


# from wordcloud import STOPWORDS
#
# import re
#
# from collections import defaultdict
#
# item_dict = defaultdict(int)
#
# with open("./data/title_key.txt", "r", encoding="utf-8") as f, \
#         open("./data/query_key.txt", "r", encoding="utf-8") as f1, \
#         open("./data/data.txt", "w", encoding="utf-8") as out:
#     for line in f:
#         key, num = line.strip("\n").split("\t")
#         if key.strip():
#             item_dict[key] = int(num)
#     for line in f1:
#         key, num = line.strip("\n").split("\t")
#         if key.strip():
#             item_dict[key] += int(num)
#     sorted_dict = sorted(item_dict.items(), key=lambda x: x[1], reverse=True)
#     for key, num in sorted_dict:
#         if num > 4 and key not in STOPWORDS:
#             if re.search("^[a-z0-9]+$", key):
#                 out.write("{}\t{}\n".format(key, num))
#
#
#


# TFRecord parsing for BERT features
import sys

import tensorflow as tf
import json

aa = {
    "label": {
        "binary_label": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        },
        "triple_label": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        },
        "four_label": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        }
    },
    "context": {
        "item_code": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        },
        "query": {
            "is_use": 1,
            "data_type": "string",
            "default_value": "",
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        },
        "prod_name": {
            "is_use": 0,
            "data_type": "string",
            "default_value": "",
            "feature_length": "fixed_length",
            "shape": [-1, 1]
        }
    },
    "feature": {
        "query_input_ids": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "feature_type": "raw",
            "shape": [-1, 10],
            "preprocess": "pad_sequence",
            "description": "raw query input_ids"
        },
        "query_attention_mask": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "feature_type": "raw",
            "shape": [-1, 10],
            "preprocess": "pad_sequence",
            "description": "raw query attention_mask"
        },
        "query_token_type_ids": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "feature_type": "raw",
            "shape": [-1, 10],
            "preprocess": "pad_sequence",
            "description": "raw query token_type_ids"
        },
        "title_input_ids": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "feature_type": "raw",
            "shape": [-1, 30],
            "preprocess": "pad_sequence",
            "description": "title input_ids"
        },
        "title_attention_mask": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "feature_type": "raw",
            "shape": [-1, 30],
            "preprocess": "pad_sequence",
            "description": "title attention_mask"
        },
        "title_token_type_ids": {
            "is_use": 1,
            "data_type": "int64",
            "default_value": 0,
            "feature_length": "fixed_length",
            "feature_type": "raw",
            "shape": [-1, 30],
            "preprocess": "pad_sequence",
            "description": "title token_type_ids"
        }
    }
}





def get_feature_description(feature_config):
    feature_description = {}
    # 1. Parse the feature config
    for feature_name, feature_stats in feature_config["feature"].items():
        if feature_stats["is_use"] != 1:
            continue
        feature_length = feature_stats["feature_length"]
        feature_shape = feature_stats["shape"][1]
        default_single_value = feature_stats["default_value"]
        data_type = feature_stats["data_type"]
        if feature_length == "fixed_length":
            if data_type == "int64":
                if feature_shape > 1:
                    default_value = [int(default_single_value)] * feature_shape
                else:
                    default_value = int(default_single_value)
                feature_description[feature_name] = tf.io.FixedLenFeature(
                    shape=(feature_shape,), dtype=tf.int64, default_value=default_value)
            elif data_type == "float32":
                if feature_shape > 1:
                    default_value = [float(default_single_value)] * feature_shape
                else:
                    default_value = float(default_single_value)
                feature_description[feature_name] = tf.io.FixedLenFeature(
                    shape=(feature_shape,), dtype=tf.float32, default_value=default_value)
            elif data_type == "string":
                if feature_shape > 1:
                    default_value = [str(default_single_value)] * feature_shape
                else:
                    default_value = str(default_single_value)
                feature_description[feature_name] = tf.io.FixedLenFeature(
                    shape=(feature_shape,), dtype=tf.string, default_value=default_value)
            else:
                raise ValueError(f"fixed_length data_type {data_type} is not supported yet!")
        elif feature_length == "var_length":
            if data_type == "int64_list":
                feature_description[feature_name] = tf.io.VarLenFeature(tf.int64)
            elif data_type == "float32_list":
                # TODO
                pass
            elif data_type == "string_list":
                # TODO
                pass
            else:
                raise ValueError(f"var_length data_type {data_type} is not supported yet!")
        else:
            # TODO
            raise ValueError(f"feature_length {feature_length} is not supported yet!")

    # 2. Parse the label config
    for label_name, label_stats in feature_config["label"].items():
        if label_stats["is_use"] != 1:
            continue
        feature_length = label_stats["feature_length"]
        data_type = label_stats["data_type"]
        feature_shape = label_stats["shape"][1]
        default_value = label_stats["default_value"]
        assert data_type == "int64"
        assert feature_length == "fixed_length"
        assert feature_shape == 1
        feature_description[label_name] = tf.io.FixedLenFeature(
            shape=(feature_shape,), dtype=tf.int64, default_value=int(default_value))

    # 3. Parse the context config
    for context_name, context_stats in feature_config["context"].items():
        if context_stats["is_use"] != 1:
            continue
        feature_length = context_stats["feature_length"]
        feature_shape = context_stats["shape"][1]
        default_value = context_stats["default_value"]
        data_type = context_stats["data_type"]
        assert feature_shape == 1
        if feature_length == "fixed_length":
            if data_type == "int64":
                feature_description[context_name] = tf.io.FixedLenFeature(
                    shape=(feature_shape,), dtype=tf.int64, default_value=int(default_value))
            elif data_type == "float32":
                feature_description[context_name] = tf.io.FixedLenFeature(
                    shape=(feature_shape,), dtype=tf.float32, default_value=float(default_value))
            elif data_type == "string":
                feature_description[context_name] = tf.io.FixedLenFeature(
                    shape=(feature_shape,), dtype=tf.string, default_value=str(default_value))
            else:
                raise ValueError(f"fixed_length data_type {data_type} is not supported yet!")

    return feature_description

feature_description = get_feature_description(aa)
# Define the parsing function
def parse_tfrecord_fn(example):
    parsed_example = tf.io.parse_single_example(example, feature_description)
    return parsed_example


# Path of the TFRecord file to parse
tfrecord_file = './data/train_test.tfrecord'

# Create a TFRecordDataset and apply the parsing function
dataset = tf.data.TFRecordDataset(tfrecord_file)
dataset = dataset.map(parse_tfrecord_fn)

# Iterate over the dataset and print the first example
for example in dataset:
    print("binary_label:", example['binary_label'].numpy())
    print("four_label:", example['four_label'].numpy())
    print("item_code:", example['item_code'].numpy())
    print("query:", example['query'].numpy())
    print("query_attention_mask:", example['query_attention_mask'].numpy())
    print("query_input_ids:", example['query_input_ids'].numpy())
    print("query_token_type_ids:", example['query_token_type_ids'].numpy())
    print("title_attention_mask:", example['title_attention_mask'].numpy())
    print("title_input_ids:", example['title_input_ids'].numpy())
    print("title_token_type_ids:", example['title_token_type_ids'].numpy())
    print("triple_label:", example['triple_label'].numpy())
    sys.exit(1)  # stop after inspecting the first example
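The train_test.tfrecord parsed here has to contain the fixed-length BERT features declared in the config. A hedged sketch of how such a file could be written is below; the token ids are placeholders rather than real tokenizer output, and pad_to / int64_feature / bytes_feature are helper names introduced only for this sketch. Query features are padded to length 10 and title features to length 30, matching the config shapes.

# Sketch: write one example that matches the BERT-style config above.
def pad_to(ids, length, pad_id=0):
    # Truncate or right-pad a list of ids to a fixed length.
    ids = ids[:length]
    return ids + [pad_id] * (length - len(ids))

def int64_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.encode()]))

query_ids = [101, 8225, 102]        # placeholder ids, not real tokenizer output
title_ids = [101, 9301, 9302, 102]  # placeholder ids, not real tokenizer output

bert_feature = {
    'binary_label': int64_feature([1]),
    'triple_label': int64_feature([2]),
    'four_label': int64_feature([3]),
    'item_code': int64_feature([123456]),
    'query': bytes_feature('apple'),
    'query_input_ids': int64_feature(pad_to(query_ids, 10)),
    'query_attention_mask': int64_feature(pad_to([1] * len(query_ids), 10)),
    'query_token_type_ids': int64_feature([0] * 10),
    'title_input_ids': int64_feature(pad_to(title_ids, 30)),
    'title_attention_mask': int64_feature(pad_to([1] * len(title_ids), 30)),
    'title_token_type_ids': int64_feature([0] * 30),
}

with tf.io.TFRecordWriter('./data/train_test.tfrecord') as writer:
    example_proto = tf.train.Example(features=tf.train.Features(feature=bert_feature))
    writer.write(example_proto.SerializeToString())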