Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.util.Arrays;
import java.util.List;
import java.util.BitSet;

/**
* CagraQuery holds the CagraSearchParams and the query vectors to be used while
Expand All @@ -31,6 +32,8 @@ public class CagraQuery {
private List<Integer> mapping;
private float[][] queryVectors;
private int topK;
private BitSet[] prefilters;
private int numDocs;

/**
* Constructs an instance of {@link CagraQuery} using cagraSearchParameters,
Expand All @@ -42,12 +45,14 @@ public class CagraQuery {
* @param mapping an instance of ID mapping
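* @param topK the top k results to return
* @param prefilters optional array of prefilter bitsets; may be {@code null} when no prefiltering is requested
* @param numDocs the number of bits each prefilter bitset is taken to span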
*/
public CagraQuery(CagraSearchParams cagraSearchParameters, float[][] queryVectors, List<Integer> mapping, int topK) {
public CagraQuery(CagraSearchParams cagraSearchParameters, float[][] queryVectors, List<Integer> mapping, int topK, BitSet[] prefilters, int numDocs) {
super();
this.cagraSearchParameters = cagraSearchParameters;
this.queryVectors = queryVectors;
this.mapping = mapping;
this.topK = topK;
this.prefilters = prefilters;
this.numDocs = numDocs;
}

/**
Expand Down Expand Up @@ -85,6 +90,14 @@ public List<Integer> getMapping() {
public int getTopK() {
  // Plain accessor for the topK value stored on this query.
  return this.topK;
}

/**
 * Gets the prefilter bitsets attached to this query.
 *
 * @return the prefilter {@link BitSet} array, or {@code null} if none was supplied
 */
public BitSet[] getPrefilters() {
  return this.prefilters;
}

/**
 * Gets the document count that was supplied together with the prefilters.
 *
 * @return the numDocs value
 */
public int getNumDocs() {
  return this.numDocs;
}

@Override
public String toString() {
Expand All @@ -101,6 +114,8 @@ public static class Builder {
private float[][] queryVectors;
private List<Integer> mapping;
private int topK = 2;
private BitSet[] prefilters;
private int numDocs;

/**
* Default constructor.
Expand Down Expand Up @@ -152,14 +167,20 @@ public Builder withTopK(int topK) {
this.topK = topK;
return this;
}

/**
 * Registers the prefilter bitsets to hand to the search, along with the
 * document count they refer to.
 *
 * @param prefilters the prefilter {@link BitSet} array
 * @param numDocs    the number of bits each prefilter bitset is taken to span
 * @return an instance of this Builder
 */
public Builder withPrefilter(BitSet[] prefilters, int numDocs) {
  this.numDocs = numDocs;
  this.prefilters = prefilters;
  return this;
}

/**
* Builds an instance of {@link CagraQuery}.
*
* @return an instance of {@link CagraQuery}
*/
public CagraQuery build() {
return new CagraQuery(cagraSearchParams, queryVectors, mapping, topK);
return new CagraQuery(cagraSearchParams, queryVectors, mapping, topK, prefilters, numDocs);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import com.nvidia.cuvs.internal.panama.cuvsIvfPqIndexParams;
import com.nvidia.cuvs.internal.panama.cuvsIvfPqParams;
import com.nvidia.cuvs.internal.panama.cuvsIvfPqSearchParams;
import java.util.BitSet;

/**
* {@link CagraIndex} encapsulates a CAGRA index, along with methods to interact
Expand All @@ -73,7 +74,7 @@ public class CagraIndexImpl implements CagraIndex {
FunctionDescriptor.of(ADDRESS, ADDRESS, C_LONG, C_LONG, ADDRESS, ADDRESS, ADDRESS, ADDRESS, C_INT));

private static final MethodHandle searchMethodHandle = downcallHandle("search_cagra_index",
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, C_INT, C_LONG, C_INT, ADDRESS, ADDRESS, ADDRESS, ADDRESS, ADDRESS));
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, C_INT, C_LONG, C_INT, ADDRESS, ADDRESS, ADDRESS, ADDRESS, ADDRESS, ADDRESS, C_LONG));

private static final MethodHandle serializeMethodHandle = downcallHandle("serialize_cagra_index",
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, ADDRESS, ADDRESS));
Expand Down Expand Up @@ -209,6 +210,16 @@ public SearchResults search(CagraQuery query) throws Throwable {
MemorySegment neighborsMemorySegment = resources.getArena().allocate(neighborsSequenceLayout);
MemorySegment distancesMemorySegment = resources.getArena().allocate(distancesSequenceLayout);
MemorySegment floatsSeg = Util.buildMemorySegment(resources.getArena(), query.getQueryVectors());

// Optional prefilter: flatten the supplied BitSets into one packed bitset for the native call.
// When no prefilter is supplied, a NULL segment and a length of 0 are passed instead.
long prefilterDataLength = 0;
MemorySegment prefilterData = MemorySegment.NULL;
BitSet[] prefilters = query.getPrefilters();
if (prefilters != null && prefilters.length > 0) {
  // Widen to long before multiplying so numDocs * prefilters.length cannot overflow int.
  prefilterDataLength = (long) query.getNumDocs() * prefilters.length;
  BitSet concatenated = Util.concatenate(prefilters, query.getNumDocs());
  long[] words = concatenated.toLongArray();
  // BitSet.toLongArray() omits trailing all-zero words, but the native side copies
  // ceil(prefilterDataLength / 32) 32-bit words out of this buffer. Zero-pad so the
  // buffer always covers the full bit range (e.g. when the last ids are all filtered out).
  int requiredWords = Math.toIntExact((prefilterDataLength + Long.SIZE - 1) / Long.SIZE);
  if (words.length < requiredWords) {
    words = java.util.Arrays.copyOf(words, requiredWords);
  }
  prefilterData = Util.buildMemorySegment(resources.getArena(), words);
}


try (var localArena = Arena.ofConfined()) {
MemorySegment returnValue = localArena.allocate(C_INT);
Expand All @@ -222,7 +233,9 @@ public SearchResults search(CagraQuery query) throws Throwable {
neighborsMemorySegment,
distancesMemorySegment,
returnValue,
segmentFromSearchParams(query.getCagraSearchParameters())
segmentFromSearchParams(query.getCagraSearchParameters()),
prefilterData,
prefilterDataLength
);
checkError(returnValue.get(C_INT, 0L), "searchMethodHandle");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

import static com.carrotsearch.randomizedtesting.RandomizedTest.assumeTrue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertFalse;

import java.io.File;
import java.io.FileInputStream;
Expand All @@ -32,6 +34,7 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.BitSet;

import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -121,6 +124,72 @@ public void testIndexingAndSearchingFlow() throws Throwable {

}
}

/**
 * Checks that a prefilter changes the search results.
 *
 * Builds an index over four 2-D vectors, runs the same single query once
 * without a prefilter and once with a prefilter whose bits 1, 2 and 3 are set,
 * then asserts that id 0 appears in the unfiltered results, is absent from the
 * filtered results, and that id 2 is present in the filtered results.
 */
@Test
public void testPrefilteringReducesResults() throws Throwable {
float[][] dataset = {
{ 0.0f, 0.0f },
{ 10.0f, 10.0f },
{ 1.0f, 1.0f },
{ 2.0f, 2.0f }
};
List<Integer> map = List.of(0, 1, 2, 3);
// Single query vector placed next to dataset row 0.
float[][] queries = {
{0.1f, 0.1f}
};

CagraIndexParams indexParams = new CagraIndexParams.Builder()
.withCagraGraphBuildAlgo(CagraGraphBuildAlgo.NN_DESCENT)
.withGraphDegree(2)
.withIntermediateGraphDegree(4)
.withNumWriterThreads(2)
.withMetric(CuvsDistanceType.L2Expanded)
.build();

try (CuVSResources resources = CuVSResources.create()) {
CagraIndex index = CagraIndex.newBuilder(resources)
.withDataset(dataset)
.withIndexParams(indexParams)
.build();

// No prefilter (all points allowed)
CagraSearchParams searchParams = new CagraSearchParams.Builder(resources).build();
CagraQuery fullQuery = new CagraQuery.Builder()
.withTopK(2)
.withSearchParams(searchParams)
.withQueryVectors(queries)
.withMapping(map)
.build();

SearchResults fullResults = index.search(fullQuery);
List<Map<Integer, Float>> full = fullResults.getResults();
log.info("Full results: {}", full);

// Apply prefilter: bits 1, 2 and 3 are set and bit 0 is left clear, so the
// test expects ids 1, 2 and 3 to stay eligible and id 0 to be filtered out.
BitSet prefilterBits = new BitSet(4);
prefilterBits.set(1);
prefilterBits.set(2);
prefilterBits.set(3);
// One bitset for the single query; numDocs (4) matches the dataset size.
BitSet[] prefilters = new BitSet[] { prefilterBits };

CagraQuery filteredQuery = new CagraQuery.Builder()
.withTopK(2)
.withSearchParams(searchParams)
.withQueryVectors(queries)
.withMapping(map)
.withPrefilter(prefilters, 4)
.build();

SearchResults filteredResults = index.search(filteredQuery);
List<Map<Integer, Float>> filtered = filteredResults.getResults();
log.info("Filtered results: {}", filtered);

// id 0 is returned without the filter and must disappear once it is filtered out.
assertTrue(full.get(0).containsKey(0));
assertFalse(filtered.get(0).containsKey(0));
assertTrue(filtered.get(0).containsKey(2));
}
}


private Runnable indexAndQueryOnce(float[][] dataset, List<Integer> map, float[][] queries,
List<Map<Integer, Float>> expectedResults, CuVSResources resources) throws Throwable, FileNotFoundException {
Expand Down
62 changes: 51 additions & 11 deletions java/internal/src/cuvs_java.c
Original file line number Diff line number Diff line change
Expand Up @@ -166,17 +166,27 @@ void deserialize_cagra_index(cuvsResources_t cuvs_resources, cuvsCagraIndex_t in
* @param[out] return_value return value for cuvsCagraSearch function call
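* @param[in] search_params reference to cuvsCagraSearchParams_t holding the search parameters
* @param[in] prefilter_data host pointer to the packed prefilter bitset, read as 32-bit words; NULL disables filtering
* @param[in] prefilter_data_length number of bits in the prefilter bitset; 0 disables filtering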
*/
void search_cagra_index(cuvsCagraIndex_t index, float *queries, int topk, long n_queries, int dimensions,
cuvsResources_t cuvs_resources, int *neighbors_h, float *distances_h, int *return_value, cuvsCagraSearchParams_t search_params) {

void search_cagra_index(cuvsCagraIndex_t index,
float *queries,
int topk,
long n_queries,
int dimensions,
cuvsResources_t cuvs_resources,
int *neighbors_h,
float *distances_h,
int *return_value,
cuvsCagraSearchParams_t search_params,
uint32_t *prefilter_data,
long prefilter_data_length) {
cudaStream_t stream;
cuvsStreamGet(cuvs_resources, &stream);

uint32_t *neighbors;
float *distances, *queries_d;
cuvsRMMAlloc(cuvs_resources, (void**) &queries_d, sizeof(float) * n_queries * dimensions);
cuvsRMMAlloc(cuvs_resources, (void**) &neighbors, sizeof(uint32_t) * n_queries * topk);
cuvsRMMAlloc(cuvs_resources, (void**) &distances, sizeof(float) * n_queries * topk);

cuvsRMMAlloc(cuvs_resources, (void **) &queries_d, sizeof(float) * n_queries * dimensions);
cuvsRMMAlloc(cuvs_resources, (void **) &neighbors, sizeof(uint32_t) * n_queries * topk);
cuvsRMMAlloc(cuvs_resources, (void **) &distances, sizeof(float) * n_queries * topk);

cudaMemcpy(queries_d, queries, sizeof(float) * n_queries * dimensions, cudaMemcpyDefault);

Expand All @@ -191,21 +201,51 @@ void search_cagra_index(cuvsCagraIndex_t index, float *queries, int topk, long n

cuvsStreamSync(cuvs_resources);

cuvsFilter filter; // TODO: Implement Cagra Pre-Filtering, but leave it as no-op for now
filter.type = NO_FILTER;
filter.addr = (uintptr_t)NULL;
cuvsFilter filter;
uint32_t *prefilter_d = NULL;
int64_t prefilter_len = 0;
DLManagedTensor *prefilter_tensor_ptr = NULL;

if (prefilter_data == NULL || prefilter_data_length == 0) {
filter.type = NO_FILTER;
filter.addr = (uintptr_t) NULL;
} else {
int64_t prefilter_shape[1] = {(prefilter_data_length + 31) / 32};
prefilter_len = prefilter_shape[0];

*return_value = cuvsCagraSearch(cuvs_resources, search_params, index, &queries_tensor, &neighbors_tensor,
&distances_tensor, filter);
cuvsRMMAlloc(cuvs_resources, (void **) &prefilter_d, sizeof(uint32_t) * prefilter_len);
cudaMemcpy(prefilter_d, prefilter_data, sizeof(uint32_t) * prefilter_len, cudaMemcpyHostToDevice);

prefilter_tensor_ptr = (DLManagedTensor *) malloc(sizeof(DLManagedTensor));
*prefilter_tensor_ptr = prepare_tensor(prefilter_d, prefilter_shape, kDLUInt, 32, 1, kDLCUDA);

filter.type = BITSET;
filter.addr = (uintptr_t) prefilter_tensor_ptr;
}

*return_value = cuvsCagraSearch(cuvs_resources,
search_params,
index,
&queries_tensor,
&neighbors_tensor,
&distances_tensor,
filter);

cudaMemcpy(neighbors_h, neighbors, sizeof(uint32_t) * n_queries * topk, cudaMemcpyDefault);
cudaMemcpy(distances_h, distances, sizeof(float) * n_queries * topk, cudaMemcpyDefault);

cuvsRMMFree(cuvs_resources, distances, sizeof(float) * n_queries * topk);
cuvsRMMFree(cuvs_resources, neighbors, sizeof(uint32_t) * n_queries * topk);
cuvsRMMFree(cuvs_resources, queries_d, sizeof(float) * n_queries * dimensions);
if (prefilter_d != NULL) {
cuvsRMMFree(cuvs_resources, prefilter_d, sizeof(uint32_t) * prefilter_len);
}
if (prefilter_tensor_ptr != NULL) {
free(prefilter_tensor_ptr);
}
}


/**
* @brief De-allocate BRUTEFORCE index
*
Expand Down