Mindspore实现手写字体识别
2021/7/3 23:51:41
本文主要是介绍Mindspore实现手写字体识别,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
Mindspore实现手写字体识别
一、实验目的
加深对神经网络原理的理解
熟悉MindSpore平台
掌握训练过程
二、实验环境
Windows + Python3+
一台装有集成开发环境(IDE)—— PyCharm的计算机
三、实验内容
1. 下载数据集,并按如下目录结构放置
四、代码填写
# encoding=utf-8
# MNIST handwritten-digit recognition with MindSpore: dataset preview,
# preprocessing pipeline, LeNet5 definition, training callback and the
# training configuration.
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import Tensor, Model
from mindspore import dtype as mstype
from mindspore.common.initializer import Normal
from mindspore.dataset.vision import Inter
from mindspore.nn import Accuracy
from mindspore.train.callback import Callback
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor

train_data_path = r"\datasets\MNIST_Data\train"
test_data_path = r"\datasets\MNIST_Data\test"

# ------- Load the raw dataset and look at one sample -------
mnist_ds = ds.MnistDataset(train_data_path)
print('The type of mnist_ds:', type(mnist_ds))
print("Number of pictures contained in the mnist_ds:", mnist_ds.get_dataset_size())

# Read a single item through a dict iterator.
dic_ds = mnist_ds.create_dict_iterator()
item = next(dic_ds)
img = item["image"].asnumpy()
label = item["label"].asnumpy()

# Print the sample's metadata and display the image.
print("The item of mnist_ds:", item.keys())
print("Tensor of image in item:", img.shape)
print("The label of item:", label)
plt.imshow(np.squeeze(img))
plt.title("number:%s" % item["label"].asnumpy())
plt.show()


# ------- Define the dataset (dataloader) -------
def create_dataset(data_path, batch_size=32, repeat_size=1, num_parallel_workers=1):
    """Build the preprocessed, shuffled and batched MNIST pipeline.

    Args:
        data_path: Directory handed to ``ds.MnistDataset``.
        batch_size: Number of samples per batch; incomplete trailing batches
            are dropped.
        repeat_size: Number of times the batched dataset is repeated.
        num_parallel_workers: Worker count passed to every ``map`` call.

    Returns:
        The transformed MindSpore dataset object.
    """
    # Read the MNIST dataset through the MindSpore API.
    mnist_ds = ds.MnistDataset(data_path)

    # ------- Preprocessing parameters -------
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # presumably 0.1307 / 0.3081 are the MNIST mean / std — TODO confirm
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # Build the transform operators from the parameters above.
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
    rescale_op = CV.Rescale(rescale, shift)
    hwc2chw_op = CV.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    # Apply the operators with map: labels are cast to int32; images are
    # resized, rescaled, normalised and finally converted from HWC to CHW.
    mnist_ds = mnist_ds.map(operations=type_cast_op, input_columns="label",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=resize_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=rescale_nml_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)
    mnist_ds = mnist_ds.map(operations=hwc2chw_op, input_columns="image",
                            num_parallel_workers=num_parallel_workers)

    # Configure reading: shuffling, batch size and repetition.
    buffer_size = 10000
    mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size)
    mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
    mnist_ds = mnist_ds.repeat(repeat_size)
    return mnist_ds


# Initialise the dataset and inspect its contents.
ms_dataset = create_dataset(train_data_path)
print('Number of groups in the dataset:', ms_dataset.get_dataset_size())

# ------- Fetch one batch with next() and look at its format -------
data = next(ms_dataset.create_dict_iterator(output_numpy=True))
images = data['image']
labels = data['label']
print('Tensor of image:', images.shape)
print('Labels:', labels)

# ------- Visualise the batch -------
# The 4 x 8 grid matches the default batch size of 32.
for position, (image, image_label) in enumerate(zip(images, labels), start=1):
    plt.subplot(4, 8, position)
    plt.imshow(np.squeeze(image))
    plt.title('num:%s' % image_label)
    plt.xticks([])
    plt.axis("off")
plt.show()


# ------- Define the LeNet5 model -------
class LeNet5(nn.Cell):
    """LeNet network structure.

    Args:
        num_class: Size of the final dense layer's output.
        num_channel: Number of input channels of the first convolution.
    """

    # Define the operators required by the network.
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        # 16 * 5 * 5 is the flattened feature size for 32 x 32 inputs:
        # 32 -> conv5 -> 28 -> pool -> 14 -> conv5 -> 10 -> pool -> 5.
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    # Use the operators above to construct the forward pass.
    def construct(self, x):
        """Run conv/pool twice, flatten, then three dense layers; returns the fc3 output."""
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# ------- Define the callback -------
class StepLossAccInfo(Callback):
    """Custom callback that records the loss of every step and periodic accuracy.

    Args:
        model: Model whose ``eval`` method is called for the accuracy.
        eval_dataset: Dataset passed to ``model.eval``.
        steps_loss: Dict with ``"step"`` and ``"loss_value"`` lists that
            receives one entry per training step (both stored as strings).
        steps_eval: Dict with ``"step"`` and ``"acc"`` lists that receives one
            entry per evaluation.
        steps_per_epoch: Number of steps in an epoch, used to turn the epoch
            number into a step offset. The default 1875 corresponds to
            60000 // 32 — presumably the MNIST training set at batch size 32.
        eval_interval: Evaluate whenever the computed step is a multiple of
            this value.
    """

    def __init__(self, model, eval_dataset, steps_loss, steps_eval,
                 steps_per_epoch=1875, eval_interval=125):
        super(StepLossAccInfo, self).__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        self.steps_loss = steps_loss
        self.steps_eval = steps_eval
        self.steps_per_epoch = steps_per_epoch
        self.eval_interval = eval_interval

    def step_end(self, run_context):
        """Append the current loss and, every ``eval_interval`` steps, the accuracy."""
        cb_params = run_context.original_args()
        cur_epoch = cb_params.cur_epoch_num
        # NOTE(review): this offset assumes cur_step_num restarts every epoch;
        # if it is already a global counter the value is double-counted for
        # epochs after the first — confirm against the MindSpore version used.
        cur_step = (cur_epoch - 1) * self.steps_per_epoch + cb_params.cur_step_num
        self.steps_loss["loss_value"].append(str(cb_params.net_outputs))
        self.steps_loss["step"].append(str(cur_step))
        if cur_step % self.eval_interval == 0:
            acc = self.model.eval(self.eval_dataset, dataset_sink_mode=False)
            self.steps_eval["step"].append(cur_step)
            self.steps_eval["acc"].append(acc["Accuracy"])


# ------- Training configuration -------
network = LeNet5()
epoch_size = 1
momentum = 0.9
lr = 0.01
mnist_path = r"\datasets\MNIST_Data"  # set this to your dataset path
model_path = r"\datasets\models\ckpt\mindspore_quick_start"  # where checkpoints are saved

# The loss class lives in mindspore.nn; it must be referenced through `nn`
# because it is not imported by name.
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)
repeat_size = 1
ds_train = create_dataset(os.path.join(mnist_path, "train"), 32, repeat_size)
eval_dataset = create_dataset(os.path.join(mnist_path, "test"), 32)
# Bundle the network, the loss function and the optimizer into a Model and
# register accuracy as the evaluation metric.
eval_metrics = {'Accuracy': Accuracy()}
model = Model(network, net_loss, net_opt, metrics=eval_metrics)

# Checkpoint settings: how often to save and how many files to keep.
config_ck = CheckpointConfig(
    save_checkpoint_steps=375,
    keep_checkpoint_max=16,
)
# Checkpoint callback: file-name prefix, target directory and the settings above.
ckpoint_cb = ModelCheckpoint(
    prefix="checkpoint_lenet",
    directory=model_path,
    config=config_ck,
)

# Containers for the per-step loss and the periodically evaluated accuracy.
steps_loss = dict(step=[], loss_value=[])
steps_eval = dict(step=[], acc=[])
import os

from mindspore import load_checkpoint, load_param_into_net

# Callback that records the loss of every step and the periodic accuracy.
step_loss_acc_info = StepLossAccInfo(model, eval_dataset, steps_loss, steps_eval)

# Train the model.
model.train(
    epoch_size,
    ds_train,
    callbacks=[ckpoint_cb, LossMonitor(125), step_loss_acc_info],
    dataset_sink_mode=False,
)

# ------- Plot the training process -------
# The callback stores steps and losses as strings; convert them for plotting.
steps = [int(step) for step in steps_loss["step"]]
loss_value = [float(loss) for loss in steps_loss["loss_value"]]
plt.plot(steps, loss_value, color="red")
plt.xlabel("Steps")
plt.ylabel("Loss_value")
plt.title("Change chart of model loss value")
plt.show()


# ------- Validate the model on the test set -------
def _latest_checkpoint(directory, prefix="checkpoint_lenet"):
    """Return the most recently modified ``<prefix>*.ckpt`` file in *directory*.

    Raises:
        FileNotFoundError: If no matching checkpoint file exists.
    """
    candidates = [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.startswith(prefix) and name.endswith(".ckpt")
    ]
    if not candidates:
        raise FileNotFoundError(
            "no '%s*.ckpt' checkpoint found in %s" % (prefix, directory)
        )
    return max(candidates, key=os.path.getmtime)


def test_net(network, model, mnist_path, ckpt_path=None):
    """Load a saved checkpoint into *network* and print the test-set accuracy.

    Args:
        network: Network that receives the checkpoint parameters.
        model: Model used to run the evaluation.
        mnist_path: Dataset root; its ``test`` sub-directory is evaluated.
        ckpt_path: Checkpoint file to load. When omitted, the newest
            ``checkpoint_lenet*.ckpt`` under the module-level ``model_path``
            is used, because the dataset directory is not a checkpoint file.
    """
    print("============== Starting Testing ==============")
    if ckpt_path is None:
        ckpt_path = _latest_checkpoint(model_path)
    # Load the saved parameters into the network.
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(network, param_dict)
    # Build the test-set dataloader with the same preprocessing as training.
    ds_eval = create_dataset(os.path.join(mnist_path, "test"))
    # Evaluate the model to obtain the accuracy.
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("============== Accuracy:{} ==============".format(acc))


test_net(network, model, mnist_path)
五、实验结果
读取数据集
数据集测试查看
数据集训练
预测
这篇关于Mindspore实现手写字体识别的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-12-25初学者必备:订单系统资料详解与实操教程
- 2024-12-24内网穿透资料入门教程
- 2024-12-24微服务资料入门指南
- 2024-12-24微信支付系统资料入门教程
- 2024-12-24微信支付资料详解:新手入门指南
- 2024-12-24Hbase资料:新手入门教程
- 2024-12-24Java部署资料
- 2024-12-24Java订单系统资料:新手入门教程
- 2024-12-24Java分布式资料入门教程
- 2024-12-24Java监控系统资料详解与入门教程