[TextCNN Full Version] A Fast, High-Accuracy Baseline
2021/11/27 23:12:59
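Preface:
Two months ago I wrote a post, "Complete TextCNN Steps (Under 60 Lines of Code)". It did not consider later engineering deployment or larger datasets that cannot be loaded into memory all at once, so today I reworked and optimized it based on a real case.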
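The TextCNN workflow can generally be split into the following steps:
1. Data preparation: in day-to-day work, text is rarely handed to you as a ready-made CSV file the way it is in competitions, so you may need to assemble it yourself. Also, TextCNN cannot use categorical labels (such as Shanghai, Beijing) directly for training and prediction, so they must be converted to integers with a map or a label encoder, and mapped back (map_reverse) once prediction is finished.
2. Build the vocabulary: tokenizer.fit_on_texts. This step is very important; if training accuracy stays in the single digits, check this step carefully first.
3. Build the tf dataset: if there is too much text to fit in memory, use batches (32 or 64). Note that if train_data and valid_data are both turned into datasets, test_data must be a dataset too; it has no labels yet, but you can fill in dummy labels of 0.
4. Build the TextCNN network: nothing special here; kernel sizes of [2, 3, 4] or [3, 4, 5] both work.
5. Set class weights: most classification tasks are imbalanced, especially multi-class ones, so setting class weights is well worth doing.
6. Train the model: tunable hyperparameters include learning_rate (3e-4 recommended), epochs (30-40 recommended; early stopping will end training anyway), optimizer (Adam works well) and EARLY_STOP_PATIENCE (the early-stopping patience; around 3 is enough, the code below uses 2).
7. Save the model: in TensorFlow 2 you can simply call model.save('./model/text_cnn.h5'); this is not demonstrated in this post.
8. Load the model: text_cnn_model = tf.keras.models.load_model('service/model/text_cnn.h5')
9. Predict: text_cnn_model.predict(test_dataset). Note that the output is a matrix of class probabilities (floats between 0 and 1), so use np.argmax(predictions, axis=-1) to pick the predicted label.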
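To make steps 1 and 9 concrete, here is a minimal numpy sketch of the label round trip: encode string labels as integer ids before training, then turn a matrix of predicted probabilities back into label names. The labels and probabilities are made up for illustration; in the real pipeline the probabilities come from text_cnn_model.predict.

```python
import numpy as np

# Step 1: turn string labels into integer ids (toy labels for illustration)
labels = ['Shanghai', 'Beijing', 'Shanghai', 'Guangzhou']
cls_dict = dict(enumerate(sorted(set(labels))))          # id -> label
cls_dict_reverse = {v: k for k, v in cls_dict.items()}   # label -> id
y = [cls_dict_reverse[name] for name in labels]

# Step 9: predict() returns one row of class probabilities per sample;
# argmax picks the most likely id, which is then mapped back to its label
predictions = np.array([[0.1, 0.7, 0.2],
                        [0.8, 0.1, 0.1]])                # made-up model output
pred_ids = np.argmax(predictions, axis=-1)
pred_labels = [cls_dict[int(i)] for i in pred_ids]
print(y)            # [2, 0, 2, 1]
print(pred_labels)  # ['Guangzhou', 'Beijing']
```

Sorting the label set before enumerating keeps the id assignment identical every time the script runs, which matters once the model is saved and reloaded for deployment.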
The full code is as follows:
I. Importing the data
```python
import os
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.utils import resample
from sklearn.model_selection import train_test_split
```
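```python
# Load the training data: each sub-directory of ./train/ is one class,
# named like "<id>-<label>"; every file inside it is one document
train_type_list = []
train_text_list = []
# Skip hidden entries such as macOS '.DS_Store' (list.remove() raises if it is absent)
train_dir_name_list = [d for d in os.listdir('./train/') if not d.startswith('.')]
for dir_name in train_dir_name_list:
    for file in os.listdir(os.path.join('./train', dir_name)):
        if file.startswith('.'):
            continue
        train_type_list.append(dir_name.split('-')[1])
        with open(os.path.join('./train', dir_name, file), 'r', encoding='gb18030', errors='ignore') as f:
            train_text_list.append(f.read().replace('\n', ' ').replace('\u3000', ''))
print(len(train_type_list))
```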
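```python
# Label dictionaries: id -> label name, and label name -> id.
# sorted() keeps the mapping identical across runs (iteration order of a set of strings is not)
cls_num = len(set(train_type_list))
cls_dict = {k: v for k, v in enumerate(sorted(set(train_type_list)))}
cls_dict_reverse = {v: k for k, v in cls_dict.items()}
```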
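```python
train_data = pd.DataFrame({'text': train_text_list, 'target': train_type_list})
train_data['target'] = train_data['target'].map(cls_dict_reverse)
# Shuffle the rows; replace=False makes resample() a permutation instead of
# a bootstrap sample (its default, replace=True, duplicates and drops rows)
train_data = resample(train_data, replace=False, random_state=27)
train_data.head()
```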
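Note that sklearn.utils.resample defaults to replace=True, i.e. a bootstrap sample rather than a shuffle. The numpy sketch below (plain numpy, not sklearn itself) compares the two on 10,000 row ids: a bootstrap keeps only about 63% of the distinct documents and repeats others, so the same document can end up in both the training and validation split and inflate val_accuracy, while a permutation keeps every document exactly once.

```python
import numpy as np

rng = np.random.default_rng(27)
n = 10_000                                           # number of documents (example size)
bootstrap = rng.choice(n, size=n, replace=True)      # sampling with replacement
permutation = rng.permutation(n)                     # a plain shuffle

bootstrap_unique = len(np.unique(bootstrap)) / n
permutation_unique = len(np.unique(permutation)) / n
print(bootstrap_unique)    # about 0.63: roughly 37% of documents never drawn
print(permutation_unique)  # 1.0: every document kept exactly once
```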
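```python
# Load the data to be predicted; it has no labels yet, so use a dummy target of 0
test_text_list = []
test_filename = []
for file in os.listdir('./test'):
    if file.startswith('.'):
        continue
    test_filename.append(file)
    with open(os.path.join('./test', file), 'r', encoding='gb18030', errors='ignore') as f:
        # Same cleaning as the training text, so train and test inputs match
        test_text_list.append(f.read().replace('\n', ' ').replace('\u3000', ''))
test_data = pd.DataFrame({'text': test_text_list, 'filename': test_filename})
test_data['target'] = 0
```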
II. Preparing the TF datasets
```python
X_train, X_val, y_train, y_val = train_test_split(train_data['text'], train_data['target'],
                                                  test_size=0.1, random_state=27)

# Tokenizer
NUM_LABEL = cls_num   # number of classes
BATCH_SIZE = 32
MAX_LEN = 200         # maximum sequence length
BUFFER_SIZE = tf.constant(train_data.shape[0], dtype=tf.int64)
tokenizer = tf.keras.preprocessing.text.Tokenizer(char_level=True)
# Fit the vocabulary on the training split only
tokenizer.fit_on_texts(X_train)
```
```python
def build_tf_dataset(text, label, is_train=False):
    '''Build a tf.data dataset from raw texts and integer labels'''
    sequence = tokenizer.texts_to_sequences(text)
    # padding='post' pads at the end; the default truncating='pre' keeps the last MAX_LEN characters
    sequence_padded = tf.keras.preprocessing.sequence.pad_sequences(sequence, padding='post', maxlen=MAX_LEN)
    dataset = tf.data.Dataset.from_tensor_slices((sequence_padded, label))
    if is_train:
        # Only the training set is shuffled
        dataset = dataset.shuffle(BUFFER_SIZE)
    dataset = dataset.batch(BATCH_SIZE)
    # prefetch() counts batches, so let tf.data choose the buffer size
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset
```
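For readers wondering what ends up inside the dataset, the pure-Python sketch below approximates what Tokenizer(char_level=True) plus pad_sequences(padding='post', maxlen=MAX_LEN) produce. The helper names char_vocab and texts_to_padded are hypothetical, and the sketch skips the Keras tokenizer's default lowercasing and punctuation filtering, so treat it as an illustration rather than a drop-in replacement. Index 0 is reserved for padding, characters not seen during fitting are dropped, and, mirroring pad_sequences' default truncating='pre', over-long texts keep their last max_len characters.

```python
from collections import Counter

def char_vocab(texts):
    """Map each character to an index >= 1, most frequent first.
    Ties keep first-seen order; 0 is reserved for padding."""
    counts = Counter()
    for t in texts:
        counts.update(t)
    return {ch: i + 1 for i, (ch, _) in enumerate(counts.most_common())}

def texts_to_padded(texts, vocab, max_len):
    """Index each character (dropping unknown ones), keep at most the last
    max_len indices, then pad with zeros at the end ('post')."""
    out = []
    for t in texts:
        seq = [vocab[ch] for ch in t if ch in vocab]
        if len(seq) > max_len:
            seq = seq[-max_len:]
        out.append(seq + [0] * (max_len - len(seq)))
    return out

vocab = char_vocab(['abca', 'bd'])
print(vocab)                                        # {'a': 1, 'b': 2, 'c': 3, 'd': 4}
print(texts_to_padded(['abz', 'dcbaab'], vocab, 4))  # [[1, 2, 0, 0], [2, 1, 1, 2]]
```

Because index 0 only ever means "padding", the vocabulary size passed to the Embedding layer has to be the number of known characters plus one.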
```python
train_dataset = build_tf_dataset(X_train, y_train, is_train=True)
val_dataset = build_tf_dataset(X_val, y_val, is_train=False)
test_dataset = build_tf_dataset(test_data['text'], test_data['target'], is_train=False)
```
III. Building the TextCNN network
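```python
VOCAB_SIZE = len(tokenizer.index_word) + 1  # +1 for the padding index 0
print(VOCAB_SIZE)
EMBEDDING_DIM = 100
FILTERS = [3, 4, 5]     # convolution kernel sizes
NUM_FILTERS = 128       # number of filters per kernel size
DENSE_DIM = 256         # size of the fully connected layer
CLASS_NUM = NUM_LABEL   # number of classes, taken from the data instead of hard-coding 20
DROPOUT_RATE = 0.5      # dropout rate
```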
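As a sanity check on these constants, the layer sizes can be worked out by hand. The helper below is a hypothetical sketch that counts trainable parameters per layer for the network defined next (Embedding, one Conv1D per kernel size, two Dense layers; the Dropout and pooling layers have no weights). VOCAB_SIZE = 5000 and 20 classes are example values only; with the real values, the total should agree with text_cnn_model.summary().

```python
def text_cnn_param_count(vocab_size, embedding_dim=100, filters=(3, 4, 5),
                         num_filters=128, dense_dim=256, class_num=20):
    """Trainable parameters of the TextCNN, layer by layer (illustrative helper)."""
    embedding = vocab_size * embedding_dim
    # Conv1D: kernel_size * input_channels * num_filters weights, plus one bias per filter
    convs = sum(k * embedding_dim * num_filters + num_filters for k in filters)
    concat_dim = len(filters) * num_filters   # width after GlobalMaxPooling1D + concatenate
    dense = concat_dim * dense_dim + dense_dim
    output = dense_dim * class_num + class_num
    return {'embedding': embedding, 'convs': convs, 'concat_dim': concat_dim,
            'dense': dense, 'output': output,
            'total': embedding + convs + dense + output}

counts = text_cnn_param_count(vocab_size=5000)
print(counts['concat_dim'])  # 384
print(counts['total'])       # 757684
```

With these settings the embedding table (500,000 weights) dominates, so for a character-level model the vocabulary size, not the number of filters, is what mostly drives model size.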
```python
def build_text_cnn_model():
    inputs = tf.keras.Input(shape=(None,))
    embed = tf.keras.layers.Embedding(
        input_dim=VOCAB_SIZE, output_dim=EMBEDDING_DIM,
        trainable=True, mask_zero=True)(inputs)
    embed = tf.keras.layers.Dropout(DROPOUT_RATE)(embed)

    # One Conv1D + GlobalMaxPooling1D branch per kernel size
    pool_outputs = []
    for filter_size in FILTERS:
        conv = tf.keras.layers.Conv1D(NUM_FILTERS, filter_size, padding='same',
                                      activation='relu', data_format='channels_last',
                                      use_bias=True)(embed)
        max_pool = tf.keras.layers.GlobalMaxPooling1D(data_format='channels_last')(conv)
        pool_outputs.append(max_pool)

    # Concatenate the pooled features from all branches, then classify
    outputs = tf.keras.layers.concatenate(pool_outputs, axis=-1)
    outputs = tf.keras.layers.Dense(DENSE_DIM, activation='relu')(outputs)
    outputs = tf.keras.layers.Dropout(DROPOUT_RATE)(outputs)
    outputs = tf.keras.layers.Dense(CLASS_NUM, activation='softmax')(outputs)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

text_cnn_model = build_text_cnn_model()
text_cnn_model.summary()
```
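What each branch computes can be sketched in numpy: a 'same'-padded 1D convolution over the embedded sequence, a ReLU, then a max over time, so every filter contributes exactly one number regardless of text length. This is a simplified single-example forward pass written for illustration (the function name is hypothetical, and dropout and batching are ignored); concatenating the three branches yields the len(FILTERS) * NUM_FILTERS = 384 features that feed the first Dense layer.

```python
import numpy as np

def conv1d_same_relu_maxpool(x, kernel, bias):
    """One TextCNN branch for a single example:
    Conv1D(padding='same', activation='relu') followed by GlobalMaxPooling1D.
    x: (seq_len, emb_dim), kernel: (k, emb_dim, num_filters), bias: (num_filters,)
    """
    k = kernel.shape[0]
    pad_left = (k - 1) // 2            # 'same' padding: any extra pad goes on the right
    pad_right = k - 1 - pad_left
    xp = np.pad(x, ((pad_left, pad_right), (0, 0)))
    # Slide the kernel over time (cross-correlation, no kernel flip)
    conv = np.stack([
        np.tensordot(xp[t:t + k], kernel, axes=([0, 1], [0, 1])) + bias
        for t in range(x.shape[0])
    ])                                 # (seq_len, num_filters)
    conv = np.maximum(conv, 0.0)       # ReLU
    return conv.max(axis=0)            # max over time -> (num_filters,)

rng = np.random.default_rng(0)
x = rng.normal(size=(7, 100))          # 7 characters, EMBEDDING_DIM = 100
features = np.concatenate([
    conv1d_same_relu_maxpool(x, rng.normal(size=(k, 100, 128)), np.zeros(128))
    for k in (3, 4, 5)                 # FILTERS
])
print(features.shape)                  # (384,)
```

The max over time is what lets the same network accept texts of any length (hence Input(shape=(None,))): each filter simply reports its strongest match anywhere in the text.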
```python
# Class weights: smallest class count / class count, so the rarest class gets 1.0.
# Built directly from the Series, which works on pandas 1.x and 2.x
# (the column names produced by value_counts().reset_index() changed in pandas 2.0)
counts = train_data['target'].value_counts().sort_index()
df_weight_dict = (counts.min() / counts).to_dict()
df_weight_dict
```
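How do these weights act during training? With class_weight, Keras scales each sample's loss by the weight of its true class. The numpy sketch below shows that idea for the sparse categorical cross-entropy used here, with invented class counts and probabilities; it averages the weighted losses over the batch, which is assumed here to match Keras's default reduction, so treat that detail as an assumption rather than a statement about the library internals.

```python
import numpy as np

# Hypothetical class sizes; weight = smallest count / count (the same rule as df_weight_dict)
counts = np.array([900, 300, 100])
weights = counts.min() / counts            # rarest class gets weight 1.0

def weighted_sparse_ce(probs, y, weights):
    """Sparse categorical cross-entropy with each sample's loss scaled by
    the weight of its true class, then averaged over the batch."""
    per_sample = -np.log(probs[np.arange(len(y)), y])
    return float(np.mean(weights[y] * per_sample))

probs = np.array([[0.90, 0.05, 0.05],      # made-up softmax outputs
                  [0.20, 0.20, 0.60]])
y = np.array([0, 2])                       # true class ids
loss = weighted_sparse_ce(probs, y, weights)
print(weights)
print(loss)
```

A mistake on a sample from the 100-document class costs nine times as much as the same mistake on the 900-document class, which pushes the model away from simply predicting the majority class.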
IV. Training
```python
LR = 3e-4
EPOCHS = 30
EARLY_STOP_PATIENCE = 2

loss = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(LR)
text_cnn_model.compile(loss=loss, optimizer=optimizer, metrics=['accuracy'])

# Stop once val_accuracy has not improved for EARLY_STOP_PATIENCE epochs,
# restoring the weights from the best epoch
callback = tf.keras.callbacks.EarlyStopping(monitor='val_accuracy',
                                            patience=EARLY_STOP_PATIENCE,
                                            restore_best_weights=True)
history = text_cnn_model.fit(train_dataset,
                             epochs=EPOCHS,
                             callbacks=[callback],
                             validation_data=val_dataset,
                             class_weight=df_weight_dict)
```
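The patience trade-off is easy to see on a toy curve. The function below is a hypothetical pure-Python sketch of the stopping rule described above (val_accuracy is maximised, min_delta is 0): training stops once the score has failed to strictly improve for `patience` consecutive epochs, and with restore_best_weights=True the model ends up with the weights from the best epoch. It is a simplification; the real Keras callback has further options that are not modeled here, and the accuracy values are invented.

```python
def early_stopping(val_scores, patience):
    """Simplified early-stopping rule (illustrative helper).
    Returns (last epoch trained, best epoch), both 0-based."""
    best, best_epoch, wait = float('-inf'), 0, 0
    for epoch, score in enumerate(val_scores):
        if score > best:                # strictly higher counts as an improvement
            best, best_epoch, wait = score, epoch, 0
        else:
            wait += 1
            if wait >= patience:        # `patience` epochs in a row without improvement
                return epoch, best_epoch
    return len(val_scores) - 1, best_epoch

# Made-up val_accuracy curve
scores = [0.70, 0.82, 0.86, 0.85, 0.86, 0.84, 0.90]
print(early_stopping(scores, patience=2))  # (4, 2): stops early, keeps epoch 2's weights
print(early_stopping(scores, patience=5))  # (6, 6): patient enough to reach the 0.90
```

A larger patience costs extra epochs but can ride out a temporary plateau; a smaller one saves time at the risk of stopping just before an improvement.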
It performs well even on a CPU, reaching an accuracy of around 90%.
V. Predicting and exporting the results
```python
test_predict = text_cnn_model.predict(test_dataset)
# Pick the most probable class id per document, then map ids back to label names
preds = np.argmax(test_predict, axis=-1)
test_data['category'] = preds
test_data['category'] = test_data['category'].map(cls_dict)
test_data[['filename', 'category']].to_csv('zhanglei.csv', index=False)
```