Java调用Keras、Tensorflow模型

  • 2018-04-03
  • 10,008

实现python离线训练模型,Java在线预测部署。

目前深度学习主流使用python训练自己的模型,有非常多的框架提供了能快速搭建神经网络的功能,其中Keras提供了high-level的语法,底层可以使用tensorflow或者theano。

但是有很多公司后台应用是用Java开发的,如果用python提供HTTP接口,对业务延迟要求比较高的话,仍然会有一定得延迟,所以能不能使用Java调用模型,python可以离线的训练模型?(tensorflow也提供了成熟的部署方案TensorFlow Serving

手头上有一个用Keras训练的模型,网上关于Java调用Keras模型的资料不是很多,而且大部分是重复的,并且也没有讲的很详细。大致有两种方案,一种是基于Java的深度学习库导入Keras模型实现,另外一种是用tensorflow提供的Java接口调用。

Deeplearning4J

Eclipse Deeplearning4j is the first commercial-grade, open-source, distributed deep-learning library written for Java and Scala. Integrated with Hadoop and Spark, DL4J brings AIAI to business environments for use on distributed GPUs and CPUs.

Deeplearning4j目前支持导入Keras训练的模型,并且提供了类似python中numpy的一些功能,更方便地处理结构化的数据。遗憾的是,Deeplearning4j现在只覆盖了Keras <2.0版本的大部分Layer,如果你是用Keras 2.0以上的版本,在导入模型的时候可能会报错。

了解更多:
Keras Model Import: Supported Features
Importing Models From Keras to Deeplearning4j

Tensorflow

文档,Java的文档很少,不过调用模型的过程也很简单。采用这种方式调用模型需要先将Keras导出的模型转成tensorflow的protobuf协议的模型。

1、Keras的h5模型转为pb模型

在Keras中使用model.save(model.h5)保存当前模型为HDF5格式的文件中。
Keras的后端框架使用的是tensorflow,所以先把模型导出为pb模型。在Java中只需要调用模型进行预测,所以将当前的graph中的Variable全部变成Constant,并且使用训练后的weight。以下是freeze graph的代码:

该方法可以将tensor为Variable的graph全部转为constant并且使用训练后的weight。注意output_name比较重要,后面Java调用模型的时候会用到。

在Keras中,模型是这么定义的:

下面的代码可以查看定义好的Keras模型的输入、输出的name,这对之后Java调用有帮助。

训练好Keras模型后,转换为pb模型:

运行之后会生成model.pb的模型,这将是之后调用的模型。

2、Java调用

新建一个maven项目,pom里面导入tensorflow包:

核心代码:

Graph和Tensor对象都是需要通过close()方法显式地释放占用的资源,代码中使用了try-with-resources的方法实现的。

至此,已经可以实现Keras离线训练,Java在线预测的功能。

链接:https://www.ioiogoo.cn/2018/04/03/java调用keras、tensorflow模型/
本站所有文章除特殊说明外均为原创,转载请注明出处!