From fd17702970eda2bad11962d0069c1b6c11bacf8e Mon Sep 17 00:00:00 2001 From: punAhuja Date: Fri, 2 May 2025 15:55:39 +0530 Subject: [PATCH 1/3] Changes to support prefiltering in CAGRA --- .../main/java/com/nvidia/cuvs/CagraQuery.java | 25 +++++++- java/internal/src/cuvs_java.c | 62 +++++++++++++++---- 2 files changed, 74 insertions(+), 13 deletions(-) diff --git a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraQuery.java b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraQuery.java index de2fc3f417..e0603eef9c 100644 --- a/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraQuery.java +++ b/java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraQuery.java @@ -18,6 +18,7 @@ import java.util.Arrays; import java.util.List; +import java.util.BitSet; /** * CagraQuery holds the CagraSearchParams and the query vectors to be used while @@ -31,6 +32,8 @@ public class CagraQuery { private List mapping; private float[][] queryVectors; private int topK; + private BitSet[] prefilters; + private int numDocs; /** * Constructs an instance of {@link CagraQuery} using cagraSearchParameters, @@ -42,12 +45,14 @@ public class CagraQuery { * @param mapping an instance of ID mapping * @param topK the top k results to return */ - public CagraQuery(CagraSearchParams cagraSearchParameters, float[][] queryVectors, List mapping, int topK) { + public CagraQuery(CagraSearchParams cagraSearchParameters, float[][] queryVectors, List mapping, int topK, BitSet[] prefilters, int numDocs) { super(); this.cagraSearchParameters = cagraSearchParameters; this.queryVectors = queryVectors; this.mapping = mapping; this.topK = topK; + this.prefilters = prefilters; + this.numDocs = numDocs; } /** @@ -85,6 +90,14 @@ public List getMapping() { public int getTopK() { return topK; } + + public BitSet[] getPrefilters() { + return prefilters; + } + + public int getNumDocs() { + return numDocs; + } @Override public String toString() { @@ -101,6 +114,8 @@ public static class Builder { private float[][] queryVectors; 
private List mapping; private int topK = 2; + private BitSet[] prefilters; + private int numDocs; /** * Default constructor. @@ -152,6 +167,12 @@ public Builder withTopK(int topK) { this.topK = topK; return this; } + + public Builder withPrefilter(BitSet[] prefilters, int numDocs) { + this.prefilters = prefilters; + this.numDocs = numDocs; + return this; + } /** * Builds an instance of CuVSQuery. @@ -159,7 +180,7 @@ public Builder withTopK(int topK) { * @return an instance of CuVSQuery */ public CagraQuery build() { - return new CagraQuery(cagraSearchParams, queryVectors, mapping, topK); + return new CagraQuery(cagraSearchParams, queryVectors, mapping, topK, prefilters, numDocs); } } } diff --git a/java/internal/src/cuvs_java.c b/java/internal/src/cuvs_java.c index 0b9c2f794e..62f8f72e89 100644 --- a/java/internal/src/cuvs_java.c +++ b/java/internal/src/cuvs_java.c @@ -166,17 +166,27 @@ void deserialize_cagra_index(cuvsResources_t cuvs_resources, cuvsCagraIndex_t in * @param[out] return_value return value for cuvsCagraSearch function call * @param[in] search_params reference to cuvsCagraSearchParams_t holding the search parameters */ -void search_cagra_index(cuvsCagraIndex_t index, float *queries, int topk, long n_queries, int dimensions, - cuvsResources_t cuvs_resources, int *neighbors_h, float *distances_h, int *return_value, cuvsCagraSearchParams_t search_params) { - +void search_cagra_index(cuvsCagraIndex_t index, + float *queries, + int topk, + long n_queries, + int dimensions, + cuvsResources_t cuvs_resources, + int *neighbors_h, + float *distances_h, + int *return_value, + cuvsCagraSearchParams_t search_params, + uint32_t *prefilter_data, + long prefilter_data_length) { cudaStream_t stream; cuvsStreamGet(cuvs_resources, &stream); uint32_t *neighbors; float *distances, *queries_d; - cuvsRMMAlloc(cuvs_resources, (void**) &queries_d, sizeof(float) * n_queries * dimensions); - cuvsRMMAlloc(cuvs_resources, (void**) &neighbors, sizeof(uint32_t) * n_queries * 
topk); - cuvsRMMAlloc(cuvs_resources, (void**) &distances, sizeof(float) * n_queries * topk); + + cuvsRMMAlloc(cuvs_resources, (void **) &queries_d, sizeof(float) * n_queries * dimensions); + cuvsRMMAlloc(cuvs_resources, (void **) &neighbors, sizeof(uint32_t) * n_queries * topk); + cuvsRMMAlloc(cuvs_resources, (void **) &distances, sizeof(float) * n_queries * topk); cudaMemcpy(queries_d, queries, sizeof(float) * n_queries * dimensions, cudaMemcpyDefault); @@ -191,12 +201,35 @@ void search_cagra_index(cuvsCagraIndex_t index, float *queries, int topk, long n cuvsStreamSync(cuvs_resources); - cuvsFilter filter; // TODO: Implement Cagra Pre-Filtering, but leave it as no-op for now - filter.type = NO_FILTER; - filter.addr = (uintptr_t)NULL; + cuvsFilter filter; + uint32_t *prefilter_d = NULL; + int64_t prefilter_len = 0; + DLManagedTensor *prefilter_tensor_ptr = NULL; + + if (prefilter_data == NULL || prefilter_data_length == 0) { + filter.type = NO_FILTER; + filter.addr = (uintptr_t) NULL; + } else { + int64_t prefilter_shape[1] = {(prefilter_data_length + 31) / 32}; + prefilter_len = prefilter_shape[0]; - *return_value = cuvsCagraSearch(cuvs_resources, search_params, index, &queries_tensor, &neighbors_tensor, - &distances_tensor, filter); + cuvsRMMAlloc(cuvs_resources, (void **) &prefilter_d, sizeof(uint32_t) * prefilter_len); + cudaMemcpy(prefilter_d, prefilter_data, sizeof(uint32_t) * prefilter_len, cudaMemcpyHostToDevice); + + prefilter_tensor_ptr = (DLManagedTensor *) malloc(sizeof(DLManagedTensor)); + *prefilter_tensor_ptr = prepare_tensor(prefilter_d, prefilter_shape, kDLUInt, 32, 1, kDLCUDA); + + filter.type = BITSET; + filter.addr = (uintptr_t) prefilter_tensor_ptr; + } + + *return_value = cuvsCagraSearch(cuvs_resources, + search_params, + index, + &queries_tensor, + &neighbors_tensor, + &distances_tensor, + filter); cudaMemcpy(neighbors_h, neighbors, sizeof(uint32_t) * n_queries * topk, cudaMemcpyDefault); cudaMemcpy(distances_h, distances, sizeof(float) * 
n_queries * topk, cudaMemcpyDefault); @@ -204,8 +237,15 @@ void search_cagra_index(cuvsCagraIndex_t index, float *queries, int topk, long n cuvsRMMFree(cuvs_resources, distances, sizeof(float) * n_queries * topk); cuvsRMMFree(cuvs_resources, neighbors, sizeof(uint32_t) * n_queries * topk); cuvsRMMFree(cuvs_resources, queries_d, sizeof(float) * n_queries * dimensions); + if (prefilter_d != NULL) { + cuvsRMMFree(cuvs_resources, prefilter_d, sizeof(uint32_t) * prefilter_len); + } + if (prefilter_tensor_ptr != NULL) { + free(prefilter_tensor_ptr); + } } + /** * @brief De-allocate BRUTEFORCE index * From c22264a87e2ae6922ef7c4be832860ca5cea770a Mon Sep 17 00:00:00 2001 From: punAhuja Date: Fri, 2 May 2025 23:11:21 +0530 Subject: [PATCH 2/3] CagraIndexImpl file changes for invoking C API --- .../nvidia/cuvs/internal/CagraIndexImpl.java | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java index 7788354022..f447cb9ca1 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java @@ -54,6 +54,7 @@ import com.nvidia.cuvs.internal.panama.cuvsIvfPqIndexParams; import com.nvidia.cuvs.internal.panama.cuvsIvfPqParams; import com.nvidia.cuvs.internal.panama.cuvsIvfPqSearchParams; +import java.util.BitSet; /** * {@link CagraIndex} encapsulates a CAGRA index, along with methods to interact @@ -73,7 +74,7 @@ public class CagraIndexImpl implements CagraIndex { FunctionDescriptor.of(ADDRESS, ADDRESS, C_LONG, C_LONG, ADDRESS, ADDRESS, ADDRESS, ADDRESS, C_INT)); private static final MethodHandle searchMethodHandle = downcallHandle("search_cagra_index", - FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, C_INT, C_LONG, C_INT, ADDRESS, ADDRESS, ADDRESS, ADDRESS, ADDRESS)); + 
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, C_INT, C_LONG, C_INT, ADDRESS, ADDRESS, ADDRESS, ADDRESS, ADDRESS, ADDRESS, C_LONG)); private static final MethodHandle serializeMethodHandle = downcallHandle("serialize_cagra_index", FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, ADDRESS, ADDRESS)); @@ -209,6 +210,16 @@ public SearchResults search(CagraQuery query) throws Throwable { MemorySegment neighborsMemorySegment = resources.getArena().allocate(neighborsSequenceLayout); MemorySegment distancesMemorySegment = resources.getArena().allocate(distancesSequenceLayout); MemorySegment floatsSeg = Util.buildMemorySegment(resources.getArena(), query.getQueryVectors()); + + long prefilterDataLength = 0; + MemorySegment prefilterData = MemorySegment.NULL; + if (query.getPrefilters() != null && query.getPrefilters().length > 0) { + BitSet concatenated = Util.concatenate(query.getPrefilters(), query.getNumDocs()); + long[] longArray = concatenated.toLongArray(); + prefilterData = Util.buildMemorySegment(resources.getArena(), longArray); + prefilterDataLength = query.getNumDocs() * query.getPrefilters().length; + } + try (var localArena = Arena.ofConfined()) { MemorySegment returnValue = localArena.allocate(C_INT); @@ -222,7 +233,9 @@ public SearchResults search(CagraQuery query) throws Throwable { neighborsMemorySegment, distancesMemorySegment, returnValue, - segmentFromSearchParams(query.getCagraSearchParameters()) + segmentFromSearchParams(query.getCagraSearchParameters()), + prefilterData, + prefilterDataLength ); checkError(returnValue.get(C_INT, 0L), "searchMethodHandle"); } From 6ad287712b9b9c0a1b7d0f8d81f24b6ca1e08e11 Mon Sep 17 00:00:00 2001 From: punAhuja Date: Fri, 2 May 2025 23:49:20 +0530 Subject: [PATCH 3/3] Added some tests to test prefiltering in CAGRA --- .../nvidia/cuvs/CagraBuildAndSearchIT.java | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchIT.java 
b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchIT.java index 75b21aed82..5f43b5acc4 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchIT.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchIT.java @@ -18,6 +18,8 @@ import static com.carrotsearch.randomizedtesting.RandomizedTest.assumeTrue; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; import java.io.File; import java.io.FileInputStream; @@ -32,6 +34,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import java.util.BitSet; import org.junit.Before; import org.junit.Test; @@ -121,6 +124,72 @@ public void testIndexingAndSearchingFlow() throws Throwable { } } + + @Test + public void testPrefilteringReducesResults() throws Throwable { + float[][] dataset = { + { 0.0f, 0.0f }, + { 10.0f, 10.0f }, + { 1.0f, 1.0f }, + { 2.0f, 2.0f } + }; + List map = List.of(0, 1, 2, 3); + float[][] queries = { + {0.1f, 0.1f} + }; + + CagraIndexParams indexParams = new CagraIndexParams.Builder() + .withCagraGraphBuildAlgo(CagraGraphBuildAlgo.NN_DESCENT) + .withGraphDegree(2) + .withIntermediateGraphDegree(4) + .withNumWriterThreads(2) + .withMetric(CuvsDistanceType.L2Expanded) + .build(); + + try (CuVSResources resources = CuVSResources.create()) { + CagraIndex index = CagraIndex.newBuilder(resources) + .withDataset(dataset) + .withIndexParams(indexParams) + .build(); + + // No prefilter (all points allowed) + CagraSearchParams searchParams = new CagraSearchParams.Builder(resources).build(); + CagraQuery fullQuery = new CagraQuery.Builder() + .withTopK(2) + .withSearchParams(searchParams) + .withQueryVectors(queries) + .withMapping(map) + .build(); + + SearchResults fullResults = index.search(fullQuery); + List> full = fullResults.getResults(); + log.info("Full results: {}", full); + + // Apply prefilter: 
only allow ids 1, 2 and 3, excluding id 0 (bits 3..0: 1110) + BitSet prefilterBits = new BitSet(4); + prefilterBits.set(1); + prefilterBits.set(2); + prefilterBits.set(3); + BitSet[] prefilters = new BitSet[] { prefilterBits }; + + CagraQuery filteredQuery = new CagraQuery.Builder() + .withTopK(2) + .withSearchParams(searchParams) + .withQueryVectors(queries) + .withMapping(map) + .withPrefilter(prefilters, 4) + .build(); + + SearchResults filteredResults = index.search(filteredQuery); + List> filtered = filteredResults.getResults(); + log.info("Filtered results: {}", filtered); + + assertTrue(full.get(0).containsKey(0)); + assertFalse(filtered.get(0).containsKey(0)); + assertTrue(filtered.get(0).containsKey(2)); + } +} + private Runnable indexAndQueryOnce(float[][] dataset, List map, float[][] queries, List> expectedResults, CuVSResources resources) throws Throwable, FileNotFoundException {