Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ public static void main(String[] args) throws Exception {
// .dataset(braintrust.fetchDataset("my-dataset-name"))
.taskFunction(getFoodType)
.scorers(
// to fetch a remote scorer:
// braintrust.fetchScorer("my-remote-scorer-6d9f"),
Scorer.of(
"exact_match",
(expected, result) -> expected.equals(result) ? 1.0 : 0.0))
Expand Down
37 changes: 37 additions & 0 deletions src/main/java/dev/braintrust/Braintrust.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import dev.braintrust.config.BraintrustConfig;
import dev.braintrust.eval.Dataset;
import dev.braintrust.eval.Eval;
import dev.braintrust.eval.Scorer;
import dev.braintrust.prompt.BraintrustPromptLoader;
import dev.braintrust.trace.BraintrustTracing;
import io.opentelemetry.api.OpenTelemetry;
Expand Down Expand Up @@ -173,4 +174,40 @@ public <INPUT, OUTPUT> Dataset<INPUT, OUTPUT> fetchDataset(
var projectName = apiClient.getOrCreateProjectAndOrgInfo(config).project().name();
return Dataset.fetchFromBraintrust(apiClient(), projectName, datasetName, datasetVersion);
}

/**
 * Look up a scorer by slug in the default project taken from configuration, without pinning
 * a version.
 *
 * <p>Delegates to {@link #fetchScorer(String, String)} with a {@code null} version.
 *
 * @param scorerSlug the unique slug identifier for the scorer
 * @return a Scorer that invokes the remote function
 */
public <INPUT, OUTPUT> Scorer<INPUT, OUTPUT> fetchScorer(String scorerSlug) {
    // No version requested: pass null through to the versioned overload.
    final String unspecifiedVersion = null;
    return fetchScorer(scorerSlug, unspecifiedVersion);
}

/**
 * Look up a scorer by slug in the default project taken from configuration, optionally at a
 * specific version.
 *
 * <p>The project name is resolved through the API client from this instance's configuration
 * before the scorer itself is fetched.
 *
 * @param scorerSlug the unique slug identifier for the scorer
 * @param version optional version of the scorer to fetch
 * @return a Scorer that invokes the remote function
 */
public <INPUT, OUTPUT> Scorer<INPUT, OUTPUT> fetchScorer(
        String scorerSlug, @Nullable String version) {
    var defaultProject = apiClient.getOrCreateProjectAndOrgInfo(config).project();
    var defaultProjectName = defaultProject.name();
    return Scorer.fetchFromBraintrust(apiClient, defaultProjectName, scorerSlug, version);
}

/**
 * Look up a scorer in an explicitly named project.
 *
 * @param projectName the name of the project containing the scorer
 * @param scorerSlug the unique slug identifier for the scorer
 * @param version optional version of the scorer to fetch
 * @return a Scorer that invokes the remote function
 */
public <INPUT, OUTPUT> Scorer<INPUT, OUTPUT> fetchScorer(
        String projectName, String scorerSlug, @Nullable String version) {
    Scorer<INPUT, OUTPUT> remoteScorer =
            Scorer.fetchFromBraintrust(apiClient, projectName, scorerSlug, version);
    return remoteScorer;
}
}
159 changes: 159 additions & 0 deletions src/main/java/dev/braintrust/api/BraintrustApiClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,26 @@ Optional<Prompt> getPrompt(
/** Query datasets by project name and dataset name */
List<Dataset> queryDatasets(String projectName, String datasetName);

/**
* Get a function by project name and slug, with optional version.
*
* <p>In the HTTP implementation, an empty result is returned when no function matches and an
* {@code ApiException} is thrown when more than one function matches.
*
* @param projectName the name of the project containing the function
* @param slug the unique slug identifier for the function
* @param version optional version identifier (transaction id or version string)
* @return the function if found, otherwise an empty {@code Optional}
*/
Optional<Function> getFunction(
@Nonnull String projectName, @Nonnull String slug, @Nullable String version);

/**
* Invoke a function (scorer, prompt, or tool) by its ID.
*
* <p>The HTTP implementation posts the request to {@code /v1/function/<functionId>/invoke} and
* wraps failures of that call in an {@code ApiException}.
*
* @param functionId the ID of the function to invoke
* @param request the invocation request containing input, expected output, etc.
* @return the result of the function invocation
*/
Object invokeFunction(@Nonnull String functionId, @Nonnull FunctionInvokeRequest request);

/**
* Create a client backed by {@code HttpImpl} for the given configuration.
*
* @param config the configuration handed to the {@code HttpImpl} constructor
* @return a new {@code HttpImpl} instance
*/
static BraintrustApiClient of(BraintrustConfig config) {
return new HttpImpl(config);
}
Expand Down Expand Up @@ -296,6 +316,54 @@ public List<Dataset> queryDatasets(String projectName, String datasetName) {
}
}

/**
 * Look up a single function by project name and slug, optionally pinned to a version.
 *
 * <p>All caller-supplied values are percent-encoded before being placed in the query string,
 * so names containing spaces or reserved characters ({@code &}, {@code =}, {@code #}, ...)
 * cannot corrupt the request.
 *
 * @param projectName the name of the project containing the function; must not be null
 * @param slug the unique slug identifier for the function; must not be null
 * @param version optional version identifier; ignored when null or empty
 * @return the matching function, or an empty {@code Optional} when the response has no objects
 * @throws NullPointerException if {@code projectName} or {@code slug} is null
 * @throws ApiException if more than one function matches
 * @throws RuntimeException if the request is interrupted or fails while executing
 */
@Override
public Optional<Function> getFunction(
        @Nonnull String projectName, @Nonnull String slug, @Nullable String version) {
    Objects.requireNonNull(projectName, "projectName must not be null");
    Objects.requireNonNull(slug, "slug must not be null");
    try {
        // Fully-qualified names keep this method free of new file-level imports.
        var utf8 = java.nio.charset.StandardCharsets.UTF_8;
        var uriBuilder = new StringBuilder("/v1/function?");
        uriBuilder.append("slug=").append(java.net.URLEncoder.encode(slug, utf8));
        uriBuilder
                .append("&project_name=")
                .append(java.net.URLEncoder.encode(projectName, utf8));

        if (version != null && !version.isEmpty()) {
            uriBuilder.append("&version=").append(java.net.URLEncoder.encode(version, utf8));
        }

        FunctionListResponse response =
                getAsync(uriBuilder.toString(), FunctionListResponse.class).get();

        if (response.objects() == null || response.objects().isEmpty()) {
            return Optional.empty();
        }

        // A (project, slug) pair is expected to identify at most one function.
        if (response.objects().size() > 1) {
            throw new ApiException(
                    "Multiple functions found for slug: "
                            + slug
                            + ", projectName: "
                            + projectName);
        }

        return Optional.of(response.objects().get(0));
    } catch (InterruptedException e) {
        // Restore the interrupt flag so callers further up the stack can observe it.
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    } catch (ExecutionException e) {
        throw new RuntimeException(e);
    }
}

/**
 * Invoke a function by ID by posting the request to {@code /v1/function/<functionId>/invoke}
 * and blocking until the call completes.
 *
 * @param functionId the ID of the function to invoke; must not be null
 * @param request the invocation request; must not be null
 * @return the response body as returned by {@code postAsync(..., Object.class)}
 * @throws NullPointerException if {@code functionId} or {@code request} is null
 * @throws ApiException if the call is interrupted or fails while executing
 */
@Override
public Object invokeFunction(
        @Nonnull String functionId, @Nonnull FunctionInvokeRequest request) {
    Objects.requireNonNull(functionId, "functionId must not be null");
    Objects.requireNonNull(request, "request must not be null");
    try {
        String path = "/v1/function/" + functionId + "/invoke";
        return postAsync(path, request, Object.class).get();
    } catch (InterruptedException e) {
        // Restore the interrupt flag so callers further up the stack can observe it.
        Thread.currentThread().interrupt();
        throw new ApiException("Failed to invoke function: " + functionId, e);
    } catch (ExecutionException e) {
        throw new ApiException("Failed to invoke function: " + functionId, e);
    }
}

private <T> CompletableFuture<T> getAsync(String path, Class<T> responseType) {
var request =
HttpRequest.newBuilder()
Expand Down Expand Up @@ -399,6 +467,9 @@ class InMemoryImpl implements BraintrustApiClient {
private final Set<Experiment> experiments =
Collections.newSetFromMap(new ConcurrentHashMap<>());
private final List<Prompt> prompts = new ArrayList<>();
// Function records held by the in-memory client.
// NOTE(review): the getFunction stub visible in this class throws instead of reading this
// list — confirm the field is still needed.
private final List<Function> functions = new ArrayList<>();
// Handlers that turn a FunctionInvokeRequest into a result object, looked up by a String key.
// NOTE(review): the key is presumably the function id — confirm against the code that
// populates this map.
private final Map<String, java.util.function.Function<FunctionInvokeRequest, Object>>
functionInvokers = new ConcurrentHashMap<>();

public InMemoryImpl(OrganizationAndProjectInfo... organizationAndProjectInfos) {
this.organizationAndProjectInfos =
Expand Down Expand Up @@ -583,6 +654,18 @@ public Optional<Dataset> getDataset(String datasetId) {
// In-memory stub: performs no lookup and always returns an empty, immutable list,
// regardless of the project or dataset name passed in.
public List<Dataset> queryDatasets(String projectName, String datasetName) {
return List.of();
}

// Not supported by the in-memory client: unconditionally throws a RuntimeException,
// so no argument is ever inspected.
@Override
public Optional<Function> getFunction(
@Nonnull String projectName, @Nonnull String slug, @Nullable String version) {
throw new RuntimeException("will not be invoked");
}

@Override
public Object invokeFunction(
@Nonnull String functionId, @Nonnull FunctionInvokeRequest request) {
throw new RuntimeException("will not be invoked");
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm beginning to switch all Braintrust tests away from mocks/doubles in favor of VCR, so these test methods won't be needed.

}
}

// Request/Response DTOs
Expand Down Expand Up @@ -681,4 +764,80 @@ record Prompt(
Optional<Object> metadata) {}

record PromptListResponse(List<Prompt> objects) {}

// Function models for remote scorers/prompts/tools

/**
* Represents a Braintrust function (scorer, prompt, tool, or task). Functions can be invoked
* remotely via the API.
*
* <p>This is the element type of {@code FunctionListResponse.objects()}. Components declared
* as {@link Optional} may be absent; the remaining components are plain values.
*
* <p>NOTE(review): the exact meaning and format of each component (for example the format of
* {@code created}) is defined by the Braintrust API, not by this record — confirm against the
* API reference before relying on it.
*/
record Function(
String id,
String projectId,
String orgId,
String name,
String slug,
Optional<String> description,
String created,
Optional<Object> functionData,
Optional<Object> promptData,
Optional<List<String>> tags,
Optional<Object> metadata,
Optional<String> functionType,
Optional<Object> origin,
Optional<Object> functionSchema) {}

/**
* Response envelope for the function list query issued by {@code getFunction}. Callers treat a
* null or empty {@code objects} list as "no function matched".
*/
record FunctionListResponse(List<Function> objects) {}

/**
 * Body sent when invoking a function. The function's arguments travel in the {@code input}
 * component.
 *
 * <p>For remote Python/TypeScript scorers, the scorer handler parameters (input, output,
 * expected, metadata) must be wrapped in the outer input field.
 */
record FunctionInvokeRequest(@Nullable Object input, @Nullable String version) {

    /** Build a request that carries only an input and no version. */
    public static FunctionInvokeRequest of(Object input) {
        return of(input, null);
    }

    /** Build a request from an input and an optional version. */
    public static FunctionInvokeRequest of(Object input, @Nullable String version) {
        return new FunctionInvokeRequest(input, version);
    }

    /**
     * Build a scorer invocation without a version. The arguments mirror the standard scorer
     * handler signature: handler(input, output, expected, metadata).
     *
     * <p>The scorer args are wrapped in the outer input field as required by the invoke API.
     */
    public static FunctionInvokeRequest forScorer(
            Object input, Object output, Object expected, Object metadata) {
        final String unspecifiedVersion = null;
        return forScorer(input, output, expected, metadata, unspecifiedVersion);
    }

    /**
     * Build a scorer invocation with an optional version. The arguments mirror the standard
     * scorer handler signature: handler(input, output, expected, metadata).
     *
     * <p>The scorer args are wrapped in the outer input field as required by the invoke API.
     */
    public static FunctionInvokeRequest forScorer(
            Object input,
            Object output,
            Object expected,
            Object metadata,
            @Nullable String version) {
        // Insertion-ordered map that accepts null values; it becomes the outer "input".
        Map<String, Object> handlerArgs = new java.util.LinkedHashMap<>();
        handlerArgs.put("input", input);
        handlerArgs.put("output", output);
        handlerArgs.put("expected", expected);
        handlerArgs.put("metadata", metadata);
        return of(handlerArgs, version);
    }
}
}
52 changes: 48 additions & 4 deletions src/main/java/dev/braintrust/devserver/Devserver.java
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,18 @@ private void handleEval(HttpExchange exchange) throws IOException {
return;
}

// TODO: support remote scorers
// Resolve remote scorers from the request
List<Scorer<Object, Object>> remoteScorers = new ArrayList<>();
if (request.getScores() != null) {
var apiClient = context.getBraintrust().apiClient();
for (var remoteScorer : request.getScores()) {
remoteScorers.add(resolveRemoteScorer(remoteScorer, apiClient));
}
log.debug(
"Resolved {} remote scorer(s): {}",
remoteScorers.size(),
remoteScorers.stream().map(Scorer::getName).toList());
}

String datasetDescription =
hasInlineData
Expand All @@ -308,7 +319,7 @@ private void handleEval(HttpExchange exchange) throws IOException {
if (isStreaming) {
// SSE streaming response - errors handled inside
log.debug("Starting streaming evaluation for '{}'", request.getName());
handleStreamingEval(exchange, eval, request, context);
handleStreamingEval(exchange, eval, request, context, remoteScorers);
} else {
throw new NotSupportedYetException("non-streaming responses");
}
Expand All @@ -325,7 +336,11 @@ private void handleEval(HttpExchange exchange) throws IOException {

@SuppressWarnings({"unchecked", "rawtypes"})
private void handleStreamingEval(
HttpExchange exchange, RemoteEval eval, EvalRequest request, RequestContext context)
HttpExchange exchange,
RemoteEval eval,
EvalRequest request,
RequestContext context,
List<Scorer<Object, Object>> remoteScorers)
throws Exception {
// Set SSE headers
exchange.getResponseHeaders().set("Content-Type", "text/event-stream");
Expand Down Expand Up @@ -423,7 +438,12 @@ private void handleStreamingEval(
taskResult);
}
// run scorers - one score span per scorer
for (var scorer : (List<Scorer<?, ?>>) eval.getScorers()) {
// Combine local scorers from RemoteEval with remote scorers
// from request
List<Scorer<?, ?>> allScorers =
new ArrayList<>(eval.getScorers());
allScorers.addAll(remoteScorers);
for (var scorer : allScorers) {
var scoreSpan = tracer.spanBuilder("score").startSpan();
try (var unused =
Context.current()
Expand Down Expand Up @@ -1037,6 +1057,30 @@ private static ParentInfo extractParentInfo(EvalRequest request) {
}
}

/**
 * Turn a remote scorer specification from an eval request into a {@code Scorer} that invokes
 * the referenced function through the given API client.
 *
 * @param remoteScorer the remote scorer specification from the request
 * @param apiClient the API client to use for invoking the scorer function
 * @return a Scorer that invokes the remote function
 * @throws IllegalArgumentException if the function_id is missing
 */
private static Scorer<Object, Object> resolveRemoteScorer(
        EvalRequest.RemoteScorer remoteScorer, BraintrustApiClient apiClient) {
    var spec = remoteScorer.getFunctionId();
    // Collapse "no spec" and "spec without an id" into a single null check.
    var functionId = (spec == null) ? null : spec.getFunctionId();
    if (functionId == null) {
        throw new IllegalArgumentException(
                "Remote scorer '" + remoteScorer.getName() + "' missing function_id");
    }

    var scorerName = remoteScorer.getName();
    var scorerVersion = spec.getVersion();
    return new ScorerBrainstoreImpl<>(apiClient, functionId, scorerName, scorerVersion);
}

public static class Builder {
private @Nullable BraintrustConfig config = null;
private String host = "localhost";
Expand Down
31 changes: 31 additions & 0 deletions src/main/java/dev/braintrust/eval/Scorer.java
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package dev.braintrust.eval;

import dev.braintrust.api.BraintrustApiClient;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;
import javax.annotation.Nullable;

/**
* A scorer evaluates the result of a test case with a score between 0 (inclusive) and 1
Expand Down Expand Up @@ -49,4 +51,33 @@ public List<Score> score(TaskResult<INPUT, OUTPUT> taskResult) {
}
};
}

/**
 * Resolve a scorer stored in Braintrust, identified by project name and slug, into a
 * {@code Scorer} backed by the remote function.
 *
 * @param apiClient the API client to use
 * @param projectName the name of the project containing the scorer
 * @param scorerSlug the unique slug identifier for the scorer
 * @param version optional version of the scorer to fetch
 * @return a Scorer that invokes the remote function
 * @throws RuntimeException if the scorer is not found
 */
static <INPUT, OUTPUT> Scorer<INPUT, OUTPUT> fetchFromBraintrust(
        BraintrustApiClient apiClient,
        String projectName,
        String scorerSlug,
        @Nullable String version) {
    var lookup = apiClient.getFunction(projectName, scorerSlug, version);
    if (lookup.isEmpty()) {
        throw new RuntimeException(
                "Scorer not found: project=" + projectName + ", slug=" + scorerSlug);
    }

    var remoteFunction = lookup.get();
    return new ScorerBrainstoreImpl<>(
            apiClient, remoteFunction.id(), remoteFunction.name(), version);
}
}
Loading