From e2c7ae58f089d332fdfb4302d591b66b3ada1556 Mon Sep 17 00:00:00 2001 From: punAhuja Date: Fri, 13 Dec 2024 15:58:55 +0530 Subject: [PATCH 1/7] Initial changes for converting cagra index to hnsw index for searching --- .../main/java/com/nvidia/cuvs/CagraIndex.java | 2 +- .../java/com/nvidia/cuvs/CuVSResources.java | 8 ++- .../java/com/nvidia/cuvs/common/Util.java | 35 ++++++++++++ .../java/com/nvidia/cuvs/TestCagraToHnsw.java | 57 +++++++++++++++++++ java/internal/src/cuvs_java.c | 31 ++++++++++ 5 files changed, 131 insertions(+), 2 deletions(-) create mode 100644 java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java index 67394b53b6..70be98d74a 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java @@ -380,7 +380,7 @@ protected IndexReference(MemorySegment indexMemorySegment) { * * @return index MemorySegment */ - protected MemorySegment getMemorySegment() { + public MemorySegment getMemorySegment() { return memorySegment; } } diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java index 60f68981b1..de729e390c 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java @@ -42,6 +42,7 @@ public class CuVSResources implements AutoCloseable { private final MethodHandle createResourcesMethodHandle; private final MethodHandle destroyResourcesMethodHandle; + public final MethodHandle cagraToHnswHandle; private MemorySegment resourcesMemorySegment; /** @@ -61,6 +62,11 @@ public CuVSResources() throws Throwable { destroyResourcesMethodHandle = linker.downcallHandle(libcuvsNativeLibrary.find("destroy_resources").get(), FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS)); createResources(); + cagraToHnswHandle = linker.downcallHandle( + libcuvsNativeLibrary.find("convert_cagra_to_hnsw") + .orElseThrow(() -> new IllegalStateException("convert_cagra_to_hnsw not found")), + FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS) + ); } /** @@ -91,7 +97,7 @@ public void close() { * * @return cuvsResources MemorySegment */ - protected MemorySegment getMemorySegment() { + public MemorySegment getMemorySegment() { return resourcesMemorySegment; } diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java index 750e49d642..aa4c74530e 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java @@ -20,15 +20,22 @@ import java.io.FileOutputStream; import java.io.IOException; import java.lang.foreign.Arena; +import java.lang.foreign.FunctionDescriptor; import java.lang.foreign.Linker; import java.lang.foreign.MemoryLayout; import java.lang.foreign.MemoryLayout.PathElement; import java.lang.foreign.MemorySegment; import java.lang.foreign.ValueLayout; +import java.lang.invoke.MethodHandle; import java.lang.invoke.VarHandle; +import java.nio.charset.StandardCharsets; import org.apache.commons.io.IOUtils; +import com.nvidia.cuvs.CagraIndex; +import java.nio.charset.StandardCharsets; +import com.nvidia.cuvs.CuVSResources; + public class Util { /** * A utility method for getting an instance of {@link MemorySegment} for a @@ -97,4 +104,32 @@ public static File loadLibraryFromJar(String path) throws IOException { return temp; } + + public static void convertCagraToHnsw(CuVSResources resources, String cagraFilePath, String hnswFilePath) { + try { + MethodHandle convertCagraToHnswHandle = resources.cagraToHnswHandle; + + byte[] cagraBytes = cagraFilePath.getBytes(StandardCharsets.UTF_8); + byte[] hnswBytes = hnswFilePath.getBytes(StandardCharsets.UTF_8); + + MemorySegment cagraFileSegment = resources.arena.allocate(cagraBytes.length + 1); + MemorySegment hnswFileSegment = resources.arena.allocate(hnswBytes.length + 1); + + cagraFileSegment.asSlice(0, cagraBytes.length).copyFrom(MemorySegment.ofArray(cagraBytes)); + cagraFileSegment.set(ValueLayout.JAVA_BYTE, cagraBytes.length, (byte) 0); + + hnswFileSegment.asSlice(0, hnswBytes.length).copyFrom(MemorySegment.ofArray(hnswBytes)); + hnswFileSegment.set(ValueLayout.JAVA_BYTE, hnswBytes.length, (byte) 0); + + int returnValue = (int) convertCagraToHnswHandle.invoke(resources.getMemorySegment(), cagraFileSegment, + hnswFileSegment); + + if (returnValue != 0) { + throw new RuntimeException("Failed to convert CAGRA to HNSW, error code: " + returnValue); + } + } catch (Throwable e) { + throw new RuntimeException("Error during CAGRA to HNSW conversion", e); + } + } + } diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java new file mode 100644 index 0000000000..e67ee14898 --- /dev/null +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java @@ -0,0 +1,57 @@ +package com.nvidia.cuvs; + +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; + +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.nvidia.cuvs.common.Util; + +public class TestCagraToHnsw { + + private static final Logger log = LoggerFactory.getLogger(TestCagraToHnsw.class); + + @Test + public void testCagraToHnswConversion() throws Throwable { + String cagraFilePath = "cagra_index.bin"; + String hnswFilePath = "hnsw_index.bin"; + + try (CuVSResources resources = new CuVSResources()) { + float[][] dataset = { + { 1.0f, 2.0f, 3.0f }, + { 4.0f, 5.0f, 6.0f }, + { 7.0f, 8.0f, 9.0f } + }; + + // Build CAGRA index + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + // Serialize CAGRA index to file using OutputStream + try (OutputStream outputStream = new FileOutputStream(cagraFilePath)) { + index.serialize(outputStream, new File(cagraFilePath)); + log.info("CAGRA index successfully created at: {}", cagraFilePath); + } + + Util.convertCagraToHnsw(resources, cagraFilePath, hnswFilePath); + log.info("HNSW index conversion completed. File created at: {}", hnswFilePath); + + File hnswFile = new File(hnswFilePath); + assertTrue("HNSW index file should exist", hnswFile.exists()); + + Path hnswPath = Path.of(hnswFilePath); + long hnswFileSize = Files.size(hnswPath); + assertTrue("HNSW index file should not be empty", hnswFileSize > 0); + } + } +} diff --git a/java/internal/src/cuvs_java.c b/java/internal/src/cuvs_java.c index b6a078e240..af4a51e742 100644 --- a/java/internal/src/cuvs_java.c +++ b/java/internal/src/cuvs_java.c @@ -21,6 +21,7 @@ #include #include #include +#include #define try bool __HadError=false; #define catch(x) ExitJmp:if(__HadError) @@ -113,3 +114,33 @@ void search_cagra_index(cuvsCagraIndex_t index, float *queries, int topk, long n cuvsRMMFree(cuvsResources, neighbors, sizeof(uint32_t) * n_queries * topk); cuvsRMMFree(cuvsResources, queries_d, sizeof(float) * n_queries * dimensions); } + +void convert_cagra_to_hnsw(cuvsResources_t resources, + const char* cagra_filename, + const char* hnsw_filename, + int* return_value) { + if (!resources || !cagra_filename || !hnsw_filename) { + *return_value = -1; // Invalid parameters + return; + } + + // Step 1: Create and deserialize the CAGRA index + cuvsCagraIndex_t cagra_index; + cuvsCagraIndexCreate(&cagra_index); + + *return_value = cuvsCagraDeserialize(resources, cagra_filename, cagra_index); + if (*return_value != 0) { + *return_value = -2; // CAGRA deserialization failed + cuvsCagraIndexDestroy(cagra_index); + return; + } + + // Step 2: Serialize the CAGRA index to an HNSW-compatible file + *return_value = cuvsCagraSerializeToHnswlib(resources, hnsw_filename, cagra_index); + if (*return_value != 0) { + *return_value = -3; // Serialization to HNSW failed + } + + // Cleanup + cuvsCagraIndexDestroy(cagra_index); +} From 9c507eef984d1044fb4846d5124379963e3778c2 Mon Sep 17 00:00:00 2001 From: punAhuja Date: Wed, 18 Dec 2024 13:40:41 +0530 Subject: [PATCH 2/7] Directly serialising to hnsw index file need to fix issues --- .../java/com/nvidia/cuvs/common/Util.java | 29 ++++++++++++-- .../java/com/nvidia/cuvs/TestCagraToHnsw.java | 40 +++++++++---------- java/internal/src/cuvs_java.c | 22 +++------- 3 files changed, 49 insertions(+), 42 deletions(-) diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java index aa4c74530e..6d9b93a0d9 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java @@ -20,7 +20,6 @@ import java.io.FileOutputStream; import java.io.IOException; import java.lang.foreign.Arena; -import java.lang.foreign.FunctionDescriptor; import java.lang.foreign.Linker; import java.lang.foreign.MemoryLayout; import java.lang.foreign.MemoryLayout.PathElement; @@ -32,11 +31,35 @@ import org.apache.commons.io.IOUtils; -import com.nvidia.cuvs.CagraIndex; -import java.nio.charset.StandardCharsets; import com.nvidia.cuvs.CuVSResources; public class Util { + + public static void serializeCagraToHnsw(CuVSResources resources, String cagraFilePath, String hnswFilePath) { + try (Arena arena = Arena.ofConfined()) { + MemorySegment cagraPathSegment = toCString(arena, cagraFilePath); + MemorySegment hnswPathSegment = toCString(arena, hnswFilePath); + + int result = (int) resources.cagraToHnswHandle.invokeExact(resources.getMemorySegment(), // resources + cagraPathSegment, // CAGRA index file path + hnswPathSegment // HNSW index file path + ); + + if (result != 0) { + throw new RuntimeException("Failed to serialize CAGRA index to HNSW file. Error code: " + result); + } + } catch (Throwable e) { + throw new RuntimeException("Error during serialization: " + e.getMessage(), e); + } + } + + private static MemorySegment toCString(Arena arena, String string) { + byte[] bytes = (string + "\0").getBytes(); + MemorySegment segment = arena.allocate(bytes.length); + segment.copyFrom(MemorySegment.ofArray(bytes)); + return segment; + } + /** * A utility method for getting an instance of {@link MemorySegment} for a * {@link String}. diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java index e67ee14898..4cc278f1fb 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java @@ -4,9 +4,6 @@ import java.io.File; import java.io.FileOutputStream; -import java.io.OutputStream; -import java.nio.file.Files; -import java.nio.file.Path; import org.junit.Test; import org.slf4j.Logger; @@ -19,39 +16,38 @@ public class TestCagraToHnsw { private static final Logger log = LoggerFactory.getLogger(TestCagraToHnsw.class); @Test - public void testCagraToHnswConversion() throws Throwable { + public void testSerialization() throws Throwable { String cagraFilePath = "cagra_index.bin"; String hnswFilePath = "hnsw_index.bin"; try (CuVSResources resources = new CuVSResources()) { - float[][] dataset = { - { 1.0f, 2.0f, 3.0f }, - { 4.0f, 5.0f, 6.0f }, - { 7.0f, 8.0f, 9.0f } + // Build and serialize a CAGRA index + float[][] dataset = { + {1.0f, 2.0f, 3.0f}, + {4.0f, 5.0f, 6.0f}, + {7.0f, 8.0f, 9.0f} }; - // Build CAGRA index - CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); - CagraIndex index = new CagraIndex.Builder(resources) + CagraIndex cagraIndex = new CagraIndex.Builder(resources) .withDataset(dataset) - .withIndexParams(indexParams) + .withIndexParams(new CagraIndexParams.Builder(resources).build()) .build(); - // Serialize CAGRA index to file using OutputStream - try (OutputStream outputStream = new FileOutputStream(cagraFilePath)) { - index.serialize(outputStream, new File(cagraFilePath)); - log.info("CAGRA index successfully created at: {}", cagraFilePath); + // Create a temporary file for intermediate serialization + File tempFile = File.createTempFile("cagra_temp_", ".tmp"); + + // Serialize CAGRA index to a file + try (FileOutputStream outputStream = new FileOutputStream(cagraFilePath)) { + cagraIndex.serialize(outputStream, tempFile); } - Util.convertCagraToHnsw(resources, cagraFilePath, hnswFilePath); - log.info("HNSW index conversion completed. File created at: {}", hnswFilePath); + // Convert the CAGRA index to HNSW format + Util.serializeCagraToHnsw(resources, cagraFilePath, hnswFilePath); + // Verify the HNSW index file File hnswFile = new File(hnswFilePath); assertTrue("HNSW index file should exist", hnswFile.exists()); - - Path hnswPath = Path.of(hnswFilePath); - long hnswFileSize = Files.size(hnswPath); - assertTrue("HNSW index file should not be empty", hnswFileSize > 0); + log.info("HNSW index file successfully created at: {}", hnswFilePath); } } } diff --git a/java/internal/src/cuvs_java.c b/java/internal/src/cuvs_java.c index af4a51e742..a77d7fdb88 100644 --- a/java/internal/src/cuvs_java.c +++ b/java/internal/src/cuvs_java.c @@ -66,8 +66,7 @@ cuvsCagraIndex_t build_cagra_index(float *dataset, long rows, long dimensions, c index_params->compression = compression_params; *returnValue = cuvsCagraBuild(cuvsResources, index_params, &dataset_tensor, index); - - omp_set_num_threads(1); + omp_set_num_threads(1); return index; } @@ -124,23 +123,12 @@ void convert_cagra_to_hnsw(cuvsResources_t resources, return; } - // Step 1: Create and deserialize the CAGRA index - cuvsCagraIndex_t cagra_index; - cuvsCagraIndexCreate(&cagra_index); - - *return_value = cuvsCagraDeserialize(resources, cagra_filename, cagra_index); + // direct conversion from CAGRA to HNSW + *return_value = cuvsCagraSerializeToHnswlib(resources, hnsw_filename, cagra_filename); if (*return_value != 0) { - *return_value = -2; // CAGRA deserialization failed - cuvsCagraIndexDestroy(cagra_index); + printf("Error: Failed to serialize directly to HNSW format\n"); return; } - // Step 2: Serialize the CAGRA index to an HNSW-compatible file - *return_value = cuvsCagraSerializeToHnswlib(resources, hnsw_filename, cagra_index); - if (*return_value != 0) { - *return_value = -3; // Serialization to HNSW failed - } - - // Cleanup - cuvsCagraIndexDestroy(cagra_index); + printf("Successfully serialized CAGRA index to HNSW format: %s\n", hnsw_filename); } From fce505988b474b0319fc883a05c0a4ae9885682c Mon Sep 17 00:00:00 2001 From: punAhuja Date: Wed, 18 Dec 2024 14:05:42 +0530 Subject: [PATCH 3/7] HNSW file generated now --- .../java/com/nvidia/cuvs/common/Util.java | 4 +-- .../java/com/nvidia/cuvs/TestCagraToHnsw.java | 24 ++++++++++++---- java/internal/src/cuvs_java.c | 28 ++++++++++++++++--- 3 files changed, 44 insertions(+), 12 deletions(-) diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java index 6d9b93a0d9..f883416826 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java @@ -41,8 +41,8 @@ public static void serializeCagraToHnsw(CuVSResources resources, String cagraFil MemorySegment hnswPathSegment = toCString(arena, hnswFilePath); int result = (int) resources.cagraToHnswHandle.invokeExact(resources.getMemorySegment(), // resources - cagraPathSegment, // CAGRA index file path - hnswPathSegment // HNSW index file path + cagraPathSegment, + hnswPathSegment ); if (result != 0) { diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java index 4cc278f1fb..c079f9e86d 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java @@ -4,6 +4,8 @@ import java.io.File; import java.io.FileOutputStream; +import java.nio.file.Files; +import java.nio.file.Path; import org.junit.Test; import org.slf4j.Logger; @@ -21,33 +23,43 @@ public void testSerialization() throws Throwable { String hnswFilePath = "hnsw_index.bin"; try (CuVSResources resources = new CuVSResources()) { - // Build and serialize a CAGRA index + log.info("Starting CAGRA to HNSW test."); + + // Step 1: Build a sample dataset float[][] dataset = { {1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}, {7.0f, 8.0f, 9.0f} }; + log.info("Building CAGRA index..."); CagraIndex cagraIndex = new CagraIndex.Builder(resources) .withDataset(dataset) .withIndexParams(new CagraIndexParams.Builder(resources).build()) .build(); + log.info("CAGRA index built successfully."); - // Create a temporary file for intermediate serialization + log.info("Serializing CAGRA index to file: {}", cagraFilePath); File tempFile = File.createTempFile("cagra_temp_", ".tmp"); - - // Serialize CAGRA index to a file try (FileOutputStream outputStream = new FileOutputStream(cagraFilePath)) { cagraIndex.serialize(outputStream, tempFile); } + assertTrue("CAGRA index file should exist", new File(cagraFilePath).exists()); + log.info("CAGRA index serialized to: {}", cagraFilePath); - // Convert the CAGRA index to HNSW format + log.info("Converting CAGRA index to HNSW format..."); Util.serializeCagraToHnsw(resources, cagraFilePath, hnswFilePath); - // Verify the HNSW index file File hnswFile = new File(hnswFilePath); assertTrue("HNSW index file should exist", hnswFile.exists()); log.info("HNSW index file successfully created at: {}", hnswFilePath); + + Files.deleteIfExists(Path.of(cagraFilePath)); + Files.deleteIfExists(Path.of(hnswFilePath)); + log.info("Test completed successfully. Temporary files cleaned up."); + } catch (Exception e) { + log.error("Test failed: ", e); + throw e; // Re-throw exception to mark test as failed } } } diff --git a/java/internal/src/cuvs_java.c b/java/internal/src/cuvs_java.c index a77d7fdb88..88677248f3 100644 --- a/java/internal/src/cuvs_java.c +++ b/java/internal/src/cuvs_java.c @@ -120,15 +120,35 @@ void convert_cagra_to_hnsw(cuvsResources_t resources, int* return_value) { if (!resources || !cagra_filename || !hnsw_filename) { *return_value = -1; // Invalid parameters + printf("Error: Invalid parameters provided\n"); return; } - // direct conversion from CAGRA to HNSW - *return_value = cuvsCagraSerializeToHnswlib(resources, hnsw_filename, cagra_filename); - if (*return_value != 0) { - printf("Error: Failed to serialize directly to HNSW format\n"); + printf("Initializing conversion: CAGRA -> HNSW\n"); + printf("CAGRA file: %s\n", cagra_filename); + printf("HNSW file: %s\n", hnsw_filename); + + cuvsCagraIndex_t cagra_index; + cuvsCagraIndexCreate(&cagra_index); + + // Step 1: Deserialize the CAGRA index + *return_value = cuvsCagraDeserialize(resources, cagra_filename, cagra_index); + if (*return_value != CUVS_SUCCESS) { + printf("Error: Failed to deserialize CAGRA index. Error code: %d\n", *return_value); + cuvsCagraIndexDestroy(cagra_index); + return; + } + printf("Successfully deserialized CAGRA index\n"); + + // Step 2: Serialize to HNSW format + *return_value = cuvsCagraSerializeToHnswlib(resources, hnsw_filename, cagra_index); + if (*return_value != CUVS_SUCCESS) { + printf("Error: Failed to serialize to HNSW format. Error code: %d\n", *return_value); + cuvsCagraIndexDestroy(cagra_index); return; } printf("Successfully serialized CAGRA index to HNSW format: %s\n", hnsw_filename); + cuvsCagraIndexDestroy(cagra_index); } + From 7fb7be6e9c46fd07b298535b118301588bd16fba Mon Sep 17 00:00:00 2001 From: punAhuja Date: Fri, 20 Dec 2024 16:29:59 +0530 Subject: [PATCH 4/7] Implemented search on hnsw index that was generated, some issues persist --- .../java/com/nvidia/cuvs/CuVSResources.java | 10 +- .../java/com/nvidia/cuvs/common/Util.java | 22 +--- .../java/com/nvidia/cuvs/hnsw/HnswQuery.java | 26 ++++ .../cuvs/hnsw/HnswSearchParameters.java | 30 +++++ .../nvidia/cuvs/hnsw/HnswSearchResults.java | 32 +++++ .../java/com/nvidia/cuvs/hnsw/HnswUtil.java | 113 ++++++++++++++++++ .../java/com/nvidia/cuvs/TestCagraToHnsw.java | 30 ++++- java/internal/src/cuvs_java.c | 46 +++++++ 8 files changed, 284 insertions(+), 25 deletions(-) create mode 100644 java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswQuery.java create mode 100644 java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchParameters.java create mode 100644 java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchResults.java create mode 100644 java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswUtil.java diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java index de729e390c..f717d1232f 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java @@ -43,6 +43,7 @@ public class CuVSResources implements AutoCloseable { private final MethodHandle createResourcesMethodHandle; private final MethodHandle destroyResourcesMethodHandle; public final MethodHandle cagraToHnswHandle; + public final MethodHandle hnswSearchHandle; private MemorySegment resourcesMemorySegment; /** @@ -65,8 +66,13 @@ public CuVSResources() throws Throwable { cagraToHnswHandle = linker.downcallHandle( libcuvsNativeLibrary.find("convert_cagra_to_hnsw") .orElseThrow(() -> new IllegalStateException("convert_cagra_to_hnsw not found")), - FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS) - ); + FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS)); + hnswSearchHandle = linker.downcallHandle( + libcuvsNativeLibrary.find("cuvsHnswSearch") + .orElseThrow(() -> new IllegalStateException("cuvsHnswSearch not found")), + FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, + ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS)); + } /** diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java index f883416826..efbc68c0f3 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java @@ -34,26 +34,8 @@ import com.nvidia.cuvs.CuVSResources; public class Util { - - public static void serializeCagraToHnsw(CuVSResources resources, String cagraFilePath, String hnswFilePath) { - try (Arena arena = Arena.ofConfined()) { - MemorySegment cagraPathSegment = toCString(arena, cagraFilePath); - MemorySegment hnswPathSegment = toCString(arena, hnswFilePath); - - int result = (int) resources.cagraToHnswHandle.invokeExact(resources.getMemorySegment(), // resources - cagraPathSegment, - hnswPathSegment - ); - - if (result != 0) { - throw new RuntimeException("Failed to serialize CAGRA index to HNSW file. Error code: " + result); - } - } catch (Throwable e) { - throw new RuntimeException("Error during serialization: " + e.getMessage(), e); - } - } - - private static MemorySegment toCString(Arena arena, String string) { + + public static MemorySegment toCString(Arena arena, String string) { byte[] bytes = (string + "\0").getBytes(); MemorySegment segment = arena.allocate(bytes.length); segment.copyFrom(MemorySegment.ofArray(bytes)); diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswQuery.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswQuery.java new file mode 100644 index 0000000000..4ff8972c10 --- /dev/null +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswQuery.java @@ -0,0 +1,26 @@ +package com.nvidia.cuvs.hnsw; + +public class HnswQuery { + private final float[][] queryVectors; + private final int topK; + private final HnswSearchParameters searchParameters; + + public HnswQuery(float[][] queryVectors, int topK, HnswSearchParameters searchParams) { + this.queryVectors = queryVectors; + this.topK = topK; + this.searchParameters = searchParams; + } + + public float[][] getQueryVectors() { + return queryVectors; + } + + public int getTopK() { + return topK; + } + + public HnswSearchParameters getSearchParameters() { + return searchParameters; + } + +} diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchParameters.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchParameters.java new file mode 100644 index 0000000000..717706ff6f --- /dev/null +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchParameters.java @@ -0,0 +1,30 @@ +package com.nvidia.cuvs.hnsw; + +public class HnswSearchParameters { + + private int ef; + private int numThreads; + + public HnswSearchParameters(int ef, int numThreads) { + this.ef = ef; + this.numThreads = numThreads; + } + + public HnswSearchParameters withEf(int ef) { + this.ef = ef; + return this; + } + + public HnswSearchParameters withNumThreads(int threads) { + this.numThreads = threads; + return this; + } + + public int getEf() { + return ef; + } + + public int getNumThreads() { + return numThreads; + } +} \ No newline at end of file diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchResults.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchResults.java new file mode 100644 index 0000000000..57ef8f50a5 --- /dev/null +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchResults.java @@ -0,0 +1,32 @@ +package com.nvidia.cuvs.hnsw; + +public class HnswSearchResults { + private final int[][] neighbors; + private final float[][] distances; + private final int topK; + private final int numQueries; + + public HnswSearchResults(int[][] neighbors, float[][] distances, int topK, int numQueries) { + this.neighbors = neighbors; + this.distances = distances; + this.topK = topK; + this.numQueries = numQueries; + } + + // Getters for validation + public int[][] getNeighbors() { + return neighbors; + } + + public float[][] getDistances() { + return distances; + } + + public int getTopK() { + return topK; + } + + public int getNumQueries() { + return numQueries; + } +} diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswUtil.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswUtil.java new file mode 100644 index 0000000000..b32724f2ee --- /dev/null +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswUtil.java @@ -0,0 +1,113 @@ +package com.nvidia.cuvs.hnsw; + +import java.lang.foreign.Arena; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; + +import com.nvidia.cuvs.CuVSResources; +import com.nvidia.cuvs.common.Util; + +public class HnswUtil { + + public static void serializeCagraToHnsw(CuVSResources resources, String cagraFilePath, String hnswFilePath) { + try (Arena arena = Arena.ofConfined()) { + MemorySegment cagraPathSegment = Util.toCString(arena, cagraFilePath); + MemorySegment hnswPathSegment = Util.toCString(arena, hnswFilePath); + + int result = (int) resources.cagraToHnswHandle.invokeExact(resources.getMemorySegment(), cagraPathSegment, + hnswPathSegment); + + if (result != 0) { + throw new RuntimeException("Failed to serialize CAGRA index to HNSW file. Error code: " + result); + } + } catch (Throwable e) { + throw new RuntimeException("Error during serialization: " + e.getMessage(), e); + } + } + + public static HnswSearchResults search(CuVSResources resources, HnswQuery query) { + try { + // Extract query parameters + float[][] queryVectors = query.getQueryVectors(); + int topK = query.getTopK(); + HnswSearchParameters searchParams = query.getSearchParameters(); + + int numQueries = queryVectors.length; + + // Prepare tensors and allocate memory for results + MemorySegment queryTensor = toTensor(resources.arena, queryVectors); + MemorySegment neighborsMemory = resources.arena.allocate(ValueLayout.JAVA_INT.byteSize() * topK * numQueries); + MemorySegment distancesMemory = resources.arena.allocate(ValueLayout.JAVA_FLOAT.byteSize() * topK * numQueries); + + // Configure search parameters + MemorySegment efSegment = resources.arena.allocate(ValueLayout.JAVA_INT.byteSize()); + efSegment.set(ValueLayout.JAVA_INT, 0, searchParams.getEf()); + + MemorySegment numThreadsSegment = resources.arena.allocate(ValueLayout.JAVA_INT.byteSize()); + numThreadsSegment.set(ValueLayout.JAVA_INT, 0, searchParams.getNumThreads()); + + // Perform the search using the native handle + int result = (int) resources.hnswSearchHandle.invokeExact(resources.getMemorySegment(), // Resources + queryTensor, // Query tensor + neighborsMemory, // Neighbors output + distancesMemory, // Distances output + efSegment, // Search parameter (ef) + numThreadsSegment // Search parameter (numThreads) + ); + + if (result != 0) { + throw new RuntimeException("Failed to perform HNSW search. Error code: " + result); + } + + // Extract results from memory segments + int[] flatNeighbors = new int[topK * numQueries]; + float[] flatDistances = new float[topK * numQueries]; + + for (int i = 0; i < flatNeighbors.length; i++) { + flatNeighbors[i] = neighborsMemory.getAtIndex(ValueLayout.JAVA_INT, i); + } + for (int i = 0; i < flatDistances.length; i++) { + flatDistances[i] = distancesMemory.getAtIndex(ValueLayout.JAVA_FLOAT, i); + } + + // Reshape the flat results into 2D arrays + int[][] neighbors = reshape(flatNeighbors, numQueries, topK); + float[][] distances = reshape(flatDistances, numQueries, topK); + + return new HnswSearchResults(neighbors, distances, topK, numQueries); + } catch (Throwable e) { + throw new RuntimeException("Error during HNSW search: " + e.getMessage(), e); + } + } + + private static MemorySegment toTensor(Arena arena, float[][] vectors) { + int rows = vectors.length; + int cols = vectors[0].length; + MemorySegment tensor = arena.allocate(ValueLayout.JAVA_FLOAT.byteSize() * rows * cols); + + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + tensor.setAtIndex(ValueLayout.JAVA_FLOAT, i * cols + j, vectors[i][j]); + } + } + + return tensor; + } + + private static int[][] reshape(int[] flatArray, int rows, int cols) { + int[][] reshaped = new int[rows][cols]; + for (int i = 0; i < rows; i++) { + System.arraycopy(flatArray, i * cols, reshaped[i], 0, cols); + } + return reshaped; + } + + private static float[][] reshape(float[] flatArray, int rows, int cols) { + float[][] reshaped = new float[rows][cols]; + for (int i = 0; i < rows; i++) { + System.arraycopy(flatArray, i * cols, reshaped[i], 0, cols); + } + return reshaped; + } + +} \ No newline at end of file diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java index c079f9e86d..96d19149e8 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java @@ -11,14 +11,17 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import com.nvidia.cuvs.common.Util; +import com.nvidia.cuvs.hnsw.HnswQuery; +import com.nvidia.cuvs.hnsw.HnswSearchParameters; +import com.nvidia.cuvs.hnsw.HnswSearchResults; +import com.nvidia.cuvs.hnsw.HnswUtil; public class TestCagraToHnsw { private static final Logger log = LoggerFactory.getLogger(TestCagraToHnsw.class); @Test - public void testSerialization() throws Throwable { + public void testSerializationAndSearch() throws Throwable { String cagraFilePath = "cagra_index.bin"; String hnswFilePath = "hnsw_index.bin"; @@ -48,12 +51,33 @@ public void testSerialization() throws Throwable { log.info("CAGRA index serialized to: {}", cagraFilePath); log.info("Converting CAGRA index to HNSW format..."); - Util.serializeCagraToHnsw(resources, cagraFilePath, hnswFilePath); + HnswUtil.serializeCagraToHnsw(resources, cagraFilePath, hnswFilePath); File hnswFile = new File(hnswFilePath); assertTrue("HNSW index file should exist", hnswFile.exists()); log.info("HNSW index file successfully created at: {}", hnswFilePath); + // Step 2: Perform a search + log.info("Starting HNSW search..."); + float[][] queryVectors = { + {2.0f, 3.0f, 4.0f}, + {5.0f, 6.0f, 7.0f} + }; + int topK = 2; + + // Create search parameters + HnswSearchParameters searchParams = new HnswSearchParameters(20, 2); + + // Create query object + HnswQuery query = new HnswQuery(queryVectors, topK, searchParams); + + // Perform the search + HnswSearchResults results = HnswUtil.search(resources, query); + + // Validate results + assertTrue("Search results should not be null", results != null); + log.info("Search results retrieved: {}", results); + Files.deleteIfExists(Path.of(cagraFilePath)); Files.deleteIfExists(Path.of(hnswFilePath)); log.info("Test completed successfully. Temporary files cleaned up."); diff --git a/java/internal/src/cuvs_java.c b/java/internal/src/cuvs_java.c index 88677248f3..cb0b82d07c 100644 --- a/java/internal/src/cuvs_java.c +++ b/java/internal/src/cuvs_java.c @@ -152,3 +152,49 @@ void convert_cagra_to_hnsw(cuvsResources_t resources, cuvsCagraIndexDestroy(cagra_index); } +void search_hnsw_index(cuvsHnswIndex_t index, + float* queries, + int topk, + long n_queries, + int dimensions, + cuvsResources_t cuvsResources, + uint64_t* neighbors_h, + float* distances_h, + int* returnValue, + cuvsHnswSearchParams_t search_params) { + + uint64_t* neighbors; + float* distances; + float* queries_d; + + cuvsRMMAlloc(cuvsResources, (void**)&queries_d, sizeof(float) * n_queries * dimensions); + cuvsRMMAlloc(cuvsResources, (void**)&neighbors, sizeof(uint64_t) * n_queries * topk); + cuvsRMMAlloc(cuvsResources, (void**)&distances, sizeof(float) * n_queries * topk); + + cudaMemcpy(queries_d, queries, sizeof(float) * n_queries * dimensions, cudaMemcpyDefault); + + int64_t queries_shape[2] = {n_queries, dimensions}; + DLManagedTensor queries_tensor = prepare_tensor(queries_d, queries_shape, kDLFloat); + + int64_t neighbors_shape[2] = {n_queries, topk}; + DLManagedTensor neighbors_tensor = prepare_tensor(neighbors, neighbors_shape, kDLUInt); + + int64_t distances_shape[2] = {n_queries, topk}; + DLManagedTensor distances_tensor = prepare_tensor(distances, distances_shape, kDLFloat); + + *returnValue = cuvsHnswSearch(cuvsResources, search_params, index, &queries_tensor, &neighbors_tensor, &distances_tensor); + + if (*returnValue != CUVS_SUCCESS) { + printf("Error: Search failed. Error code: %d\n", *returnValue); + } else { + printf("Search completed successfully\n"); + } + + cudaMemcpy(neighbors_h, neighbors, sizeof(uint64_t) * n_queries * topk, cudaMemcpyDefault); + cudaMemcpy(distances_h, distances, sizeof(float) * n_queries * topk, cudaMemcpyDefault); + + cuvsRMMFree(cuvsResources, distances, sizeof(float) * n_queries * topk); + cuvsRMMFree(cuvsResources, neighbors, sizeof(uint64_t) * n_queries * topk); + cuvsRMMFree(cuvsResources, queries_d, sizeof(float) * n_queries * dimensions); +} + From 39b8f6b93e21e277de5b9a16541979805a296e2e Mon Sep 17 00:00:00 2001 From: punAhuja Date: Fri, 20 Dec 2024 16:55:42 +0530 Subject: [PATCH 5/7] Corrected parameters --- .../main/java/com/nvidia/cuvs/CuVSResources.java | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java index f717d1232f..3091b34ced 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java @@ -67,11 +67,15 @@ public CuVSResources() throws Throwable { libcuvsNativeLibrary.find("convert_cagra_to_hnsw") .orElseThrow(() -> new IllegalStateException("convert_cagra_to_hnsw not found")), FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS)); - hnswSearchHandle = linker.downcallHandle( - libcuvsNativeLibrary.find("cuvsHnswSearch") - .orElseThrow(() -> new IllegalStateException("cuvsHnswSearch not found")), - FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, - ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS)); + hnswSearchHandle = linker.downcallHandle(libcuvsNativeLibrary.find("cuvsHnswSearch").get(), + FunctionDescriptor.of(ValueLayout.JAVA_INT, // Return type + ValueLayout.ADDRESS, // cuvsResources_t + ValueLayout.ADDRESS, // cuvsHnswSearchParams_t + ValueLayout.ADDRESS, // cuvsHnswIndex_t + ValueLayout.ADDRESS, // DLManagedTensor* queries + ValueLayout.ADDRESS, // DLManagedTensor* neighbors + ValueLayout.ADDRESS // DLManagedTensor* distances + )); } From 3de4908c653d6a7f3a8c52be4ccbc91e773a8104 Mon Sep 17 00:00:00 2001 From: punAhuja Date: Fri, 27 Dec 2024 11:26:30 +0530 Subject: [PATCH 6/7] Updated search implementation. C function is invoked but crash is seen. --- .../java/com/nvidia/cuvs/CuVSResources.java | 22 +-- .../java/com/nvidia/cuvs/hnsw/HnswQuery.java | 11 +- .../nvidia/cuvs/hnsw/HnswSearchResults.java | 7 +- .../java/com/nvidia/cuvs/hnsw/HnswUtil.java | 91 ++++++------ .../nvidia/cuvs/CagraBuildAndSearchTest.java | 49 ++----- .../java/com/nvidia/cuvs/TestCagraToHnsw.java | 133 +++++++++--------- java/internal/src/cuvs_java.c | 28 +++- 7 files changed, 183 insertions(+), 158 deletions(-) diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java index 3091b34ced..61cf6e59ed 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java @@ -43,8 +43,8 @@ public class CuVSResources implements AutoCloseable { private final MethodHandle createResourcesMethodHandle; private final MethodHandle destroyResourcesMethodHandle; public final MethodHandle cagraToHnswHandle; - public final MethodHandle hnswSearchHandle; private MemorySegment resourcesMemorySegment; + public final MethodHandle hnswSearchHandle; /** * Constructor that allocates the resources needed for cuVS @@ -63,20 +63,26 @@ public CuVSResources() throws Throwable { destroyResourcesMethodHandle = linker.downcallHandle(libcuvsNativeLibrary.find("destroy_resources").get(), FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS)); createResources(); + cagraToHnswHandle = linker.downcallHandle( libcuvsNativeLibrary.find("convert_cagra_to_hnsw") .orElseThrow(() -> new IllegalStateException("convert_cagra_to_hnsw not found")), FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS)); - hnswSearchHandle = linker.downcallHandle(libcuvsNativeLibrary.find("cuvsHnswSearch").get(), + + hnswSearchHandle = linker.downcallHandle( + libcuvsNativeLibrary.find("search_hnsw_index") + .orElseThrow(() -> new RuntimeException("search_hnsw_index function not found in native library")), FunctionDescriptor.of(ValueLayout.JAVA_INT, // Return type ValueLayout.ADDRESS, // cuvsResources_t - ValueLayout.ADDRESS, // cuvsHnswSearchParams_t ValueLayout.ADDRESS, // cuvsHnswIndex_t - ValueLayout.ADDRESS, // DLManagedTensor* queries - ValueLayout.ADDRESS, // DLManagedTensor* neighbors - ValueLayout.ADDRESS // DLManagedTensor* distances + ValueLayout.ADDRESS, // float* queries + ValueLayout.JAVA_INT, // topK + ValueLayout.JAVA_INT, // n_queries + ValueLayout.JAVA_INT, // dimensions + ValueLayout.ADDRESS, // neighbors_h + ValueLayout.ADDRESS, // distances_h + ValueLayout.ADDRESS // search_params )); - } /** @@ -114,7 +120,7 @@ public MemorySegment getMemorySegment() { /** * Returns the loaded libcuvs_java.so as a {@link SymbolLookup} */ - protected SymbolLookup getLibcuvsNativeLibrary() { + public SymbolLookup getLibcuvsNativeLibrary() { return libcuvsNativeLibrary; } } \ No newline at end of file diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswQuery.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswQuery.java index 4ff8972c10..bd0b221164 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswQuery.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswQuery.java @@ -1,14 +1,19 @@ package com.nvidia.cuvs.hnsw; +import java.lang.foreign.MemorySegment; + public class HnswQuery { private final float[][] queryVectors; private final int topK; private final HnswSearchParameters searchParameters; + private final MemorySegment indexMemorySegment; - public HnswQuery(float[][] queryVectors, int topK, HnswSearchParameters searchParams) { + public HnswQuery(MemorySegment indexMemorySegment, float[][] queryVectors, int topK, + HnswSearchParameters searchParams) { this.queryVectors = queryVectors; this.topK = topK; this.searchParameters = searchParams; + this.indexMemorySegment = indexMemorySegment; } public float[][] getQueryVectors() { @@ -23,4 +28,8 @@ public HnswSearchParameters getSearchParameters() { return searchParameters; } + public MemorySegment getIndexMemorySegment() { + return indexMemorySegment; + } + } diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchResults.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchResults.java index 57ef8f50a5..038aafa1ba 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchResults.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchResults.java @@ -1,20 +1,19 @@ package com.nvidia.cuvs.hnsw; public class HnswSearchResults { - private final int[][] neighbors; + private final long[][] neighbors; private final float[][] distances; private final int topK; private final int numQueries; - public HnswSearchResults(int[][] neighbors, float[][] distances, int topK, int numQueries) { + public HnswSearchResults(long[][] neighbors, float[][] distances, int topK, int numQueries) { this.neighbors = neighbors; this.distances = distances; this.topK = topK; this.numQueries = numQueries; } - // Getters for validation - public int[][] getNeighbors() { + public long[][] getNeighbors() { return neighbors; } diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswUtil.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswUtil.java index b32724f2ee..be0de9bac6 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswUtil.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswUtil.java @@ -27,75 +27,84 @@ public static void serializeCagraToHnsw(CuVSResources resources, String cagraFil public static HnswSearchResults search(CuVSResources resources, HnswQuery query) { try { - // Extract query parameters + // Extract parameters from query float[][] queryVectors = query.getQueryVectors(); int topK = query.getTopK(); - HnswSearchParameters searchParams = query.getSearchParameters(); - + int dimensions = queryVectors[0].length; int numQueries = queryVectors.length; - // Prepare tensors and allocate memory for results - MemorySegment queryTensor = toTensor(resources.arena, queryVectors); - MemorySegment neighborsMemory = resources.arena.allocate(ValueLayout.JAVA_INT.byteSize() * topK * numQueries); - MemorySegment distancesMemory = resources.arena.allocate(ValueLayout.JAVA_FLOAT.byteSize() * topK * numQueries); - - // Configure search parameters - MemorySegment efSegment = resources.arena.allocate(ValueLayout.JAVA_INT.byteSize()); - efSegment.set(ValueLayout.JAVA_INT, 0, searchParams.getEf()); + // Use the shared arena from resources + Arena arena = resources.arena; - MemorySegment numThreadsSegment = resources.arena.allocate(ValueLayout.JAVA_INT.byteSize()); - numThreadsSegment.set(ValueLayout.JAVA_INT, 0, searchParams.getNumThreads()); + // Prepare flat memory for queries + MemorySegment queriesMemory = toFlatArray(arena, queryVectors); + MemorySegment neighborsMemory = arena.allocate(ValueLayout.JAVA_LONG.byteSize() * topK * numQueries); + MemorySegment distancesMemory = arena.allocate(ValueLayout.JAVA_FLOAT.byteSize() * topK * numQueries); - // Perform the search using the native handle - int result = (int) resources.hnswSearchHandle.invokeExact(resources.getMemorySegment(), // Resources - queryTensor, // Query tensor + // Configure search parameters + MemorySegment searchParamsMemory = arena.allocate(ValueLayout.JAVA_INT.byteSize() * 2); + searchParamsMemory.setAtIndex(ValueLayout.JAVA_INT, 0, query.getSearchParameters().getEf()); + searchParamsMemory.setAtIndex(ValueLayout.JAVA_INT, 1, query.getSearchParameters().getNumThreads()); + + // Invoke the HNSW search function using the native handle + int result = (int) resources.hnswSearchHandle.invokeExact(resources.getMemorySegment(), // cuVS resources + query.getIndexMemorySegment(), // cuvsHnswIndex_t + queriesMemory, // Queries + topK, // topK + numQueries, // n_queries + dimensions, // dimensions neighborsMemory, // Neighbors output distancesMemory, // Distances output - efSegment, // Search parameter (ef) - numThreadsSegment // Search parameter (numThreads) + searchParamsMemory // Search parameters ); if (result != 0) { - throw new RuntimeException("Failed to perform HNSW search. Error code: " + result); + throw new RuntimeException("HNSW search failed. Error code: " + result); } - // Extract results from memory segments - int[] flatNeighbors = new int[topK * numQueries]; - float[] flatDistances = new float[topK * numQueries]; - - for (int i = 0; i < flatNeighbors.length; i++) { - flatNeighbors[i] = neighborsMemory.getAtIndex(ValueLayout.JAVA_INT, i); - } - for (int i = 0; i < flatDistances.length; i++) { - flatDistances[i] = distancesMemory.getAtIndex(ValueLayout.JAVA_FLOAT, i); - } + // Extract neighbors and distances from memory + long[] flatNeighbors = extractLongArray(neighborsMemory, topK * numQueries); + float[] flatDistances = extractFloatArray(distancesMemory, topK * numQueries); - // Reshape the flat results into 2D arrays - int[][] neighbors = reshape(flatNeighbors, numQueries, topK); + long[][] neighbors = reshape(flatNeighbors, numQueries, topK); float[][] distances = reshape(flatDistances, numQueries, topK); return new HnswSearchResults(neighbors, distances, topK, numQueries); } catch (Throwable e) { - throw new RuntimeException("Error during HNSW search: " + e.getMessage(), e); + throw new RuntimeException("Error during HNSW search", e); } } - private static MemorySegment toTensor(Arena arena, float[][] vectors) { - int rows = vectors.length; - int cols = vectors[0].length; - MemorySegment tensor = arena.allocate(ValueLayout.JAVA_FLOAT.byteSize() * rows * cols); - + private static MemorySegment toFlatArray(Arena arena, float[][] array) { + int rows = array.length; + int cols = array[0].length; + MemorySegment flatArray = arena.allocate(ValueLayout.JAVA_FLOAT.byteSize() * rows * cols); for (int i = 0; i < rows; i++) { for (int j = 0; j < cols; j++) { - tensor.setAtIndex(ValueLayout.JAVA_FLOAT, i * cols + j, vectors[i][j]); + flatArray.setAtIndex(ValueLayout.JAVA_FLOAT, i * cols + j, array[i][j]); } } + return flatArray; + } + + private static long[] extractLongArray(MemorySegment memory, int size) { + long[] result = new long[size]; + for (int i = 0; i < size; i++) { + result[i] = memory.get(ValueLayout.JAVA_LONG, i * ValueLayout.JAVA_LONG.byteSize()); + } + return result; + } - return tensor; + private static float[] extractFloatArray(MemorySegment memory, int size) { + float[] result = new float[size]; + for (int i = 0; i < size; i++) { + result[i] = memory.get(ValueLayout.JAVA_FLOAT, i * ValueLayout.JAVA_FLOAT.byteSize()); + } + return result; } - private static int[][] reshape(int[] flatArray, int rows, int cols) { - int[][] reshaped = new int[rows][cols]; + private static long[][] reshape(long[] flatArray, int rows, int cols) { + long[][] reshaped = new long[rows][cols]; for (int i = 0; i < rows; i++) { System.arraycopy(flatArray, i * cols, reshaped[i], 0, cols); } diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java index 47e42f2a5b..615b0298d7 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java @@ -48,25 +48,15 @@ public class CagraBuildAndSearchTest { public void testIndexingAndSearchingFlow() throws Throwable { // Sample data and query - float[][] dataset = { - { 0.74021935f, 0.9209938f }, - { 0.03902049f, 0.9689629f }, - { 0.92514056f, 0.4463501f }, - { 0.6673192f, 0.10993068f } - }; + float[][] dataset = { { 0.74021935f, 0.9209938f }, { 0.03902049f, 0.9689629f }, { 0.92514056f, 0.4463501f }, + { 0.6673192f, 0.10993068f } }; Map map = Map.of(0, 0, 1, 1, 2, 2, 3, 3); - float[][] queries = { - { 0.48216683f, 0.0428398f }, - { 0.5084142f, 0.6545497f }, - { 0.51260436f, 0.2643005f }, - { 0.05198065f, 0.5789965f } - }; + float[][] queries = { { 0.48216683f, 0.0428398f }, { 0.5084142f, 0.6545497f }, { 0.51260436f, 0.2643005f }, + { 0.05198065f, 0.5789965f } }; // Expected search results - List> expectedResults = Arrays.asList( - Map.of(3, 0.038782578f, 2, 0.3590463f, 0, 0.83774555f), - Map.of(0, 0.12472608f, 2, 0.21700792f, 1, 0.31918612f), - Map.of(3, 0.047766715f, 2, 0.20332818f, 0, 0.48305473f), + List> expectedResults = Arrays.asList(Map.of(3, 0.038782578f, 2, 0.3590463f, 0, 0.83774555f), + Map.of(0, 0.12472608f, 2, 0.21700792f, 1, 0.31918612f), Map.of(3, 0.047766715f, 2, 0.20332818f, 0, 0.48305473f), Map.of(1, 0.15224178f, 0, 0.59063464f, 3, 0.5986642f)); for (int j = 0; j < 10; j++) { @@ -75,17 +65,11 @@ public void testIndexingAndSearchingFlow() throws Throwable { // Configure index parameters CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) - .withCagraGraphBuildAlgo(CagraGraphBuildAlgo.NN_DESCENT) - .withGraphDegree(1) - .withIntermediateGraphDegree(2) - .withNumWriterThreads(32) - .build(); + .withCagraGraphBuildAlgo(CagraGraphBuildAlgo.NN_DESCENT).withGraphDegree(1).withIntermediateGraphDegree(2) + .withNumWriterThreads(32).build(); // Create the index with the dataset - CagraIndex index = new CagraIndex.Builder(resources) - .withDataset(dataset) - .withIndexParams(indexParams) - .build(); + CagraIndex index = new CagraIndex.Builder(resources).withDataset(dataset).withIndexParams(indexParams).build(); // Saving the index on to the disk. String indexFileName = UUID.randomUUID().toString() + ".cag"; @@ -94,21 +78,14 @@ public void testIndexingAndSearchingFlow() throws Throwable { // Loading a CAGRA index from disk. File indexFile = new File(indexFileName); InputStream inputStream = new FileInputStream(indexFile); - CagraIndex loadedIndex = new CagraIndex.Builder(resources) - .from(inputStream) - .build(); + CagraIndex loadedIndex = new CagraIndex.Builder(resources).from(inputStream).build(); // Configure search parameters - CagraSearchParams searchParams = new CagraSearchParams.Builder(resources) - .build(); + CagraSearchParams searchParams = new CagraSearchParams.Builder(resources).build(); // Create a query object with the query vectors - CagraQuery cuvsQuery = new CagraQuery.Builder() - .withTopK(3) - .withSearchParams(searchParams) - .withQueryVectors(queries) - .withMapping(map) - .build(); + CagraQuery cuvsQuery = new CagraQuery.Builder().withTopK(3).withSearchParams(searchParams) + .withQueryVectors(queries).withMapping(map).build(); // Perform the search SearchResults results = index.search(cuvsQuery); diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java index 96d19149e8..2664cc3bdd 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java @@ -1,9 +1,12 @@ package com.nvidia.cuvs; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import java.io.File; import java.io.FileOutputStream; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; import java.nio.file.Files; import java.nio.file.Path; @@ -18,72 +21,68 @@ public class TestCagraToHnsw { - private static final Logger log = LoggerFactory.getLogger(TestCagraToHnsw.class); - - @Test - public void testSerializationAndSearch() throws Throwable { - String cagraFilePath = "cagra_index.bin"; - String hnswFilePath = "hnsw_index.bin"; - - try (CuVSResources resources = new CuVSResources()) { - log.info("Starting CAGRA to HNSW test."); - - // Step 1: Build a sample dataset - float[][] dataset = { - {1.0f, 2.0f, 3.0f}, - {4.0f, 5.0f, 6.0f}, - {7.0f, 8.0f, 9.0f} - }; - - log.info("Building CAGRA index..."); - CagraIndex cagraIndex = new CagraIndex.Builder(resources) - .withDataset(dataset) - .withIndexParams(new CagraIndexParams.Builder(resources).build()) - .build(); - log.info("CAGRA index built successfully."); - - log.info("Serializing CAGRA index to file: {}", cagraFilePath); - File tempFile = File.createTempFile("cagra_temp_", ".tmp"); - try (FileOutputStream outputStream = new FileOutputStream(cagraFilePath)) { - cagraIndex.serialize(outputStream, tempFile); - } - assertTrue("CAGRA index file should exist", new File(cagraFilePath).exists()); - log.info("CAGRA index serialized to: {}", cagraFilePath); - - log.info("Converting CAGRA index to HNSW format..."); - HnswUtil.serializeCagraToHnsw(resources, cagraFilePath, hnswFilePath); - - File hnswFile = new File(hnswFilePath); - assertTrue("HNSW index file should exist", hnswFile.exists()); - log.info("HNSW index file successfully created at: {}", hnswFilePath); - - // Step 2: Perform a search - log.info("Starting HNSW search..."); - float[][] queryVectors = { - {2.0f, 3.0f, 4.0f}, - {5.0f, 6.0f, 7.0f} - }; - int topK = 2; - - // Create search parameters - HnswSearchParameters searchParams = new HnswSearchParameters(20, 2); - - // Create query object - HnswQuery query = new HnswQuery(queryVectors, topK, searchParams); - - // Perform the search - HnswSearchResults results = HnswUtil.search(resources, query); - - // Validate results - assertTrue("Search results should not be null", results != null); - log.info("Search results retrieved: {}", results); - - Files.deleteIfExists(Path.of(cagraFilePath)); - Files.deleteIfExists(Path.of(hnswFilePath)); - log.info("Test completed successfully. Temporary files cleaned up."); - } catch (Exception e) { - log.error("Test failed: ", e); - throw e; // Re-throw exception to mark test as failed - } + private static final Logger log = LoggerFactory.getLogger(TestCagraToHnsw.class); + + @Test + public void testSerializationAndSearch() throws Throwable { + String cagraFilePath = "cagra_index.bin"; + String hnswFilePath = "hnsw_index.bin"; + + try (CuVSResources resources = new CuVSResources()) { + log.info("Starting CAGRA to HNSW test."); + + // Step 1: Build a sample dataset + float[][] dataset = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } }; + + log.info("Building CAGRA index..."); + CagraIndex cagraIndex = new CagraIndex.Builder(resources).withDataset(dataset) + .withIndexParams(new CagraIndexParams.Builder(resources).build()).build(); + log.info("CAGRA index built successfully."); + + log.info("Serializing CAGRA index to file: {}", cagraFilePath); + try (FileOutputStream outputStream = new FileOutputStream(cagraFilePath)) { + cagraIndex.serialize(outputStream); + } + assertTrue("CAGRA index file should exist", new File(cagraFilePath).exists()); + log.info("CAGRA index serialized to: {}", cagraFilePath); + + log.info("Converting CAGRA index to HNSW format..."); + HnswUtil.serializeCagraToHnsw(resources, cagraFilePath, hnswFilePath); + + File hnswFile = new File(hnswFilePath); + assertTrue("HNSW index file should exist", hnswFile.exists()); + log.info("HNSW index file successfully created at: {}", hnswFilePath); + + // Step 2: Prepare a memory segment for the HNSW index + MemorySegment hnswIndexSegment = resources.arena.allocate(ValueLayout.ADDRESS.byteSize()); + + // Step 3: Perform a search + log.info("Starting HNSW search..."); + float[][] queryVectors = { { 2.0f, 3.0f, 4.0f }, { 5.0f, 6.0f, 7.0f } }; + int topK = 2; + + // Create search parameters + HnswSearchParameters searchParams = new HnswSearchParameters(20, 2); + + // Create query object with the correct constructor + HnswQuery query = new HnswQuery(hnswIndexSegment, queryVectors, topK, searchParams); + + // Perform the search + HnswSearchResults results = HnswUtil.search(resources, query); + + // Validate results + assertNotNull("Search results should not be null", results); + assertTrue("Search results should have neighbors", results.getNeighbors().length > 0); + assertTrue("Search results should have distances", results.getDistances().length > 0); + log.info("Search results retrieved: Neighbors = {}, Distances = {}", results.getNeighbors(), + results.getDistances()); + + Files.deleteIfExists(Path.of(cagraFilePath)); + Files.deleteIfExists(Path.of(hnswFilePath)); + log.info("Test completed successfully. Temporary files cleaned up."); + } catch (Exception e) { + log.error("Test failed: ", e); + throw e; // Re-throw exception to mark test as failed } + } } diff --git a/java/internal/src/cuvs_java.c b/java/internal/src/cuvs_java.c index cb0b82d07c..8d4cf3ca40 100644 --- a/java/internal/src/cuvs_java.c +++ b/java/internal/src/cuvs_java.c @@ -146,7 +146,7 @@ void convert_cagra_to_hnsw(cuvsResources_t resources, printf("Error: Failed to serialize to HNSW format. Error code: %d\n", *return_value); cuvsCagraIndexDestroy(cagra_index); return; - } + } printf("Successfully serialized CAGRA index to HNSW format: %s\n", hnsw_filename); cuvsCagraIndexDestroy(cagra_index); @@ -163,16 +163,20 @@ void search_hnsw_index(cuvsHnswIndex_t index, int* returnValue, cuvsHnswSearchParams_t search_params) { + printf("C layer function invoked successfully\n"); uint64_t* neighbors; float* distances; float* queries_d; + // Allocate memory cuvsRMMAlloc(cuvsResources, (void**)&queries_d, sizeof(float) * n_queries * dimensions); cuvsRMMAlloc(cuvsResources, (void**)&neighbors, sizeof(uint64_t) * n_queries * topk); cuvsRMMAlloc(cuvsResources, (void**)&distances, sizeof(float) * n_queries * topk); + // Copy queries to device memory cudaMemcpy(queries_d, queries, sizeof(float) * n_queries * dimensions, cudaMemcpyDefault); + // Prepare tensors int64_t queries_shape[2] = {n_queries, dimensions}; DLManagedTensor queries_tensor = prepare_tensor(queries_d, queries_shape, kDLFloat); @@ -182,19 +186,41 @@ void search_hnsw_index(cuvsHnswIndex_t index, int64_t distances_shape[2] = {n_queries, topk}; DLManagedTensor distances_tensor = prepare_tensor(distances, distances_shape, kDLFloat); + // Perform search *returnValue = cuvsHnswSearch(cuvsResources, search_params, index, &queries_tensor, &neighbors_tensor, &distances_tensor); + // Check and print results if (*returnValue != CUVS_SUCCESS) { printf("Error: Search failed. Error code: %d\n", *returnValue); } else { printf("Search completed successfully\n"); + + // Print neighbors and distances + printf("HNSW Search Results:\n"); + for (long i = 0; i < n_queries; i++) { + printf("Query %ld:\n", i); + printf(" Neighbors: "); + for (int j = 0; j < topk; j++) { + printf("%lu ", neighbors[i * topk + j]); + } + printf("\n"); + + printf(" Distances: "); + for (int j = 0; j < topk; j++) { + printf("%.6f ", distances[i * topk + j]); + } + printf("\n"); + } } + // Copy results back to host memory cudaMemcpy(neighbors_h, neighbors, sizeof(uint64_t) * n_queries * topk, cudaMemcpyDefault); cudaMemcpy(distances_h, distances, sizeof(float) * n_queries * topk, cudaMemcpyDefault); + // Free allocated resources cuvsRMMFree(cuvsResources, distances, sizeof(float) * n_queries * topk); cuvsRMMFree(cuvsResources, neighbors, sizeof(uint64_t) * n_queries * topk); cuvsRMMFree(cuvsResources, queries_d, sizeof(float) * n_queries * dimensions); } + From fbc3340e83217ce851aecc7cab69ed35bb9fe3a8 Mon Sep 17 00:00:00 2001 From: punAhuja Date: Fri, 27 Dec 2024 11:39:32 +0530 Subject: [PATCH 7/7] Reused buildMemorySegment from Util class --- .../main/java/com/nvidia/cuvs/hnsw/HnswUtil.java | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswUtil.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswUtil.java index be0de9bac6..4c9d7dea97 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswUtil.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswUtil.java @@ -37,7 +37,7 @@ public static HnswSearchResults search(CuVSResources resources, HnswQuery query) Arena arena = resources.arena; // Prepare flat memory for queries - MemorySegment queriesMemory = toFlatArray(arena, queryVectors); + MemorySegment queriesMemory = Util.buildMemorySegment(resources.linker, resources.arena, queryVectors); MemorySegment neighborsMemory = arena.allocate(ValueLayout.JAVA_LONG.byteSize() * topK * numQueries); MemorySegment distancesMemory = arena.allocate(ValueLayout.JAVA_FLOAT.byteSize() * topK * numQueries); @@ -75,18 +75,6 @@ public static HnswSearchResults search(CuVSResources resources, HnswQuery query) } } - private static MemorySegment toFlatArray(Arena arena, float[][] array) { - int rows = array.length; - int cols = array[0].length; - MemorySegment flatArray = arena.allocate(ValueLayout.JAVA_FLOAT.byteSize() * rows * cols); - for (int i = 0; i < rows; i++) { - for (int j = 0; j < cols; j++) { - flatArray.setAtIndex(ValueLayout.JAVA_FLOAT, i * cols + j, array[i][j]); - } - } - return flatArray; - } - private static long[] extractLongArray(MemorySegment memory, int size) { long[] result = new long[size]; for (int i = 0; i < size; i++) {