Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion engine/packages/api-builder/src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ lazy_static::lazy_static! {
*REGISTRY
).unwrap();
pub static ref API_REQUEST_ERRORS: IntCounterVec = register_int_counter_vec_with_registry!(
"api_request_errors",
"api_request_errors_total",
"All errors made to this request.",
&["method", "path", "status", "error_code"],
*REGISTRY,
Expand Down
60 changes: 37 additions & 23 deletions engine/packages/gasoline/src/builder/workflow/lupe.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::time::Instant;
use std::time::{Duration, Instant};

use anyhow::Result;
use serde::{Serialize, de::DeserializeOwned};
Expand Down Expand Up @@ -117,12 +117,11 @@ impl<'a, S: Serialize + DeserializeOwned> LoopBuilder<'a, S> {

// Used to defer loop upsertion for parallelization
let mut loop_event_upsert_fut = None;
let mut iteration_dt = Duration::ZERO;

loop {
ctx.check_stop()?;

let start_instant = Instant::now();

// Create a new branch for each iteration of the loop at location {...loop location, iteration idx}
let mut iteration_branch = loop_branch.branch_inner(
ctx.input().clone(),
Expand All @@ -140,8 +139,7 @@ impl<'a, S: Serialize + DeserializeOwned> LoopBuilder<'a, S> {
let i = iteration;

// Async block for instrumentation purposes
let (dt2, res) = async {
let start_instant2 = Instant::now();
let res = async {
let db2 = ctx.db().clone();

// NOTE: Great care has been taken to optimize this function. This join allows multiple
Expand All @@ -151,9 +149,14 @@ impl<'a, S: Serialize + DeserializeOwned> LoopBuilder<'a, S> {
// commit the loop event. This only happens on the first iteration of the loop
// 2. Second, we commit the branch event for the current iteration
// 3. Third, we run the user's loop code
// 4. Last, if we have to upsert the loop event, we save the future and process it in the
// 4. Last, if we have to upsert the loop event, we save the future and poll it in the
// next iteration of the loop as part of this join
let (loop_event_commit_res, loop_event_upsert_res, branch_commit_res, loop_res) = tokio::join!(
let (
loop_event_commit_res,
loop_event_upsert_res,
branch_commit_res,
(loop_res, cb_dt),
) = tokio::join!(
async {
if let Some(loop_event_init_fut) = loop_event_init_fut.take() {
loop_event_init_fut.await
Expand All @@ -163,10 +166,14 @@ impl<'a, S: Serialize + DeserializeOwned> LoopBuilder<'a, S> {
},
async {
if let Some(loop_event_upsert_fut) = loop_event_upsert_fut.take() {
loop_event_upsert_fut.await
} else {
Ok(())
let start_instant = Instant::now();
loop_event_upsert_fut.await?;
metrics::LOOP_COMMIT_DURATION
.with_label_values(&[&ctx.name().to_string()])
.observe(start_instant.elapsed().as_secs_f64());
}

anyhow::Ok(())
},
async {
// Insert event if iteration is not a replay
Expand All @@ -177,22 +184,35 @@ impl<'a, S: Serialize + DeserializeOwned> LoopBuilder<'a, S> {
ctx.version(),
Some(&loop_location),
)
.await
} else {
Ok(())
.await?;

// Only record iteration duration if it's not a replay
metrics::LOOP_ITERATION_DURATION
.with_label_values(&[&ctx.name().to_string()])
.observe(iteration_dt.as_secs_f64());
}

anyhow::Ok(())
},
cb(&mut iteration_branch, &mut state),
async {
let iteration_start_instant = Instant::now();

(
cb(&mut iteration_branch, &mut state).await,
iteration_start_instant.elapsed(),
)
}
);

loop_event_commit_res?;
loop_event_upsert_res?;
branch_commit_res?;

iteration_dt = cb_dt;

// Run loop
match loop_res? {
Loop::Continue => {
let dt2 = start_instant2.elapsed().as_secs_f64();
iteration += 1;

// Commit workflow state to db
Expand Down Expand Up @@ -226,10 +246,9 @@ impl<'a, S: Serialize + DeserializeOwned> LoopBuilder<'a, S> {
});
}

anyhow::Ok((dt2, None))
anyhow::Ok(None)
}
Loop::Break(res) => {
let dt2 = start_instant2.elapsed().as_secs_f64();
iteration += 1;

let state_val = serde_json::value::to_raw_value(&state)
Expand All @@ -252,7 +271,7 @@ impl<'a, S: Serialize + DeserializeOwned> LoopBuilder<'a, S> {
)
.await?;

Ok((dt2, Some(res)))
Ok(Some(res))
}
}
}
Expand All @@ -262,11 +281,6 @@ impl<'a, S: Serialize + DeserializeOwned> LoopBuilder<'a, S> {
// Validate no leftover events
iteration_branch.cursor().check_clear()?;

let dt = start_instant.elapsed().as_secs_f64();
metrics::LOOP_ITERATION_DURATION
.with_label_values(&[&ctx.name().to_string()])
.observe(dt - dt2);

if let Some(res) = res {
break res;
}
Expand Down
10 changes: 9 additions & 1 deletion engine/packages/gasoline/src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,17 @@ lazy_static::lazy_static! {
*REGISTRY
).unwrap();

// Histogram of how long a single loop commit takes, labelled by `workflow_name`.
// Uses the shared `BUCKETS` boundaries and is registered in `REGISTRY`.
// The `.unwrap()` panics if the registration macro returns an error.
// NOTE(review): presumably observed in seconds (callers pass `as_secs_f64()`) - confirm
// against all call sites.
pub static ref LOOP_COMMIT_DURATION: HistogramVec = register_histogram_vec_with_registry!(
"gasoline_loop_commit_duration",
"Total duration of a single loop commit.",
&["workflow_name"],
BUCKETS.to_vec(),
*REGISTRY
).unwrap();

pub static ref LOOP_ITERATION_DURATION: HistogramVec = register_histogram_vec_with_registry!(
"gasoline_loop_iteration_duration",
"Total duration of a single loop iteration (excluding its body).",
"Total duration of a single loop iteration.",
&["workflow_name"],
BUCKETS.to_vec(),
*REGISTRY
Expand Down
5 changes: 5 additions & 0 deletions engine/packages/guard-core/src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ lazy_static! {
"Number of active in-flight counters",
*REGISTRY
).unwrap();
// Unlabelled integer gauge tracking the number of active in-flight requests.
// Registered in `REGISTRY`; the `.unwrap()` panics if registration returns an error.
// NOTE(review): appears to be set from the in-flight request cache's `entry_count()`
// whenever a request ID is inserted or released - confirm against the proxy service.
pub static ref IN_FLIGHT_REQUEST_COUNT: IntGauge = register_int_gauge_with_registry!(
"guard_in_flight_request_count",
"Number of active in-flight requests",
*REGISTRY
).unwrap();

// MARK: TCP
pub static ref TCP_CONNECTION_TOTAL: IntCounter = register_int_counter_with_registry!(
Expand Down
103 changes: 13 additions & 90 deletions engine/packages/guard-core/src/proxy_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@ use serde_json;

use rivet_runner_protocol as protocol;
use std::{
borrow::Cow,
collections::HashMap as StdHashMap,
net::SocketAddr,
sync::Arc,
time::{Duration, Instant},
Expand Down Expand Up @@ -134,44 +132,13 @@ pub struct RouteConfig {
pub enum RoutingOutput {
/// Return the data to route to.
Route(RouteConfig),
/// Return a custom response.
Response(StructuredResponse),
/// Return a custom serve handler.
CustomServe(Arc<dyn CustomServeTrait>),
}

/// A custom response that routing can hand back in place of a route target.
///
/// It is rendered as a JSON object with a `message` key and, when present,
/// a `docs` key.
#[derive(Clone, Debug)]
pub struct StructuredResponse {
    /// HTTP status code used for the rendered response.
    pub status: StatusCode,
    /// Text emitted under the `message` key of the JSON body.
    pub message: Cow<'static, str>,
    /// Optional text emitted under the `docs` key of the JSON body.
    pub docs: Option<Cow<'static, str>>,
}

impl StructuredResponse {
    /// Renders this value as an `application/json` HTTP response.
    ///
    /// # Errors
    ///
    /// Returns an error if the body cannot be serialized to JSON or if the
    /// response builder rejects the status/header/body combination.
    pub fn build_response(&self) -> Result<Response<ResponseBody>> {
        // `message` is always present; `docs` is appended only when set.
        let fields: StdHashMap<&'static str, String> =
            std::iter::once(("message", self.message.to_string()))
                .chain(self.docs.iter().map(|docs| ("docs", docs.to_string())))
                .collect();

        let payload = Bytes::from(serde_json::to_string(&fields)?);

        let http_response = Response::builder()
            .status(self.status)
            .header(hyper::header::CONTENT_TYPE, "application/json")
            .body(ResponseBody::Full(Full::new(payload)))?;

        Ok(http_response)
    }
}

#[derive(Clone)]
enum ResolveRouteOutput {
Target(RouteTarget),
Response(StructuredResponse),
CustomServe(Arc<dyn CustomServeTrait>),
}

Expand Down Expand Up @@ -347,6 +314,7 @@ pub struct ProxyState {
cache_key_fn: CacheKeyFn,
middleware_fn: MiddlewareFn,
route_cache: RouteCache,
// We use moka::Cache instead of scc::HashMap because it automatically handles TTL and capacity
rate_limiters: Cache<(Id, std::net::IpAddr), Arc<Mutex<RateLimiter>>>,
in_flight_counters: Cache<(Id, std::net::IpAddr), Arc<Mutex<InFlightCounter>>>,
in_flight_requests: Cache<protocol::RequestId, ()>,
Expand Down Expand Up @@ -478,15 +446,6 @@ impl ProxyState {
Err(errors::NoRouteTargets.build())
}
}
RoutingOutput::Response(response) => {
tracing::debug!(
hostname = %hostname_only,
path = %path,
status = ?response.status,
"Routing returned custom response"
);
Ok(ResolveRouteOutput::Response(response))
}
RoutingOutput::CustomServe(handler) => {
tracing::debug!(
hostname = %hostname_only,
Expand Down Expand Up @@ -660,6 +619,7 @@ impl ProxyState {

// Release request ID
self.in_flight_requests.invalidate(&request_id).await;
metrics::IN_FLIGHT_REQUEST_COUNT.set(self.in_flight_requests.entry_count() as i64);
}

/// Generate a unique request ID that is not currently in flight
Expand All @@ -668,11 +628,19 @@ impl ProxyState {

for attempt in 0..MAX_TRIES {
let request_id = protocol::util::generate_request_id();
let mut inserted = false;

// Check if this ID is already in use
if self.in_flight_requests.get(&request_id).await.is_none() {
// Insert the ID and return it
self.in_flight_requests.insert(request_id, ()).await;
self.in_flight_requests
.entry(request_id)
.or_insert_with(async {
inserted = true;
})
.await;

if inserted {
metrics::IN_FLIGHT_REQUEST_COUNT.set(self.in_flight_requests.entry_count() as i64);

return Ok(request_id);
}

Expand Down Expand Up @@ -769,10 +737,6 @@ impl ProxyService {

// Resolve target
let target = target_res?;
if let ResolveRouteOutput::Response(response) = &target {
// Return the custom response
return response.build_response();
}

let actor_id = if let ResolveRouteOutput::Target(target) = &target {
target.actor_id
Expand Down Expand Up @@ -1088,9 +1052,6 @@ impl ProxyService {
}
.build());
}
ResolveRouteOutput::Response(_) => {
unreachable!()
}
ResolveRouteOutput::CustomServe(mut handler) => {
let req_headers = req.headers().clone();
let req_method = req.method().clone();
Expand Down Expand Up @@ -1554,20 +1515,6 @@ impl ProxyService {
Ok(ResolveRouteOutput::Target(new_target)) => {
target = new_target;
}
Ok(ResolveRouteOutput::Response(response)) => {
tracing::debug!(
status=?response.status,
message=?response.message,
docs=?response.docs,
"got response instead of websocket target",
);

// Close the WebSocket connection with the response message
let _ = client_ws
.close(Some(str_to_close_frame(response.message.as_ref())))
.await;
return;
}
Ok(ResolveRouteOutput::CustomServe(_)) => {
let err = errors::WebSocketTargetChanged.build();
tracing::warn!(
Expand Down Expand Up @@ -1907,7 +1854,6 @@ impl ProxyService {
.instrument(tracing::info_span!("handle_ws_task_target")),
);
}
ResolveRouteOutput::Response(_) => unreachable!(),
ResolveRouteOutput::CustomServe(mut handler) => {
tracing::debug!(%req_path, "Spawning task to handle WebSocket communication");
let mut request_context = request_context.clone();
Expand Down Expand Up @@ -2090,19 +2036,6 @@ impl ProxyService {
handler = new_handler;
continue;
}
Ok(ResolveRouteOutput::Response(response)) => {
ws_handle
.send(to_hyper_close(Some(str_to_close_frame(
response.message.as_ref(),
))))
.await?;

// Flush to ensure close frame is sent
ws_handle.flush().await?;

// Keep TCP connection open briefly to allow client to process close
tokio::time::sleep(WEBSOCKET_CLOSE_LINGER).await;
}
Ok(ResolveRouteOutput::Target(_)) => {
let err = errors::WebSocketTargetChanged.build();
tracing::warn!(
Expand Down Expand Up @@ -2666,16 +2599,6 @@ pub fn is_ws_hibernate(err: &anyhow::Error) -> bool {
}
}

/// Builds an error close frame whose reason is taken from `err`.
///
/// The reason is cut down with `rivet_util::safe_slice(err, 0, 123)` and the
/// close code is always `CloseCode::Error`.
fn str_to_close_frame(err: &str) -> CloseFrame {
    CloseFrame {
        code: CloseCode::Error,
        // NOTE: the WS protocol spec caps the close reason at 123 bytes
        reason: rivet_util::safe_slice(err, 0, 123).into(),
    }
}

fn err_to_close_frame(err: anyhow::Error, ray_id: Option<Id>) -> CloseFrame {
let rivet_err = err
.chain()
Expand Down
Loading