11/*
2- Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+ Copyright 2021 The TensorFlow Authors. All Rights Reserved.
33
4- Licensed under the Apache License, Version 2.0 (the "License");
5- you may not use this file except in compliance with the License.
6- You may obtain a copy of the License at
4+ Licensed under the Apache License, Version 2.0 (the "License");
5+ you may not use this file except in compliance with the License.
6+ You may obtain a copy of the License at
77
8- http://www.apache.org/licenses/LICENSE-2.0
8+ http://www.apache.org/licenses/LICENSE-2.0
99
10- Unless required by applicable law or agreed to in writing, software
11- distributed under the License is distributed on an "AS IS" BASIS,
12- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13- See the License for the specific language governing permissions and
14- limitations under the License.
15- ==============================================================================
16- */
10+ Unless required by applicable law or agreed to in writing, software
11+ distributed under the License is distributed on an "AS IS" BASIS,
12+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ See the License for the specific language governing permissions and
14+ limitations under the License.
15+ ==============================================================================
16+ */
1717package org .tensorflow .op .generator ;
1818
1919import com .google .protobuf .InvalidProtocolBufferException ;
2020import com .squareup .javapoet .JavaFile ;
2121import com .squareup .javapoet .TypeSpec ;
22+ import org .tensorflow .proto .framework .ApiDef ;
23+ import org .tensorflow .proto .framework .OpDef ;
24+ import org .tensorflow .proto .framework .OpList ;
25+
2226import java .io .File ;
2327import java .io .FileInputStream ;
2428import java .io .FileNotFoundException ;
3236import java .nio .file .attribute .BasicFileAttributes ;
3337import java .util .LinkedHashMap ;
3438import java .util .Map ;
35- import org .tensorflow .proto .framework .ApiDef ;
36- import org .tensorflow .proto .framework .OpDef ;
37- import org .tensorflow .proto .framework .OpList ;
3839
3940public final class OpGenerator {
4041
4142 private static final String LICENSE =
42- "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n " +
43- "\n " +
44- "Licensed under the Apache License, Version 2.0 (the \" License\" );\n " +
45- "you may not use this file except in compliance with the License.\n " +
46- "You may obtain a copy of the License at\n " +
47- "\n " +
48- " http://www.apache.org/licenses/LICENSE-2.0\n " +
49- "\n " +
50- "Unless required by applicable law or agreed to in writing, software\n " +
51- "distributed under the License is distributed on an \" AS IS\" BASIS,\n " +
52- "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n " +
53- "See the License for the specific language governing permissions and\n " +
54- "limitations under the License.\n " +
55- "=======================================================================*/" +
56- "\n " ;
43+ "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n "
44+ + "\n "
45+ + "Licensed under the Apache License, Version 2.0 (the \" License\" );\n "
46+ + "you may not use this file except in compliance with the License.\n "
47+ + "You may obtain a copy of the License at\n "
48+ + "\n "
49+ + " http://www.apache.org/licenses/LICENSE-2.0\n "
50+ + "\n "
51+ + "Unless required by applicable law or agreed to in writing, software\n "
52+ + "distributed under the License is distributed on an \" AS IS\" BASIS,\n "
53+ + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n "
54+ + "See the License for the specific language governing permissions and\n "
55+ + "limitations under the License.\n "
56+ + "=======================================================================*/"
57+ + "\n " ;
5758
5859 private static final String HELP_TEXT = "Args should be: <outputDir> <opDefFile> [base_package]" ;
5960
@@ -62,9 +63,9 @@ public final class OpGenerator {
6263 /**
6364 * Args should be {@code <outputDir> <opDefFile> [base_package]}.
6465 *
65- * {@code base_package} is {@code org.tensorflow.op} by default.
66- * <p>
67- * <b>Will delete everything in {@code outputDir}.</b>
66+ * <p> {@code base_package} is {@code org.tensorflow.op} by default.
67+ *
68+ * <p>< b>Will delete everything in {@code outputDir}.</b>
6869 */
6970 public static void main (String [] args ) {
7071
@@ -106,19 +107,23 @@ public static void main(String[] args) {
106107
107108 if (basePackage .exists ()) {
108109 try {
109- Files .walkFileTree (basePackage .toPath (), new SimpleFileVisitor <Path >() {
110- @ Override
111- public FileVisitResult visitFile (Path file , BasicFileAttributes attrs ) throws IOException {
112- Files .delete (file );
113- return FileVisitResult .CONTINUE ;
114- }
115-
116- @ Override
117- public FileVisitResult postVisitDirectory (Path dir , IOException exc ) throws IOException {
118- Files .delete (dir );
119- return FileVisitResult .CONTINUE ;
120- }
121- });
110+ Files .walkFileTree (
111+ basePackage .toPath (),
112+ new SimpleFileVisitor <Path >() {
113+ @ Override
114+ public FileVisitResult visitFile (Path file , BasicFileAttributes attrs )
115+ throws IOException {
116+ Files .delete (file );
117+ return FileVisitResult .CONTINUE ;
118+ }
119+
120+ @ Override
121+ public FileVisitResult postVisitDirectory (Path dir , IOException exc )
122+ throws IOException {
123+ Files .delete (dir );
124+ return FileVisitResult .CONTINUE ;
125+ }
126+ });
122127 } catch (IOException ignored ) {
123128
124129 }
@@ -127,9 +132,7 @@ public FileVisitResult postVisitDirectory(Path dir, IOException exc) throws IOEx
127132 generate (outputDir , packageName , inputFile );
128133 }
129134
130- /**
131- * Build the list of ops and api defs, then call {@link #generate(File, String, Map)}
132- */
135+ /** Build the list of ops and api defs, then call {@link #generate(File, String, Map)} */
133136 private static void generate (File outputDir , String packageName , File opDefs ) {
134137 OpList opList = null ;
135138 try {
@@ -140,87 +143,121 @@ private static void generate(File outputDir, String packageName, File opDefs) {
140143 System .exit (1 );
141144 } catch (IOException e ) {
142145 throw new RuntimeException (
143- "Error parsing op def file " + opDefs + ", was it generated by op_export_main.cc (in tensorflow-core-api)?" ,
146+ "Error parsing op def file "
147+ + opDefs
148+ + ", was it generated by op_export_main.cc (in tensorflow-core-api)?" ,
144149 e );
145150 }
146151 Map <OpDef , ApiDef > defs = new LinkedHashMap <>(opList .getOpCount ());
147152
148153 for (OpDef op : opList .getOpList ()) {
149154 try {
150155 if (!op .getUnknownFields ().hasField (API_DEF_FIELD_NUMBER )) {
151- throw new RuntimeException ("No attached ApiDef for op " + op .getName () + ", was " + opDefs
152- + " generated by op_export_main.cc (in tensorflow-core-api)? The op's ApiDef should be"
153- + " attached as unknown field " + API_DEF_FIELD_NUMBER + "." );
156+ throw new RuntimeException (
157+ "No attached ApiDef for op "
158+ + op .getName ()
159+ + ", was "
160+ + opDefs
161+ + " generated by op_export_main.cc (in tensorflow-core-api)? The op's ApiDef should be"
162+ + " attached as unknown field "
163+ + API_DEF_FIELD_NUMBER
164+ + "." );
154165 }
155- ApiDef api = ApiDef
156- .parseFrom (op .getUnknownFields ().getField (API_DEF_FIELD_NUMBER ).getLengthDelimitedList ().get (0 ));
166+ ApiDef api =
167+ ApiDef .parseFrom (
168+ op .getUnknownFields ()
169+ .getField (API_DEF_FIELD_NUMBER )
170+ .getLengthDelimitedList ()
171+ .get (0 ));
157172 defs .put (op , api );
158173 } catch (InvalidProtocolBufferException e ) {
159- throw new RuntimeException ("Could not parse attached ApiDef for op " + op .getName () + ", was " + opDefs
160- + " generated by op_export_main.cc (in tensorflow-core-api)?" , e );
174+ throw new RuntimeException (
175+ "Could not parse attached ApiDef for op "
176+ + op .getName ()
177+ + ", was "
178+ + opDefs
179+ + " generated by op_export_main.cc (in tensorflow-core-api)?" ,
180+ e );
161181 }
162182 }
163183
164184 generate (outputDir , packageName , defs );
165185 }
166186
167- /**
168- * Generate all the ops that pass {@link ClassGenerator#canGenerateOp(OpDef, ApiDef)}.
169- */
187+ /** Generate all the ops that pass {@link ClassGenerator#canGenerateOp(OpDef, ApiDef)}. */
170188 private static void generate (File outputDir , String basePackage , Map <OpDef , ApiDef > ops ) {
171- ops .entrySet ().stream ().filter (e -> ClassGenerator .canGenerateOp (e .getKey (), e .getValue ())).forEach ((entry ) -> {
172- entry .getValue ().getEndpointList ().forEach ((endpoint ) -> {
173-
174- String name ;
175- String pack ;
176-
177- int pos = endpoint .getName ().lastIndexOf ('.' );
178- if (pos > -1 ) {
179- pack = endpoint .getName ().substring (0 , pos );
180- name = endpoint .getName ().substring (pos + 1 );
181- } else {
182- pack = "core" ;
183- name = endpoint .getName ();
184- }
185-
186- TypeSpec .Builder cls = TypeSpec .classBuilder (name );
187- try {
188- new ClassGenerator (
189- cls ,
190- entry .getKey (),
191- entry .getValue (),
192- basePackage ,
193- basePackage + "." + pack ,
194- pack ,
195- name ,
196- endpoint ).buildClass ();
197- } catch (Exception e ) {
198- throw new IllegalStateException ("Failed to generate class for op " + entry .getKey ().getName (), e );
199- }
200- TypeSpec spec = cls .build ();
201-
202- JavaFile file = JavaFile .builder (basePackage + "." + pack , spec )
203- .indent (" " )
204- .skipJavaLangImports (true )
205- .build ();
206-
207- File outputFile = new File (outputDir , basePackage .replace ('.' , '/' ) +
208- '/' + pack .replace ('.' , '/' ) + '/' + spec .name + ".java" );
209- outputFile .getParentFile ().mkdirs ();
210- try {
211- StringBuilder builder = new StringBuilder ();
212- builder .append (LICENSE + '\n' );
213- builder .append ("// This class has been generated, DO NOT EDIT!\n \n " );
214- file .writeTo (builder );
215-
216- Files .write (outputFile .toPath (), builder .toString ().getBytes (StandardCharsets .UTF_8 ), StandardOpenOption .WRITE ,
217- StandardOpenOption .CREATE , StandardOpenOption .TRUNCATE_EXISTING );
218- } catch (IOException ioException ) {
219- throw new IllegalStateException ("Failed to write file " + outputFile , ioException );
220- }
221- });
222- });
189+ ops .entrySet ().stream ()
190+ .filter (e -> ClassGenerator .canGenerateOp (e .getKey (), e .getValue ()))
191+ .forEach (
192+ (entry ) -> {
193+ entry
194+ .getValue ()
195+ .getEndpointList ()
196+ .forEach (
197+ (endpoint ) -> {
198+ String name ;
199+ String pack ;
200+
201+ int pos = endpoint .getName ().lastIndexOf ('.' );
202+ if (pos > -1 ) {
203+ pack = endpoint .getName ().substring (0 , pos );
204+ name = endpoint .getName ().substring (pos + 1 );
205+ } else {
206+ pack = "core" ;
207+ name = endpoint .getName ();
208+ }
209+
210+ TypeSpec .Builder cls = TypeSpec .classBuilder (name );
211+ try {
212+ new ClassGenerator (
213+ cls ,
214+ entry .getKey (),
215+ entry .getValue (),
216+ basePackage ,
217+ basePackage + "." + pack ,
218+ pack ,
219+ name ,
220+ endpoint )
221+ .buildClass ();
222+ } catch (Exception e ) {
223+ throw new IllegalStateException (
224+ "Failed to generate class for op " + entry .getKey ().getName (), e );
225+ }
226+ TypeSpec spec = cls .build ();
227+
228+ JavaFile file =
229+ JavaFile .builder (basePackage + "." + pack , spec )
230+ .indent (" " )
231+ .skipJavaLangImports (true )
232+ .build ();
233+
234+ File outputFile =
235+ new File (
236+ outputDir ,
237+ basePackage .replace ('.' , '/' )
238+ + '/'
239+ + pack .replace ('.' , '/' )
240+ + '/'
241+ + spec .name
242+ + ".java" );
243+ outputFile .getParentFile ().mkdirs ();
244+ try {
245+ StringBuilder builder = new StringBuilder ();
246+ builder .append (LICENSE + '\n' );
247+ builder .append ("// This class has been generated, DO NOT EDIT!\n \n " );
248+ file .writeTo (builder );
249+
250+ Files .write (
251+ outputFile .toPath (),
252+ builder .toString ().getBytes (StandardCharsets .UTF_8 ),
253+ StandardOpenOption .WRITE ,
254+ StandardOpenOption .CREATE ,
255+ StandardOpenOption .TRUNCATE_EXISTING );
256+ } catch (IOException ioException ) {
257+ throw new IllegalStateException (
258+ "Failed to write file " + outputFile , ioException );
259+ }
260+ });
261+ });
223262 }
224-
225-
226263}
0 commit comments