diff --git a/java/cuvs-java/pom.xml b/java/cuvs-java/pom.xml
index aacbad2ca2..ac2d8f4826 100644
--- a/java/cuvs-java/pom.xml
+++ b/java/cuvs-java/pom.xml
@@ -14,144 +14,152 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
--->
+-->
- 4.0.0
- com.nvidia.cuvs
- cuvs-java
- 24.12.1
- cuvs-java
- jar
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
+ 4.0.0
+ com.nvidia.cuvs
+ cuvs-java
+ 24.12.1
+ cuvs-java
+ jar
-
- 22
- 22
- UTF-8
- UTF-8
-
+
+ 22
+ 22
+ UTF-8
+ UTF-8
+
-
-
- commons-io
- commons-io
- 2.15.1
-
+
+
+ commons-io
+ commons-io
+ 2.15.1
+
-
- com.github.fommil
- jniloader
- 1.1
-
+
+ com.github.fommil
+ jniloader
+ 1.1
+
-
- org.slf4j
- slf4j-api
- 2.0.13
-
+
+ org.slf4j
+ slf4j-api
+ 2.0.13
+
-
- org.slf4j
- slf4j-simple
- 2.0.13
- runtime
-
+
+ org.slf4j
+ slf4j-simple
+ 2.0.13
+ runtime
+
-
- org.junit.jupiter
- junit-jupiter-api
- 5.10.0
-
+
+ junit
+ junit
+ 4.13.1
+ test
+
-
+
+ org.apache.lucene
+ lucene-test-framework
+ 9.12.0
+ test
+
+
-
-
-
- org.apache.maven.plugins
- maven-surefire-plugin
- 2.7
-
-
- ${project.build.directory}/classes
-
-
-
-
- org.apache.maven.plugins
- maven-dependency-plugin
- 2.10
-
-
- copy
- compile
-
- copy
-
-
-
-
- com.nvidia.cuvs
- cuvs-java-internal
- 24.12
- so
- false
-
- ${project.build.directory}/classes
- libcuvs_java.so
-
-
-
-
-
-
+
+
+
+ org.apache.maven.plugins
+ maven-surefire-plugin
+ 2.7
+
+
+ ${project.build.directory}/classes
+
+
+
+
+ org.apache.maven.plugins
+ maven-dependency-plugin
+ 2.10
+
+
+ copy
+ compile
+
+ copy
+
+
+
+
+ com.nvidia.cuvs
+ cuvs-java-internal
+ 24.12
+ so
+ false
+
+ ${project.build.directory}/classes
+ libcuvs_java.so
+
+
+
+
+
+
-
- org.apache.maven.plugins
- maven-assembly-plugin
- 3.4.2
-
-
- jar-with-dependencies
-
-
- add
-
-
-
-
- assemble-all
- package
-
- single
-
-
-
-
-
- org.apache.maven.plugins
- maven-jar-plugin
- 2.2
-
-
-
- true
-
- com.nvidia.cuvs.examples.CagraExample
-
-
-
-
-
- org.apache.maven.plugins
- maven-javadoc-plugin
- 3.6.2
-
- com.nvidia.cuvs.examples,com.nvidia.cuvs.panama
- ${project.build.directory}
-
-
-
-
+
+ org.apache.maven.plugins
+ maven-assembly-plugin
+ 3.4.2
+
+
+ jar-with-dependencies
+
+
+ add
+
+
+
+
+ assemble-all
+ package
+
+ single
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-jar-plugin
+ 2.2
+
+
+
+ true
+
+ com.nvidia.cuvs.examples.CagraExample
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-javadoc-plugin
+ 3.6.2
+
+
+ com.nvidia.cuvs.examples,com.nvidia.cuvs.panama
+ ${project.build.directory}
+
+
+
+
diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java
index 5020c0b6d8..373a295a8e 100644
--- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java
+++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSResources.java
@@ -41,6 +41,7 @@ public class CuVSResources {
private final MethodHandle createResourceMethodHandle;
private final MemorySegment memorySegment;
+ private final MethodHandle getGpuDetailsHandle;
/**
* Constructor that allocates the resources needed for cuVS
@@ -61,6 +62,10 @@ public CuVSResources() throws Throwable {
MemorySegment returnValueMemorySegment = arena.allocate(returnValueMemoryLayout);
memorySegment = (MemorySegment) createResourceMethodHandle.invokeExact(returnValueMemorySegment);
+ getGpuDetailsHandle = linker.downcallHandle(
+ libcuvsNativeLibrary.find("get_gpu_details")
+ .orElseThrow(() -> new IllegalStateException("get_gpu_details not found in library")),
+ FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.ADDRESS, ValueLayout.JAVA_INT, ValueLayout.JAVA_INT));
}
/**
@@ -78,4 +83,13 @@ protected MemorySegment getMemorySegment() {
protected SymbolLookup getLibcuvsNativeLibrary() {
return libcuvsNativeLibrary;
}
+
+ /**
+  * Exposes the downcall handle that this instance bound to the native
+  * {@code get_gpu_details} symbol during construction.
+  *
+  * @return the {@link MethodHandle} for {@code get_gpu_details}
+  */
+ public MethodHandle getGpuDetailsHandle() {
+   return this.getGpuDetailsHandle;
+ }
}
\ No newline at end of file
diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/GpuDetail.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/GpuDetail.java
new file mode 100644
index 0000000000..fdddb88c09
--- /dev/null
+++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/GpuDetail.java
@@ -0,0 +1,30 @@
+package com.nvidia.cuvs.common;
+
+public class GpuDetail {
+ private final String name;
+ private final long totalMemory;
+ private final long freeMemory;
+
+ public GpuDetail(String name, long totalMemory, long freeMemory) {
+ this.name = name;
+ this.totalMemory = totalMemory;
+ this.freeMemory = freeMemory;
+ }
+
+ public String getName() {
+ return name;
+ }
+
+ public long getTotalMemory() {
+ return totalMemory;
+ }
+
+ public long getFreeMemory() {
+ return freeMemory;
+ }
+
+ @Override
+ public String toString() {
+ return "GpuDetail{" + "name='" + name + '\'' + ", totalMemory=" + totalMemory + ", freeMemory=" + freeMemory + '}';
+ }
+}
diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java
index 750e49d642..fba6fb427a 100644
--- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java
+++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/common/Util.java
@@ -20,16 +20,50 @@
import java.io.FileOutputStream;
import java.io.IOException;
import java.lang.foreign.Arena;
+import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemoryLayout.PathElement;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
+import java.lang.invoke.MethodHandle;
import java.lang.invoke.VarHandle;
import org.apache.commons.io.IOUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.nvidia.cuvs.CuVSResources;
+import com.nvidia.cuvs.panama.GpuDetailLayout;
public class Util {
+
+ private static final Logger log = LoggerFactory.getLogger(Util.class);
+
+ /**
+  * Calls the native {@code get_gpu_details} function and converts the structs
+  * it fills in into {@link GpuDetail} objects.
+  *
+  * <p>The result buffer is read back as consecutive structs of
+  * {@link GpuDetailLayout#LAYOUT}, so every entry is allocated at least one
+  * struct wide regardless of {@code maxDetailLength}. The buffer lives in a
+  * method-local arena and is released before this method returns; the returned
+  * objects hold copies of the data.
+  *
+  * @param resources       the cuVS resources providing the native method handle
+  * @param maxGpus         the maximum number of GPUs to report; must not be negative
+  * @param maxDetailLength per-GPU buffer size hint in bytes; must not be negative.
+  *                        It is also forwarded as the third native argument.
+  * @return one {@link GpuDetail} per reported GPU, never more than {@code maxGpus}
+  * @throws IllegalArgumentException if {@code maxGpus} or {@code maxDetailLength} is negative
+  * @throws RuntimeException         if the native call throws or reports a negative count
+  */
+ public static GpuDetail[] getGpuDetails(CuVSResources resources, int maxGpus, int maxDetailLength) {
+   if (maxGpus < 0 || maxDetailLength < 0) {
+     throw new IllegalArgumentException(
+         "maxGpus and maxDetailLength must not be negative: " + maxGpus + ", " + maxDetailLength);
+   }
+
+   long structSize = GpuDetailLayout.LAYOUT.byteSize();
+   // Size in long arithmetic (no int overflow) and never below one struct per GPU,
+   // because the entries are sliced with a stride of structSize below.
+   long bufferSize = (long) maxGpus * Math.max(structSize, (long) maxDetailLength);
+
+   try (Arena scratch = Arena.ofConfined()) {
+     // Align for the struct so the 64-bit fields can be read with JAVA_LONG.
+     MemorySegment detailsSegment = scratch.allocate(bufferSize, GpuDetailLayout.LAYOUT.byteAlignment());
+
+     int gpuCount;
+     try {
+       // NOTE(review): the C definition of get_gpu_details in cuvs_java.c declares only
+       // (details, maxGpus); the third argument matches the Java FunctionDescriptor.
+       // Confirm the intended native signature.
+       gpuCount = (int) resources.getGpuDetailsHandle().invoke(detailsSegment, maxGpus, maxDetailLength);
+     } catch (Throwable e) {
+       throw new RuntimeException("Failed to invoke get_gpu_details", e);
+     }
+
+     if (gpuCount < 0) {
+       throw new RuntimeException("Failed to retrieve GPU details");
+     }
+
+     // Never read more entries than the buffer was sized for, even if the native
+     // side reports a larger device count.
+     int filled = Math.min(gpuCount, maxGpus);
+     GpuDetail[] gpuDetails = new GpuDetail[filled];
+     for (int i = 0; i < filled; i++) {
+       MemorySegment structSegment = detailsSegment.asSlice(i * structSize, structSize);
+       gpuDetails[i] = GpuDetailLayout.fromMemorySegment(structSegment);
+     }
+     return gpuDetails;
+   }
+ }
+
/**
* A utility method for getting an instance of {@link MemorySegment} for a
* {@link String}.
diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/panama/GpuDetailLayout.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/panama/GpuDetailLayout.java
new file mode 100644
index 0000000000..ad3aaddeda
--- /dev/null
+++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/panama/GpuDetailLayout.java
@@ -0,0 +1,27 @@
+package com.nvidia.cuvs.panama;
+
+import java.lang.foreign.MemoryLayout;
+import java.lang.foreign.MemorySegment;
+import java.lang.foreign.ValueLayout;
+import java.nio.charset.StandardCharsets;
+
+import com.nvidia.cuvs.common.GpuDetail;
+
+/**
+ * Memory layout of the native {@code GpuDetail} struct, plus a decoder that
+ * turns one struct-sized {@link MemorySegment} into a {@link GpuDetail}.
+ */
+public class GpuDetailLayout {
+
+  /** Size in bytes of the fixed-width {@code name} field. */
+  public static final int MAX_NAME_LENGTH = 64;
+
+  /** Struct layout: the name bytes followed by the total and free memory values. */
+  public static final MemoryLayout LAYOUT = MemoryLayout.structLayout(
+      MemoryLayout.sequenceLayout(MAX_NAME_LENGTH, ValueLayout.JAVA_BYTE).withName("name"),
+      ValueLayout.JAVA_LONG.withName("totalMemory"),
+      ValueLayout.JAVA_LONG.withName("freeMemory"));
+
+  // Field offsets are derived from LAYOUT so they cannot drift from it.
+  private static final long NAME_OFFSET = offsetOf("name");
+  private static final long TOTAL_MEMORY_OFFSET = offsetOf("totalMemory");
+  private static final long FREE_MEMORY_OFFSET = offsetOf("freeMemory");
+
+  /** Returns the byte offset of the named top-level field inside {@link #LAYOUT}. */
+  private static long offsetOf(String field) {
+    return LAYOUT.byteOffset(MemoryLayout.PathElement.groupElement(field));
+  }
+
+  /**
+   * Decodes one struct from the given segment.
+   *
+   * <p>The name is read up to the first NUL byte (or the full field width if
+   * there is none), decoded as UTF-8 and trimmed, so bytes following the
+   * terminator never end up in the result.
+   *
+   * @param segment a segment positioned at the start of one struct and at least
+   *                {@code LAYOUT.byteSize()} bytes long
+   * @return the decoded {@link GpuDetail}
+   */
+  public static GpuDetail fromMemorySegment(MemorySegment segment) {
+    byte[] rawName = segment.asSlice(NAME_OFFSET, MAX_NAME_LENGTH).toArray(ValueLayout.JAVA_BYTE);
+    int nameLength = 0;
+    while (nameLength < rawName.length && rawName[nameLength] != 0) {
+      nameLength++;
+    }
+    String name = new String(rawName, 0, nameLength, StandardCharsets.UTF_8).trim();
+
+    long totalMemory = segment.get(ValueLayout.JAVA_LONG, TOTAL_MEMORY_OFFSET);
+    long freeMemory = segment.get(ValueLayout.JAVA_LONG, FREE_MEMORY_OFFSET);
+
+    return new GpuDetail(name, totalMemory, freeMemory);
+  }
+}
diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java
index c5788d3427..f5aaf80ae4 100644
--- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java
+++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java
@@ -16,7 +16,7 @@
package com.nvidia.cuvs;
-import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.Assert.assertEquals;
import java.io.File;
import java.io.FileInputStream;
@@ -28,7 +28,7 @@
import java.util.Map;
import java.util.UUID;
-import org.junit.jupiter.api.Test;
+import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -100,14 +100,14 @@ public void testIndexingAndSearchingFlow() throws Throwable {
// Check results
log.info(results.getResults().toString());
- assertEquals(expectedResults, results.getResults(), "Results different than expected");
+ assertEquals("Results different than expected", expectedResults, results.getResults());
// Search from deserialized index
results = loadedIndex.search(cuvsQuery);
// Check results
log.info(results.getResults().toString());
- assertEquals(expectedResults, results.getResults(), "Results different than expected");
+ assertEquals("Results different than expected", expectedResults, results.getResults());
// Cleanup
if (indexFile.exists()) {
diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java
new file mode 100644
index 0000000000..8cb6173c02
--- /dev/null
+++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java
@@ -0,0 +1,79 @@
+package com.nvidia.cuvs;
+
+import java.lang.invoke.MethodHandles;
+import java.util.Random;
+
+import org.apache.lucene.tests.util.LuceneTestCase;
+import org.junit.Before;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.carrotsearch.randomizedtesting.RandomizedContext;
+
+/**
+ * Randomized test that builds a CAGRA index over a small random dataset, runs
+ * random queries against it and checks the size of every per-query result.
+ */
+public class CagraRandomizedTest extends LuceneTestCase {
+
+  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
+
+  private Random random;
+
+  /** Captures the framework-provided random source and logs the runner seed. */
+  @Before
+  public void setup() {
+    this.random = random();
+    log.info("Test seed: " + RandomizedContext.current().getRunnerSeedAsString());
+  }
+
+  /**
+   * Indexes a random dataset, searches it with random queries and a random
+   * topK, and asserts that each result holds {@code min(topK, numRows)} entries.
+   */
+  @Test
+  public void testResultsTopKWithRandomValues() throws Throwable {
+    // Dataset: 1-10 rows of 1-5 columns.
+    final int numRows = random.nextInt(10) + 1;
+    final int numCols = random.nextInt(5) + 1;
+    final float[][] dataset = randomMatrix(numRows, numCols);
+
+    // Queries: 1-5 vectors with the dataset's dimensionality.
+    final int numQueries = random.nextInt(5) + 1;
+    final float[][] queries = randomMatrix(numQueries, numCols);
+
+    final int topK = random.nextInt(numRows) + 1;
+
+    log.info("Dataset size: {}x{}", numRows, numCols);
+    log.info("Query size: {}x{}", numQueries, numCols);
+    log.info("TopK: {}", topK);
+    logRows("Dataset:", dataset);
+    logRows("Queries:", queries);
+
+    CuVSResources resources = new CuVSResources();
+    CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build();
+    CagraIndex index = new CagraIndex.Builder(resources)
+        .withDataset(dataset)
+        .withIndexParams(indexParams)
+        .build();
+
+    CagraSearchParams searchParams = new CagraSearchParams.Builder(resources).build();
+    CagraQuery query = new CagraQuery.Builder()
+        .withQueryVectors(queries)
+        .withTopK(topK)
+        .withSearchParams(searchParams)
+        .build();
+
+    CagraSearchResults results = index.search(query);
+
+    final int expectedSize = Math.min(topK, numRows);
+    results.getResults().forEach(result -> {
+      log.info("Result size: {}", result.size());
+      assertEquals("TopK mismatch for query.", expectedSize, result.size());
+    });
+  }
+
+  /** Fills a rows x cols matrix, row by row, with random floats scaled by 100. */
+  private float[][] randomMatrix(int rows, int cols) {
+    float[][] matrix = new float[rows][cols];
+    for (float[] row : matrix) {
+      for (int col = 0; col < cols; col++) {
+        row[col] = random.nextFloat() * 100;
+      }
+    }
+    return matrix;
+  }
+
+  /** Logs the header line followed by one line per matrix row. */
+  private static void logRows(String header, float[][] rows) {
+    log.info(header);
+    for (float[] row : rows) {
+      log.info(java.util.Arrays.toString(row));
+    }
+  }
+
+}
diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/common/TestUtil.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/common/TestUtil.java
new file mode 100644
index 0000000000..03156de712
--- /dev/null
+++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/common/TestUtil.java
@@ -0,0 +1,40 @@
+package com.nvidia.cuvs.common;
+
+import com.nvidia.cuvs.CuVSResources;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests for {@link Util}.
+ */
+public class TestUtil {
+
+  private static final Logger log = LoggerFactory.getLogger(TestUtil.class);
+
+  /**
+   * Queries the GPU details and checks that at least one GPU is reported.
+   *
+   * <p>Exceptions and assertion failures propagate unchanged so the test
+   * runner reports them as what they are, instead of everything being wrapped
+   * in a {@link RuntimeException}.
+   *
+   * @throws Throwable if creating {@link CuVSResources} or the GPU query fails
+   */
+  @Test
+  public void testGpuDetails() throws Throwable {
+    CuVSResources resources = new CuVSResources();
+
+    int maxGpus = 10;
+    int maxDetailLength = 256;
+
+    GpuDetail[] gpuDetails = Util.getGpuDetails(resources, maxGpus, maxDetailLength);
+
+    assertNotNull("GPU details should not be null", gpuDetails);
+    assertTrue("GPU details array should contain at least one GPU", gpuDetails.length > 0);
+
+    log.info("Number of GPUs: {}", gpuDetails.length);
+    for (GpuDetail detail : gpuDetails) {
+      log.info("GPU Name: {}", detail.getName());
+      log.info("Total Memory (MB): {}", detail.getTotalMemory() / (1024 * 1024));
+      log.info("Free Memory (MB): {}", detail.getFreeMemory() / (1024 * 1024));
+    }
+  }
+}
diff --git a/java/internal/src/cuvs_java.c b/java/internal/src/cuvs_java.c
index ec9ecb6af8..b9de569418 100644
--- a/java/internal/src/cuvs_java.c
+++ b/java/internal/src/cuvs_java.c
@@ -20,6 +20,8 @@
#include
#include
#include
+#include
+#include
cuvsResources_t create_resource(int *returnValue) {
cuvsResources_t cuvsResources;
@@ -91,3 +93,30 @@ void search_cagra_index(cuvsCagraIndex_t index, float *queries, int topk, long n
cudaMemcpy(neighbors_h, neighbors, sizeof(uint32_t) * n_queries * topk, cudaMemcpyDefault);
cudaMemcpy(distances_h, distances, sizeof(float) * n_queries * topk, cudaMemcpyDefault);
}
+
+/*
+ * Per-GPU record filled in by get_gpu_details().
+ *
+ * The Java side reads this struct through GpuDetailLayout (64 name bytes
+ * followed by two 64-bit values), so the field order and sizes here have to
+ * stay in sync with it. That mapping assumes size_t is 8 bytes wide.
+ */
+typedef struct {
+ /* Device name; get_gpu_details() always null-terminates it. */
+ char name[64];
+ /* Total memory value obtained from cudaMemGetInfo(). */
+ size_t totalMemory;
+ /* Free memory value obtained from cudaMemGetInfo(). */
+ size_t freeMemory;
+} GpuDetail;
+
+/*
+ * Fills `details` with the name and memory figures of the visible CUDA devices.
+ *
+ * details  caller-provided array with room for at least `maxGpus` entries
+ * maxGpus  capacity of `details`; values <= 0 mean nothing is written
+ *
+ * Returns the number of entries written (never more than `maxGpus`), or -1 if
+ * a CUDA call fails or `details` is NULL while entries would be written.
+ *
+ * NOTE(review): the Java FunctionDescriptor for this symbol passes a third int
+ * argument (maxDetailLength) that this definition does not declare - confirm
+ * which signature is intended.
+ */
+int get_gpu_details(GpuDetail *details, int maxGpus) {
+  int deviceCount = 0;
+  if (cudaGetDeviceCount(&deviceCount) != cudaSuccess) {
+    return -1;
+  }
+
+  /* Only report what fits into the caller's array. */
+  int filled = deviceCount < maxGpus ? deviceCount : maxGpus;
+  if (filled <= 0) {
+    return 0;
+  }
+  if (details == NULL) {
+    return -1;
+  }
+
+  /*
+   * cudaMemGetInfo() reports on the current device, so each device has to be
+   * selected in turn. Remember the caller's device and restore it afterwards.
+   */
+  int previousDevice = 0;
+  if (cudaGetDevice(&previousDevice) != cudaSuccess) {
+    return -1;
+  }
+
+  int status = filled;
+  for (int i = 0; i < filled; i++) {
+    struct cudaDeviceProp deviceProp;
+    if (cudaGetDeviceProperties(&deviceProp, i) != cudaSuccess ||
+        cudaSetDevice(i) != cudaSuccess ||
+        cudaMemGetInfo(&details[i].freeMemory, &details[i].totalMemory) != cudaSuccess) {
+      status = -1;
+      break;
+    }
+
+    strncpy(details[i].name, deviceProp.name, sizeof(details[i].name) - 1);
+    details[i].name[sizeof(details[i].name) - 1] = '\0'; /* Null-terminate */
+  }
+
+  /* Best effort: put the caller's device selection back. */
+  cudaSetDevice(previousDevice);
+
+  return status;
+}