TFRecord的Shuffle、划分和读取
2022/7/27 23:24:21
本文主要是介绍TFRecord的Shuffle、划分和读取,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
对数据集的shuffle处理需要设置相应的buffer_size参数,相当于需要将相应数目的样本读入内存,且这部分内存会在训练过程中一直保持占用。完全的shuffle需要将整个数据集读入内存,这在大规模数据集的情况下是不现实的,故需要结合设备内存以及Batch大小将TFRecord文件随机划分为多个子文件,再对数据集做local shuffle(即设置相对较小的buffer_size,不小于单个子文件的样本数)。
Shuffle和划分
下文以一个异常检测数据集(正负样本不平衡)为例,在生成第一批TFRecord时,我将正负样本分别写入单独的TFrecord文件以备后续在对正负样本有不同处理策略的情况下无需再解析example_proto。比如在以下代码中,我对正负样本有不同的验证集比例,并将他们写入不同的验证集文件。
import numpy as np import tensorflow as tf from tqdm.notebook import tqdm as tqdm # TFRecord划分 raw_normal_dataset = tf.data.TFRecordDataset("normal_16_256.tfrecords","GZIP") raw_anomaly_dataset = tf.data.TFRecordDataset("anomaly_16_256.tfrecords","GZIP") normal_val_writer = tf.io.TFRecordWriter(r'ex_1/'+'normal_val_16_256.tfrecords',"GZIP") anomaly_val_writer = tf.io.TFRecordWriter(r'ex_1/'+'anomaly_val_16_256.tfrecords',"GZIP") train_writer_list = [tf.io.TFRecordWriter(r'ex_1/'+'train_16_256_{}.tfrecords'.format(i),"GZIP") for i in range(SUBFILE_NUM+1)] with tqdm(total=LEN_NORMAL_DATASET+LEN_ANOMALY_DATASET) as pbar: for example_proto in raw_normal_dataset: # 划分训练集和测试集 if np.random.random() > 0.99: # 正样本测试集的比例 normal_val_writer.write(example_proto.numpy()) else: train_writer_list[np.random.randint(0,SUBFILE_NUM+1)].write(example_proto.numpy()) pbar.update(1) for example_proto in raw_anomaly_dataset: # 划分训练集和测试集 if np.random.random() > 0.7: # 负样本测试集的比例 anomaly_val_writer.write(example_proto.numpy()) else: train_writer_list[np.random.randint(0,SUBFILE_NUM+1)].write(example_proto.numpy()) pbar.update(1) normal_val_writer.close() anomaly_val_writer.close() for train_writer in train_writer_list: train_writer.close()
读取
raw_train_dataset = tf.data.TFRecordDataset([r'ex_1/'+'train_16_256_{}.tfrecords'.format(i) for i in range(SUBFILE_NUM+1)],"GZIP") raw_train_dataset = raw_train_dataset.shuffle(buffer_size=100000).batch(BATCH_SIZE) parsed_train_dataset = raw_train_dataset.map(map_func=map_func) raw_normal_val_dataset = tf.data.TFRecordDataset(r'ex_1/'+'normal_val_16_256.tfrecords',"GZIP") raw_anomaly_val_dataset = tf.data.TFRecordDataset(r'ex_1/'+'anomaly_val_16_256.tfrecords',"GZIP") parsed_nomarl_val_dataset = raw_normal_val_dataset.batch(BATCH_SIZE).map(map_func=map_func) parsed_anomaly_val_dateset = raw_anomaly_val_dataset.batch(BATCH_SIZE).map(map_func=map_func)
这篇关于TFRecord的Shuffle、划分和读取的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-12-27文件掩码什么意思?-icode9专业技术文章分享
- 2024-12-27如何使用循环来处理多个订单的退款请求,代码怎么写?-icode9专业技术文章分享
- 2024-12-27VSCode 在编辑时切换到另一个文件后再切回来如何保持在原来的位置?-icode9专业技术文章分享
- 2024-12-27Sealos Devbox 基础教程:使用 Cursor 从零开发一个 One API 替代品 审核中
- 2024-12-27TypeScript面试真题解析与实战指南
- 2024-12-27TypeScript大厂面试真题详解与解析
- 2024-12-26怎么使用nsenter命令进入容器?-icode9专业技术文章分享
- 2024-12-26导入文件提示存在乱码,请确定使用的是UTF-8编码怎么解决?-icode9专业技术文章分享
- 2024-12-26csv文件怎么设置编码?-icode9专业技术文章分享
- 2024-12-25TypeScript基础知识详解