1113-七言诗词收集与LSTM自动写诗
2021/11/13 23:39:45
本文主要是介绍1113-七言诗词收集与LSTM自动写诗,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
七言诗词收集
数据清洗
通过之前对每个诗词进行的诗词形式的分类:提取诗词形式与对应的诗词内容两列
开始清洗:
①找到formal为七言绝句的诗词
②对诗词进行分词,判断是否符合要求,然后去除一些非法字符的段落
import pandas as pd import re #获取指定文件夹下的excel import os def get_filename(path,filetype): # 输入路径、文件类型例如'.xlsx' name = [] for root,dirs,files in os.walk(path): for i in files: if os.path.splitext(i)[1]==filetype: name.append(i) return name # 输出由有后缀的文件名组成的列表 def read(): file = 'data/' list = get_filename(file, '.xlsx') qi_list=[] for it in list: newfile =file+it print(newfile) # 获取诗词内容 data = pd.read_excel(newfile) formal=data.formal content=data.content for i in range(len(formal)): fom=formal[i] if fom=='七言绝句': text=content[i].replace('\n','') text_list=re.split('[,。]',text) print(text_list) if len(text_list)==9 and len(text_list[len(text_list)-1])==0: f = True for i in range(len(text_list)-1): it=text_list[i] print(len(it)) if len(it)!=7 or it.find('□')!=-1: f=False break if f: #print(text) qi_list.append(text[:32]) qi_list.append(text[32:64]) print(qi_list) return qi_list def write(content): with open("./poem_train/qi_jueju.txt", "w", encoding="utf-8") as f: for it in content: f.write(it) # 自带文件关闭功能,不需要再写f.close() f.write("\n") if __name__ == '__main__': content=read() write(content)
保存形式
整理了3万多条诗句,感觉还可以在细化提取,之后在改善一下此代码
LSTM自动写诗
代码
import torch import torch.nn as nn import numpy as np from gensim.models.word2vec import Word2Vec import pickle from torch.utils.data import Dataset,DataLoader import os def split_poetry(file='qi_jueju2.txt'): all_data=open(file,"r",encoding="utf-8").read() all_data_split=" ".join(all_data) with open("split.txt","w",encoding='utf-8') as f: f.write(all_data_split) def train_vec(split_file='split.txt',org_file='qi_jueju2.txt'): #word2vec模型 vec_params_file="vec_params.pkl" #判断切分文件是否存在,不存在进行切分 if os.path.exists(split_file)==False: split_poetry() #读取切分的文件 split_all_data=open(split_file,"r",encoding="utf-8").read().split("\n") #读取原始文件 org_data=open(org_file,"r",encoding="utf-8").read().split("\n") #存在模型文件就去加载,返回数据即可 if os.path.exists(vec_params_file): return org_data,pickle.load(open(vec_params_file,"rb")) #词向量大小:vector_size,构造word2vec模型,字维度107,只要出现一次就统计该字,workers=6同时工作 embedding_num=128 model=Word2Vec(split_all_data,vector_size=embedding_num,min_count=1,workers=6) #保存模型 pickle.dump((model.syn1neg,model.wv.key_to_index,model.wv.index_to_key),open(vec_params_file,"wb")) return org_data,(model.syn1neg,model.wv.key_to_index,model.wv.index_to_key) class MyDataset(Dataset): #数据打包 #加载所有数据 #存储和初始化变量 def __init__(self,all_data,w1,word_2_index): self.w1=w1 self.word_2_index=word_2_index self.all_data=all_data #获取一条数据,并做处理 def __getitem__(self, index): a_poetry_words = self.all_data[index] a_poetry_index = [self.word_2_index[word] for word in a_poetry_words] xs_index = a_poetry_index[:-1] ys_index = a_poetry_index[1:] #取出31个字,每个字对应107维度向量,【31,107】 xs_embedding=self.w1[xs_index] return xs_embedding,np.array(ys_index).astype(np.int64) #获取数据总长度 def __len__(self): return len(self.all_data) class Mymodel(nn.Module): def __init__(self,embedding_num,hidden_num,word_size): super(Mymodel, self).__init__() self.embedding_num=embedding_num self.hidden_num = hidden_num self.word_size = word_size #num_layer:两层,代表层数,出来后的维度[5,31,64],设置hidden_num=64 self.lstm=nn.LSTM(input_size=embedding_num,hidden_size=hidden_num,batch_first=True,num_layers=2,bidirectional=False) #做一个随机失活,防止过拟合,同时可以保持生成的古诗不唯一 self.dropout=nn.Dropout(0.3) #做一个flatten,将维度合并【5*31,64】 self.flatten=nn.Flatten(0,1) #加一个线性层:[64,词库大小] self.linear=nn.Linear(hidden_num,word_size) #交叉熵 self.cross_entropy=nn.CrossEntropyLoss() def forward(self,xs_embedding,h_0=None,c_0=None): xs_embedding=xs_embedding.to(device) if h_0==None or c_0==None: #num_layers,batch_size,hidden_size h_0=torch.tensor(np.zeros((2,xs_embedding.shape[0],self.hidden_num),np.float32)) c_0 = torch.tensor(np.zeros((2, xs_embedding.shape[0], self.hidden_num),np.float32)) h_0=h_0.to(device) c_0=c_0.to(device) hidden,(h_0,c_0)=self.lstm(xs_embedding,(h_0,c_0)) hidden_drop=self.dropout(hidden) flatten_hidden=self.flatten(hidden_drop) pre=self.linear(flatten_hidden) return pre,(h_0,c_0) def generate_poetry_auto(): result='' #随机产生第一个字的下标 word_index=np.random.randint(0,word_size,1)[0] result += index_2_word[word_index] h_0 = torch.tensor(np.zeros((2, 1, hidden_num), np.float32)) c_0 = torch.tensor(np.zeros((2, 1, hidden_num), np.float32)) for i in range(31): word_embedding=torch.tensor(w1[word_index].reshape(1,1,-1)) pre,(h_0,c_0)=model(word_embedding,h_0,c_0) word_index=int(torch.argmax(pre)) result+=index_2_word[word_index] print(result) def test(): if os.path.exists(model_result_file): model=pickle.load(open(model_result_file, "rb")) generate_poetry_auto() if __name__ == '__main__': device="cuda" if torch.cuda.is_available() else "cpu" print(device) #源数据小了,batch不能太大 batch_size=64 all_data,(w1,word_2_index,index_2_word)=train_vec() dataset=MyDataset(all_data,w1,word_2_index) dataloader=DataLoader(dataset,batch_size=batch_size,shuffle=True) epoch=1000 word_size , embedding_num=w1.shape lr=0.01 hidden_num=128 model_result_file='model_lstm.pkl' #测试代码 # if os.path.exists(model_result_file): # model=pickle.load(open(model_result_file, "rb")) # generate_poetry_auto() #训练代码 model=Mymodel(embedding_num,hidden_num,word_size) #放入gpu训练 model.to(device) optimizer=torch.optim.AdamW(model.parameters(),lr=lr) for e in range(epoch): #按照指定的batch_size获取诗词条数【32,31,107】 #ys_index:torch.Size([32,31]) for batch_index,(xs_embedding,ys_index) in enumerate(dataloader): xs_embedding=xs_embedding.to(device) ys_index=ys_index.to(device) pre,_=model.forward(xs_embedding) loss=model.cross_entropy(pre,ys_index.reshape(-1)) optimizer.zero_grad() loss.backward() optimizer.step() if batch_index%100==0: print(f"loss:{loss:.3f}") generate_poetry_auto() pickle.dump(model, open(model_result_file, "wb"))
结果
这篇关于1113-七言诗词收集与LSTM自动写诗的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2025-01-11cursor试用出现:Too many free trial accounts used on this machine 的解决方法
- 2025-01-11百万架构师第十四课:源码分析:Spring 源码分析:深入分析IOC那些鲜为人知的细节|JavaGuide
- 2025-01-11不得不了解的高效AI办公工具API
- 2025-01-102025 蛇年,J 人直播带货内容审核团队必备的办公软件有哪 6 款?
- 2025-01-10高效运营背后的支柱:文档管理优化指南
- 2025-01-10年末压力山大?试试优化你的文档管理
- 2025-01-10跨部门协作中的进度追踪重要性解析
- 2025-01-10总结 JavaScript 中的变体函数调用方式
- 2025-01-10HR团队如何通过数据驱动提升管理效率?6个策略
- 2025-01-10WBS实战指南:如何一步步构建高效项目管理框架?