Skip to content

Commit b92525e

Browse files
authored
Generate op specs and package as resource (#247)
1 parent 5ca958d commit b92525e

File tree

7 files changed

+161
-5
lines changed

7 files changed

+161
-5
lines changed

tensorflow-core/pom.xml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
</modules>
3737

3838
<properties>
39+
<!-- Match version used by TensorFlow, in tensorflow/workspace.bzl -->
40+
<protobuf.version>3.8.0</protobuf.version>
41+
3942
<native.classifier>${javacpp.platform}${javacpp.platform.extension}</native.classifier>
4043
<javacpp.build.skip>false</javacpp.build.skip> <!-- To skip execution of build.sh: -Djavacpp.build.skip=true -->
4144
<javacpp.parser.skip>false</javacpp.parser.skip> <!-- To skip header file parsing phase: -Djavacpp.parser.skip=true -->

tensorflow-core/tensorflow-core-api/BUILD

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,33 @@ cc_library(
3636
],
3737
)
3838

39+
tf_cc_binary(
40+
name = "java_op_exporter",
41+
linkopts = select({
42+
"@org_tensorflow//tensorflow:windows": [],
43+
"//conditions:default": ["-lm"],
44+
}),
45+
deps = [
46+
":java_op_export_lib",
47+
],
48+
)
49+
50+
cc_library(
51+
name = "java_op_export_lib",
52+
srcs = [
53+
"src/bazel/op_generator/op_export_main.cc",
54+
],
55+
hdrs = [
56+
],
57+
copts = tf_copts(),
58+
deps = [
59+
"@org_tensorflow//tensorflow/core:framework",
60+
"@org_tensorflow//tensorflow/core:lib",
61+
"@org_tensorflow//tensorflow/core:op_gen_lib",
62+
"@org_tensorflow//tensorflow/core:protos_all_cc",
63+
],
64+
)
65+
3966
filegroup(
4067
name = "java_api_def",
4168
srcs = glob(["src/bazel/api_def/*"])

tensorflow-core/tensorflow-core-api/build.sh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ bazel build $BUILD_FLAGS ${BUILD_USER_FLAGS:-} \
3838
@org_tensorflow//tensorflow/tools/lib_package:jnilicenses_generate \
3939
:java_proto_gen_sources \
4040
:java_op_generator \
41+
:java_op_exporter \
4142
:java_api_import \
4243
:custom_ops_test
4344

@@ -85,7 +86,17 @@ $BAZEL_BIN/java_op_generator \
8586
--api_dirs=$BAZEL_SRCS/external/org_tensorflow/tensorflow/core/api_def/base_api,src/bazel/api_def \
8687
$TENSORFLOW_LIB
8788

89+
GEN_RESOURCE_DIR=src/gen/resources/org/tensorflow/op
90+
mkdir -p $GEN_RESOURCE_DIR
91+
92+
# Generate Java operator wrappers
93+
$BAZEL_BIN/java_op_exporter \
94+
--api_dirs=$BAZEL_SRCS/external/org_tensorflow/tensorflow/core/api_def/base_api,src/bazel/api_def \
95+
$TENSORFLOW_LIB > $GEN_RESOURCE_DIR/ops.pb
96+
97+
8898
# Copy generated Java protos from source jars
99+
89100
cd $GEN_SRCS_DIR
90101
find $TENSORFLOW_BIN/core -name \*-speed-src.jar -exec jar xf {} \;
91102
rm -rf META-INF

tensorflow-core/tensorflow-core-api/pom.xml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
<description>Platform-dependent native code and pure-Java code for the TensorFlow machine intelligence library.</description>
1616

1717
<properties>
18-
<!-- Match version used by TensorFlow, in tensorflow/workspace.bzl -->
19-
<protobuf.version>3.8.0</protobuf.version>
2018
<native.build.skip>false</native.build.skip>
2119
<javacpp.build.skip>${native.build.skip}</javacpp.build.skip>
2220
<javacpp.parser.skip>${native.build.skip}</javacpp.parser.skip>
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include <string>
17+
#include <vector>
18+
#include <stdio.h>
19+
20+
#include "tensorflow/core/framework/op.h"
21+
#include "tensorflow/core/lib/core/status.h"
22+
#include "tensorflow/core/lib/strings/str_util.h"
23+
#include "tensorflow/core/platform/env.h"
24+
#include "tensorflow/core/platform/init_main.h"
25+
#include "tensorflow/core/util/command_line_flags.h"
26+
#include "tensorflow/core/framework/api_def.pb.h"
27+
#include "tensorflow/core/framework/op_def.pb.h"
28+
#include "tensorflow/core/lib/core/status.h"
29+
#include "tensorflow/core/lib/core/errors.h"
30+
#include "tensorflow/core/lib/io/path.h"
31+
#include "tensorflow/core/framework/op_gen_lib.h"
32+
#include "google/protobuf/unknown_field_set.h"
33+
34+
namespace tensorflow {
35+
namespace java {
36+
37+
const char kUsageHeader[] =
38+
"\n\nExporter of operation and API defs, for use in Java op generation.\n\n"
39+
"This executable exports the op def and api def protos for all operations "
40+
"registered in the provided list of libraries. The proto will be printed "
41+
"to stdout in binary format. It is an OpList proto, with each OpDef having"
42+
" the associated ApiDef attached as unknown field 100\n\n"
43+
"The first argument is the location of the tensorflow binary built for TF-"
44+
"Java.\nFor example, `bazel-out/k8-opt/bin/external/org_tensorflow/tensorfl"
45+
"ow/libtensorflow_cc.so`.\n\n"
46+
"Finally, the `--api_dirs` argument takes a list of comma-separated "
47+
"directories of API definitions can be provided to override default\n"
48+
"values found in the ops definitions. Directories are ordered by priority "
49+
"(the last having precedence over the first).\nFor example, `bazel-tensorf"
50+
"low-core-api/external/org_tensorflow/tensorflow/core/api_def/base_api,src"
51+
"/bazel/api_def`\n\n";
52+
53+
void Write(OpDef* op_def, const ApiDef& api_def){
54+
auto *refl = op_def->GetReflection();
55+
refl->MutableUnknownFields(op_def)->AddLengthDelimited(100, api_def.SerializeAsString());
56+
}
57+
58+
Status UpdateOpDefs(OpList* op_list, const std::vector<tensorflow::string>& api_dirs_, Env* env_) {
59+
ApiDefMap api_map(*op_list);
60+
if (!api_dirs_.empty()) {
61+
// Only load api files that correspond to the requested "op_list"
62+
for (const auto& op : op_list->op()) {
63+
for (const auto& api_def_dir : api_dirs_) {
64+
const std::string api_def_file_pattern =
65+
io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt");
66+
if (env_->FileExists(api_def_file_pattern).ok()) {
67+
TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern))
68+
<< api_def_file_pattern;
69+
}
70+
}
71+
}
72+
}
73+
api_map.UpdateDocs();
74+
75+
for (int i = 0 ; i < op_list->op_size() ; i++) {
76+
OpDef *op_def = op_list->mutable_op(i);
77+
const ApiDef* api_def = api_map.GetApiDef(op_def->name());
78+
Write(op_def, *api_def);
79+
}
80+
return Status::OK();
81+
}
82+
83+
}
84+
}
85+
86+
// See usage header.
87+
// Writes an OpList proto to stdout, with each OpDef having its ApiDef in field 100
88+
int main(int argc, char* argv[]) {
89+
tensorflow::string api_dirs_str;
90+
std::vector<tensorflow::Flag> flag_list = {
91+
tensorflow::Flag(
92+
"api_dirs", &api_dirs_str,
93+
"List of directories that contain the ops API definitions protos")};
94+
tensorflow::string usage = tensorflow::java::kUsageHeader;
95+
usage += tensorflow::Flags::Usage(
96+
tensorflow::string(argv[0]) + " <ops library paths...>", flag_list);
97+
bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
98+
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
99+
QCHECK(parsed_flags_ok && argc > 1) << usage;
100+
std::vector<tensorflow::string> api_dirs = tensorflow::str_util::Split(
101+
api_dirs_str, ",", tensorflow::str_util::SkipEmpty());
102+
103+
tensorflow::Env* env = tensorflow::Env::Default();
104+
void* ops_libs_handles[50];
105+
for (int i = 1; i < argc; ++i) {
106+
TF_CHECK_OK(env->LoadDynamicLibrary(argv[1], &ops_libs_handles[i - 1]));
107+
}
108+
tensorflow::OpList ops;
109+
tensorflow::OpRegistry::Global()->Export(false, &ops);
110+
TF_CHECK_OK(tensorflow::java::UpdateOpDefs(&ops, api_dirs, env));
111+
112+
113+
std::ostream & out = std::cout;
114+
ops.SerializeToOstream(&out);
115+
116+
return 0;
117+
}

tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -358,10 +358,10 @@ public final class Ops {
358358

359359
public final TpuOps tpu;
360360

361-
public final AudioOps audio;
362-
363361
public final MathOps math;
364362

363+
public final AudioOps audio;
364+
365365
public final SignalOps signal;
366366

367367
public final TrainOps train;
@@ -387,8 +387,8 @@ private Ops(Scope scope) {
387387
sparse = new SparseOps(this);
388388
bitwise = new BitwiseOps(this);
389389
tpu = new TpuOps(this);
390-
audio = new AudioOps(this);
391390
math = new MathOps(this);
391+
audio = new AudioOps(this);
392392
signal = new SignalOps(this);
393393
train = new TrainOps(this);
394394
quantization = new QuantizationOps(this);
Binary file not shown.

0 commit comments

Comments
 (0)