使用微调后的Bert模型做编码器进行文本特征向量抽取
2021/4/14 18:56:19
本文主要是介绍使用微调后的Bert模型做编码器进行文本特征向量抽取,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
通常,我们使用bert做文本分类,泛化性好、表现优秀。在进行文本相似性计算任务时,往往是对语料训练词向量,再聚合文本向量embedding数据,计算相似度;但是,word2vec是静态词向量,表征能力有限,此时,可以用已进行特定环境下训练的bert模型,抽取出cls向量作为整个句子的表征向量以供下游任务使用,可以说是一个附加产物;主要流程如下:
1)加载ckpt模型
2)确定输出tensor名称,在bert中,cls的名称为:bert/pooler/dense/Tanh(而不是SoftMax)
3)存储为pb model
主代码:
def extract_bert_vector(): """ 抽取bert 768 特征向量 :return: """ OUTPUT_GRAPH = 'pb_model/bert_encoder.pb' output_node = ["bert/pooler/dense/Tanh"] ckpt_model = r'output' bert_config_file = r'chinese_L-12_H-768_A-12/bert_config.json' max_seq_length = 200 gpu_config = tf.ConfigProto() gpu_config.gpu_options.allow_growth = True sess = tf.Session(config=gpu_config) graph = tf.get_default_graph() with open(r'data/file_dict.json', 'r') as fr: label_list = json.load(fr) with graph.as_default(): print("going to restore checkpoint") input_ids_p = tf.placeholder(tf.int32, [None, max_seq_length], name="input_ids") input_mask_p = tf.placeholder(tf.int32, [None, max_seq_length], name="input_mask") bert_config = modeling.BertConfig.from_json_file(bert_config_file) (loss, per_example_loss, logits, probabilities) = create_model( bert_config=bert_config, is_training=False, input_ids=input_ids_p, input_mask=input_mask_p, segment_ids=None, labels=None, num_labels=len(label_list), use_one_hot_embeddings=False) saver = tf.train.Saver() saver.restore(sess, tf.train.latest_checkpoint(ckpt_model)) graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node) with tf.gfile.GFile(OUTPUT_GRAPH, "wb") as f: f.write(graph.SerializeToString()) print('extract vector pb model saved!')
这篇关于使用微调后的Bert模型做编码器进行文本特征向量抽取的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-12-22项目:远程温湿度检测系统
- 2024-12-21《鸿蒙HarmonyOS应用开发从入门到精通(第2版)》简介
- 2024-12-21后台管理系统开发教程:新手入门全指南
- 2024-12-21后台开发教程:新手入门及实战指南
- 2024-12-21后台综合解决方案教程:新手入门指南
- 2024-12-21接口模块封装教程:新手必备指南
- 2024-12-21请求动作封装教程:新手必看指南
- 2024-12-21RBAC的权限教程:从入门到实践
- 2024-12-21登录鉴权实战:新手入门教程
- 2024-12-21动态权限实战入门指南