Java程序员学深度学习 DJL上手5 训练自己的模型

2021/9/20 17:28:43

本文主要是介绍Java程序员学深度学习 DJL上手5 训练自己的模型,对大家解决编程问题具有一定的参考价值,需要的程序猿们随着小编来一起学习吧!

Java程序员学深度学习 DJL上手5 训练自己的模型

  • 一、准备环境
  • 二、创建示例项目
  • 三、准备数据集
  • 四、创建模型
  • 五、创建训练器
    • 1. 训练器配置
    • 2. 初始化训练器
    • 3. 训练模型
    • 4. 保存模型
  • 六、源代码
    • 1. pom
    • 2. java

一、准备环境

  • windows
  • idea
  • maven

二、创建示例项目

三、准备数据集

int batchSize = 32;
Mnist mnist = Mnist.builder().setSampling(batchSize, true).build();
mnist.prepare(new ProgressBar());

这里对数据集进行了分批处理,每批大小32,合适的分批大小将在训练时显著提升性能。

四、创建模型

本节会根据之前文章创建模型。由于 MNIST 数据集中的图像为 28x28 灰度图像,这里我们创建一个具有 28 x 28 输入的 MLP 块。
输出的图输出为 10,因为每个图像可能有 10 个可能的类(0 到 9)。
对于隐藏的层,其大小是猜测的值new int[] {128, 64}

Model model = Model.newInstance("mlp");
model.setBlock(new Mlp(28 * 28, 10, new int[] {128, 64}));

五、创建训练器

1. 训练器配置

  • 损失函数,用来测量模型与测试数据集的匹配程度,值越低越好;这里定义为softmaxCrossEntropyLoss()
  • 评估函数,也用于测量模型与数据集的匹配程度。与损失不同,它们只供人们查看,不用于优化模型。
  • 监听器,用来监控训练过程。

        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());

        Trainer trainer = model.newTrainer(config);

2. 初始化训练器

这里使用输入的形状来初始化训练器。初始化函数里形状的第一个参数是批次大小,这个不影响参数初始化。
第二个参数是输入图像的像素数,即28*28。

        trainer.initialize(new Shape(1, 28 * 28));

3. 训练模型

这里使用了DJL的EasyTrain,

        int epoch = 2;
        EasyTrain.fit(trainer, epoch, mnist, null);

4. 保存模型

保存模型还可以添加一些元数据,如训练迭代次数、训练精度等。

        Path modelDir = Paths.get("build/mlp");
        Files.createDirectories(modelDir);

        model.setProperty("Epoch", String.valueOf(epoch));

        model.save(modelDir, "mlp");

        System.out.println(model);

六、源代码

1. pom

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>com.xundh</groupId>
    <artifactId>djl-learning</artifactId>
    <version>0.1-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
        <java.version>8</java.version>
        <djl.version>0.13.0-SNAPSHOT</djl.version>
    </properties>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>ai.djl</groupId>
                <artifactId>bom</artifactId>
                <version>${djl.version}</version>
                <type>pom</type>
                <scope>import</scope>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>api</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>basicdataset</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl</groupId>
            <artifactId>model-zoo</artifactId>
        </dependency>
        <!-- Pytorch -->
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-engine</artifactId>
        </dependency>
        <dependency>
            <groupId>ai.djl.pytorch</groupId>
            <artifactId>pytorch-native-auto</artifactId>
            <version>1.7.0</version>
        </dependency>
    </dependencies>
</project>

2. java

package com.xundh;

import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.Mnist;
import ai.djl.basicmodelzoo.basic.Mlp;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public class NDArrayLearning {
    public static void main(String[] args) throws IOException, TranslateException {
        int batchSize = 32;
        Mnist mnist = Mnist.builder().setSampling(batchSize, true).build();
        mnist.prepare(new ProgressBar());

        Model model = Model.newInstance("mlp");
        model.setBlock(new Mlp(28 * 28, 10, new int[]{128, 64}));
        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .addTrainingListeners(TrainingListener.Defaults.logging());

        Trainer trainer = model.newTrainer(config);
        trainer.initialize(new Shape(1, 28 * 28));
        int epoch = 2;

        EasyTrain.fit(trainer, epoch, mnist, null);

        Path modelDir = Paths.get("build/mlp");
        Files.createDirectories(modelDir);

        model.setProperty("Epoch", String.valueOf(epoch));

        model.save(modelDir, "mlp");

        System.out.println(model);
    }
}

运行结果示例:
在这里插入图片描述



这篇关于Java程序员学深度学习 DJL上手5 训练自己的模型的文章就介绍到这儿,希望我们推荐的文章对大家有所帮助,也希望大家多多支持为之网!


扫一扫关注最新编程教程