Java程序员学深度学习 DJL上手5 训练自己的模型
2021/9/20 17:28:43
本文主要是介绍Java程序员学深度学习 DJL上手5 训练自己的模型,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!
Java程序员学深度学习 DJL上手5 训练自己的模型
- 一、准备环境
- 二、创建示例项目
- 三、准备数据集
- 四、创建模型
- 五、创建训练器
- 1. 训练器配置
- 2. 初始化训练器
- 3. 训练模型
- 4. 保存模型
- 六、源代码
- 1. pom
- 2. java
一、准备环境
- windows
- idea
- maven
二、创建示例项目
三、准备数据集
int batchSize = 32; Mnist mnist = Mnist.builder().setSampling(batchSize, true).build(); mnist.prepare(new ProgressBar());
这里对数据集进行了分批处理,每批大小32,合适的分批大小将在训练时显著提升性能。
四、创建模型
本节会根据之前文章创建模型。由于 MNIST 数据集中的图像为 28x28 灰度图像,这里我们创建一个具有 28 x 28 输入的 MLP 块。
输出的图输出为 10,因为每个图像可能有 10 个可能的类(0 到 9)。
对于隐藏的层,其大小是猜测的值new int[] {128, 64}
Model model = Model.newInstance("mlp"); model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));
五、创建训练器
1. 训练器配置
- 损失函数,用来测量模型与测试数据集的匹配程度,值越低越好;这里定义为
softmaxCrossEntropyLoss()
- 评估函数,也用于测量模型与数据集的匹配程度。与损失不同,它们只供人们查看,不用于优化模型。
- 监听器,用来监控训练过程。
DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .addEvaluator(new Accuracy()) .addTrainingListeners(TrainingListener.Defaults.logging()); Trainer trainer = model.newTrainer(config);
2. 初始化训练器
这里使用输入的形状来初始化训练器。初始化函数里形状的第一个参数是批次大小,这个不影响参数初始化。
第二个参数是输入图像的像素数,即28*28。
trainer.initialize(new Shape(1, 28 * 28));
3. 训练模型
这里使用了DJL的EasyTrain,
int epoch = 2; EasyTrain.fit(trainer, epoch, mnist, null);
4. 保存模型
保存模型还可以添加一些元数据,如训练迭代次数、训练精度等。
Path modelDir = Paths.get("build/mlp"); Files.createDirectories(modelDir); model.setProperty("Epoch", String.valueOf(epoch)); model.save(modelDir, "mlp"); System.out.println(model);
六、源代码
1. pom
<?xml version="1.0" encoding="UTF-8"?> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <modelVersion>4.0.0</modelVersion> <groupId>com.xundh</groupId> <artifactId>djl-learning</artifactId> <version>0.1-SNAPSHOT</version> <properties> <maven.compiler.source>1.8</maven.compiler.source> <maven.compiler.target>1.8</maven.compiler.target> <java.version>8</java.version> <djl.version>0.13.0-SNAPSHOT</djl.version> </properties> <dependencyManagement> <dependencies> <dependency> <groupId>ai.djl</groupId> <artifactId>bom</artifactId> <version>${djl.version}</version> <type>pom</type> <scope>import</scope> </dependency> </dependencies> </dependencyManagement> <dependencies> <dependency> <groupId>ai.djl</groupId> <artifactId>api</artifactId> </dependency> <dependency> <groupId>ai.djl</groupId> <artifactId>basicdataset</artifactId> </dependency> <dependency> <groupId>ai.djl</groupId> <artifactId>model-zoo</artifactId> </dependency> <!-- Pytorch --> <dependency> <groupId>ai.djl.pytorch</groupId> <artifactId>pytorch-engine</artifactId> </dependency> <dependency> <groupId>ai.djl.pytorch</groupId> <artifactId>pytorch-native-auto</artifactId> <version>1.7.0</version> </dependency> </dependencies> </project>
2. java
package com.xundh; import ai.djl.Model; import ai.djl.basicdataset.cv.classification.Mnist; import ai.djl.basicmodelzoo.basic.Mlp; import ai.djl.ndarray.types.Shape; import ai.djl.training.DefaultTrainingConfig; import ai.djl.training.EasyTrain; import ai.djl.training.Trainer; import ai.djl.training.evaluator.Accuracy; import ai.djl.training.listener.TrainingListener; import ai.djl.training.loss.Loss; import ai.djl.training.util.ProgressBar; import ai.djl.translate.TranslateException; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; public class NDArrayLearning { public static void main(String[] args) throws IOException, TranslateException { int batchSize = 32; Mnist mnist = Mnist.builder().setSampling(batchSize, true).build(); mnist.prepare(new ProgressBar()); Model model = Model.newInstance("mlp"); model.setBlock(new Mlp(28 * 28, 10, new int[]{128, 64})); DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss()) .addEvaluator(new Accuracy()) .addTrainingListeners(TrainingListener.Defaults.logging()); Trainer trainer = model.newTrainer(config); trainer.initialize(new Shape(1, 28 * 28)); int epoch = 2; EasyTrain.fit(trainer, epoch, mnist, null); Path modelDir = Paths.get("build/mlp"); Files.createDirectories(modelDir); model.setProperty("Epoch", String.valueOf(epoch)); model.save(modelDir, "mlp"); System.out.println(model); } }
运行结果示例:
这篇关于Java程序员学深度学习 DJL上手5 训练自己的模型的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!
- 2024-12-28一步到位:购买适合 SEO 的域名全攻略
- 2024-12-27OpenFeign服务间调用学习入门
- 2024-12-27OpenFeign服务间调用学习入门
- 2024-12-27OpenFeign学习入门:轻松掌握微服务通信
- 2024-12-27OpenFeign学习入门:轻松掌握微服务间的HTTP请求
- 2024-12-27JDK17新特性学习入门:简洁教程带你轻松上手
- 2024-12-27JMeter传递token学习入门教程
- 2024-12-27JMeter压测学习入门指南
- 2024-12-27JWT单点登录学习入门指南
- 2024-12-27JWT单点登录原理学习入门