Tensorflow 实现Mnist图片预测

2019/12/28 0:28:48

本文主要是介绍Tensorflow 实现Mnist图片预测,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

1.准备数据,使用占位符,动态加载训练数据

x=tf.placeholder(tf.float32,[None,784])
y_true=tf.placeholder(tf.int32,[None,10])

2.初始化参数,建立模型

weight=tf.Variable(tf.random_normal([784,10],mean=0.0,stddev=1.0))
bias=tf.Variable(tf.canstant(0.0,shape=[10]))
y_predict=tf.matmul(x,weight)+bias

3.求平均交叉熵损失

loss=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict))

4.梯度下降优化

train_op=tf.GradientDescentOptimizer(0.3).minimize(loss)

5.求准确率

equal_list=tf.equal(tf.arg_max(y_true,1),tf.arg_max(y_predict,1))
accuracy=tf.reduce_mean(tf.cast(equal_list,tf.float32))

完整代码:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os

mnist = input_data.read_data_sets('./data/MNISI_data/', one_hot=True)


def full_connection():
    # 1.准备数据
    with tf.variable_scope("data"):
        x = tf.placeholder(tf.float32, [None, 784])
        y_true = tf.placeholder(tf.int32, [None, 10])
    # 2.建立模型
    with tf.variable_scope('predict_model'):
        weight = tf.Variable(tf.random_normal([784, 10], mean=0.0, stddev=1.0), name='w')
        bias = tf.Variable(tf.constant(0.0, shape=[10]))
        y_predict = tf.matmul(x, weight) + bias
    # 3.平均交叉熵损失
    with tf.variable_scope('loss'):
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_predict))
    # 4.梯度下降优化
    with tf.variable_scope('optimizer'):
        train_op = tf.train.GradientDescentOptimizer(0.4).minimize(loss)
    # 5.求准确率
    with tf.variable_scope('acc'):
        equal_list = tf.equal(tf.arg_max(y_true, 1), tf.arg_max(y_predict, 1))
        accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))
    init_op = tf.initialize_all_variables()
    # 收集变量,tensorboard使用
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('accuracy', accuracy)
    tf.summary.histogram('weight', weight)
    tf.summary.histogram('bias', bias)
    merged = tf.summary.merge_all()
    saver = tf.train.Saver()
    is_train = False
    with tf.Session() as sess:
        if is_train == True:
            sess.run(init_op)
            fileWriter = tf.summary.FileWriter('./temp/summary/test', graph=sess.graph)
            if os.path.exists('./temp/ckpt/checkpoint'):
                # 加载训练的模型
                saver.restore(sess, './temp/ckpt/full_conn')
            for i in range(4000):
                # 每次批量货期50个数据集
                mnist_x, mnist_y = mnist.train.next_batch(50)
                sess.run(train_op, feed_dict={x: mnist_x, y_true: mnist_y})
                summary = sess.run(merged, feed_dict={x: mnist_x, y_true: mnist_y})
                fileWriter.add_summary(summary, i)
                print("训练低%d步,准确率为:%f" % (i, sess.run(accuracy, feed_dict={x: mnist_x, y_true: mnist_y})))
            # 保存训练完的模型
            saver.save(sess, './temp/ckpt/full_conn')
        else:
            saver.restore(sess, './temp/ckpt/full_conn')
            for i in range(100):
                # 每次批量货期1个数据集
                x_test, y_test = mnist.test.next_batch(1)
                print('低%d张图片,手写数字图片目标:%d--%d' % (
                    i,
                    tf.arg_max(y_test, 1).eval(),
                    tf.arg_max(sess.run(y_predict, feed_dict={x: x_test, y_true: y_test}), 1).eval()
                ))


if __name__ == '__main__':
    full_connection()


这篇关于Tensorflow 实现Mnist图片预测的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!


扫一扫关注最新编程教程