|
| 1 | +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| 2 | +
|
| 3 | + Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | + you may not use this file except in compliance with the License. |
| 5 | + You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | + Unless required by applicable law or agreed to in writing, software |
| 10 | + distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | + See the License for the specific language governing permissions and |
| 13 | + limitations under the License. |
| 14 | +==============================================================================*/ |
| 15 | + |
| 16 | +#include <string> |
| 17 | +#include <vector> |
| 18 | +#include <stdio.h> |
| 19 | + |
| 20 | +#include "tensorflow/core/framework/op.h" |
| 21 | +#include "tensorflow/core/lib/core/status.h" |
| 22 | +#include "tensorflow/core/lib/strings/str_util.h" |
| 23 | +#include "tensorflow/core/platform/env.h" |
| 24 | +#include "tensorflow/core/platform/init_main.h" |
| 25 | +#include "tensorflow/core/util/command_line_flags.h" |
| 26 | +#include "tensorflow/core/framework/api_def.pb.h" |
| 27 | +#include "tensorflow/core/framework/op_def.pb.h" |
| 28 | +#include "tensorflow/core/lib/core/status.h" |
| 29 | +#include "tensorflow/core/lib/core/errors.h" |
| 30 | +#include "tensorflow/core/lib/io/path.h" |
| 31 | +#include "tensorflow/core/framework/op_gen_lib.h" |
| 32 | +#include "google/protobuf/unknown_field_set.h" |
| 33 | + |
| 34 | +namespace tensorflow { |
| 35 | +namespace java { |
| 36 | + |
| 37 | +const char kUsageHeader[] = |
| 38 | + "\n\nExporter of operation and API defs, for use in Java op generation.\n\n" |
| 39 | + "This executable exports the op def and api def protos for all operations " |
| 40 | + "registered in the provided list of libraries. The proto will be printed " |
| 41 | + "to stdout in binary format. It is an OpList proto, with each OpDef having" |
| 42 | + " the associated ApiDef attached as unknown field 100\n\n" |
| 43 | + "The first argument is the location of the tensorflow binary built for TF-" |
| 44 | + "Java.\nFor example, `bazel-out/k8-opt/bin/external/org_tensorflow/tensorfl" |
| 45 | + "ow/libtensorflow_cc.so`.\n\n" |
| 46 | + "Finally, the `--api_dirs` argument takes a list of comma-separated " |
| 47 | + "directories of API definitions can be provided to override default\n" |
| 48 | + "values found in the ops definitions. Directories are ordered by priority " |
| 49 | + "(the last having precedence over the first).\nFor example, `bazel-tensorf" |
| 50 | + "low-core-api/external/org_tensorflow/tensorflow/core/api_def/base_api,src" |
| 51 | + "/bazel/api_def`\n\n"; |
| 52 | + |
| 53 | +void Write(OpDef* op_def, const ApiDef& api_def){ |
| 54 | + auto *refl = op_def->GetReflection(); |
| 55 | + refl->MutableUnknownFields(op_def)->AddLengthDelimited(100, api_def.SerializeAsString()); |
| 56 | +} |
| 57 | + |
| 58 | +Status UpdateOpDefs(OpList* op_list, const std::vector<tensorflow::string>& api_dirs_, Env* env_) { |
| 59 | + ApiDefMap api_map(*op_list); |
| 60 | + if (!api_dirs_.empty()) { |
| 61 | + // Only load api files that correspond to the requested "op_list" |
| 62 | + for (const auto& op : op_list->op()) { |
| 63 | + for (const auto& api_def_dir : api_dirs_) { |
| 64 | + const std::string api_def_file_pattern = |
| 65 | + io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt"); |
| 66 | + if (env_->FileExists(api_def_file_pattern).ok()) { |
| 67 | + TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern)) |
| 68 | + << api_def_file_pattern; |
| 69 | + } |
| 70 | + } |
| 71 | + } |
| 72 | + } |
| 73 | + api_map.UpdateDocs(); |
| 74 | + |
| 75 | + for (int i = 0 ; i < op_list->op_size() ; i++) { |
| 76 | + OpDef *op_def = op_list->mutable_op(i); |
| 77 | + const ApiDef* api_def = api_map.GetApiDef(op_def->name()); |
| 78 | + Write(op_def, *api_def); |
| 79 | + } |
| 80 | + return Status::OK(); |
| 81 | +} |
| 82 | + |
| 83 | +} |
| 84 | +} |
| 85 | + |
| 86 | +// See usage header. |
| 87 | +// Writes an OpList proto to stdout, with each OpDef having its ApiDef in field 100 |
| 88 | +int main(int argc, char* argv[]) { |
| 89 | + tensorflow::string api_dirs_str; |
| 90 | + std::vector<tensorflow::Flag> flag_list = { |
| 91 | + tensorflow::Flag( |
| 92 | + "api_dirs", &api_dirs_str, |
| 93 | + "List of directories that contain the ops API definitions protos")}; |
| 94 | + tensorflow::string usage = tensorflow::java::kUsageHeader; |
| 95 | + usage += tensorflow::Flags::Usage( |
| 96 | + tensorflow::string(argv[0]) + " <ops library paths...>", flag_list); |
| 97 | + bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); |
| 98 | + tensorflow::port::InitMain(usage.c_str(), &argc, &argv); |
| 99 | + QCHECK(parsed_flags_ok && argc > 1) << usage; |
| 100 | + std::vector<tensorflow::string> api_dirs = tensorflow::str_util::Split( |
| 101 | + api_dirs_str, ",", tensorflow::str_util::SkipEmpty()); |
| 102 | + |
| 103 | + tensorflow::Env* env = tensorflow::Env::Default(); |
| 104 | + void* ops_libs_handles[50]; |
| 105 | + for (int i = 1; i < argc; ++i) { |
| 106 | + TF_CHECK_OK(env->LoadDynamicLibrary(argv[1], &ops_libs_handles[i - 1])); |
| 107 | + } |
| 108 | + tensorflow::OpList ops; |
| 109 | + tensorflow::OpRegistry::Global()->Export(false, &ops); |
| 110 | + TF_CHECK_OK(tensorflow::java::UpdateOpDefs(&ops, api_dirs, env)); |
| 111 | + |
| 112 | + |
| 113 | + std::ostream & out = std::cout; |
| 114 | + ops.SerializeToOstream(&out); |
| 115 | + |
| 116 | + return 0; |
| 117 | +} |
0 commit comments