From b83d9663b5b704b85c7865afc0a9074762d0853b Mon Sep 17 00:00:00 2001 From: punAhuja Date: Mon, 18 Nov 2024 23:33:14 +0530 Subject: [PATCH 1/9] Some unit tests added --- .../main/java/com/nvidia/cuvs/CagraIndex.java | 6 + .../java/com/nvidia/cuvs/CagraIndexTest.java | 144 ++++++++++++++++++ 2 files changed, 150 insertions(+) create mode 100644 java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java index 8ad64d7d69..052221213f 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java @@ -117,6 +117,9 @@ private void initializeMethodHandles() throws IOException { * index */ private IndexReference build() throws Throwable { + if (dataset == null || dataset.length == 0 || dataset[0].length == 0) { + throw new IllegalArgumentException("Dataset cannot be null or empty"); + } long rows = dataset.length; long cols = dataset[0].length; MemoryLayout layout = resources.linker.canonicalLayouts().get("int"); @@ -168,6 +171,9 @@ public CagraSearchResults search(CagraQuery query) throws Throwable { * bytes into */ public void serialize(OutputStream outputStream) throws Throwable { + if (outputStream == null) { + throw new IllegalArgumentException("Output stream cannot be null"); + } serialize(outputStream, File.createTempFile(UUID.randomUUID().toString(), ".cag")); } diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java new file mode 100644 index 0000000000..afaeedd6bf --- /dev/null +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java @@ -0,0 +1,144 @@ +package com.nvidia.cuvs; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import org.junit.jupiter.api.Disabled; + + + +import java.io.FileOutputStream; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +public class CagraIndexTest { + + @Test + public void testInvalidDataset() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + // Use consistent dataset parameters as the working test + float[][] invalidDataset = null; // Simulate an invalid dataset + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder() + .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) + .build(); + new CagraIndex.Builder(resources) + .withDataset(invalidDataset) + .withIndexParams(indexParams) + .build(); + }); + + assertEquals("Dataset cannot be null or empty", exception.getMessage()); + } + + @Test + public void testSerializationWithoutOutputStream() throws Throwable { + // Use the same dataset as the working test + float[][] dataset = { + {0.74021935f, 0.9209938f}, + {0.03902049f, 0.9689629f}, + {0.92514056f, 0.4463501f}, + {0.6673192f, 0.10993068f} + }; + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder() + .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) + .build(); + + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + index.serialize(null); // Pass null output stream + }); + + assertEquals("Output stream cannot be null", exception.getMessage()); + } + @Disabled + @Test + public void testSingleElementDataset() throws Throwable { + // Match dataset and parameters to the working test + float[][] dataset = { + {0.74021935f, 0.9209938f}, + {0.03902049f, 0.9689629f}, + {0.92514056f, 0.4463501f}, + {0.6673192f, 0.10993068f} + }; + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder() + .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) + .build(); + + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + float[][] query = { + {0.48216683f, 0.0428398f}, + {0.5084142f, 0.6545497f}, + {0.51260436f, 0.2643005f}, + {0.05198065f, 0.5789965f} + }; + + CagraQuery cuvsQuery = new CagraQuery.Builder() + .withTopK(3) + .withSearchParams(new CagraSearchParams.Builder().build()) + .withQueryVectors(query) + .build(); + + CagraSearchResults results = index.search(cuvsQuery); + + // Verify the results size matches the queries + assertEquals(query.length, results.getResults().size(), "Expected one result for each query"); + } + @Disabled + @Test + public void testSearchResultMapping() throws Throwable { + // Match dataset and parameters to the working test + float[][] dataset = { + {0.74021935f, 0.9209938f}, + {0.03902049f, 0.9689629f}, + {0.92514056f, 0.4463501f}, + {0.6673192f, 0.10993068f} + }; + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder() + .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) + .build(); + + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + // Use consistent query and mapping + Map mapping = Map.of(0, 100, 1, 200, 2, 300, 3, 400); + float[][] query = { + {0.48216683f, 0.0428398f}, + {0.5084142f, 0.6545497f}, + {0.51260436f, 0.2643005f}, + {0.05198065f, 0.5789965f} + }; + + CagraQuery cuvsQuery = new CagraQuery.Builder() + .withTopK(3) + .withSearchParams(new CagraSearchParams.Builder().build()) + .withQueryVectors(query) + .withMapping(mapping) + .build(); + + CagraSearchResults results = index.search(cuvsQuery); + + // Verify mapped results contain expected keys + results.getResults().forEach(result -> { + assertTrue(result.containsKey(100) || result.containsKey(200) || result.containsKey(300) || result.containsKey(400)); + }); + } +} From 1b43ebd5ae6b587b63dc587a699000f3b5b0e0d1 Mon Sep 17 00:00:00 2001 From: punAhuja Date: Tue, 19 Nov 2024 21:52:27 +0530 Subject: [PATCH 2/9] Added some more test-cases --- .../java/com/nvidia/cuvs/CagraIndexTest.java | 162 +++++++++++++++++- 1 file changed, 155 insertions(+), 7 deletions(-) diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java index afaeedd6bf..bf5cf3ee8f 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java @@ -3,17 +3,19 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import org.junit.jupiter.api.Disabled; - - +import java.io.File; +import java.io.FileInputStream; import java.io.FileOutputStream; +import java.io.InputStream; import java.util.Map; +import java.util.UUID; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; public class CagraIndexTest { - + @Test public void testInvalidDataset() { Throwable exception = assertThrows(IllegalArgumentException.class, () -> { @@ -31,7 +33,7 @@ public void testInvalidDataset() { assertEquals("Dataset cannot be null or empty", exception.getMessage()); } - + @Test public void testSerializationWithoutOutputStream() throws Throwable { // Use the same dataset as the working test @@ -58,7 +60,7 @@ public void testSerializationWithoutOutputStream() throws Throwable { assertEquals("Output stream cannot be null", exception.getMessage()); } - @Disabled + @Test public void testSingleElementDataset() throws Throwable { // Match dataset and parameters to the working test @@ -97,7 +99,7 @@ public void testSingleElementDataset() throws Throwable { // Verify the results size matches the queries assertEquals(query.length, results.getResults().size(), "Expected one result for each query"); } - @Disabled + @Test public void testSearchResultMapping() throws Throwable { // Match dataset and parameters to the working test @@ -141,4 +143,150 @@ public void testSearchResultMapping() throws Throwable { assertTrue(result.containsKey(100) || result.containsKey(200) || result.containsKey(300) || result.containsKey(400)); }); } + + public void testResultsTopK() throws Throwable { + float[][] dataset = { + {0.74021935f, 0.9209938f}, + {0.03902049f, 0.9689629f}, + {0.92514056f, 0.4463501f} + }; + + float[][] queries = { + {0.48216683f, 0.0428398f} + }; + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder().build(); + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + CagraQuery query = new CagraQuery.Builder() + .withQueryVectors(queries) + .withTopK(2) + .withSearchParams(new CagraSearchParams.Builder().build()) + .build(); + + CagraSearchResults results = index.search(query); + + // Verify each query result contains exactly TopK neighbors + results.getResults().forEach(result -> assertEquals(2, result.size())); + } + + @Test + public void testEmptyResults() throws Throwable { + float[][] dataset = { + {10.0f, 10.0f}, + {20.0f, 20.0f} + }; + + float[][] queries = { + {1000.0f, 1000.0f} + }; + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder().build(); + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + CagraQuery query = new CagraQuery.Builder() + .withQueryVectors(queries) + .withTopK(2) + .withSearchParams(new CagraSearchParams.Builder().build()) + .build(); + + CagraSearchResults results = index.search(query); + System.out.println(results.getResults()); + + // Verify no neighbors were found + assertTrue(results.getResults().isEmpty()); + } + + @Test + public void testSearchWithDeletedIndexFile() throws Throwable { + // Dataset and Query + float[][] dataset = { + {0.74021935f, 0.9209938f}, + {0.03902049f, 0.9689629f}, + {0.92514056f, 0.4463501f}, + {0.6673192f, 0.10993068f} + }; + + float[][] queries = { + {0.48216683f, 0.0428398f}, + {0.5084142f, 0.6545497f} + }; + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder() + .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) + .build(); + + // Create and serialize index + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + String indexFileName = UUID.randomUUID().toString() + ".cag"; + index.serialize(new FileOutputStream(indexFileName)); + + // Delete the serialized file + File indexFile = new File(indexFileName); + if (indexFile.exists()) { + indexFile.delete(); + } + + // Attempt to create an InputStream from the deleted file + Throwable exception = assertThrows(Exception.class, () -> { + try (InputStream inputStream = new FileInputStream(indexFile)) { + CagraIndex deletedIndex = new CagraIndex.Builder(resources) + .from(inputStream) + .build(); + + CagraQuery query = new CagraQuery.Builder() + .withTopK(3) + .withSearchParams(new CagraSearchParams.Builder().build()) + .withQueryVectors(queries) + .build(); + + deletedIndex.search(query); + } + }); + + // Assert the exception type + assertTrue(exception instanceof java.io.FileNotFoundException, "Expected FileNotFoundException"); + } + + @Test + public void testNullQueryVectors() throws Throwable { + float[][] dataset = { + {0.74021935f, 0.9209938f}, + {0.03902049f, 0.9689629f} + }; + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder().build(); + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + CagraQuery invalidQuery = new CagraQuery.Builder() + .withQueryVectors(null) + .withTopK(3) + .withSearchParams(new CagraSearchParams.Builder().build()) + .build(); + + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + index.search(invalidQuery); + }); + + assertEquals("Query vectors cannot be null", exception.getMessage()); + } + + } From b1cb501ac7bf74fe074130c4d38b46e4944bcfd1 Mon Sep 17 00:00:00 2001 From: punAhuja Date: Tue, 26 Nov 2024 15:39:02 +0530 Subject: [PATCH 3/9] Generating random values using generated seed value sample test --- .../java/com/nvidia/cuvs/CagraIndexTest.java | 98 ++++++++++++------- 1 file changed, 62 insertions(+), 36 deletions(-) diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java index bf5cf3ee8f..71412ed7f4 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java @@ -9,13 +9,14 @@ import java.io.FileOutputStream; import java.io.InputStream; import java.util.Map; +import java.util.Random; import java.util.UUID; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; public class CagraIndexTest { - + @Disabled @Test public void testInvalidDataset() { Throwable exception = assertThrows(IllegalArgumentException.class, () -> { @@ -33,7 +34,7 @@ public void testInvalidDataset() { assertEquals("Dataset cannot be null or empty", exception.getMessage()); } - + @Disabled @Test public void testSerializationWithoutOutputStream() throws Throwable { // Use the same dataset as the working test @@ -60,7 +61,7 @@ public void testSerializationWithoutOutputStream() throws Throwable { assertEquals("Output stream cannot be null", exception.getMessage()); } - + @Disabled @Test public void testSingleElementDataset() throws Throwable { // Match dataset and parameters to the working test @@ -99,7 +100,7 @@ public void testSingleElementDataset() throws Throwable { // Verify the results size matches the queries assertEquals(query.length, results.getResults().size(), "Expected one result for each query"); } - + @Disabled @Test public void testSearchResultMapping() throws Throwable { // Match dataset and parameters to the working test @@ -144,36 +145,61 @@ public void testSearchResultMapping() throws Throwable { }); } - public void testResultsTopK() throws Throwable { - float[][] dataset = { - {0.74021935f, 0.9209938f}, - {0.03902049f, 0.9689629f}, - {0.92514056f, 0.4463501f} - }; - - float[][] queries = { - {0.48216683f, 0.0428398f} - }; - - CuVSResources resources = new CuVSResources(); - CagraIndexParams indexParams = new CagraIndexParams.Builder().build(); - CagraIndex index = new CagraIndex.Builder(resources) - .withDataset(dataset) - .withIndexParams(indexParams) - .build(); - - CagraQuery query = new CagraQuery.Builder() - .withQueryVectors(queries) - .withTopK(2) - .withSearchParams(new CagraSearchParams.Builder().build()) - .build(); - - CagraSearchResults results = index.search(query); - - // Verify each query result contains exactly TopK neighbors - results.getResults().forEach(result -> assertEquals(2, result.size())); - } - + @Test + public void testResultsTopKWithRandomValues() throws Throwable { + long seed = System.currentTimeMillis(); + String seedProperty = System.getProperty("test.seed"); + if (seedProperty != null) { + seed = Long.parseLong(seedProperty); + } + System.out.println("Using seed: " + seed); + + Random random = new Random(seed); + + int numRows = random.nextInt(10) + 1; // 1 - 10 rows + int numCols = random.nextInt(5) + 1; // 1 - 5 columns + float[][] dataset = new float[numRows][numCols]; + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + dataset[i][j] = random.nextFloat() * 100; + } + } + + int numQueries = random.nextInt(5) + 1; // 1 - 5 queries + float[][] queries = new float[numQueries][numCols]; + for (int i = 0; i < numQueries; i++) { + for (int j = 0; j < numCols; j++) { + queries[i][j] = random.nextFloat() * 100; + } + } + + int topK = random.nextInt(numRows) + 1; + System.out.println("Dataset size: " + numRows + "x" + numCols); + System.out.println("Query size: " + numQueries + "x" + numCols); + System.out.println("TopK: " + topK); + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder().build(); + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + CagraQuery query = new CagraQuery.Builder() + .withQueryVectors(queries) + .withTopK(topK) + .withSearchParams(new CagraSearchParams.Builder().build()) + .build(); + + CagraSearchResults results = index.search(query); + + results.getResults().forEach(result -> { + System.out.println("Result size: " + result.size()); + assertEquals(topK, result.size(), "TopK mismatch for query."); + }); + } + + @Disabled @Test public void testEmptyResults() throws Throwable { float[][] dataset = { @@ -204,7 +230,7 @@ public void testEmptyResults() throws Throwable { // Verify no neighbors were found assertTrue(results.getResults().isEmpty()); } - + @Disabled @Test public void testSearchWithDeletedIndexFile() throws Throwable { // Dataset and Query @@ -260,7 +286,7 @@ public void testSearchWithDeletedIndexFile() throws Throwable { // Assert the exception type assertTrue(exception instanceof java.io.FileNotFoundException, "Expected FileNotFoundException"); } - + @Disabled @Test public void testNullQueryVectors() throws Throwable { float[][] dataset = { From 989d7e2325d7d9ef2f307d67b48eaf89cfc4efaf Mon Sep 17 00:00:00 2001 From: Ishan Chattopadhyaya Date: Tue, 26 Nov 2024 21:59:36 +0530 Subject: [PATCH 4/9] Using lucene-test-framework for randomization logic --- java/cuvs-java/pom.xml | 14 ++++- .../nvidia/cuvs/CagraBuildAndSearchTest.java | 9 +-- .../java/com/nvidia/cuvs/CagraIndexTest.java | 55 ++++++++++--------- 3 files changed, 45 insertions(+), 33 deletions(-) diff --git a/java/cuvs-java/pom.xml b/java/cuvs-java/pom.xml index aacbad2ca2..e4b5a43bd4 100644 --- a/java/cuvs-java/pom.xml +++ b/java/cuvs-java/pom.xml @@ -58,10 +58,18 @@ runtime + + junit + junit + 4.13.1 + test + + - org.junit.jupiter - junit-jupiter-api - 5.10.0 + org.apache.lucene + lucene-test-framework + 9.12.0 + test diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java index c5788d3427..bfedd646c3 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchTest.java @@ -16,7 +16,8 @@ package com.nvidia.cuvs; -import static org.junit.jupiter.api.Assertions.assertEquals; + +import static org.junit.Assert.*; import java.io.File; import java.io.FileInputStream; @@ -28,7 +29,7 @@ import java.util.Map; import java.util.UUID; -import org.junit.jupiter.api.Test; +import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -100,14 +101,14 @@ public void testIndexingAndSearchingFlow() throws Throwable { // Check results log.info(results.getResults().toString()); - assertEquals(expectedResults, results.getResults(), "Results different than expected"); + assertEquals("Results different than expected", expectedResults, results.getResults()); // Search from deserialized index results = loadedIndex.search(cuvsQuery); // Check results log.info(results.getResults().toString()); - assertEquals(expectedResults, results.getResults(), "Results different than expected"); + assertEquals("Results different than expected", expectedResults, results.getResults()); // Cleanup if (indexFile.exists()) { diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java index 71412ed7f4..246ca70103 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java @@ -1,22 +1,34 @@ package com.nvidia.cuvs; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; - import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.InputStream; +import java.lang.invoke.MethodHandles; import java.util.Map; import java.util.Random; import java.util.UUID; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; +import org.apache.lucene.tests.util.LuceneTestCase; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.carrotsearch.randomizedtesting.RandomizedContext; + +public class CagraIndexTest extends LuceneTestCase { + Random random; + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + @Before + public void setup() { + this.random = random(); + log.info("Test seed: " +RandomizedContext.current().getRunnerSeedAsString()); + } -public class CagraIndexTest { - @Disabled + @Ignore @Test public void testInvalidDataset() { Throwable exception = assertThrows(IllegalArgumentException.class, () -> { @@ -34,7 +46,7 @@ public void testInvalidDataset() { assertEquals("Dataset cannot be null or empty", exception.getMessage()); } - @Disabled + @Ignore @Test public void testSerializationWithoutOutputStream() throws Throwable { // Use the same dataset as the working test @@ -61,7 +73,7 @@ public void testSerializationWithoutOutputStream() throws Throwable { assertEquals("Output stream cannot be null", exception.getMessage()); } - @Disabled + @Ignore @Test public void testSingleElementDataset() throws Throwable { // Match dataset and parameters to the working test @@ -98,9 +110,9 @@ public void testSingleElementDataset() throws Throwable { CagraSearchResults results = index.search(cuvsQuery); // Verify the results size matches the queries - assertEquals(query.length, results.getResults().size(), "Expected one result for each query"); + assertEquals("Expected one result for each query", query.length, results.getResults().size()); } - @Disabled + @Ignore @Test public void testSearchResultMapping() throws Throwable { // Match dataset and parameters to the working test @@ -147,15 +159,6 @@ public void testSearchResultMapping() throws Throwable { @Test public void testResultsTopKWithRandomValues() throws Throwable { - long seed = System.currentTimeMillis(); - String seedProperty = System.getProperty("test.seed"); - if (seedProperty != null) { - seed = Long.parseLong(seedProperty); - } - System.out.println("Using seed: " + seed); - - Random random = new Random(seed); - int numRows = random.nextInt(10) + 1; // 1 - 10 rows int numCols = random.nextInt(5) + 1; // 1 - 5 columns float[][] dataset = new float[numRows][numCols]; @@ -195,11 +198,11 @@ public void testResultsTopKWithRandomValues() throws Throwable { results.getResults().forEach(result -> { System.out.println("Result size: " + result.size()); - assertEquals(topK, result.size(), "TopK mismatch for query."); + assertEquals("TopK mismatch for query.", topK, result.size()); }); } - @Disabled + @Ignore @Test public void testEmptyResults() throws Throwable { float[][] dataset = { @@ -230,7 +233,7 @@ public void testEmptyResults() throws Throwable { // Verify no neighbors were found assertTrue(results.getResults().isEmpty()); } - @Disabled + @Ignore @Test public void testSearchWithDeletedIndexFile() throws Throwable { // Dataset and Query @@ -284,9 +287,9 @@ public void testSearchWithDeletedIndexFile() throws Throwable { }); // Assert the exception type - assertTrue(exception instanceof java.io.FileNotFoundException, "Expected FileNotFoundException"); + assertTrue("Expected FileNotFoundException", exception instanceof java.io.FileNotFoundException); } - @Disabled + @Ignore @Test public void testNullQueryVectors() throws Throwable { float[][] dataset = { From 49168c89019aebae921c1c80e5e97d72c1530e75 Mon Sep 17 00:00:00 2001 From: punAhuja Date: Fri, 29 Nov 2024 06:08:53 +0530 Subject: [PATCH 5/9] Created separate class for randomized tests --- .../java/com/nvidia/cuvs/CagraIndexTest.java | 28 +-- .../com/nvidia/cuvs/CagraRandomizedTest.java | 184 ++++++++++++++++++ 2 files changed, 198 insertions(+), 14 deletions(-) create mode 100644 java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java index 246ca70103..bc0d2f5dbd 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraIndexTest.java @@ -35,7 +35,7 @@ public void testInvalidDataset() { // Use consistent dataset parameters as the working test float[][] invalidDataset = null; // Simulate an invalid dataset CuVSResources resources = new CuVSResources(); - CagraIndexParams indexParams = new CagraIndexParams.Builder() + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) .build(); new CagraIndex.Builder(resources) @@ -58,7 +58,7 @@ public void testSerializationWithoutOutputStream() throws Throwable { }; CuVSResources resources = new CuVSResources(); - CagraIndexParams indexParams = new CagraIndexParams.Builder() + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) .build(); @@ -85,7 +85,7 @@ public void testSingleElementDataset() throws Throwable { }; CuVSResources resources = new CuVSResources(); - CagraIndexParams indexParams = new CagraIndexParams.Builder() + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) .build(); @@ -103,7 +103,7 @@ public void testSingleElementDataset() throws Throwable { CagraQuery cuvsQuery = new CagraQuery.Builder() .withTopK(3) - .withSearchParams(new CagraSearchParams.Builder().build()) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) .withQueryVectors(query) .build(); @@ -124,7 +124,7 @@ public void testSearchResultMapping() throws Throwable { }; CuVSResources resources = new CuVSResources(); - CagraIndexParams indexParams = new CagraIndexParams.Builder() + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) .build(); @@ -144,7 +144,7 @@ public void testSearchResultMapping() throws Throwable { CagraQuery cuvsQuery = new CagraQuery.Builder() .withTopK(3) - .withSearchParams(new CagraSearchParams.Builder().build()) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) .withQueryVectors(query) .withMapping(mapping) .build(); @@ -182,7 +182,7 @@ public void testResultsTopKWithRandomValues() throws Throwable { System.out.println("TopK: " + topK); CuVSResources resources = new CuVSResources(); - CagraIndexParams indexParams = new CagraIndexParams.Builder().build(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); CagraIndex index = new CagraIndex.Builder(resources) .withDataset(dataset) .withIndexParams(indexParams) @@ -191,7 +191,7 @@ public void testResultsTopKWithRandomValues() throws Throwable { CagraQuery query = new CagraQuery.Builder() .withQueryVectors(queries) .withTopK(topK) - .withSearchParams(new CagraSearchParams.Builder().build()) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) .build(); CagraSearchResults results = index.search(query); @@ -215,7 +215,7 @@ public void testEmptyResults() throws Throwable { }; CuVSResources resources = new CuVSResources(); - CagraIndexParams indexParams = new CagraIndexParams.Builder().build(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); CagraIndex index = new CagraIndex.Builder(resources) .withDataset(dataset) .withIndexParams(indexParams) @@ -224,7 +224,7 @@ public void testEmptyResults() throws Throwable { CagraQuery query = new CagraQuery.Builder() .withQueryVectors(queries) .withTopK(2) - .withSearchParams(new CagraSearchParams.Builder().build()) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) .build(); CagraSearchResults results = index.search(query); @@ -250,7 +250,7 @@ public void testSearchWithDeletedIndexFile() throws Throwable { }; CuVSResources resources = new CuVSResources(); - CagraIndexParams indexParams = new CagraIndexParams.Builder() + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) .build(); @@ -278,7 +278,7 @@ public void testSearchWithDeletedIndexFile() throws Throwable { CagraQuery query = new CagraQuery.Builder() .withTopK(3) - .withSearchParams(new CagraSearchParams.Builder().build()) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) .withQueryVectors(queries) .build(); @@ -298,7 +298,7 @@ public void testNullQueryVectors() throws Throwable { }; CuVSResources resources = new CuVSResources(); - CagraIndexParams indexParams = new CagraIndexParams.Builder().build(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); CagraIndex index = new CagraIndex.Builder(resources) .withDataset(dataset) .withIndexParams(indexParams) @@ -307,7 +307,7 @@ public void testNullQueryVectors() throws Throwable { CagraQuery invalidQuery = new CagraQuery.Builder() .withQueryVectors(null) .withTopK(3) - .withSearchParams(new CagraSearchParams.Builder().build()) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) .build(); Throwable exception = assertThrows(IllegalArgumentException.class, () -> { diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java new file mode 100644 index 0000000000..8b07752f25 --- /dev/null +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java @@ -0,0 +1,184 @@ +package com.nvidia.cuvs; + +import java.lang.invoke.MethodHandles; +import java.util.Random; +import java.util.UUID; + +import org.apache.lucene.tests.util.LuceneTestCase; +import org.junit.Before; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import java.io.File; +import java.io.FileInputStream; +import java.io.InputStream; +import java.io.FileOutputStream; +import com.carrotsearch.randomizedtesting.RandomizedContext; +import org.junit.Ignore; + + +import static org.junit.Assert.assertEquals; + +public class CagraRandomizedTest extends LuceneTestCase { + private Random random; + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + @Before + public void setup() { + // Initialize the random instance and log the test seed for reproducibility + this.random = random(); + log.info("Test seed: " + RandomizedContext.current().getRunnerSeedAsString()); + } + @Ignore + @Test + public void testResultsTopKWithRandomValues() throws Throwable { + // Generate a random dataset + int numRows = random.nextInt(10) + 1; // 1 - 10 rows + int numCols = random.nextInt(5) + 1; // 1 - 5 columns + float[][] dataset = new float[numRows][numCols]; + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + dataset[i][j] = random.nextFloat() * 100; // Random values between 0 and 100 + } + } + + // Generate random query vectors + int numQueries = random.nextInt(5) + 1; // 1 - 5 queries + float[][] queries = new float[numQueries][numCols]; + for (int i = 0; i < numQueries; i++) { + for (int j = 0; j < numCols; j++) { + queries[i][j] = random.nextFloat() * 100; // Random values between 0 and 100 + } + } + + // Set TopK to be within the range of dataset size + int topK = random.nextInt(numRows) + 1; + + // Log dataset and query information for debugging + log.info("Dataset size: {}x{}", numRows, numCols); + log.info("Query size: {}x{}", numQueries, numCols); + log.info("TopK: {}", topK); + + log.info("Dataset:"); + for (float[] row : dataset) { + log.info(java.util.Arrays.toString(row)); + } + + log.info("Queries:"); + for (float[] query : queries) { + log.info(java.util.Arrays.toString(query)); + } + + // Create CuVSResources + CuVSResources resources = new CuVSResources(); + + // Create index parameters + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); + + // Create the index + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + // Create the query object + CagraQuery query = new CagraQuery.Builder() + .withQueryVectors(queries) + .withTopK(topK) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) + .build(); + + // Perform the search + CagraSearchResults results = index.search(query); + + // Validate results + results.getResults().forEach(result -> { + log.info("Result size: {}", result.size()); + assertEquals("TopK mismatch for query.", Math.min(topK, numRows), result.size()); + }); + } + + @Test + public void testSearchWithDeletedIndexFile() throws Throwable { + Random random = random(); // Use LuceneTestCase random for reproducibility + + // Generate random dataset + int numRows = random.nextInt(10) + 1; // 1 - 10 rows + int numCols = random.nextInt(5) + 1; // 1 - 5 columns + float[][] dataset = new float[numRows][numCols]; + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + dataset[i][j] = random.nextFloat() * 100; // Random values between 0 and 100 + } + } + + // Generate random query vectors + int numQueries = random.nextInt(5) + 1; // 1 - 5 queries + float[][] queries = new float[numQueries][numCols]; + for (int i = 0; i < numQueries; i++) { + for (int j = 0; j < numCols; j++) { + queries[i][j] = random.nextFloat() * 100; // Random values between 0 and 100 + } + } + + // Set TopK value + int topK = random.nextInt(numRows) + 1; + + // Log dataset and query details + System.out.println("Dataset size: " + numRows + "x" + numCols); + System.out.println("Query size: " + numQueries + "x" + numCols); + System.out.println("TopK: " + topK); + + System.out.println("Dataset:"); + for (float[] row : dataset) { + System.out.println(java.util.Arrays.toString(row)); + } + + System.out.println("Queries:"); + for (float[] query : queries) { + System.out.println(java.util.Arrays.toString(query)); + } + + // Create resources and index parameters + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) + .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) + .build(); + + // Create and serialize the index + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + String indexFileName = UUID.randomUUID().toString() + ".cag"; + index.serialize(new FileOutputStream(indexFileName)); + + // Delete the serialized file + File indexFile = new File(indexFileName); + if (indexFile.exists()) { + indexFile.delete(); + } + + // Attempt to create an InputStream from the deleted file + Throwable exception = assertThrows(Exception.class, () -> { + try (InputStream inputStream = new FileInputStream(indexFile)) { + CagraIndex deletedIndex = new CagraIndex.Builder(resources) + .from(inputStream) + .build(); + + CagraQuery query = new CagraQuery.Builder() + .withTopK(topK) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) + .withQueryVectors(queries) + .build(); + + deletedIndex.search(query); + } + }); + + // Assert the exception type + assertTrue("Expected FileNotFoundException", exception instanceof java.io.FileNotFoundException); + } + +} From dbcb55a56e600a0cc3d75c6f791a3208c426779a Mon Sep 17 00:00:00 2001 From: punAhuja Date: Fri, 29 Nov 2024 14:51:16 +0530 Subject: [PATCH 6/9] Added some more tests to the RandomizedTest class --- .../com/nvidia/cuvs/CagraRandomizedTest.java | 137 +++++++++++++++--- 1 file changed, 114 insertions(+), 23 deletions(-) diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java index 8b07752f25..08aabebe19 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java @@ -1,6 +1,7 @@ package com.nvidia.cuvs; import java.lang.invoke.MethodHandles; +import java.util.Map; import java.util.Random; import java.util.UUID; @@ -25,36 +26,139 @@ public class CagraRandomizedTest extends LuceneTestCase { @Before public void setup() { - // Initialize the random instance and log the test seed for reproducibility + this.random = random(); log.info("Test seed: " + RandomizedContext.current().getRunnerSeedAsString()); } + @Ignore + @Test + public void testInvalidDataset() throws Throwable { + float[][] invalidDataset = null; // Simulate an invalid dataset + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) + .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) + .build(); + + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + new CagraIndex.Builder(resources) + .withDataset(invalidDataset) + .withIndexParams(indexParams) + .build(); + }); + + assertEquals("Dataset cannot be null or empty", exception.getMessage()); + } + @Ignore + @Test + public void testSerializationWithoutOutputStream() throws Throwable { + // Randomize dataset + int numRows = random.nextInt(10) + 1; + int numCols = random.nextInt(5) + 1; + float[][] dataset = new float[numRows][numCols]; + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + dataset[i][j] = random.nextFloat() * 100; + } + } + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) + .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) + .build(); + + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + index.serialize(null); + }); + + assertEquals("Output stream cannot be null", exception.getMessage()); + } + + @Test + public void testSearchResultMapping() throws Throwable { + // Randomize dataset + int numRows = random.nextInt(10) + 1; + int numCols = random.nextInt(5) + 1; + float[][] dataset = new float[numRows][numCols]; + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + dataset[i][j] = random.nextFloat() * 100; + } + } + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) + .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) + .build(); + + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + + Map mapping = new java.util.HashMap<>(); + for (int i = 0; i < numRows; i++) { + mapping.put(i, i + 1000); + } + + // Randomize query vectors + float[][] query = new float[4][numCols]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < numCols; j++) { + query[i][j] = random.nextFloat() * 100; + } + } + + CagraQuery cuvsQuery = new CagraQuery.Builder() + .withTopK(3) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) + .withQueryVectors(query) + .withMapping(mapping) + .build(); + + CagraSearchResults results = index.search(cuvsQuery); + + // Validate the results + results.getResults().forEach(result -> { + result.keySet().forEach(key -> { + assertNotNull("Key should not be null", key); + assertTrue("Key not in mapping: " + key, mapping.containsValue(key)); + }); + }); + } + + @Ignore @Test public void testResultsTopKWithRandomValues() throws Throwable { // Generate a random dataset - int numRows = random.nextInt(10) + 1; // 1 - 10 rows - int numCols = random.nextInt(5) + 1; // 1 - 5 columns + int numRows = random.nextInt(10) + 1; + int numCols = random.nextInt(5) + 1; float[][] dataset = new float[numRows][numCols]; for (int i = 0; i < numRows; i++) { for (int j = 0; j < numCols; j++) { - dataset[i][j] = random.nextFloat() * 100; // Random values between 0 and 100 + dataset[i][j] = random.nextFloat() * 100; } } // Generate random query vectors - int numQueries = random.nextInt(5) + 1; // 1 - 5 queries + int numQueries = random.nextInt(5) + 1; float[][] queries = new float[numQueries][numCols]; for (int i = 0; i < numQueries; i++) { for (int j = 0; j < numCols; j++) { - queries[i][j] = random.nextFloat() * 100; // Random values between 0 and 100 + queries[i][j] = random.nextFloat() * 100; } } - // Set TopK to be within the range of dataset size + int topK = random.nextInt(numRows) + 1; - // Log dataset and query information for debugging log.info("Dataset size: {}x{}", numRows, numCols); log.info("Query size: {}x{}", numQueries, numCols); log.info("TopK: {}", topK); @@ -69,38 +173,32 @@ public void testResultsTopKWithRandomValues() throws Throwable { log.info(java.util.Arrays.toString(query)); } - // Create CuVSResources CuVSResources resources = new CuVSResources(); - // Create index parameters CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); - // Create the index CagraIndex index = new CagraIndex.Builder(resources) .withDataset(dataset) .withIndexParams(indexParams) .build(); - // Create the query object CagraQuery query = new CagraQuery.Builder() .withQueryVectors(queries) .withTopK(topK) .withSearchParams(new CagraSearchParams.Builder(resources).build()) .build(); - // Perform the search CagraSearchResults results = index.search(query); - // Validate results results.getResults().forEach(result -> { log.info("Result size: {}", result.size()); assertEquals("TopK mismatch for query.", Math.min(topK, numRows), result.size()); }); } - + @Ignore @Test public void testSearchWithDeletedIndexFile() throws Throwable { - Random random = random(); // Use LuceneTestCase random for reproducibility + Random random = random(); // Generate random dataset int numRows = random.nextInt(10) + 1; // 1 - 10 rows @@ -121,10 +219,8 @@ public void testSearchWithDeletedIndexFile() throws Throwable { } } - // Set TopK value int topK = random.nextInt(numRows) + 1; - // Log dataset and query details System.out.println("Dataset size: " + numRows + "x" + numCols); System.out.println("Query size: " + numQueries + "x" + numCols); System.out.println("TopK: " + topK); @@ -139,13 +235,11 @@ public void testSearchWithDeletedIndexFile() throws Throwable { System.out.println(java.util.Arrays.toString(query)); } - // Create resources and index parameters CuVSResources resources = new CuVSResources(); CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) .build(); - // Create and serialize the index CagraIndex index = new CagraIndex.Builder(resources) .withDataset(dataset) .withIndexParams(indexParams) @@ -154,13 +248,11 @@ public void testSearchWithDeletedIndexFile() throws Throwable { String indexFileName = UUID.randomUUID().toString() + ".cag"; index.serialize(new FileOutputStream(indexFileName)); - // Delete the serialized file File indexFile = new File(indexFileName); if (indexFile.exists()) { indexFile.delete(); } - // Attempt to create an InputStream from the deleted file Throwable exception = assertThrows(Exception.class, () -> { try (InputStream inputStream = new FileInputStream(indexFile)) { CagraIndex deletedIndex = new CagraIndex.Builder(resources) @@ -177,7 +269,6 @@ public void testSearchWithDeletedIndexFile() throws Throwable { } }); - // Assert the exception type assertTrue("Expected FileNotFoundException", exception instanceof java.io.FileNotFoundException); } From c4fb3d2794227a53420ff3c11a395282d9109b66 Mon Sep 17 00:00:00 2001 From: punAhuja Date: Fri, 29 Nov 2024 20:02:31 +0530 Subject: [PATCH 7/9] Added more tests to class --- .../main/java/com/nvidia/cuvs/CagraIndex.java | 3 + .../com/nvidia/cuvs/CagraRandomizedTest.java | 196 +++++++++++++++++- 2 files changed, 188 insertions(+), 11 deletions(-) diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java index 052221213f..2c9b51825d 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java @@ -141,6 +141,9 @@ private IndexReference build() throws Throwable { * @return an instance of {@link CagraSearchResults} containing the results */ public CagraSearchResults search(CagraQuery query) throws Throwable { + if (query.getQueryVectors() == null) { + throw new IllegalArgumentException("Query vectors cannot be null"); + } long numQueries = query.getQueryVectors().length; long numBlocks = query.getTopK() * numQueries; int vectorDimension = numQueries > 0 ? query.getQueryVectors()[0].length : 0; diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java index 08aabebe19..069739a507 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java @@ -17,9 +17,6 @@ import com.carrotsearch.randomizedtesting.RandomizedContext; import org.junit.Ignore; - -import static org.junit.Assert.assertEquals; - public class CagraRandomizedTest extends LuceneTestCase { private Random random; private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); @@ -78,7 +75,7 @@ public void testSerializationWithoutOutputStream() throws Throwable { assertEquals("Output stream cannot be null", exception.getMessage()); } - + @Ignore @Test public void testSearchResultMapping() throws Throwable { // Randomize dataset @@ -200,22 +197,20 @@ public void testResultsTopKWithRandomValues() throws Throwable { public void testSearchWithDeletedIndexFile() throws Throwable { Random random = random(); - // Generate random dataset - int numRows = random.nextInt(10) + 1; // 1 - 10 rows - int numCols = random.nextInt(5) + 1; // 1 - 5 columns + int numRows = random.nextInt(10) + 1; + int numCols = random.nextInt(5) + 1; float[][] dataset = new float[numRows][numCols]; for (int i = 0; i < numRows; i++) { for (int j = 0; j < numCols; j++) { - dataset[i][j] = random.nextFloat() * 100; // Random values between 0 and 100 + dataset[i][j] = random.nextFloat() * 100; } } - // Generate random query vectors - int numQueries = random.nextInt(5) + 1; // 1 - 5 queries + int numQueries = random.nextInt(5) + 1; float[][] queries = new float[numQueries][numCols]; for (int i = 0; i < numQueries; i++) { for (int j = 0; j < numCols; j++) { - queries[i][j] = random.nextFloat() * 100; // Random values between 0 and 100 + queries[i][j] = random.nextFloat() * 100; } } @@ -271,5 +266,184 @@ public void testSearchWithDeletedIndexFile() throws Throwable { assertTrue("Expected FileNotFoundException", exception instanceof java.io.FileNotFoundException); } + @Ignore + @Test + public void testNullQueryVectors() throws Throwable { + // Generate a random dataset + int numRows = random.nextInt(10) + 1; // At least 1 row + int numCols = random.nextInt(5) + 1; // At least 1 column + float[][] dataset = new float[numRows][numCols]; + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + dataset[i][j] = random.nextFloat() * 100; + } + } + + // Log dataset for debugging + System.out.println("Dataset size: " + numRows + "x" + numCols); + for (float[] row : dataset) { + System.out.println(java.util.Arrays.toString(row)); + } + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + // Create an invalid query with null query vectors + CagraQuery invalidQuery = new CagraQuery.Builder() + .withQueryVectors(null) + .withTopK(3) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) + .build(); + + // Assert that an exception is thrown + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + index.search(invalidQuery); + }); + + // Verify the exception message + assertEquals("Query vectors cannot be null", exception.getMessage()); + } + @Ignore + @Test + public void testTopKExceedsDatasetSize() throws Throwable { + // Generate a random dataset + int numRows = random.nextInt(10) + 1; + int numCols = random.nextInt(5) + 1; + float[][] dataset = new float[numRows][numCols]; + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + dataset[i][j] = random.nextFloat() * 100; + } + } + + int numQueries = random.nextInt(5) + 1; // At least 1 query + float[][] queries = new float[numQueries][numCols]; + for (int i = 0; i < numQueries; i++) { + for (int j = 0; j < numCols; j++) { + queries[i][j] = random.nextFloat() * 100; + } + } + + // Set TopK to exceed dataset size + int topK = numRows + random.nextInt(5) + 1; + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + CagraQuery query = new CagraQuery.Builder() + .withQueryVectors(queries) + .withTopK(topK) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) + .build(); + + CagraSearchResults results = index.search(query); + + + results.getResults().forEach(result -> { + assertTrue( + "Result size should not exceed the smaller of TopK or dataset size", + result.size() <= topK + ); + }); + } + @Ignore + @Test + public void testDuplicateQueryVectors() throws Throwable { + // Generate a random dataset + int numRows = random.nextInt(10) + 1; + int numCols = random.nextInt(5) + 1; + float[][] dataset = new float[numRows][numCols]; + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + dataset[i][j] = random.nextFloat() * 100; + } + } + + float[][] queries = new float[2][numCols]; + for (int j = 0; j < numCols; j++) { + queries[0][j] = random.nextFloat() * 100; + } + System.arraycopy(queries[0], 0, queries[1], 0, numCols); // Duplicate query + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + CagraQuery query = new CagraQuery.Builder() + .withQueryVectors(queries) + .withTopK(3) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) + .build(); + + CagraSearchResults results = index.search(query); + + // Validate the results for duplicate queries are identical + assertEquals("Results for duplicate queries should be the same", + results.getResults().get(0), results.getResults().get(1)); + } + + @Test + public void testSerializationDeserializationConsistency() throws Throwable { + + int numRows = random.nextInt(10) + 1; + int numCols = random.nextInt(5) + 1; + float[][] dataset = new float[numRows][numCols]; + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + dataset[i][j] = random.nextFloat() * 100; + } + } + + + int numQueries = random.nextInt(5) + 1; + float[][] queries = new float[numQueries][numCols]; + for (int i = 0; i < numQueries; i++) { + for (int j = 0; j < numCols; j++) { + queries[i][j] = random.nextFloat() * 100; + } + } + + CuVSResources resources = new CuVSResources(); + CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); + CagraIndex index = new CagraIndex.Builder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + + File tempFile = File.createTempFile("cagra-index", ".cag"); + tempFile.deleteOnExit(); + index.serialize(new FileOutputStream(tempFile)); + + + CagraIndex deserializedIndex = new CagraIndex.Builder(resources) + .from(new FileInputStream(tempFile)) + .build(); + + CagraQuery query = new CagraQuery.Builder() + .withQueryVectors(queries) + .withTopK(3) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) + .build(); + + // Validate that results from the original and deserialized index are identical + CagraSearchResults originalResults = index.search(query); + CagraSearchResults deserializedResults = deserializedIndex.search(query); + + assertEquals("Results from original and deserialized index should match", + originalResults.getResults(), deserializedResults.getResults()); + } + } From 8e3a46fddfed734525b6d9d9f7949a5430d48ad8 Mon Sep 17 00:00:00 2001 From: punAhuja Date: Mon, 2 Dec 2024 20:31:12 +0530 Subject: [PATCH 8/9] Removed other tests, keeping one working testcase --- .../com/nvidia/cuvs/CagraRandomizedTest.java | 358 +----------------- 1 file changed, 1 insertion(+), 357 deletions(-) diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java index 069739a507..cb6f2936f6 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java @@ -27,112 +27,7 @@ public void setup() { this.random = random(); log.info("Test seed: " + RandomizedContext.current().getRunnerSeedAsString()); } - @Ignore - @Test - public void testInvalidDataset() throws Throwable { - float[][] invalidDataset = null; // Simulate an invalid dataset - - CuVSResources resources = new CuVSResources(); - CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) - .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) - .build(); - - Throwable exception = assertThrows(IllegalArgumentException.class, () -> { - new CagraIndex.Builder(resources) - .withDataset(invalidDataset) - .withIndexParams(indexParams) - .build(); - }); - - assertEquals("Dataset cannot be null or empty", exception.getMessage()); - } - @Ignore - @Test - public void testSerializationWithoutOutputStream() throws Throwable { - // Randomize dataset - int numRows = random.nextInt(10) + 1; - int numCols = random.nextInt(5) + 1; - float[][] dataset = new float[numRows][numCols]; - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numCols; j++) { - dataset[i][j] = random.nextFloat() * 100; - } - } - - CuVSResources resources = new CuVSResources(); - CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) - .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) - .build(); - - CagraIndex index = new CagraIndex.Builder(resources) - .withDataset(dataset) - .withIndexParams(indexParams) - .build(); - - Throwable exception = assertThrows(IllegalArgumentException.class, () -> { - index.serialize(null); - }); - - assertEquals("Output stream cannot be null", exception.getMessage()); - } - @Ignore - @Test - public void testSearchResultMapping() throws Throwable { - // Randomize dataset - int numRows = random.nextInt(10) + 1; - int numCols = random.nextInt(5) + 1; - float[][] dataset = new float[numRows][numCols]; - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numCols; j++) { - dataset[i][j] = random.nextFloat() * 100; - } - } - - CuVSResources resources = new CuVSResources(); - CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) - .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) - .build(); - - CagraIndex index = new CagraIndex.Builder(resources) - .withDataset(dataset) - .withIndexParams(indexParams) - .build(); - - - Map mapping = new java.util.HashMap<>(); - for (int i = 0; i < numRows; i++) { - mapping.put(i, i + 1000); - } - - // Randomize query vectors - float[][] query = new float[4][numCols]; - for (int i = 0; i < 4; i++) { - for (int j = 0; j < numCols; j++) { - query[i][j] = random.nextFloat() * 100; - } - } - - CagraQuery cuvsQuery = new CagraQuery.Builder() - .withTopK(3) - .withSearchParams(new CagraSearchParams.Builder(resources).build()) - .withQueryVectors(query) - .withMapping(mapping) - .build(); - - CagraSearchResults results = index.search(cuvsQuery); - - // Validate the results - results.getResults().forEach(result -> { - result.keySet().forEach(key -> { - assertNotNull("Key should not be null", key); - assertTrue("Key not in mapping: " + key, mapping.containsValue(key)); - }); - }); - } - - - @Ignore - @Test + @Test public void testResultsTopKWithRandomValues() throws Throwable { // Generate a random dataset int numRows = random.nextInt(10) + 1; @@ -192,258 +87,7 @@ public void testResultsTopKWithRandomValues() throws Throwable { assertEquals("TopK mismatch for query.", Math.min(topK, numRows), result.size()); }); } - @Ignore - @Test - public void testSearchWithDeletedIndexFile() throws Throwable { - Random random = random(); - - int numRows = random.nextInt(10) + 1; - int numCols = random.nextInt(5) + 1; - float[][] dataset = new float[numRows][numCols]; - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numCols; j++) { - dataset[i][j] = random.nextFloat() * 100; - } - } - - int numQueries = random.nextInt(5) + 1; - float[][] queries = new float[numQueries][numCols]; - for (int i = 0; i < numQueries; i++) { - for (int j = 0; j < numCols; j++) { - queries[i][j] = random.nextFloat() * 100; - } - } - - int topK = random.nextInt(numRows) + 1; - - System.out.println("Dataset size: " + numRows + "x" + numCols); - System.out.println("Query size: " + numQueries + "x" + numCols); - System.out.println("TopK: " + topK); - - System.out.println("Dataset:"); - for (float[] row : dataset) { - System.out.println(java.util.Arrays.toString(row)); - } - - System.out.println("Queries:"); - for (float[] query : queries) { - System.out.println(java.util.Arrays.toString(query)); - } - - CuVSResources resources = new CuVSResources(); - CagraIndexParams indexParams = new CagraIndexParams.Builder(resources) - .withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT) - .build(); - - CagraIndex index = new CagraIndex.Builder(resources) - .withDataset(dataset) - .withIndexParams(indexParams) - .build(); - - String indexFileName = UUID.randomUUID().toString() + ".cag"; - index.serialize(new FileOutputStream(indexFileName)); - - File indexFile = new File(indexFileName); - if (indexFile.exists()) { - indexFile.delete(); - } - - Throwable exception = assertThrows(Exception.class, () -> { - try (InputStream inputStream = new FileInputStream(indexFile)) { - CagraIndex deletedIndex = new CagraIndex.Builder(resources) - .from(inputStream) - .build(); - - CagraQuery query = new CagraQuery.Builder() - .withTopK(topK) - .withSearchParams(new CagraSearchParams.Builder(resources).build()) - .withQueryVectors(queries) - .build(); - - deletedIndex.search(query); - } - }); - - assertTrue("Expected FileNotFoundException", exception instanceof java.io.FileNotFoundException); - } - @Ignore - @Test - public void testNullQueryVectors() throws Throwable { - // Generate a random dataset - int numRows = random.nextInt(10) + 1; // At least 1 row - int numCols = random.nextInt(5) + 1; // At least 1 column - float[][] dataset = new float[numRows][numCols]; - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numCols; j++) { - dataset[i][j] = random.nextFloat() * 100; - } - } - - // Log dataset for debugging - System.out.println("Dataset size: " + numRows + "x" + numCols); - for (float[] row : dataset) { - System.out.println(java.util.Arrays.toString(row)); - } - - CuVSResources resources = new CuVSResources(); - CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); - CagraIndex index = new CagraIndex.Builder(resources) - .withDataset(dataset) - .withIndexParams(indexParams) - .build(); - - // Create an invalid query with null query vectors - CagraQuery invalidQuery = new CagraQuery.Builder() - .withQueryVectors(null) - .withTopK(3) - .withSearchParams(new CagraSearchParams.Builder(resources).build()) - .build(); - - // Assert that an exception is thrown - Throwable exception = assertThrows(IllegalArgumentException.class, () -> { - index.search(invalidQuery); - }); - - // Verify the exception message - assertEquals("Query vectors cannot be null", exception.getMessage()); - } - @Ignore - @Test - public void testTopKExceedsDatasetSize() throws Throwable { - // Generate a random dataset - int numRows = random.nextInt(10) + 1; - int numCols = random.nextInt(5) + 1; - float[][] dataset = new float[numRows][numCols]; - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numCols; j++) { - dataset[i][j] = random.nextFloat() * 100; - } - } - - int numQueries = random.nextInt(5) + 1; // At least 1 query - float[][] queries = new float[numQueries][numCols]; - for (int i = 0; i < numQueries; i++) { - for (int j = 0; j < numCols; j++) { - queries[i][j] = random.nextFloat() * 100; - } - } - - // Set TopK to exceed dataset size - int topK = numRows + random.nextInt(5) + 1; - - CuVSResources resources = new CuVSResources(); - CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); - CagraIndex index = new CagraIndex.Builder(resources) - .withDataset(dataset) - .withIndexParams(indexParams) - .build(); - - CagraQuery query = new CagraQuery.Builder() - .withQueryVectors(queries) - .withTopK(topK) - .withSearchParams(new CagraSearchParams.Builder(resources).build()) - .build(); - - CagraSearchResults results = index.search(query); - - - results.getResults().forEach(result -> { - assertTrue( - "Result size should not exceed the smaller of TopK or dataset size", - result.size() <= topK - ); - }); - } - @Ignore - @Test - public void testDuplicateQueryVectors() throws Throwable { - // Generate a random dataset - int numRows = random.nextInt(10) + 1; - int numCols = random.nextInt(5) + 1; - float[][] dataset = new float[numRows][numCols]; - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numCols; j++) { - dataset[i][j] = random.nextFloat() * 100; - } - } - - float[][] queries = new float[2][numCols]; - for (int j = 0; j < numCols; j++) { - queries[0][j] = random.nextFloat() * 100; - } - System.arraycopy(queries[0], 0, queries[1], 0, numCols); // Duplicate query - - CuVSResources resources = new CuVSResources(); - CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); - CagraIndex index = new CagraIndex.Builder(resources) - .withDataset(dataset) - .withIndexParams(indexParams) - .build(); - - CagraQuery query = new CagraQuery.Builder() - .withQueryVectors(queries) - .withTopK(3) - .withSearchParams(new CagraSearchParams.Builder(resources).build()) - .build(); - - CagraSearchResults results = index.search(query); - - // Validate the results for duplicate queries are identical - assertEquals("Results for duplicate queries should be the same", - results.getResults().get(0), results.getResults().get(1)); - } - @Test - public void testSerializationDeserializationConsistency() throws Throwable { - - int numRows = random.nextInt(10) + 1; - int numCols = random.nextInt(5) + 1; - float[][] dataset = new float[numRows][numCols]; - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numCols; j++) { - dataset[i][j] = random.nextFloat() * 100; - } - } - - - int numQueries = random.nextInt(5) + 1; - float[][] queries = new float[numQueries][numCols]; - for (int i = 0; i < numQueries; i++) { - for (int j = 0; j < numCols; j++) { - queries[i][j] = random.nextFloat() * 100; - } - } - - CuVSResources resources = new CuVSResources(); - CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); - CagraIndex index = new CagraIndex.Builder(resources) - .withDataset(dataset) - .withIndexParams(indexParams) - .build(); - - - File tempFile = File.createTempFile("cagra-index", ".cag"); - tempFile.deleteOnExit(); - index.serialize(new FileOutputStream(tempFile)); - - - CagraIndex deserializedIndex = new CagraIndex.Builder(resources) - .from(new FileInputStream(tempFile)) - .build(); - - CagraQuery query = new CagraQuery.Builder() - .withQueryVectors(queries) - .withTopK(3) - .withSearchParams(new CagraSearchParams.Builder(resources).build()) - .build(); - - // Validate that results from the original and deserialized index are identical - CagraSearchResults originalResults = index.search(query); - CagraSearchResults deserializedResults = deserializedIndex.search(query); - - assertEquals("Results from original and deserialized index should match", - originalResults.getResults(), deserializedResults.getResults()); - } } From e382bf1149519ef13b8e63710b35d153ffa07904 Mon Sep 17 00:00:00 2001 From: punAhuja Date: Fri, 6 Dec 2024 14:28:41 +0530 Subject: [PATCH 9/9] Added a test with new random logic, and comparing actual and expected search results --- .../com/nvidia/cuvs/CagraRandomizedTest.java | 159 ++++++++++++------ .../java/com/nvidia/cuvs/CuVSTestCase.java | 19 +++ 2 files changed, 127 insertions(+), 51 deletions(-) create mode 100644 java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java index cb6f2936f6..cf58a89065 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraRandomizedTest.java @@ -1,93 +1,150 @@ package com.nvidia.cuvs; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + import java.lang.invoke.MethodHandles; +import java.util.ArrayList; +import java.util.List; import java.util.Map; -import java.util.Random; -import java.util.UUID; +import java.util.TreeMap; -import org.apache.lucene.tests.util.LuceneTestCase; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.File; -import java.io.FileInputStream; -import java.io.InputStream; -import java.io.FileOutputStream; -import com.carrotsearch.randomizedtesting.RandomizedContext; -import org.junit.Ignore; - -public class CagraRandomizedTest extends LuceneTestCase { - private Random random; + +import com.carrotsearch.randomizedtesting.RandomizedRunner; + +@RunWith(RandomizedRunner.class) +public class CagraRandomizedTest extends CuVSTestCase { + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); @Before public void setup() { - - this.random = random(); - log.info("Test seed: " + RandomizedContext.current().getRunnerSeedAsString()); + initializeRandom(); + log.info("Random context initialized for test."); } - @Test + + @Test public void testResultsTopKWithRandomValues() throws Throwable { + // Use old-style random generation logic + int datasetSize = random.nextInt(400) + 1; + int dimensions = random.nextInt(500) + 1; + int numQueries = random.nextInt(500) + 1; + int topK = random.nextInt(datasetSize) + 1; + // Generate a random dataset - int numRows = random.nextInt(10) + 1; - int numCols = random.nextInt(5) + 1; - float[][] dataset = new float[numRows][numCols]; - for (int i = 0; i < numRows; i++) { - for (int j = 0; j < numCols; j++) { + float[][] dataset = new float[datasetSize][dimensions]; + for (int i = 0; i < datasetSize; i++) { + for (int j = 0; j < dimensions; j++) { dataset[i][j] = random.nextFloat() * 100; } } // Generate random query vectors - int numQueries = random.nextInt(5) + 1; - float[][] queries = new float[numQueries][numCols]; + float[][] queries = new float[numQueries][dimensions]; for (int i = 0; i < numQueries; i++) { - for (int j = 0; j < numCols; j++) { + for (int j = 0; j < dimensions; j++) { queries[i][j] = random.nextFloat() * 100; } } - - int topK = random.nextInt(numRows) + 1; - - log.info("Dataset size: {}x{}", numRows, numCols); - log.info("Query size: {}x{}", numQueries, numCols); + log.info("Dataset size: {}x{}", datasetSize, dimensions); + log.info("Query size: {}x{}", numQueries, dimensions); log.info("TopK: {}", topK); - log.info("Dataset:"); - for (float[] row : dataset) { - log.info(java.util.Arrays.toString(row)); - } + // Debugging: Log dataset and queries + if (log.isDebugEnabled()) { + log.debug("Dataset:"); + for (float[] row : dataset) { + log.debug(java.util.Arrays.toString(row)); + } - log.info("Queries:"); - for (float[] query : queries) { - log.info(java.util.Arrays.toString(query)); + log.debug("Queries:"); + for (float[] query : queries) { + log.debug(java.util.Arrays.toString(query)); + } } - CuVSResources resources = new CuVSResources(); + // Sanity checks + assert dataset.length > 0 : "Dataset is empty."; + assert queries.length > 0 : "Queries are empty."; + assert dimensions > 0 : "Invalid dimensions."; + assert topK > 0 && topK <= datasetSize : "Invalid topK value."; + + // Generate expected results using brute force + List> expected = generateExpectedResults(topK, dataset, queries); + // Create CuVS index and query + CuVSResources resources = new CuVSResources(); CagraIndexParams indexParams = new CagraIndexParams.Builder(resources).build(); + CagraIndex index = new CagraIndex.Builder(resources).withDataset(dataset).withIndexParams(indexParams).build(); - CagraIndex index = new CagraIndex.Builder(resources) - .withDataset(dataset) - .withIndexParams(indexParams) - .build(); + log.info("Index built successfully."); CagraQuery query = new CagraQuery.Builder() - .withQueryVectors(queries) - .withTopK(topK) - .withSearchParams(new CagraSearchParams.Builder(resources).build()) - .build(); + .withQueryVectors(queries) + .withTopK(topK) + .withSearchParams(new CagraSearchParams.Builder(resources).build()) + .build(); + log.info("Query built successfully. Executing search..."); + + // Execute search and retrieve results CagraSearchResults results = index.search(query); - results.getResults().forEach(result -> { - log.info("Result size: {}", result.size()); - assertEquals("TopK mismatch for query.", Math.min(topK, numRows), result.size()); - }); + // actual vs. expected results + for (int i = 0; i < results.getResults().size(); i++) { + Map result = results.getResults().get(i); + log.info("Actual result for query {}: {}", i, result.keySet()); + log.info("Expected result for query {}: {}", i, expected.get(i)); + + assertEquals("TopK mismatch for query.", Math.min(topK, datasetSize), result.size()); + + // Sort result by values (distances) and extract keys + List sortedResultKeys = result.entrySet().stream() + .sorted(Map.Entry.comparingByValue()) // Sort by value (distance) + .map(Map.Entry::getKey) // Extract sorted keys + .toList(); + + log.info("Sorted Actual result for query {}: {}", i, sortedResultKeys); + + // Compare using primitive int arrays + assertArrayEquals( + "Query " + i + " mismatched", + expected.get(i).stream().mapToInt(Integer::intValue).toArray(), + sortedResultKeys.stream().mapToInt(Integer::intValue).toArray() + ); + } + } - - + private List> generateExpectedResults(int topK, float[][] dataset, float[][] queries) { + List> neighborsResult = new ArrayList<>(); + int dimensions = dataset[0].length; + + for (float[] query : queries) { + Map distances = new TreeMap<>(); + for (int j = 0; j < dataset.length; j++) { + double distance = 0; + for (int k = 0; k < dimensions; k++) { + distance += (query[k] - dataset[j][k]) * (query[k] - dataset[j][k]); + } + distances.put(j, Math.sqrt(distance)); + } + + // Sort by distance and select the topK nearest neighbors + List neighbors = distances.entrySet().stream() + .sorted(Map.Entry.comparingByValue()) + .map(Map.Entry::getKey) + .toList(); + neighborsResult.add(neighbors.subList(0, Math.min(topK, dataset.length))); + } + + log.info("Expected results generated successfully."); + return neighborsResult; + } } diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java new file mode 100644 index 0000000000..a07cb4ebfe --- /dev/null +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java @@ -0,0 +1,19 @@ +package com.nvidia.cuvs; + +import java.lang.invoke.MethodHandles; +import java.util.Random; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.carrotsearch.randomizedtesting.RandomizedContext; + +public abstract class CuVSTestCase { + protected Random random; + private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + protected void initializeRandom() { + random = RandomizedContext.current().getRandom(); + log.info("Test seed: " + RandomizedContext.current().getRunnerSeedAsString()); + } +}