load(
    "@fmeum_rules_jni//jni:defs.bzl",
    "cc_jni_library",
    "java_jni_library",
    "jni_headers",
)

# Java classes that declare the reader/writer API. Their `native` methods are
# implemented by the C++ sources of :riegeli_jni.
java_library(
    name = "wrapper",
    srcs = ["RecordReader.java", "RecordWriter.java"],
)

# Native header generated from the classes in :wrapper; the C++ sources below
# include it.
jni_headers(
    name = "header",
    lib = ":wrapper",
)

# C++ implementation of the native methods declared in :wrapper.
cc_jni_library(
    name = "riegeli_jni",
    srcs = ["jni_record_reader.cc", "jni_record_writer.cc"],
    visibility = [],
    deps = [
        ":header",
        "//riegeli/bytes:fd_reader",
        "//riegeli/bytes:fd_writer",
        "//riegeli/records:record_reader",
        "//riegeli/records:record_writer",
    ],
)

# Java entry point that carries :riegeli_jni as a native library and loads it.
java_jni_library(
    name = "loader",
    srcs = ["Loader.java"],
    native_libs = [":riegeli_jni"],
    visibility = [],
    deps = [":wrapper"],
)

# Test sources, compiled separately so the java_test below only needs
# runtime_deps.
java_library(
    name = "tests",
    testonly = True,
    srcs = ["RecordReadWriteTest.java"],
    deps = [
        ":loader",
        ":wrapper",
        "@junit_long",
        "@org_hamcrest_core",
    ],
)

java_test(
    name = "RecordReadWriteTest",
    runtime_deps = [":tests"],
)
+ static { + RulesJni.loadLibrary("riegeli_jni", RecordReader.class); + RulesJni.loadLibrary("riegeli_jni", RecordWriter.class); + } + + public final static RecordWriter newWriter() { + return new RecordWriter(); + } + + public static RecordReader newReader() { + return new RecordReader(); + } +} diff --git a/java/com/google/riegeli/RecordReadWriteTest.java b/java/com/google/riegeli/RecordReadWriteTest.java new file mode 100644 index 00000000..9a5efe55 --- /dev/null +++ b/java/com/google/riegeli/RecordReadWriteTest.java @@ -0,0 +1,46 @@ +package com.google.riegeli; + +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; + +public class RecordReadWriteTest { + + @Before + public void setUp() { + } + + private String createTestString(int length) { + // String.repeat is only available from Java 11, so create a + // helper method instead. + return String.join("", java.util.Collections.nCopies(length, "a")); + } + + @Test + public void writeWriteString() throws IOException { + // TODO: create a random file on TEST_TEMP directory. 
/**
 * JNI wrapper around the native Riegeli record reader.
 *
 * <p>All operations are implemented in native code, so the {@code riegeli_jni}
 * library must be loaded before any method is called. The class implements
 * {@link AutoCloseable} so that an instance can be managed by a
 * try-with-resources statement.
 */
public class RecordReader implements AutoCloseable {

    /** Reader options. Nothing is supported for now. */
    public static final class Options {
        // Nothing is supported for now.
    }

    // Address of the native reader object; 0 while no reader is open.
    // Native code looks this field up by name ("recordReaderPtr") and JNI type
    // ("J"), so neither the name nor the type may change.
    private long recordReaderPtr;

    /**
     * Opens {@code filename} for reading and attaches the native reader to this
     * object.
     *
     * @param filename path of the Riegeli file to read
     * @throws IOException declared for failures reported by the native layer
     */
    public native void open(String filename) throws IOException;

    /**
     * Reads the next record.
     *
     * @return the record bytes, or {@code null} when no further record could be
     *     read or when no reader is open
     */
    public native byte[] readRecord();

    /**
     * Closes and releases the native reader. Calling this when no reader is open
     * does nothing.
     *
     * @throws IOException declared for failures reported by the native layer
     */
    @Override
    public native void close() throws IOException;
}
| + // "snappy" | + // "window_log" ":" window_log | + // "chunk_size" ":" chunk_size | + // "bucket_fraction" ":" bucket_fraction | + // "pad_to_block_boundary" (":" ("true" | "false"))? | + // "parallelism" ":" parallelism + // brotli_level ::= integer in the range [0..11] (default 6) + // zstd_level ::= integer in the range [-131072..22] (default 3) + // window_log ::= "auto" or integer in the range [10..31] + // chunk_size ::= "auto" or positive integer expressed as real with + // optional suffix [BkKMGTPE] + // bucket_fraction ::= real in the range [0..1] + // parallelism ::= non-negative integer + // ``` + public native void open(String filename, String options) throws IOException; + + public void writeRecord(String record) throws IOException { + writeRecord(record.getBytes()); + } + + public native void writeRecord(byte[] record) throws IOException; + + // Flush the data into disk, more `writeRecord` can be called. + public native void flush() throws IOException; + + public native void close() throws IOException; +} diff --git a/java/com/google/riegeli/jni_record_reader.cc b/java/com/google/riegeli/jni_record_reader.cc new file mode 100644 index 00000000..4dea4b82 --- /dev/null +++ b/java/com/google/riegeli/jni_record_reader.cc @@ -0,0 +1,82 @@ +#include "com_google_riegeli_RecordReader.h" + +#include "riegeli/bytes/fd_reader.h" +#include "riegeli/records/record_reader.h" + +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: com_google_riegeli_RecordReader + * Method: open + * Signature: (Ljava/lang/String;)V + */ +using ReaderType = riegeli::RecordReader>; + +JNIEXPORT void JNICALL Java_com_google_riegeli_RecordReader_open(JNIEnv* env, jobject reader, jstring filename) { + /* Obtain a C-copy of the Java string */ + const char* fname = env->GetStringUTFChars(filename, nullptr); + + /* Create the recorder */ + riegeli::RecordReaderBase::Options record_reader_options; + auto* record_reader = new ReaderType( + std::forward_as_tuple(fname, O_RDONLY), + 
record_reader_options); + + env->ReleaseStringUTFChars(filename, fname); + + /* Get the Field ID of the instance variables "recordReaderPtr" */ + jfieldID fid = env->GetFieldID(env->GetObjectClass(reader), "recordReaderPtr", "J"); + + // Save the pointer as member. + env->SetLongField(reader, fid, reinterpret_cast(record_reader)); +} + +namespace { +ReaderType* getRecordReader(JNIEnv* env, jobject reader) { + jfieldID fid = env->GetFieldID(env->GetObjectClass(reader), "recordReaderPtr", "J"); + jlong ptr = env->GetLongField(reader, fid); + return reinterpret_cast(ptr); +} +} + +/* + * Class: com_google_riegeli_RecordReader + * Method: readRecord + * Signature: ()[B + */ +JNIEXPORT jbyteArray JNICALL Java_com_google_riegeli_RecordReader_readRecord(JNIEnv* env, jobject obj) { + auto* record_reader = getRecordReader(env, obj); + if (!record_reader) { + return nullptr; + } + std::string record; + if (record_reader->ReadRecord(record)) { + jbyteArray ret = env->NewByteArray(record.size()); + env->SetByteArrayRegion(ret, 0, record.size(), reinterpret_cast(record.data())); + return ret; + } else { + return nullptr; + } +} + +/* + * Class: com_google_riegeli_RecordReader + * Method: close + * Signature: ()V + */ +JNIEXPORT void JNICALL Java_com_google_riegeli_RecordReader_close(JNIEnv* env, jobject reader) { + auto* record_reader = getRecordReader(env, reader); + if (record_reader) { + record_reader->Close(); + delete record_reader; + + jfieldID fid = env->GetFieldID(env->GetObjectClass(reader), "recordReaderPtr", "J"); + env->SetLongField(reader, fid, 0L); + } +} + +#ifdef __cplusplus +} +#endif + diff --git a/java/com/google/riegeli/jni_record_writer.cc b/java/com/google/riegeli/jni_record_writer.cc new file mode 100644 index 00000000..c7e92c4e --- /dev/null +++ b/java/com/google/riegeli/jni_record_writer.cc @@ -0,0 +1,106 @@ +#include "com_google_riegeli_RecordWriter.h" + +#include "riegeli/bytes/fd_writer.h" +#include "riegeli/records/record_writer.h" + +#ifdef 
__cplusplus +extern "C" { +#endif + +using WriterType = riegeli::RecordWriter>; +/* + * Class: com_google_riegeli_RecordWriter + * Method: open + * Signature: (Ljava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_com_google_riegeli_RecordWriter_open(JNIEnv* env, jobject writer, jstring filename, jstring options) { + // Prepare the options + riegeli::RecordWriterBase::Options record_writer_options; + { + const char* options_str = env->GetStringUTFChars(options, nullptr); + const auto status = record_writer_options.FromString(options_str); + env->ReleaseStringUTFChars(options, options_str); + } + + // Create the writer + const char* fname = env->GetStringUTFChars(filename, nullptr); + auto* record_writer = new WriterType( + std::forward_as_tuple(fname, O_WRONLY | O_CREAT | O_TRUNC), + record_writer_options); + env->ReleaseStringUTFChars(filename, fname); + + /* Get the Field ID of the instance variables "recordReaderPtr" */ + jfieldID fid = env->GetFieldID(env->GetObjectClass(writer), "recordWriterPtr", "J"); + + // Save the pointer as member. 
+ env->SetLongField(writer, fid, reinterpret_cast(record_writer)); +} + +namespace { +WriterType* getRecordWriter(JNIEnv* env, jobject writer) { + jfieldID fid = env->GetFieldID(env->GetObjectClass(writer), "recordWriterPtr", "J"); + jlong ptr = env->GetLongField(writer, fid); + return reinterpret_cast(ptr); +} + +void throwException(JNIEnv* env, const char* exceptionClass, const char* message) { + jclass Exception = env->FindClass(exceptionClass); + env->ThrowNew(Exception, message); // Error Message +} + +void throwIllegalStateException(JNIEnv* env, const char* message) { + throwException(env, "java/lang/IllegalStateException", message); +} + +void throwIOException(JNIEnv* env, const char* message) { + throwException(env, "java/io/IOException", message); +} + +} + +JNIEXPORT void JNICALL Java_com_google_riegeli_RecordWriter_writeRecord( + JNIEnv* env, jobject writer, jbyteArray record) { + auto* native_writer = getRecordWriter(env, writer); + if (!native_writer) { + throwIllegalStateException(env, "open should have been called"); + return; + } + // TODO: throw a runtime exception if `open` method has not been + // called successfully. 
+ const jint size = env->GetArrayLength(record); + jbyte* data = env->GetByteArrayElements(record, 0); + // TODO: check the return value and status + bool ret = native_writer->WriteRecord(absl::string_view(reinterpret_cast(data), size)); + env->ReleaseByteArrayElements(record, data, 0); + if (!ret) { + throwIOException(env, "Fail to write record"); + } +} + +JNIEXPORT void JNICALL Java_com_google_riegeli_RecordWriter_flush(JNIEnv* env, jobject writer) { + auto* record_writer = getRecordWriter(env, writer); + if (record_writer) { + record_writer->Flush(); + } else { + throwIllegalStateException(env, "open should have been called"); + } +} + +JNIEXPORT void JNICALL Java_com_google_riegeli_RecordWriter_close(JNIEnv* env, jobject writer) { + auto* record_writer = getRecordWriter(env, writer); + if (record_writer) { + record_writer->Close(); + delete record_writer; + + jfieldID fid = env->GetFieldID(env->GetObjectClass(writer), "recordWriterPtr", "J"); + env->SetLongField(writer, fid, 0L); + } else { + throwIllegalStateException(env, "open should have been called"); + } +} + + +#ifdef __cplusplus +} +#endif +