深度学习之基于CNN和ResNet实现鸟类识别
2021/7/22 23:14:16
本文主要是介绍深度学习之基于CNN和ResNet实现鸟类识别,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
本次利用迁移学习用已经构建好的ResNet网络对鸟类图片进行分类,但是结果不甚理想。
1.导入库
import numpy as np import tensorflow as tf import os,PIL import random import pathlib import matplotlib.pyplot as plt from sklearn.model_selection import train_test_split from keras.utils import np_utils
2.数据加载
(需要数据的可以私信我
加载数据 dataset_url = "E:/tmp/.keras/datasets/Birds_photos" dataset_dir = pathlib.Path(dataset_url) train_Bananaquit = os.path.join(dataset_dir,"train","Bananaquit") train_BlackSki = os.path.join(dataset_dir,"train","Black Skimmer") train_BTB = os.path.join(dataset_dir,"train","Black Throated Bushtiti") train_Cockatoo = os.path.join(dataset_dir,"train","Cockatoo") train_dir = os.path.join(dataset_dir,"train") test_Bananaquit = os.path.join(dataset_dir,"test","Bananaquit") test_BlackSki = os.path.join(dataset_dir,"test","Black Skimmer") test_BTB = os.path.join(dataset_dir,"test","Black Throated Bushtiti") test_Cockatoo = os.path.join(dataset_dir,"test","Cockatoo") test_dir = os.path.join(dataset_dir,"test") #统计训练集和测试集的数据数目 train_Bananaquit_num = len(os.listdir(train_Bananaquit)) train_BlackSki_num = len(os.listdir(train_BlackSki)) train_BTB_num = len(os.listdir(train_BTB)) train_Cockatoo_num = len(os.listdir(train_Cockatoo)) train_all = train_Bananaquit_num+train_BlackSki_num+train_BTB_num+train_Cockatoo_num test_Bananaquit_num = len(os.listdir(test_Bananaquit)) test_BlackSki_num = len(os.listdir(test_BlackSki)) test_BTB_num = len(os.listdir(test_BTB)) test_Cockatoo_num = len(os.listdir(test_Cockatoo)) test_all = test_Bananaquit_num+test_BlackSki_num+test_BTB_num+test_Cockatoo_num
3.超参数的设置
其实这一模块博主一直不太明白,每次都是乱试,不知道怎样设置超参数才能使得效果最好。
batch_size = 32 epochs = 10 height = 224 width = 224
4.数据预处理
数据预处理的几步:归一化->调整图片大小
train_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255) test_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255) train_data_gen = train_generator.flow_from_directory( batch_size=batch_size, directory=train_dir, shuffle=True, target_size=(height,width), class_mode="categorical" ) test_data_gen = test_generator.flow_from_directory( batch_size=batch_size, directory=test_dir, shuffle=True, target_size=(height,width), class_mode="categorical" )
5.CNN网络搭建&&编译
model = tf.keras.Sequential([ tf.keras.layers.Conv2D(16,3,padding="same",activation="relu",input_shape=(height,width,3)), tf.keras.layers.MaxPooling2D(), tf.keras.layers.Conv2D(32,3,padding="same",activation="relu"), tf.keras.layers.MaxPooling2D(), tf.keras.layers.Conv2D(64,3,padding="same",activation="relu"), tf.keras.layers.AveragePooling2D((2,2)), tf.keras.layers.Dropout(0.5), tf.keras.layers.Flatten(), tf.keras.layers.Dense(128,activation="relu"), tf.keras.layers.Dense(4,activation='softmax') ]) model.compile(optimizer="adam", loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True), metrics=["acc"])
结果如下所示:
虽然加入了Dropout层,但是仍然出现了过拟合的现象。基于此,进行数据增强操作。
6.数据增强
这一部分应当与数据预处理合为一步操作。数据增强包括随机选择、水平翻转、放大操作等。
train_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255, rotation_range=45,#随机翻转 width_shift_range=.15, height_shift_range=.15, horizontal_flip=True,#水平翻转 zoom_range=0.5#放大操作 )
经过数据增强之后的实验结果如下所示:
经过20次epochs之后,过拟合的现象得到了缓解。
7.ResNet
利用已经搭建好的ResNet网络对同样的数据集进行训练。
conv_base = tf.keras.applications.ResNet50(weights='imagenet', include_top=False, input_shape=(height,width, 3)) conv_base.trainable = False model = tf.keras.Sequential() model.add(conv_base) model.add(tf.keras.layers.GlobalAveragePooling2D()) model.add(tf.keras.layers.Flatten()) model.add(tf.keras.layers.Dense(512,activation='relu')) model.add(tf.keras.layers.Dense(4,activation='sigmoid')) model.compile(optimizer='Adam', loss='binary_crossentropy', metrics=['acc'])
由于硬件的原因,训练速度特别慢,而且实验效果很差,在没有经过数据增强之前,过拟合现象(有可能不是这种现象)很严重,至于数据增强之后的效果如何,博主并未测试。
对于训练集,准确率有时高达100%。但是对于测试集,实验效果很难差强人意。希望过路的大佬指正。除此之外,博主还利用了VGG16网络进行训练,实验效果相对于ResNet50而言变好了,但是训练速度特别慢。
努力加油a啊
这篇关于深度学习之基于CNN和ResNet实现鸟类识别的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-12-22怎么通过控制台去看我的页面渲染的内容在哪个文件中呢-icode9专业技术文章分享
- 2024-12-22el-tabs 组件只被引用了一次,但有时会渲染两次是什么原因?-icode9专业技术文章分享
- 2024-12-22wordpress有哪些好的安全插件?-icode9专业技术文章分享
- 2024-12-22wordpress如何查看系统有哪些cron任务?-icode9专业技术文章分享
- 2024-12-21Svg Sprite Icon教程:轻松入门与应用指南
- 2024-12-20Excel数据导出实战:新手必学的简单教程
- 2024-12-20RBAC的权限实战:新手入门教程
- 2024-12-20Svg Sprite Icon实战:从入门到上手的全面指南
- 2024-12-20LCD1602显示模块详解
- 2024-12-20利用Gemini构建处理各种PDF文档的Document AI管道