TensorFlow首次快速体验

/ Java / 没有评论 / 2019浏览

TensorFlow是深度学习中使用人数最多的框架,本文快速尝试一下其能力,方便入门

添加依赖

<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.13.1</version>
</dependency>

定义图模型

示例完成一个简单的函数:

f(x, y) = z = a*x + b*y

其中a, b是常量,x, y是变量

Graph graph = new Graph()
Operation a = graph.opBuilder("Const", "a")
        .setAttr("dtype", DataType.fromClass(Double.class))
        .setAttr("value", Tensor.<Double>create(3.0, Double.class))
        .build();
Operation b = graph.opBuilder("Const", "b")
        .setAttr("dtype", DataType.fromClass(Double.class))
        .setAttr("value", Tensor.<Double>create(2.0, Double.class))
        .build()
Operation x = graph.opBuilder("Placeholder", "x")
        .setAttr("dtype", DataType.fromClass(Double.class))
        .build();
Operation y = graph.opBuilder("Placeholder", "y")
        .setAttr("dtype", DataType.fromClass(Double.class))
        .build();
Operation ax = graph.opBuilder("Mul", "ax")
        .addInput(a.output(0))
        .addInput(x.output(0))
        .build();
Operation by = graph.opBuilder("Mul", "by")
        .addInput(b.output(0))
        .addInput(y.output(0))
        .build();
Operation z = graph.opBuilder("Add", "z")
        .addInput(ax.output(0))
        .addInput(by.output(0))
        .build();

可以看出来,用Java定义图模型比较麻烦,但是使用Python会简单很多

执行

Session session = new Session(graph);
Tensor<Double> tensor = session.runner().fetch("z")
        .feed("x", Tensor.create(3.0, Double.class))
        .feed("y", Tensor.create(6.0, Double.class))
        .run().get(0).expect(Double.class);
System.out.println(tensor.doubleValue());

图模型保存及加载

Path path = Paths.get("tensor.model");
byte[] bytes = graph.toGraphDef();
Files.write(path, bytes);
Graph graph = new Graph();
byte[] bytes = Files.readAllBytes(path);
graph.importGraphDef(bytes);

ps: 模型可以在不同语言通用,所以可以使用python训练模型,然后提供给其他语言使用,比如Java

结果

最后输出结果:21.0

参考