Simple RNN Training

2021/5/7 18:27:29

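This post walks through a simple RNN training example in TensorFlow 1.x: a basic RNN classifies MNIST handwritten digits by reading each 28×28 image as a sequence of 28 rows of 28 pixels.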
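The training script feeds each flattened 784-pixel MNIST image to the RNN as a sequence: reshape(-1, n_steps, n_inputs) turns it into 28 time steps of 28 pixels, one image row per step. Below is a minimal pure-Python sketch of that split for a single image. The helper image_to_sequence and the stand-in pixel values are hypothetical and serve only as illustration.

```python
# Minimal sketch (pure Python, no TensorFlow): split one flattened MNIST image
# into the sequence of rows the RNN reads, mirroring
# x_batch.reshape(-1, n_steps, n_inputs) for a single image.
n_steps = 28   # one time step per image row
n_inputs = 28  # pixels fed to the RNN at each step


def image_to_sequence(flat_pixels, n_steps=n_steps, n_inputs=n_inputs):
    """Hypothetical helper: cut a flat list of n_steps * n_inputs pixels into n_steps rows."""
    if len(flat_pixels) != n_steps * n_inputs:
        raise ValueError("expected %d pixels, got %d" % (n_steps * n_inputs, len(flat_pixels)))
    return [flat_pixels[t * n_inputs:(t + 1) * n_inputs] for t in range(n_steps)]


# Hypothetical stand-in image: each pixel value is its flat index / 783.
fake_image = [i / 783 for i in range(784)]
seq = image_to_sequence(fake_image)
print(len(seq), len(seq[0]))        # 28 28
print(seq[1][0] == fake_image[28])  # True: step 1 starts at pixel 28
```

Row-major order means time step t sees pixels t*28 through t*28+27, so the RNN scans the digit from top to bottom.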

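# Requires TensorFlow 1.x: tf.contrib and tensorflow.examples.tutorials were removed in TensorFlow 2.x.
from tensorflow.contrib.layers import fully_connected
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# Each 28x28 MNIST image is fed as 28 time steps of 28 pixels (one row per step).
n_steps = 28
n_inputs = 28
n_neurons = 150   # hidden units in the RNN cell
n_outputs = 10    # one logit per digit class

learning_rate = 0.001

x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.int32, [None])  # integer class labels, not one-hot

# Unroll a basic RNN cell over the time steps. For BasicRNNCell, `states` is
# the hidden state after the last step, shape [batch_size, n_neurons].
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, x, dtype=tf.float32)

# Linear output layer on the final state, then softmax cross-entropy loss.
logits = fully_connected(states, n_outputs, activation_fn=None)
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
loss = tf.reduce_mean(xentropy)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(loss)

# Accuracy: fraction of samples whose true class has the highest logit.
correct = tf.nn.in_top_k(logits, y, 1)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
init = tf.global_variables_initializer()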

# Downloads MNIST to /tmp/data/ on first run; labels are integers (one_hot=False by default).
mnist = input_data.read_data_sets("/tmp/data/")
x_test = mnist.test.images.reshape(-1, n_steps, n_inputs)
y_test = mnist.test.labels

n_epochs = 100
batch_size = 150

with tf.Session() as sess:
    init.run()
    for epoch in range(n_epochs):
        for iteration in range(mnist.train.num_examples // batch_size):
            x_batch, y_batch = mnist.train.next_batch(batch_size)
            # Flat 784-pixel vectors -> (batch, 28 steps, 28 inputs)
            x_batch = x_batch.reshape(-1, n_steps, n_inputs)
            sess.run(training_op, feed_dict={x: x_batch, y: y_batch})
        # Train accuracy is measured on the last mini-batch only.
        acc_train = accuracy.eval(feed_dict={x: x_batch, y: y_batch})
        acc_test = accuracy.eval(feed_dict={x: x_test, y: y_test})
        print(epoch, 'Train acc:', acc_train, 'Test acc:', acc_test)
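Why the output layer can take states directly: for BasicRNNCell the output at each step and the new hidden state are the same vector, h_t = tanh(x_t W_x + h_{t-1} W_h + b), and dynamic_rnn starts from a zero state when no initial_state is passed. So, when every sequence runs its full length as here, states equals the last time step of outputs. The pure-Python sketch below unrolls that recurrence with tiny, made-up weights to show the relationship; rnn_cell_step, run_rnn and all weight values are hypothetical, and this is not TensorFlow's implementation.

```python
import math


def rnn_cell_step(x_t, h_prev, W_x, W_h, b):
    """One BasicRNNCell-style step: h_t = tanh(x_t @ W_x + h_prev @ W_h + b)."""
    n_units = len(b)
    h_t = []
    for j in range(n_units):
        z = b[j]
        z += sum(x_t[i] * W_x[i][j] for i in range(len(x_t)))
        z += sum(h_prev[k] * W_h[k][j] for k in range(n_units))
        h_t.append(math.tanh(z))
    return h_t


def run_rnn(sequence, W_x, W_h, b):
    """Unroll the cell over the whole sequence; return (outputs, final_state)."""
    h = [0.0] * len(b)  # zero initial state
    outputs = []
    for x_t in sequence:
        h = rnn_cell_step(x_t, h, W_x, W_h, b)
        outputs.append(h)  # the output at each step is the new state itself
    return outputs, h


# Hypothetical toy sizes: 2 inputs per step, 3 hidden units, 4 time steps.
W_x = [[0.1, -0.2, 0.3], [0.0, 0.5, -0.1]]
W_h = [[0.2, 0.0, -0.3], [0.1, 0.1, 0.0], [0.0, -0.2, 0.4]]
b = [0.0, 0.1, -0.1]
sequence = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [1.0, -1.0]]

outputs, state = run_rnn(sequence, W_x, W_h, b)
print(len(outputs))          # 4
print(state == outputs[-1])  # True
```

In the real model the same idea applies with 28 steps, 28 inputs and 150 units, giving states of shape [batch_size, 150] for fully_connected.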

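The loss and accuracy ops reduce to simple per-sample formulas: sparse_softmax_cross_entropy_with_logits computes -log softmax(logits)[label] from an integer label, and in_top_k(logits, y, 1) checks whether the true class has the highest logit. A minimal pure-Python sketch on a hypothetical 3-class batch (function names and numbers are illustrative only):

```python
import math


def sparse_softmax_xent(logits, label):
    """-log(softmax(logits)[label]), computed with the log-sum-exp trick."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def top1_correct(logits, label):
    """True if the true class's logit is the largest (a tie for the top also counts)."""
    return logits[label] == max(logits)


# Hypothetical batch of 3 samples over 3 classes.
batch_logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0], [1.0, 1.5, 0.0]]
batch_labels = [0, 2, 0]

loss = sum(sparse_softmax_xent(z, y) for z, y in zip(batch_logits, batch_labels)) / len(batch_labels)
accuracy = sum(top1_correct(z, y) for z, y in zip(batch_logits, batch_labels)) / len(batch_labels)
print(accuracy)  # 0.6666666666666666 (third sample's true class is not the top logit)
```

Subtracting the maximum logit before exponentiating keeps exp from overflowing and does not change the result, since softmax is unchanged when every logit is shifted by the same constant.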
