深度学习-----从零开始实现识别手写字体任务(六)计算测试集的准确率和Tensorflow的执行阶段
2021/7/2 23:21:21
本文主要是介绍深度学习-----从零开始实现识别手写字体任务(六)计算测试集的准确率和Tensorflow的执行阶段,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
计算测试集的准确率
def compute_accuracy(v_xs, v_ys): global prediction # y_pre将v_xs输入模型后得到的预测值 (10000,10) y_pre = sess.run(prediction, feed_dict={xs: v_xs, keep_prob: 1}) # argmax(axis) axis = 1 返回结果为:数组中每一行最大值所在“列”索引值 # tf.equal返回布尔值,correct_prediction (10000,1) correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(v_ys, 1)) # tf.cast将bool转成float32, tf.reduce_mean求均值,作为accuracy值(0到1) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys, keep_prob: 1}) return result
该函数有两个参数,v_xs为测试集,v_ys为测试集的标签。
y_pre将测试集输入后的模型转化为预测值
argmax函数将数组每一行的最大值所在的列返回出来
使用tf.equal比较测试集和标签的值是否相等,若相等返回ture
tf.cast将布尔值转化为浮点数,通过求均值来计算准确率
最后返回准确率。
Tensorflow的执行阶段
TensorFlow 程序通常被组织成一个构建阶段和一个执行阶段,之前的都是构建阶段,现在是执行阶段,需要创立一个session对象来一遍遍执行上述程序。
Session对象在使用完后需要关闭以释放资源. 除了显式调用 close 外, 也可以使用 "with" 代码块 来自动完成关闭动作。
keep_prob_rate = 0.6 with tf.Session() as sess: # 初始化图中所有Variables init = tf.global_variables_initializer() sess.run(init) # 总迭代次数(batch)为max_epoch=1000,每次取100张图做batch梯度下降 print("step 0, test accuracy %g" % (compute_accuracy( mnist.test.images, mnist.test.labels))) for i in range(max_epoch): # mnist.train.next_batch 默认shuffle=True,随机读取,batch大小为100 batch_xs, batch_ys = mnist.train.next_batch(100) # 此batch是个2维tuple,batch[0]是(100,784)的样本数据数组,batch[1]是(100,10)的样本标签数组,分别赋值给batch_xs, batch_ys sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: keep_prob_rate}) # 暂时不进行赋值的元素叫占位符(如xs、ys),run需要它们时得赋值,feed_dict就是用来赋值的,格式为字典型 if (i + 1) % 50 == 0: print("step %d, test accuracy %g" % (i + 1, compute_accuracy( mnist.test.images, mnist.test.labels)))
第一步我们要先初始化所有的variables
先输出模型一开始的准确率
然后进行迭代
经过训练通过train_step改变所有variables的值,从而提高准确率
最后输出准确率
这篇关于深度学习-----从零开始实现识别手写字体任务(六)计算测试集的准确率和Tensorflow的执行阶段的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-10-30tensorflow是什么-icode9专业技术文章分享
- 2024-10-15成功地使用本地的 NVIDIA GPU 运行 PyTorch 或 TensorFlow
- 2024-01-23供应链投毒预警 | 恶意Py包仿冒tensorflow AI框架实施后门投毒攻击
- 2024-01-19attributeerror: module 'tensorflow' has no attribute 'placeholder'
- 2024-01-19module 'tensorflow.compat.v2' has no attribute 'internal'
- 2023-07-17【2023年】第33天 Neural Networks and Deep Learning with TensorFlow
- 2023-07-10【2023年】第32天 Boosted Trees with TensorFlow 2.0(随机森林)
- 2023-07-09【2023年】第31天 Logistic Regression with TensorFlow 2.0(用TensorFlow进行逻辑回归)
- 2023-07-01【2023年】第30天 Supervised Learning with TensorFlow 2(用TensorFlow进行监督学习 2)
- 2023-06-18【2023年】第29天 Supervised Learning with TensorFlow 1(用TensorFlow进行监督学习 1)