Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ protected IndexReference(MemorySegment indexMemorySegment) {
*
* @return index MemorySegment
*/
protected MemorySegment getMemorySegment() {
public MemorySegment getMemorySegment() {
return memorySegment;
}
}
Expand Down
26 changes: 24 additions & 2 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ public class CuVSResources implements AutoCloseable {

private final MethodHandle createResourcesMethodHandle;
private final MethodHandle destroyResourcesMethodHandle;
public final MethodHandle cagraToHnswHandle;
private MemorySegment resourcesMemorySegment;
public final MethodHandle hnswSearchHandle;

/**
* Constructor that allocates the resources needed for cuVS
Expand All @@ -61,6 +63,26 @@ public CuVSResources() throws Throwable {
destroyResourcesMethodHandle = linker.downcallHandle(libcuvsNativeLibrary.find("destroy_resources").get(),
FunctionDescriptor.ofVoid(ValueLayout.ADDRESS, ValueLayout.ADDRESS));
createResources();

cagraToHnswHandle = linker.downcallHandle(
libcuvsNativeLibrary.find("convert_cagra_to_hnsw")
.orElseThrow(() -> new IllegalStateException("convert_cagra_to_hnsw not found")),
FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS));

hnswSearchHandle = linker.downcallHandle(
libcuvsNativeLibrary.find("search_hnsw_index")
.orElseThrow(() -> new RuntimeException("search_hnsw_index function not found in native library")),
FunctionDescriptor.of(ValueLayout.JAVA_INT, // Return type
ValueLayout.ADDRESS, // cuvsResources_t
ValueLayout.ADDRESS, // cuvsHnswIndex_t
ValueLayout.ADDRESS, // float* queries
ValueLayout.JAVA_INT, // topK
ValueLayout.JAVA_INT, // n_queries
ValueLayout.JAVA_INT, // dimensions
ValueLayout.ADDRESS, // neighbors_h
ValueLayout.ADDRESS, // distances_h
ValueLayout.ADDRESS // search_params
));
}

/**
Expand Down Expand Up @@ -91,14 +113,14 @@ public void close() {
*
* @return cuvsResources MemorySegment
*/
protected MemorySegment getMemorySegment() {
public MemorySegment getMemorySegment() {
return resourcesMemorySegment;
}

/**
* Returns the loaded libcuvs_java.so as a {@link SymbolLookup}
*/
protected SymbolLookup getLibcuvsNativeLibrary() {
public SymbolLookup getLibcuvsNativeLibrary() {
return libcuvsNativeLibrary;
}
}
40 changes: 40 additions & 0 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,23 @@
import java.lang.foreign.MemoryLayout.PathElement;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.VarHandle;
import java.nio.charset.StandardCharsets;

import org.apache.commons.io.IOUtils;

import com.nvidia.cuvs.CuVSResources;

public class Util {

public static MemorySegment toCString(Arena arena, String string) {
byte[] bytes = (string + "\0").getBytes();
MemorySegment segment = arena.allocate(bytes.length);
segment.copyFrom(MemorySegment.ofArray(bytes));
return segment;
}

/**
* A utility method for getting an instance of {@link MemorySegment} for a
* {@link String}.
Expand Down Expand Up @@ -97,4 +109,32 @@ public static File loadLibraryFromJar(String path) throws IOException {

return temp;
}

public static void convertCagraToHnsw(CuVSResources resources, String cagraFilePath, String hnswFilePath) {
try {
MethodHandle convertCagraToHnswHandle = resources.cagraToHnswHandle;

byte[] cagraBytes = cagraFilePath.getBytes(StandardCharsets.UTF_8);
byte[] hnswBytes = hnswFilePath.getBytes(StandardCharsets.UTF_8);

MemorySegment cagraFileSegment = resources.arena.allocate(cagraBytes.length + 1);
MemorySegment hnswFileSegment = resources.arena.allocate(hnswBytes.length + 1);

cagraFileSegment.asSlice(0, cagraBytes.length).copyFrom(MemorySegment.ofArray(cagraBytes));
cagraFileSegment.set(ValueLayout.JAVA_BYTE, cagraBytes.length, (byte) 0);

hnswFileSegment.asSlice(0, hnswBytes.length).copyFrom(MemorySegment.ofArray(hnswBytes));
hnswFileSegment.set(ValueLayout.JAVA_BYTE, hnswBytes.length, (byte) 0);

int returnValue = (int) convertCagraToHnswHandle.invoke(resources.getMemorySegment(), cagraFileSegment,
hnswFileSegment);

if (returnValue != 0) {
throw new RuntimeException("Failed to convert CAGRA to HNSW, error code: " + returnValue);
}
} catch (Throwable e) {
throw new RuntimeException("Error during CAGRA to HNSW conversion", e);
}
}

}
35 changes: 35 additions & 0 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswQuery.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.nvidia.cuvs.hnsw;

import java.lang.foreign.MemorySegment;

public class HnswQuery {
private final float[][] queryVectors;
private final int topK;
private final HnswSearchParameters searchParameters;
private final MemorySegment indexMemorySegment;

public HnswQuery(MemorySegment indexMemorySegment, float[][] queryVectors, int topK,
HnswSearchParameters searchParams) {
this.queryVectors = queryVectors;
this.topK = topK;
this.searchParameters = searchParams;
this.indexMemorySegment = indexMemorySegment;
}

public float[][] getQueryVectors() {
return queryVectors;
}

public int getTopK() {
return topK;
}

public HnswSearchParameters getSearchParameters() {
return searchParameters;
}

public MemorySegment getIndexMemorySegment() {
return indexMemorySegment;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package com.nvidia.cuvs.hnsw;

public class HnswSearchParameters {

private int ef;
private int numThreads;

public HnswSearchParameters(int ef, int numThreads) {
this.ef = ef;
this.numThreads = numThreads;
}

public HnswSearchParameters withEf(int ef) {
this.ef = ef;
return this;
}

public HnswSearchParameters withNumThreads(int threads) {
this.numThreads = threads;
return this;
}

public int getEf() {
return ef;
}

public int getNumThreads() {
return numThreads;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.nvidia.cuvs.hnsw;

public class HnswSearchResults {
private final long[][] neighbors;
private final float[][] distances;
private final int topK;
private final int numQueries;

public HnswSearchResults(long[][] neighbors, float[][] distances, int topK, int numQueries) {
this.neighbors = neighbors;
this.distances = distances;
this.topK = topK;
this.numQueries = numQueries;
}

public long[][] getNeighbors() {
return neighbors;
}

public float[][] getDistances() {
return distances;
}

public int getTopK() {
return topK;
}

public int getNumQueries() {
return numQueries;
}
}
110 changes: 110 additions & 0 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/hnsw/HnswUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package com.nvidia.cuvs.hnsw;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;

import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.common.Util;

public class HnswUtil {

public static void serializeCagraToHnsw(CuVSResources resources, String cagraFilePath, String hnswFilePath) {
try (Arena arena = Arena.ofConfined()) {
MemorySegment cagraPathSegment = Util.toCString(arena, cagraFilePath);
MemorySegment hnswPathSegment = Util.toCString(arena, hnswFilePath);

int result = (int) resources.cagraToHnswHandle.invokeExact(resources.getMemorySegment(), cagraPathSegment,
hnswPathSegment);

if (result != 0) {
throw new RuntimeException("Failed to serialize CAGRA index to HNSW file. Error code: " + result);
}
} catch (Throwable e) {
throw new RuntimeException("Error during serialization: " + e.getMessage(), e);
}
}

public static HnswSearchResults search(CuVSResources resources, HnswQuery query) {
try {
// Extract parameters from query
float[][] queryVectors = query.getQueryVectors();
int topK = query.getTopK();
int dimensions = queryVectors[0].length;
int numQueries = queryVectors.length;

// Use the shared arena from resources
Arena arena = resources.arena;

// Prepare flat memory for queries
MemorySegment queriesMemory = Util.buildMemorySegment(resources.linker, resources.arena, queryVectors);
MemorySegment neighborsMemory = arena.allocate(ValueLayout.JAVA_LONG.byteSize() * topK * numQueries);
MemorySegment distancesMemory = arena.allocate(ValueLayout.JAVA_FLOAT.byteSize() * topK * numQueries);

// Configure search parameters
MemorySegment searchParamsMemory = arena.allocate(ValueLayout.JAVA_INT.byteSize() * 2);
searchParamsMemory.setAtIndex(ValueLayout.JAVA_INT, 0, query.getSearchParameters().getEf());
searchParamsMemory.setAtIndex(ValueLayout.JAVA_INT, 1, query.getSearchParameters().getNumThreads());

// Invoke the HNSW search function using the native handle
int result = (int) resources.hnswSearchHandle.invokeExact(resources.getMemorySegment(), // cuVS resources
query.getIndexMemorySegment(), // cuvsHnswIndex_t
queriesMemory, // Queries
topK, // topK
numQueries, // n_queries
dimensions, // dimensions
neighborsMemory, // Neighbors output
distancesMemory, // Distances output
searchParamsMemory // Search parameters
);

if (result != 0) {
throw new RuntimeException("HNSW search failed. Error code: " + result);
}

// Extract neighbors and distances from memory
long[] flatNeighbors = extractLongArray(neighborsMemory, topK * numQueries);
float[] flatDistances = extractFloatArray(distancesMemory, topK * numQueries);

long[][] neighbors = reshape(flatNeighbors, numQueries, topK);
float[][] distances = reshape(flatDistances, numQueries, topK);

return new HnswSearchResults(neighbors, distances, topK, numQueries);
} catch (Throwable e) {
throw new RuntimeException("Error during HNSW search", e);
}
}

private static long[] extractLongArray(MemorySegment memory, int size) {
long[] result = new long[size];
for (int i = 0; i < size; i++) {
result[i] = memory.get(ValueLayout.JAVA_LONG, i * ValueLayout.JAVA_LONG.byteSize());
}
return result;
}

private static float[] extractFloatArray(MemorySegment memory, int size) {
float[] result = new float[size];
for (int i = 0; i < size; i++) {
result[i] = memory.get(ValueLayout.JAVA_FLOAT, i * ValueLayout.JAVA_FLOAT.byteSize());
}
return result;
}

private static long[][] reshape(long[] flatArray, int rows, int cols) {
long[][] reshaped = new long[rows][cols];
for (int i = 0; i < rows; i++) {
System.arraycopy(flatArray, i * cols, reshaped[i], 0, cols);
}
return reshaped;
}

private static float[][] reshape(float[] flatArray, int rows, int cols) {
float[][] reshaped = new float[rows][cols];
for (int i = 0; i < rows; i++) {
System.arraycopy(flatArray, i * cols, reshaped[i], 0, cols);
}
return reshaped;
}

}
Loading
Loading