Mindspore实现手写字体识别

2021/7/3 23:51:41

本文主要是介绍Mindspore实现手写字体识别,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

Mindspore实现手写字体识别

一、实验目的

加深对神经网络原理的理解
熟悉MindSpore平台
掌握训练过程

二、实验环境

Windows + Python3+
一台装有集成开发环境(IDE)—— PyCharm的计算机

三、实验内容

1.下载数据集放置目录如下
在这里插入图片描述

四、代码填写

#encoding=utf-8
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import mindspore.dataset as ds
# Locations of the MNIST training and test splits.
train_data_path = r"\datasets\MNIST_Data\train"
test_data_path = r"\datasets\MNIST_Data\test"

# Load the training split and report its type and size.
mnist_ds = ds.MnistDataset(train_data_path)
print('The type of mnist_ds:', type(mnist_ds))
print("Number of pictures contained in the mnist_ds:", mnist_ds.get_dataset_size())

# Pull the first sample through a dict iterator and convert both columns
# to NumPy arrays.
dic_ds = mnist_ds.create_dict_iterator()
item = next(dic_ds)
img, label = (item[column].asnumpy() for column in ("image", "label"))

# Print what the sample contains, then display the image titled with its label.
print("The item of mnist_ds:", item.keys())
print("Tensor of image in item:", img.shape)
print("The label of item:", label)
plt.imshow(np.squeeze(img))
plt.title("number:%s" % label)
plt.show()
"""
-------定义dataset(dataloader)-----
"""
import mindspore.dataset.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.vision import Inter
from mindspore import dtype as mstype

def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Build the preprocessed, shuffled and batched MNIST pipeline.

    Labels are cast to int32. Images are resized to 32x32, scaled by 1/255,
    then rescaled with ``1 / 0.3081`` and shifted by ``-0.1307 / 0.3081``
    (i.e. ``(x - 0.1307) / 0.3081``), and finally converted from HWC to CHW.

    Args:
        data_path: Directory handed to ``ds.MnistDataset``.
        batch_size: Number of samples per batch; an incomplete final batch
            is dropped.
        repeat_size: How many times the batched dataset is repeated.
        num_parallel_workers: Worker count passed to every ``map`` call.

    Returns:
        The transformed dataset object.
    """
    # Read the raw MNIST data through the dataset API.
    mnist_ds = ds.MnistDataset(data_path)

    # Parameters of the preprocessing steps.
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # Operators built from the parameters above.
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # Apply the label cast first, then the image operators in their
    # original order (resize -> rescale -> normalise -> HWC2CHW).
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label",
                            num_parallel_workers=num_parallel_workers)
    for image_op in (resize_op, rescale_op, rescale_nml_op, hwc2chw_op):
        mnist_ds = mnist_ds.map(operations=image_op, input_columns="image",
                                num_parallel_workers=num_parallel_workers)

    # Shuffle, batch and repeat.
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds

# Build the training pipeline and report how many batches it yields.
ms_dataset = create_dataset(train_data_path)
print('Number of groups in the dataset:', ms_dataset.get_dataset_size())

# Fetch a single batch as NumPy arrays and inspect its layout.
data = next(ms_dataset.create_dict_iterator(output_numpy=True))
images, labels = data['image'], data['label']
print('Tensor of image:', images.shape)
print('Labels:', labels)

# Draw the batch on a 4x8 grid; every tile is titled with its label and
# has its axes hidden.
for position, (picture, digit) in enumerate(zip(images, labels), start=1):
    plt.subplot(4, 8, position)
    plt.imshow(np.squeeze(picture))
    plt.title('num:%s'%digit)
    plt.xticks([])
    plt.axis("off")
plt.show()
"""
-------定义LeNet5模型-----
"""
import mindspore.nn as nn
from mindspore.common.initializer import Normal

class LeNet5(nn.Cell):
    """LeNet-5 style convolutional classifier.

    Two 5x5 "valid" convolutions, each followed by ReLU and 2x2 max pooling,
    feed three dense layers. ``fc1`` takes 16*5*5 input features, which is
    what the convolution/pooling stack produces for a 32x32 input.

    Args:
        num_class: Output size of the last dense layer (default 10).
        num_channel: Number of input channels of the first convolution
            (default 1).
    """

    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        # Define the operators used by construct().
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        """Run the forward pass and return the output of ``fc3``."""
        # Convolution -> ReLU -> pooling, twice.
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        # Flatten, then the three dense layers (no activation after fc3).
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# Instantiate the network with its default arguments. Note that `network`
# is assigned a fresh LeNet5() again in the training section further down.
network = LeNet5()

"""
-------定义CALLBACK函数-----
"""
from mindspore.train.callback import Callback
# Custom callback: records the loss at every step and the evaluation
# accuracy at a fixed step interval.
class StepLossAccInfo(Callback):
    """Collect per-step loss and periodic accuracy during training.

    Args:
        model: Object whose ``eval`` method is called to measure accuracy.
        eval_dataset: Dataset passed to ``model.eval``.
        steps_loss: Dict with list values under ``"step"`` and
            ``"loss_value"``; one entry is appended per training step
            (both stored as strings).
        steps_eval: Dict with list values under ``"step"`` and ``"acc"``;
            one entry is appended per evaluation.
        steps_per_epoch: Number of steps used as the per-epoch offset when
            computing the step index (default 1875, the previously
            hard-coded value).
        eval_interval: Evaluate whenever the step index is a multiple of
            this value (default 125, the previously hard-coded value).

    Raises:
        ValueError: If ``eval_interval`` is not positive.
    """

    def __init__(self, model, eval_dataset, steps_loss, steps_eval,
                 steps_per_epoch=1875, eval_interval=125):
        super(StepLossAccInfo, self).__init__()
        if eval_interval <= 0:
            raise ValueError("eval_interval must be a positive integer")
        self.model = model
        self.eval_dataset = eval_dataset
        self.steps_loss = steps_loss
        self.steps_eval = steps_eval
        self.steps_per_epoch = steps_per_epoch
        self.eval_interval = eval_interval

    def step_end(self, run_context):
        """Record the loss of the finished step and evaluate if it is due."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        # NOTE(review): this treats cur_step_num as the step index inside the
        # current epoch. If the framework reports it cumulatively across
        # epochs, the offset double-counts from epoch 2 on -- confirm.
        cur_step = (cur_epoch - 1) * self.steps_per_epoch + cb_params.cur_step_num
        self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
        self.steps_loss["step"].append(str(cur_step))
        if cur_step % self.eval_interval == 0:
            acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            self.steps_eval["step"].append(cur_step)
            self.steps_eval["acc"].append(acc["Accuracy"])

"""
-------开始训练-----
"""
import os
from mindspore import Tensor, Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.nn import Accuracy
# Fresh network plus the training hyper-parameters.
network = LeNet5()
epoch_size = 1
momentum = 0.9
lr = 0.01

# Dataset root and checkpoint output directory -- adjust to your machine.
mnist_path = r"\datasets\MNIST_Data"
model_path = r"\datasets\models\ckpt\mindspore_quick_start"
train_data_path = r"\datasets\MNIST_Data\train"
test_data_path = r"\datasets\MNIST_Data\test"

# Loss and optimizer. The loss class is referenced through `nn` because the
# bare name SoftmaxCrossEntropyWithLogits is never imported in this script.
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)

# Training and evaluation pipelines, both with batch size 32.
repeat_size = 1
ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)
eval_dataset = create_dataset(os.path.join(mnist_path, "test"), 32)

# Model bundles the network, the loss, the optimizer and the metric.
model = Model(network, net_loss, net_opt, metrics={'Accuracy': Accuracy()})

# Checkpointing: save every 375 steps, keep at most 16 files, written to
# model_path with the prefix "checkpoint_lenet".
config_ck = CheckpointConfig(save_checkpoint_steps=375, keep_checkpoint_max=16)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", directory=model_path,
                             config=config_ck)

# Containers filled by StepLossAccInfo with per-step loss and periodic accuracy.
steps_loss = {"step": [], "loss_value": []}
steps_eval = {"step": [], "acc": []}
step_loss_acc_info = StepLossAccInfo(model, eval_dataset, steps_loss, steps_eval)

# Run training with the checkpoint, loss-monitor and recording callbacks.
model.train(epoch_size, ds_train,
            callbacks=[ckpoint_cb, LossMonitor(125), step_loss_acc_info],
            dataset_sink_mode=False)
"""
-------打印想训练过程-----
"""

# The callback stored steps and losses as strings; convert them back to
# numbers before plotting the loss curve.
steps = [int(step_text) for step_text in steps_loss["step"]]
loss_value = [float(loss_text) for loss_text in steps_loss["loss_value"]]

plt.plot(steps, loss_value, color="red")
plt.xlabel("Steps")
plt.ylabel("Loss_value")
plt.title("Change chart of model loss value")
plt.show()
"""
------在测试集上验证模型-----
"""
from mindspore import load_checkpoint, load_param_into_net
# Evaluation helper: restore a saved checkpoint and score it on the test split.
def test_net(network, model, mnist_path, ckpt_path=None):
    """Load a checkpoint into ``network`` and print the test-set metrics.

    Args:
        network: Network that receives the checkpoint parameters.
        model: Object whose ``eval`` method computes the metrics.
        mnist_path: MNIST root directory; its ``test`` sub-directory is
            evaluated.
        ckpt_path: Checkpoint file to load. When omitted, the most recently
            modified ``*.ckpt`` file in the module-level ``model_path``
            directory is used.

    Raises:
        FileNotFoundError: If ``ckpt_path`` is omitted and ``model_path``
            contains no ``*.ckpt`` file.
    """
    import glob  # local import: only needed to locate the default checkpoint

    print("============== Starting Testing ==============")
    if ckpt_path is None:
        # The checkpoint callback writes into model_path, not into the
        # dataset directory, so look for the newest file there.
        candidates = glob.glob(os.path.join(model_path, "*.ckpt"))
        if not candidates:
            raise FileNotFoundError(
                "No checkpoint (*.ckpt) found in %s" % model_path)
        ckpt_path = max(candidates, key=os.path.getmtime)

    # Restore the saved parameters into the network.
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(network, param_dict)
    # Build the test pipeline and evaluate.
    ds_eval = create_dataset(os.path.join(mnist_path, "test"))
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("============== Accuracy:{} ==============".format(acc))


test_net(network, model, mnist_path)

五、实验结果
读取数据集
在这里插入图片描述

数据集测试查看

在这里插入图片描述

数据集训练
在这里插入图片描述

预测
在这里插入图片描述
在这里插入图片描述



这篇关于Mindspore实现手写字体识别的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!


扫一扫关注最新编程教程