一个好玩的deep learning Demo!
2022/8/30 23:52:49
本文主要是介绍一个好玩的deep learning Demo!,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
对于生活中的熟悉的动物,我们人脑经过一次扫描,便可以得到该动物的物种!那么机器是如何识别这个图片上的动物是属于哪一物种呢?
本次实验借生活中最常见的猫和狗来探究其原理!
环境准备:
tensorflow ,python,一些data
实验预期:
当模型训练完成后,我们可以用该模型去预测一张图片属于哪一个类别,很显然,本次项目属于一个二分类问题,
网上有很多此类的项目,但是都不能很好的落地,那么这次实验所完成的最终结果是,我们上传一张图片,控制台
便会返回该图片的类别:猫/狗
模型搭建:
对于图片识别来说,最强大的工具莫过于卷积神经网络,对于CNN的原理也不是很难,只要知道其主要的计算过程即可,
熟悉CNN的人都知道,并不是层数越多越好,因为层数过多,会造正过拟合,导致实验结果不会很理想,所以经过我多次的实验,
最终模型的设置如下:
model = tf.keras.models.Sequential([ tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=(150, 150, 3)), tf.keras.layers.MaxPooling2D(2, 2), tf.keras.layers.Conv2D(32, (3, 3), activation='relu'), tf.keras.layers.MaxPooling2D(2, 2), tf.keras.layers.Conv2D(64, (3, 3), activation='relu'), tf.keras.layers.MaxPooling2D(2, 2), tf.keras.layers.Flatten(), tf.keras.layers.Dense(512, activation='relu'), tf.keras.layers.Dense(1, activation='sigmoid') ])
每一层卷积跟一层最大池化,Conv2D()中参数:16表示卷积核个数,(3,3)表示卷积核大小,很多论文中给出的代码中设定的也是(3,3),input_shape表示输入数据形状,后面是通道数;
经过最大池化留下来的神经元对输出才会有贡献!环节卷积层对位置的敏感性!
然后再模型之前,我们也需要对数据进行一些操作:读取数据,将数据分为验证数据集和训练数据集
base_dir = 'D:/cats and dogs' train_dir = os.path.join(base_dir, 'train') validation_dir = os.path.join(base_dir, 'validation') train_cats_dir = os.path.join(train_dir, 'cats') train_dogs_dir = os.path.join(train_dir, 'dogs') validation_cats_dir = os.path.join(validation_dir, 'cats') validation_dogs_dir = os.path.join(validation_dir, 'dogs')
接下来的操作就是一些固定的步骤,对数据进行归一化,生成带标签的数据,绘制损失曲线等,直接上代码:
train_datagen = ImageDataGenerator(rescale=1.0 / 255.) test_datagen = ImageDataGenerator(rescale=1.0 / 255.) train_generator = train_datagen.flow_from_directory(train_dir, batch_size=20, class_mode='binary', target_size=(150, 150)) validation_generator = test_datagen.flow_from_directory(validation_dir, batch_size=20, class_mode='binary', target_size=(150, 150)) history = model.fit_generator(train_generator, validation_data=validation_generator, steps_per_epoch=100, epochs=15, validation_steps=50, verbose=2) model.save('model.h5') acc = history.history['acc'] val_acc = history.history['val_acc'] loss = history.history['loss'] val_loss = history.history['val_loss'] epochs = range(len(acc)) plt.plot(epochs, acc) plt.plot(epochs, val_acc) plt.title('Training and validation accuracy') plt.legend(('Training accuracy', 'validation accuracy')) plt.figure() plt.plot(epochs, loss) plt.plot(epochs, val_loss) plt.legend(('Training loss', 'validation loss')) plt.title('Training and validation loss') plt.show()
预测部分
from tensorflow.keras.models import load_model import numpy as np from tensorflow.keras.preprocessing import image path = 'D:/cats and dogs/cat.123.jpg' model = load_model('model.h5') img = image.load_img(path, target_size=(150, 150)) x = image.img_to_array(img) / 255.0 x = np.expand_dims(x, axis=0) # np.vstack:按垂直方向(行顺序)堆叠数组构成一个新的数组 images = np.vstack([x]) classes = model.predict(images, batch_size=1) if classes[0] > 0.5: print("图片识别为狗") else: print("图片识别为猫")
结果说明还可以!!!!!!!
这篇关于一个好玩的deep learning Demo!的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-05-01为什么公共事业机构会偏爱 TiDB :TiDB 数据库在某省妇幼健康管理系统的应用
- 2024-04-26敏捷开发:想要快速交付就必须舍弃产品质量?
- 2024-04-26静态代码分析的这些好处,我竟然都不知道?
- 2024-04-26你在测试金字塔的哪一层?(下)
- 2024-04-26快刀斩乱麻,DevOps让代码评审也自动起来
- 2024-04-262024年最好用的10款ER图神器!
- 2024-04-2203-为啥大模型LLM还没能完全替代你?
- 2024-04-2101-大语言模型发展
- 2024-04-17基于SpringWeb MultipartFile文件上传、下载功能
- 2024-04-14个人开发者,Spring Boot 项目如何部署