diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java
index 67394b53b6..70be98d74a 100644
--- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java
+++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java
@@ -380,7 +380,7 @@ protected IndexReference(MemorySegment indexMemorySegment) {
      *
      * @return index MemorySegment
      */
-    protected MemorySegment getMemorySegment() {
+    public MemorySegment getMemorySegment() {
       return memorySegment;
     }
   }
diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java
index 60f68981b1..61cf6e59ed 100644
--- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java
+++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java
@@ -42,7 +42,11 @@ public class CuVSResources implements AutoCloseable {
 
   private final MethodHandle createResourcesMethodHandle;
   private final MethodHandle destroyResourcesMethodHandle;
+  /** Downcall handle for the native {@code convert_cagra_to_hnsw} function. */
+  public final MethodHandle cagraToHnswHandle;
+  /** Downcall handle for the native {@code search_hnsw_index} function. */
+  public final MethodHandle hnswSearchHandle;
   private MemorySegment resourcesMemorySegment;
 
   /**
    * Constructor that allocates the resources needed for cuVS
@@ -61,6 +65,33 @@ public CuVSResources() throws Throwable {
     destroyResourcesMethodHandle = linker.downcallHandle(libcuvsNativeLibrary.find("destroy_resources").get(),
         FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS));
+
+    // Resolve the HNSW entry points before creating the native resources so a
+    // missing symbol cannot leave an allocated handle that is never closed.
+    cagraToHnswHandle = linker.downcallHandle(
+        libcuvsNativeLibrary.find("convert_cagra_to_hnsw")
+            .orElseThrow(() -> new IllegalStateException("convert_cagra_to_hnsw not found in native library")),
+        FunctionDescriptor.of(ValueLayout.JAVA_INT, // return: status code
+            ValueLayout.ADDRESS, // cuvsResources_t
+            ValueLayout.ADDRESS, // const char* cagra_filename
+            ValueLayout.ADDRESS // const char* hnsw_filename
+        ));
+
+    hnswSearchHandle = linker.downcallHandle(
+        libcuvsNativeLibrary.find("search_hnsw_index")
+            .orElseThrow(() -> new IllegalStateException("search_hnsw_index not found in native library")),
+        FunctionDescriptor.of(ValueLayout.JAVA_INT, // return: status code
+            ValueLayout.ADDRESS, // cuvsResources_t
+            ValueLayout.ADDRESS, // cuvsHnswIndex_t
+            ValueLayout.ADDRESS, // float* queries
+            ValueLayout.JAVA_INT, // topK
+            ValueLayout.JAVA_INT, // n_queries
+            ValueLayout.JAVA_INT, // dimensions
+            ValueLayout.ADDRESS, // neighbors_h
+            ValueLayout.ADDRESS, // distances_h
+            ValueLayout.ADDRESS // search_params
+        ));
+
     createResources();
   }
 
   /**
@@ -91,14 +122,14 @@ public void close() {
    *
    * @return cuvsResources MemorySegment
    */
-  protected MemorySegment getMemorySegment() {
+  public MemorySegment getMemorySegment() {
     return resourcesMemorySegment;
   }
 
   /**
    * Returns the loaded libcuvs_java.so as a {@link SymbolLookup}
    */
-  protected SymbolLookup getLibcuvsNativeLibrary() {
+  public SymbolLookup getLibcuvsNativeLibrary() {
     return libcuvsNativeLibrary;
   }
 }
\ No newline at end of file
diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java
index 750e49d642..efbc68c0f3 100644
--- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java
+++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java
@@ -25,11 +25,33 @@
 import java.lang.foreign.MemoryLayout.PathElement;
 import java.lang.foreign.MemorySegment;
 import java.lang.foreign.ValueLayout;
 import java.lang.invoke.VarHandle;
+import java.nio.charset.StandardCharsets;
 
 import org.apache.commons.io.IOUtils;
 
+import com.nvidia.cuvs.CuVSResources;
+
 public class Util {
+
+  /**
+   * Copies a {@link String} into native memory as a NUL-terminated UTF-8 C
+   * string.
+   *
+   * @param arena  the arena that owns the returned segment
+   * @param string the text to encode
+   * @return a segment holding the UTF-8 bytes of {@code string} followed by a
+   *         terminating zero byte
+   */
+  public static MemorySegment toCString(Arena arena, String string) {
+    // Encode explicitly as UTF-8; the platform default charset is not
+    // guaranteed to match what the native side expects.
+    byte[] bytes = (string + "\0").getBytes(StandardCharsets.UTF_8);
+    MemorySegment segment = arena.allocate(bytes.length);
+    segment.copyFrom(MemorySegment.ofArray(bytes));
+    return segment;
+  }
+
   /**
    * A utility method for getting an instance of {@link MemorySegment} for a
    * {@link String}.
@@ -97,4 +119,33 @@ public static File loadLibraryFromJar(String path) throws IOException {
 
     return temp;
   }
+
+  /**
+   * Converts a serialized CAGRA index file into an HNSW index file by invoking
+   * the native {@code convert_cagra_to_hnsw} function.
+   *
+   * @param resources     the cuVS resources providing the native method handle
+   * @param cagraFilePath path of the serialized CAGRA index to read
+   * @param hnswFilePath  path of the HNSW index file to write
+   * @throws RuntimeException if the native call fails or reports a non-zero
+   *                          error code
+   */
+  public static void convertCagraToHnsw(CuVSResources resources, String cagraFilePath, String hnswFilePath) {
+    int returnValue;
+    // The path strings are only needed for the duration of the call, so they
+    // live in a confined arena instead of the long-lived resources arena.
+    try (Arena arena = Arena.ofConfined()) {
+      MemorySegment cagraFileSegment = toCString(arena, cagraFilePath);
+      MemorySegment hnswFileSegment = toCString(arena, hnswFilePath);
+      returnValue = (int) resources.cagraToHnswHandle.invokeExact(resources.getMemorySegment(), cagraFileSegment,
+          hnswFileSegment);
+    } catch (Throwable e) {
+      throw new RuntimeException("Error during CAGRA to HNSW conversion", e);
+    }
+    // Checked outside the try block so the error code is not re-wrapped.
+    if (returnValue != 0) {
+      throw new RuntimeException("Failed to convert CAGRA to HNSW, error code: " + returnValue);
+    }
+  }
+
 }
diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswQuery.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswQuery.java
new file mode 100644
index 0000000000..bd0b221164
--- /dev/null
+++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswQuery.java
@@ -0,0 +1,56 @@
+package com.nvidia.cuvs.hnsw;
+
+import java.lang.foreign.MemorySegment;
+import java.util.Objects;
+
+/**
+ * Describes one HNSW search request: the native index handle, the query
+ * vectors, the number of neighbors to return and the search parameters.
+ */
+public class HnswQuery {
+  private final float[][] queryVectors;
+  private final int topK;
+  private final HnswSearchParameters searchParameters;
+  private final MemorySegment indexMemorySegment;
+
+  /**
+   * Creates a query.
+   *
+   * @param indexMemorySegment the native HNSW index handle to search
+   * @param queryVectors       the query vectors, one row per query
+   * @param topK               the number of neighbors to return per query
+   * @param searchParams       the HNSW search parameters
+   * @throws NullPointerException     if any reference argument is null
+   * @throws IllegalArgumentException if {@code topK} is not positive
+   */
+  public HnswQuery(MemorySegment indexMemorySegment, float[][] queryVectors, int topK,
+      HnswSearchParameters searchParams) {
+    if (topK <= 0) {
+      throw new IllegalArgumentException("topK must be positive, got: " + topK);
+    }
+    this.queryVectors = Objects.requireNonNull(queryVectors, "queryVectors");
+    this.topK = topK;
+    this.searchParameters = Objects.requireNonNull(searchParams, "searchParams");
+    this.indexMemorySegment = Objects.requireNonNull(indexMemorySegment, "indexMemorySegment");
+  }
+
+  /** @return the query vectors, one row per query (not copied) */
+  public float[][] getQueryVectors() {
+    return queryVectors;
+  }
+
+  /** @return the number of neighbors requested per query */
+  public int getTopK() {
+    return topK;
+  }
+
+  /** @return the HNSW search parameters */
+  public HnswSearchParameters getSearchParameters() {
+    return searchParameters;
+  }
+
+  /** @return the native HNSW index handle */
+  public MemorySegment getIndexMemorySegment() {
+    return indexMemorySegment;
+  }
+}
diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchParameters.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchParameters.java
new file mode 100644
index 0000000000..717706ff6f
--- /dev/null
+++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchParameters.java
@@ -0,0 +1,50 @@
+package com.nvidia.cuvs.hnsw;
+
+/**
+ * Mutable holder for the two values passed to the native HNSW search: the
+ * {@code ef} setting and the number of threads.
+ */
+public class HnswSearchParameters {
+
+  private int ef;
+  private int numThreads;
+
+  /**
+   * @param ef         the {@code ef} value forwarded to the native search
+   * @param numThreads the thread count forwarded to the native search
+   */
+  public HnswSearchParameters(int ef, int numThreads) {
+    this.ef = ef;
+    this.numThreads = numThreads;
+  }
+
+  /**
+   * Updates {@code ef} in place.
+   *
+   * @return this instance, for chaining
+   */
+  public HnswSearchParameters withEf(int ef) {
+    this.ef = ef;
+    return this;
+  }
+
+  /**
+   * Updates the thread count in place.
+   *
+   * @return this instance, for chaining
+   */
+  public HnswSearchParameters withNumThreads(int threads) {
+    this.numThreads = threads;
+    return this;
+  }
+
+  /** @return the {@code ef} value */
+  public int getEf() {
+    return ef;
+  }
+
+  /** @return the thread count */
+  public int getNumThreads() {
+    return numThreads;
+  }
+}
diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchResults.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchResults.java
new file mode 100644
index 0000000000..038aafa1ba
--- /dev/null
+++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswSearchResults.java
@@ -0,0 +1,44 @@
+package com.nvidia.cuvs.hnsw;
+
+/**
+ * Result of an HNSW search: per-query neighbor ids and distances.
+ */
+public class HnswSearchResults {
+  private final long[][] neighbors;
+  private final float[][] distances;
+  private final int topK;
+  private final int numQueries;
+
+  /**
+   * @param neighbors  neighbor ids, one row per query
+   * @param distances  distances matching {@code neighbors} element for element
+   * @param topK       the number of neighbors requested per query
+   * @param numQueries the number of queries searched
+   */
+  public HnswSearchResults(long[][] neighbors, float[][] distances, int topK, int numQueries) {
+    this.neighbors = neighbors;
+    this.distances = distances;
+    this.topK = topK;
+    this.numQueries = numQueries;
+  }
+
+  /** @return neighbor ids, one row per query (not copied) */
+  public long[][] getNeighbors() {
+    return neighbors;
+  }
+
+  /** @return distances, one row per query (not copied) */
+  public float[][] getDistances() {
+    return distances;
+  }
+
+  /** @return the number of neighbors requested per query */
+  public int getTopK() {
+    return topK;
+  }
+
+  /** @return the number of queries searched */
+  public int getNumQueries() {
+    return numQueries;
+  }
+}
diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswUtil.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswUtil.java
new file mode 100644
index 0000000000..4c9d7dea97
--- /dev/null
+++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswUtil.java
@@ -0,0 +1,121 @@
+package com.nvidia.cuvs.hnsw;
+
+import java.lang.foreign.Arena;
+import java.lang.foreign.MemorySegment;
+import java.lang.foreign.ValueLayout;
+
+import com.nvidia.cuvs.CuVSResources;
+import com.nvidia.cuvs.common.Util;
+
+/**
+ * Static helpers for converting CAGRA index files to HNSW files and for
+ * searching an HNSW index through the native cuVS bindings.
+ */
+public class HnswUtil {
+
+  private HnswUtil() {
+    // Static utility class; not meant to be instantiated.
+  }
+
+  /**
+   * Converts a serialized CAGRA index file into an HNSW index file.
+   *
+   * @param resources     the cuVS resources providing the native method handle
+   * @param cagraFilePath path of the serialized CAGRA index to read
+   * @param hnswFilePath  path of the HNSW index file to write
+   * @throws RuntimeException if the native conversion fails
+   */
+  public static void serializeCagraToHnsw(CuVSResources resources, String cagraFilePath, String hnswFilePath) {
+    // Single implementation lives in Util so the two entry points cannot drift.
+    Util.convertCagraToHnsw(resources, cagraFilePath, hnswFilePath);
+  }
+
+  /**
+   * Runs a k-nearest-neighbor search against a native HNSW index.
+   *
+   * @param resources the cuVS resources providing the native method handle
+   * @param query     the index handle, query vectors, topK and search parameters
+   * @return the neighbors and distances, one row per query vector
+   * @throws IllegalArgumentException if the query has no vectors or the vectors
+   *                                  differ in length
+   * @throws RuntimeException         if the native call fails or reports a
+   *                                  non-zero error code
+   */
+  public static HnswSearchResults search(CuVSResources resources, HnswQuery query) {
+    float[][] queryVectors = query.getQueryVectors();
+    if (queryVectors == null || queryVectors.length == 0) {
+      throw new IllegalArgumentException("At least one query vector is required");
+    }
+    int topK = query.getTopK();
+    int numQueries = queryVectors.length;
+    int dimensions = queryVectors[0].length;
+    for (float[] vector : queryVectors) {
+      if (vector.length != dimensions) {
+        throw new IllegalArgumentException("All query vectors must have " + dimensions + " dimensions");
+      }
+    }
+    // Guards the int element count used to size the result buffers.
+    int resultCount = Math.multiplyExact(topK, numQueries);
+    HnswSearchParameters searchParameters = query.getSearchParameters();
+
+    int status;
+    long[] flatNeighbors;
+    float[] flatDistances;
+    // Per-call buffers are released when the search returns instead of
+    // accumulating in the long-lived resources arena.
+    try (Arena arena = Arena.ofConfined()) {
+      MemorySegment queriesMemory = Util.buildMemorySegment(resources.linker, arena, queryVectors);
+      MemorySegment neighborsMemory = arena.allocate(ValueLayout.JAVA_LONG.byteSize() * resultCount);
+      MemorySegment distancesMemory = arena.allocate(ValueLayout.JAVA_FLOAT.byteSize() * resultCount);
+
+      // Two consecutive ints: ef, then the thread count.
+      MemorySegment searchParamsMemory = arena.allocate(ValueLayout.JAVA_INT.byteSize() * 2);
+      searchParamsMemory.setAtIndex(ValueLayout.JAVA_INT, 0, searchParameters.getEf());
+      searchParamsMemory.setAtIndex(ValueLayout.JAVA_INT, 1, searchParameters.getNumThreads());
+
+      status = (int) resources.hnswSearchHandle.invokeExact(resources.getMemorySegment(), // cuVS resources
+          query.getIndexMemorySegment(), // cuvsHnswIndex_t
+          queriesMemory, // queries
+          topK, // topK
+          numQueries, // n_queries
+          dimensions, // dimensions
+          neighborsMemory, // neighbors output
+          distancesMemory, // distances output
+          searchParamsMemory // search parameters
+      );
+
+      // Copy the results to the heap before the arena frees the buffers.
+      flatNeighbors = neighborsMemory.toArray(ValueLayout.JAVA_LONG);
+      flatDistances = distancesMemory.toArray(ValueLayout.JAVA_FLOAT);
+    } catch (Throwable e) {
+      throw new RuntimeException("Error during HNSW search", e);
+    }
+
+    // Checked outside the try block so the error code is not re-wrapped.
+    if (status != 0) {
+      throw new RuntimeException("HNSW search failed. Error code: " + status);
+    }
+
+    long[][] neighbors = reshape(flatNeighbors, numQueries, topK);
+    float[][] distances = reshape(flatDistances, numQueries, topK);
+    return new HnswSearchResults(neighbors, distances, topK, numQueries);
+  }
+
+  /** Splits a row-major flat array into {@code rows} arrays of {@code cols} elements. */
+  private static long[][] reshape(long[] flatArray, int rows, int cols) {
+    long[][] reshaped = new long[rows][cols];
+    for (int i = 0; i < rows; i++) {
+      System.arraycopy(flatArray, i * cols, reshaped[i], 0, cols);
+    }
+    return reshaped;
+  }
+
+  /** Splits a row-major flat array into {@code rows} arrays of {@code cols} elements. */
+  private static float[][] reshape(float[] flatArray, int rows, int cols) {
+    float[][] reshaped = new float[rows][cols];
+    for (int i = 0; i < rows; i++) {
+      System.arraycopy(flatArray, i * cols, reshaped[i], 0, cols);
+    }
+    return reshaped;
+  }
+}
diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java
index 47e42f2a5b..615b0298d7 100644
--- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java
+++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java
@@ -48,25 +48,15 @@ public class CagraBuildAndSearchTest {
   public void testIndexingAndSearchingFlow() throws Throwable {
 
     // Sample data and query
-    float[][] dataset = {
-        { 0.74021935f, 0.9209938f },
-        { 0.03902049f, 0.9689629f },
-        { 0.92514056f, 0.4463501f },
-        { 0.6673192f, 0.10993068f }
-       };
+    float[][] dataset = { { 0.74021935f, 0.9209938f }, { 0.03902049f, 0.9689629f }, { 0.92514056f, 0.4463501f },
+        { 0.6673192f, 0.10993068f } };
     Map map = Map.of(0, 0, 1, 1, 2, 2, 3, 3);
-    float[][] queries = {
-        { 0.48216683f, 0.0428398f },
-        { 0.5084142f, 0.6545497f },
-        { 0.51260436f, 0.2643005f },
-        { 0.05198065f, 0.5789965f }
-       };
+    float[][] queries = { { 0.48216683f, 0.0428398f }, { 0.5084142f, 0.6545497f }, { 0.51260436f, 0.2643005f },
+        { 0.05198065f, 0.5789965f } };
 
     // Expected search results
-    List<Map<Integer, Float>> expectedResults = Arrays.asList(
-        Map.of(3, 0.038782578f, 2, 0.3590463f, 0, 0.83774555f),
-        Map.of(0, 0.12472608f, 2, 0.21700792f, 1, 0.31918612f),
-        Map.of(3, 0.047766715f, 2, 0.20332818f, 0, 0.48305473f),
+    List<Map<Integer, Float>> expectedResults = Arrays.asList(Map.of(3, 0.038782578f, 2, 0.3590463f, 0, 0.83774555f),
+        Map.of(0, 0.12472608f, 2, 0.21700792f, 1, 0.31918612f), Map.of(3, 0.047766715f, 2, 0.20332818f, 0, 0.48305473f),
         Map.of(1, 0.15224178f, 0, 0.59063464f, 3, 0.5986642f));
 
     for (int j = 0; j < 10; j++) {
@@ -75,17 +65,11 @@ public void testIndexingAndSearchingFlow() throws Throwable {
 
         // Configure index parameters
         CagraIndexParams indexParams = new CagraIndexParams.Builder(resources)
-            .withCagraGraphBuildAlgo(CagraGraphBuildAlgo.NN_DESCENT)
-            .withGraphDegree(1)
-            .withIntermediateGraphDegree(2)
-            .withNumWriterThreads(32)
-            .build();
+            .withCagraGraphBuildAlgo(CagraGraphBuildAlgo.NN_DESCENT).withGraphDegree(1).withIntermediateGraphDegree(2)
+            .withNumWriterThreads(32).build();
 
         // Create the index with the dataset
-        CagraIndex index = new CagraIndex.Builder(resources)
-            .withDataset(dataset)
-            .withIndexParams(indexParams)
-            .build();
+        CagraIndex index = new CagraIndex.Builder(resources).withDataset(dataset).withIndexParams(indexParams).build();
 
         // Saving the index on to the disk.
         String indexFileName = UUID.randomUUID().toString() + ".cag";
@@ -94,21 +78,14 @@ public void testIndexingAndSearchingFlow() throws Throwable {
         // Loading a CAGRA index from disk.
         File indexFile = new File(indexFileName);
         InputStream inputStream = new FileInputStream(indexFile);
-        CagraIndex loadedIndex = new CagraIndex.Builder(resources)
-            .from(inputStream)
-            .build();
+        CagraIndex loadedIndex = new CagraIndex.Builder(resources).from(inputStream).build();
 
         // Configure search parameters
-        CagraSearchParams searchParams = new CagraSearchParams.Builder(resources)
-            .build();
+        CagraSearchParams searchParams = new CagraSearchParams.Builder(resources).build();
 
         // Create a query object with the query vectors
-        CagraQuery cuvsQuery = new CagraQuery.Builder()
-            .withTopK(3)
-            .withSearchParams(searchParams)
-            .withQueryVectors(queries)
-            .withMapping(map)
-            .build();
+        CagraQuery cuvsQuery = new CagraQuery.Builder().withTopK(3).withSearchParams(searchParams)
+            .withQueryVectors(queries).withMapping(map).build();
 
         // Perform the search
         SearchResults results = index.search(cuvsQuery);
diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java
new file mode 100644
index 0000000000..2664cc3bdd
--- /dev/null
+++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/TestCagraToHnsw.java
@@ -0,0 +1,90 @@
+package com.nvidia.cuvs;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import java.io.FileOutputStream;
+import java.lang.foreign.MemorySegment;
+import java.lang.foreign.ValueLayout;
+import java.nio.file.Files;
+import java.nio.file.Path;
+
+import org.junit.Ignore;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.nvidia.cuvs.hnsw.HnswQuery;
+import com.nvidia.cuvs.hnsw.HnswSearchParameters;
+import com.nvidia.cuvs.hnsw.HnswSearchResults;
+import com.nvidia.cuvs.hnsw.HnswUtil;
+
+public class TestCagraToHnsw {
+
+  private static final Logger log = LoggerFactory.getLogger(TestCagraToHnsw.class);
+
+  /**
+   * Builds a small CAGRA index, serializes it to disk and checks that the
+   * conversion produces an HNSW index file.
+   */
+  @Test
+  public void testSerialization() throws Throwable {
+    // A private temp directory avoids clashes with other runs and stray files.
+    Path workDir = Files.createTempDirectory("cagra-to-hnsw");
+    Path cagraFile = workDir.resolve("cagra_index.bin");
+    Path hnswFile = workDir.resolve("hnsw_index.bin");
+
+    try (CuVSResources resources = new CuVSResources()) {
+      float[][] dataset = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } };
+
+      log.info("Building CAGRA index...");
+      CagraIndex cagraIndex = new CagraIndex.Builder(resources).withDataset(dataset)
+          .withIndexParams(new CagraIndexParams.Builder(resources).build()).build();
+
+      log.info("Serializing CAGRA index to file: {}", cagraFile);
+      try (FileOutputStream outputStream = new FileOutputStream(cagraFile.toFile())) {
+        cagraIndex.serialize(outputStream);
+      }
+      assertTrue("CAGRA index file should exist", Files.exists(cagraFile));
+
+      log.info("Converting CAGRA index to HNSW format...");
+      HnswUtil.serializeCagraToHnsw(resources, cagraFile.toString(), hnswFile.toString());
+      assertTrue("HNSW index file should exist", Files.exists(hnswFile));
+    } finally {
+      // Clean up even when an assertion or the native call fails.
+      Files.deleteIfExists(cagraFile);
+      Files.deleteIfExists(hnswFile);
+      Files.deleteIfExists(workDir);
+    }
+  }
+
+  /**
+   * Exercises {@link HnswUtil#search}.
+   *
+   * NOTE(review): disabled because the index segment below is an empty
+   * pointer-sized allocation rather than a loaded HNSW index, and nothing in
+   * this change loads an HNSW file into a native index. Re-enable once such a
+   * binding exists and the segment is replaced with a real index handle.
+   */
+  @Ignore("No binding loads an HNSW index yet; the handle below is not a real index")
+  @Test
+  public void testSearch() throws Throwable {
+    try (CuVSResources resources = new CuVSResources()) {
+      MemorySegment hnswIndexSegment = resources.arena.allocate(ValueLayout.ADDRESS.byteSize());
+
+      float[][] queryVectors = { { 2.0f, 3.0f, 4.0f }, { 5.0f, 6.0f, 7.0f } };
+      int topK = 2;
+      HnswSearchParameters searchParams = new HnswSearchParameters(20, 2);
+      HnswQuery query = new HnswQuery(hnswIndexSegment, queryVectors, topK, searchParams);
+
+      HnswSearchResults results = HnswUtil.search(resources, query);
+
+      assertNotNull("Search results should not be null", results);
+      assertEquals("One neighbor row per query", queryVectors.length, results.getNeighbors().length);
+      assertEquals("One distance row per query", queryVectors.length, results.getDistances().length);
+      log.info("Search results retrieved: Neighbors = {}, Distances = {}", results.getNeighbors(),
+          results.getDistances());
+    }
+  }
+}
diff --git a/java/internal/src/cuvs_java.c b/java/internal/src/cuvs_java.c
index b6a078e240..8d4cf3ca40 100644
--- a/java/internal/src/cuvs_java.c
+++ b/java/internal/src/cuvs_java.c
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include <cuvs/neighbors/hnsw.h>
 
 #define try bool __HadError=false;
 #define catch(x) ExitJmp:if(__HadError)
@@ -65,8 +66,7 @@ cuvsCagraIndex_t build_cagra_index(float *dataset, long rows, long dimensions, c
   index_params->compression = compression_params;
 
   *returnValue = cuvsCagraBuild(cuvsResources, index_params, &dataset_tensor, index);
-
-  omp_set_num_threads(1);
+    omp_set_num_threads(1);
   return index;
 }
 
@@ -113,3 +113,106 @@ void search_cagra_index(cuvsCagraIndex_t index, float *queries, int topk, long n
   cuvsRMMFree(cuvsResources, neighbors, sizeof(uint32_t) * n_queries * topk);
   cuvsRMMFree(cuvsResources, queries_d, sizeof(float) * n_queries * dimensions);
 }
+
+/**
+ * Converts a serialized CAGRA index file into an hnswlib index file.
+ *
+ * The parameter list and int return value match the downcall descriptor
+ * declared for this symbol on the Java side.
+ *
+ * @param resources cuVS resources handle
+ * @param cagra_filename path of the serialized CAGRA index to read
+ * @param hnsw_filename path of the HNSW index file to write
+ * @return 0 on success; -1 for invalid arguments, -2 if the index handle could
+ *         not be created, -3 if deserialization failed, -4 if serialization failed
+ */
+int convert_cagra_to_hnsw(cuvsResources_t resources,
+                          const char* cagra_filename,
+                          const char* hnsw_filename) {
+  if (!resources || !cagra_filename || !hnsw_filename) {
+    fprintf(stderr, "Error: Invalid parameters provided\n");
+    return -1;
+  }
+
+  cuvsCagraIndex_t cagra_index;
+  if (cuvsCagraIndexCreate(&cagra_index) != CUVS_SUCCESS) {
+    fprintf(stderr, "Error: Failed to create CAGRA index handle\n");
+    return -2;
+  }
+
+  // The Java caller treats any non-zero value as failure, so cuvsError_t values
+  // are mapped to explicit codes rather than returned as-is.
+  int status = 0;
+  if (cuvsCagraDeserialize(resources, cagra_filename, cagra_index) != CUVS_SUCCESS) {
+    fprintf(stderr, "Error: Failed to deserialize CAGRA index: %s\n", cagra_filename);
+    status = -3;
+  } else if (cuvsCagraSerializeToHnswlib(resources, hnsw_filename, cagra_index) != CUVS_SUCCESS) {
+    fprintf(stderr, "Error: Failed to serialize to HNSW format: %s\n", hnsw_filename);
+    status = -4;
+  }
+
+  // Single exit path so the index handle is always released.
+  cuvsCagraIndexDestroy(cagra_index);
+  return status;
+}
+
+/* Wraps a caller-owned row-major host buffer in a 2-D DLPack tensor. */
+static DLManagedTensor make_host_tensor(void* data, int64_t* shape, uint8_t type_code, uint8_t bits) {
+  DLManagedTensor tensor = {0};
+  tensor.dl_tensor.data = data;
+  tensor.dl_tensor.device.device_type = kDLCPU;
+  tensor.dl_tensor.ndim = 2;
+  tensor.dl_tensor.dtype.code = type_code;
+  tensor.dl_tensor.dtype.bits = bits;
+  tensor.dl_tensor.dtype.lanes = 1;
+  tensor.dl_tensor.shape = shape;
+  return tensor;
+}
+
+/**
+ * Searches an HNSW index for the nearest neighbors of each query vector.
+ *
+ * @param cuvsResources cuVS resources handle
+ * @param index HNSW index to search
+ * @param queries row-major host array of n_queries x dimensions floats
+ * @param topk number of neighbors to return per query
+ * @param n_queries number of query vectors
+ * @param dimensions number of floats per query vector
+ * @param neighbors_h host output array of n_queries x topk neighbor ids
+ * @param distances_h host output array of n_queries x topk distances
+ * @param search_params HNSW search parameters
+ * @return 0 on success; -1 for invalid arguments, -2 if the search failed
+ */
+int search_hnsw_index(cuvsResources_t cuvsResources,
+                      cuvsHnswIndex_t index,
+                      float* queries,
+                      int topk,
+                      int n_queries,
+                      int dimensions,
+                      uint64_t* neighbors_h,
+                      float* distances_h,
+                      cuvsHnswSearchParams_t search_params) {
+  if (!cuvsResources || !index || !queries || !neighbors_h || !distances_h || !search_params ||
+      topk <= 0 || n_queries <= 0 || dimensions <= 0) {
+    fprintf(stderr, "Error: Invalid parameters provided\n");
+    return -1;
+  }
+
+  // cuvsHnswSearch works on host memory, so the caller's buffers are wrapped
+  // directly; no device staging copies are needed.
+  int64_t queries_shape[2] = {n_queries, dimensions};
+  DLManagedTensor queries_tensor = make_host_tensor(queries, queries_shape, kDLFloat, 32);
+
+  int64_t neighbors_shape[2] = {n_queries, topk};
+  DLManagedTensor neighbors_tensor = make_host_tensor(neighbors_h, neighbors_shape, kDLUInt, 64);
+
+  int64_t distances_shape[2] = {n_queries, topk};
+  DLManagedTensor distances_tensor = make_host_tensor(distances_h, distances_shape, kDLFloat, 32);
+
+  if (cuvsHnswSearch(cuvsResources, search_params, index, &queries_tensor, &neighbors_tensor,
+                     &distances_tensor) != CUVS_SUCCESS) {
+    fprintf(stderr, "Error: HNSW search failed\n");
+    return -2;
+  }
+  return 0;
+}