diff --git a/.bleep b/.bleep
index b99e531f..9c45175b 100644
--- a/.bleep
+++ b/.bleep
@@ -1 +1 @@
-c90e4ce2596840c60b5ff1737e2141447e5953e1
+f123f5e43e9ada31a0e541b917ea674527fd06a3
\ No newline at end of file
diff --git a/docs/user_guide/index.md b/docs/user_guide/index.md
index a8abcb10..cb57eb61 100644
--- a/docs/user_guide/index.md
+++ b/docs/user_guide/index.md
@@ -20,6 +20,7 @@ In this guide, we will cover the most used features, operations and settings of
 * [Examples: take control of the request](modify_filter.md)
 * [Connection pooling and reuse](pooling.md)
 * [Handling failures and failover](failover.md)
+* [RateLimiter quickstart](rate_limiter.md)
 
 ## Advanced topics (WIP)
 * [Pingora internals](internals.md)
diff --git a/docs/user_guide/rate_limiter.md b/docs/user_guide/rate_limiter.md
new file mode 100644
index 00000000..fe337a19
--- /dev/null
+++ b/docs/user_guide/rate_limiter.md
@@ -0,0 +1,167 @@
+# **RateLimiter quickstart**
+The `pingora-limits` crate provides a simple, easy-to-use rate limiter for your application. Below is an example of how you can use [`Rate`](https://docs.rs/pingora-limits/latest/pingora_limits/rate/struct.Rate.html) to build an application that restricts the rate at which requests can be made on a per-app basis (determined by a request header).
+
+## Steps
+1. Add the following dependencies to your `Cargo.toml`:
+   ```toml
+   async-trait = "0.1"
+   pingora = { version = "0.3", features = [ "lb" ] }
+   pingora-limits = "0.3.0"
+   once_cell = "1.19.0"
+   ```
+2. Declare a global rate limiter that tracks the request count for each client within the current window. In this example, clients are identified by the `appid` request header.
+3. Override the `request_filter` method of the `ProxyHttp` trait to implement rate limiting.
+   1. Retrieve the client `appid` from the request header.
+   2. Retrieve the current window request count from the rate limiter. If there is no count for this client yet, one is created.
+   3. 
If the current window request count exceeds the limit, return 429 with the rate-limiting headers attached.
+   4. If the request is not rate limited, return `Ok(false)` to continue the request.
+
+## Example
+```rust
+use async_trait::async_trait;
+use once_cell::sync::Lazy;
+use pingora::http::ResponseHeader;
+use pingora::prelude::*;
+use pingora_limits::rate::Rate;
+use std::sync::Arc;
+use std::time::Duration;
+
+fn main() {
+    let mut server = Server::new(Some(Opt::default())).unwrap();
+    server.bootstrap();
+    let mut upstreams = LoadBalancer::try_from_iter(["1.1.1.1:443", "1.0.0.1:443"]).unwrap();
+    // Set health check
+    let hc = TcpHealthCheck::new();
+    upstreams.set_health_check(hc);
+    upstreams.health_check_frequency = Some(Duration::from_secs(1));
+    // Set background service
+    let background = background_service("health check", upstreams);
+    let upstreams = background.task();
+    // Set load balancer
+    let mut lb = http_proxy_service(&server.configuration, LB(upstreams));
+    lb.add_tcp("0.0.0.0:6188");
+
+    server.add_service(background);
+    server.add_service(lb);
+    server.run_forever();
+}
+
+pub struct LB(Arc<LoadBalancer<RoundRobin>>);
+
+impl LB {
+    pub fn get_request_appid(&self, session: &mut Session) -> Option<String> {
+        match session
+            .req_header()
+            .headers
+            .get("appid")
+            .map(|v| v.to_str())
+        {
+            None => None,
+            Some(v) => match v {
+                Ok(v) => Some(v.to_string()),
+                Err(_) => None,
+            },
+        }
+    }
+}
+
+// Rate limiter
+static RATE_LIMITER: Lazy<Rate> = Lazy::new(|| Rate::new(Duration::from_secs(1)));
+
+// max requests per second per client
+static MAX_REQ_PER_SEC: isize = 1;
+
+#[async_trait]
+impl ProxyHttp for LB {
+    type CTX = ();
+
+    fn new_ctx(&self) {}
+
+    async fn upstream_peer(
+        &self,
+        _session: &mut Session,
+        _ctx: &mut Self::CTX,
+    ) -> Result<Box<HttpPeer>> {
+        let upstream = self.0.select(b"", 256).unwrap();
+        // Set SNI
+        let peer = Box::new(HttpPeer::new(upstream, true, "one.one.one.one".to_string()));
+        Ok(peer)
+    }
+
+    async fn upstream_request_filter(
+        &self,
+        _session: &mut Session,
+        upstream_request: &mut RequestHeader,
+        _ctx: &mut Self::CTX,
+    ) -> Result<()>
+    where
+        Self::CTX: Send + Sync,
+    {
+        upstream_request
+            .insert_header("Host", "one.one.one.one")
+            .unwrap();
+        Ok(())
+    }
+
+    async fn request_filter(&self, session: &mut Session, _ctx: &mut Self::CTX) -> Result<bool>
+    where
+        Self::CTX: Send + Sync,
+    {
+        let appid = match self.get_request_appid(session) {
+            None => return Ok(false), // no client appid found, skip rate limiting
+            Some(appid) => appid,
+        };
+
+        // retrieve the current window request count
+        let curr_window_requests = RATE_LIMITER.observe(&appid, 1);
+        if curr_window_requests > MAX_REQ_PER_SEC {
+            // rate limited, return 429
+            let mut header = ResponseHeader::build(429, None).unwrap();
+            header
+                .insert_header("X-Rate-Limit-Limit", MAX_REQ_PER_SEC.to_string())
+                .unwrap();
+            header.insert_header("X-Rate-Limit-Remaining", "0").unwrap();
+            header.insert_header("X-Rate-Limit-Reset", "1").unwrap();
+            session.set_keepalive(None);
+            session
+                .write_response_header(Box::new(header), true)
+                .await?;
+            return Ok(true);
+        }
+        Ok(false)
+    }
+}
+```
+
+## Testing
+To use the example above,
+
+1. Run your program with `cargo run`.
+2. Verify the program is working with a few executions of `curl localhost:6188 -H "appid:1" -v`
+   - The first request should succeed, and any later request that arrives within 1s of a previous one should fail with:
+   ```
+   * Trying 127.0.0.1:6188...
+   * Connected to localhost (127.0.0.1) port 6188 (#0)
+   > GET / HTTP/1.1
+   > Host: localhost:6188
+   > User-Agent: curl/7.88.1
+   > Accept: */*
+   > appid:1
+   >
+   < HTTP/1.1 429 Too Many Requests
+   < X-Rate-Limit-Limit: 1
+   < X-Rate-Limit-Remaining: 0
+   < X-Rate-Limit-Reset: 1
+   < Date: Sun, 14 Jul 2024 20:29:02 GMT
+   < Connection: close
+   <
+   * Closing connection 0
+   ```
+
+## Complete Example
+You can run the pre-made example code in the [`pingora-proxy` examples folder](https://github.com/cloudflare/pingora/tree/main/pingora-proxy/examples/rate_limiter.rs) with
+
+```
+cargo run --example rate_limiter
+```
\ No newline at end of file
diff --git a/pingora-cache/src/key.rs b/pingora-cache/src/key.rs
index 1c8c5329..73053cdc 100644
--- a/pingora-cache/src/key.rs
+++ b/pingora-cache/src/key.rs
@@ -130,7 +130,7 @@ impl CacheKey {
 /// Storage optimized cache key to keep in memory or in storage
 // 16 bytes + 8 bytes (+16 * u8) + user_tag.len() + 16 Bytes (Box)
-#[derive(Debug, Deserialize, Serialize, Clone, Hash, PartialEq, Eq)]
+#[derive(Debug, Deserialize, Serialize, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
 pub struct CompactCacheKey {
     pub primary: HashBinary,
     // save 8 bytes for non-variance but waste 8 bytes for variance vs, store flat 16 bytes
diff --git a/pingora-cache/src/lib.rs b/pingora-cache/src/lib.rs
index 86cfbdfd..d1346e50 100644
--- a/pingora-cache/src/lib.rs
+++ b/pingora-cache/src/lib.rs
@@ -77,6 +77,8 @@ pub enum CachePhase {
     Miss,
     /// A staled (expired) asset is found
     Stale,
+    /// A staled (expired) asset was found, but another request is revalidating it
+    StaleUpdating,
     /// A staled (expired) asset was found, so a fresh one was fetched
     Expired,
     /// A staled (expired) asset was found, and it was revalidated to be fresh
@@ -96,6 +98,7 @@ impl CachePhase {
             CachePhase::Hit => "hit",
             CachePhase::Miss => "miss",
             CachePhase::Stale => "stale",
+            CachePhase::StaleUpdating => "stale-updating",
             CachePhase::Expired => "expired",
             CachePhase::Revalidated =>
 "revalidated",
             CachePhase::RevalidatedNoCache(_) => "revalidated-nocache",
@@ -260,7 +263,7 @@ impl HttpCache {
         use CachePhase::*;
         match self.phase {
             Disabled(_) | Bypass | Miss | Expired | Revalidated | RevalidatedNoCache(_) => true,
-            Hit | Stale => false,
+            Hit | Stale | StaleUpdating => false,
             Uninit | CacheKey => false, // invalid states for this call, treat them as false to keep it simple
         }
     }
@@ -493,7 +496,8 @@ impl HttpCache {
         match self.phase {
             // from CacheKey: set state to miss during cache lookup
             // from Bypass: response became cacheable, set state to miss to cache
-            CachePhase::CacheKey | CachePhase::Bypass => {
+            // from Stale: waited for cache lock, then retried and found asset was gone
+            CachePhase::CacheKey | CachePhase::Bypass | CachePhase::Stale => {
                 self.phase = CachePhase::Miss;
                 self.inner_mut().traces.start_miss_span();
             }
@@ -508,6 +512,7 @@ impl HttpCache {
         match self.phase {
             CachePhase::Hit
             | CachePhase::Stale
+            | CachePhase::StaleUpdating
             | CachePhase::Revalidated
             | CachePhase::RevalidatedNoCache(_) => self.inner_mut().body_reader.as_mut().unwrap(),
             _ => panic!("wrong phase {:?}", self.phase),
@@ -543,6 +548,7 @@ impl HttpCache {
             | CachePhase::Miss
             | CachePhase::Expired
             | CachePhase::Stale
+            | CachePhase::StaleUpdating
             | CachePhase::Revalidated
             | CachePhase::RevalidatedNoCache(_) => {
                 let inner = self.inner_mut();
@@ -785,6 +791,14 @@ impl HttpCache {
         // TODO: remove this asset from cache once finished?
     }
 
+    /// Mark this asset as stale, but being updated separately from this request.
+    pub fn set_stale_updating(&mut self) {
+        match self.phase {
+            CachePhase::Stale => self.phase = CachePhase::StaleUpdating,
+            _ => panic!("wrong phase {:?}", self.phase),
+        }
+    }
+
     /// Update the variance of the [CacheMeta].
     ///
     /// Note that this process may change the lookup `key`, and eventually (when the asset is
@@ -853,6 +867,7 @@ impl HttpCache {
         match self.phase {
             // TODO: allow in Bypass phase?
             CachePhase::Stale
+            | CachePhase::StaleUpdating
             | CachePhase::Expired
             | CachePhase::Hit
             | CachePhase::Revalidated
@@ -881,6 +896,7 @@ impl HttpCache {
         match self.phase {
             CachePhase::Miss
             | CachePhase::Stale
+            | CachePhase::StaleUpdating
             | CachePhase::Expired
             | CachePhase::Hit
             | CachePhase::Revalidated
@@ -1005,7 +1021,7 @@ impl HttpCache {
     /// Whether this request's cache hit is staled
     fn has_staled_asset(&self) -> bool {
-        self.phase == CachePhase::Stale
+        matches!(self.phase, CachePhase::Stale | CachePhase::StaleUpdating)
     }
 
     /// Whether this asset is staled and stale if error is allowed
diff --git a/pingora-core/src/protocols/http/compression/brotli.rs b/pingora-core/src/protocols/http/compression/brotli.rs
index 956f87da..89f7b4ec 100644
--- a/pingora-core/src/protocols/http/compression/brotli.rs
+++ b/pingora-core/src/protocols/http/compression/brotli.rs
@@ -42,7 +42,6 @@ impl Decompressor {
 impl Encode for Decompressor {
     fn encode(&mut self, input: &[u8], end: bool) -> Result<Bytes> {
-        // reserve at most 16k
         const MAX_INIT_COMPRESSED_SIZE_CAP: usize = 4 * 1024;
         // Brotli compress ratio can be 3.5 to 4.5
         const ESTIMATED_COMPRESSION_RATIO: usize = 4;
diff --git a/pingora-core/src/protocols/http/compression/gzip.rs b/pingora-core/src/protocols/http/compression/gzip.rs
index d64c961b..f7f997d1 100644
--- a/pingora-core/src/protocols/http/compression/gzip.rs
+++ b/pingora-core/src/protocols/http/compression/gzip.rs
@@ -12,15 +12,65 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-use super::Encode;
+use super::{Encode, COMPRESSION_ERROR};
 use bytes::Bytes;
-use flate2::write::GzEncoder;
-use pingora_error::Result;
+use flate2::write::{GzDecoder, GzEncoder};
+use pingora_error::{OrErr, Result};
 use std::io::Write;
 use std::time::{Duration, Instant};
 
-// TODO: unzip
+pub struct Decompressor {
+    decompress: GzDecoder<Vec<u8>>,
+    total_in: usize,
+    total_out: usize,
+    duration: Duration,
+}
+
+impl Decompressor {
+    pub fn new() -> Self {
+        Decompressor {
+            decompress: GzDecoder::new(vec![]),
+            total_in: 0,
+            total_out: 0,
+            duration: Duration::new(0, 0),
+        }
+    }
+}
+
+impl Encode for Decompressor {
+    fn encode(&mut self, input: &[u8], end: bool) -> Result<Bytes> {
+        const MAX_INIT_COMPRESSED_SIZE_CAP: usize = 4 * 1024;
+        const ESTIMATED_COMPRESSION_RATIO: usize = 3; // estimated 2.5-3x compression
+        let start = Instant::now();
+        self.total_in += input.len();
+        // cap the buf size amplification, there is a DoS risk in always allocating
+        // 3x the memory of the input buffer
+        let reserve_size = if input.len() < MAX_INIT_COMPRESSED_SIZE_CAP {
+            input.len() * ESTIMATED_COMPRESSION_RATIO
+        } else {
+            input.len()
+        };
+        self.decompress.get_mut().reserve(reserve_size);
+        self.decompress
+            .write_all(input)
+            .or_err(COMPRESSION_ERROR, "while decompress Gzip")?;
+        // write to vec will never fail, the only possible error is that the input data
+        // was not actually gzip compressed
+        if end {
+            self.decompress
+                .try_finish()
+                .or_err(COMPRESSION_ERROR, "while decompress Gzip")?;
+        }
+        self.total_out += self.decompress.get_ref().len();
+        self.duration += start.elapsed();
+        Ok(std::mem::take(self.decompress.get_mut()).into()) // into() Bytes will drop excess capacity
+    }
+
+    fn stat(&self) -> (&'static str, usize, usize, Duration) {
+        ("de-gzip", self.total_in, self.total_out, self.duration)
+    }
+}
 
 pub struct Compressor {
     // TODO: enum for other compression algorithms
@@ -66,6 +116,20 @@ impl Encode for Compressor {
 }
 
 use std::ops::{Deref, DerefMut};
+impl Deref for
 Decompressor {
+    type Target = GzDecoder<Vec<u8>>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.decompress
+    }
+}
+
+impl DerefMut for Decompressor {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.decompress
+    }
+}
+
 impl Deref for Compressor {
     type Target = GzEncoder<Vec<u8>>;
 
@@ -100,4 +164,21 @@ mod tests_stream {
         assert!(compressor.get_ref().is_empty());
     }
+
+    #[test]
+    fn gunzip_data() {
+        let mut decompressor = Decompressor::new();
+
+        let compressed_bytes = &[
+            0x1f, 0x8b, 0x08, 0, 0, 0, 0, 0, 0, 255, 75, 76, 74, 78, 73, 77, 75, 7, 0, 166, 106,
+            42, 49, 7, 0, 0, 0,
+        ];
+        let decompressed = decompressor.encode(compressed_bytes, true).unwrap();
+
+        assert_eq!(&decompressed[..], b"abcdefg");
+        assert_eq!(decompressor.total_in, compressed_bytes.len());
+        assert_eq!(decompressor.total_out, decompressed.len());
+
+        assert!(decompressor.get_ref().is_empty());
+    }
 }
diff --git a/pingora-core/src/protocols/http/compression/mod.rs b/pingora-core/src/protocols/http/compression/mod.rs
index ad1cd0b4..6236a0d8 100644
--- a/pingora-core/src/protocols/http/compression/mod.rs
+++ b/pingora-core/src/protocols/http/compression/mod.rs
@@ -67,10 +67,10 @@ pub struct ResponseCompressionCtx(CtxInner);
 
 enum CtxInner {
     HeaderPhase {
-        decompress_enable: bool,
         // Store the preferred list to compare with content-encoding
         accept_encoding: Vec<Algorithm>,
         encoding_levels: [u32; Algorithm::COUNT],
+        decompress_enable: [bool; Algorithm::COUNT],
     },
     BodyPhase(Option<Box<dyn Encode + Send + Sync>>),
 }
@@ -81,9 +81,9 @@ impl ResponseCompressionCtx {
     /// The `decompress_enable` flag will tell the ctx to decompress if needed.
     pub fn new(compression_level: u32, decompress_enable: bool) -> Self {
         Self(CtxInner::HeaderPhase {
-            decompress_enable,
             accept_encoding: Vec::new(),
             encoding_levels: [compression_level; Algorithm::COUNT],
+            decompress_enable: [decompress_enable; Algorithm::COUNT],
         })
     }
 
@@ -93,9 +93,9 @@
         match &self.0 {
             CtxInner::HeaderPhase {
                 decompress_enable,
-                accept_encoding: _,
                 encoding_levels: levels,
-            } => levels.iter().any(|l| *l != 0) || *decompress_enable,
+                ..
+            } => levels.iter().any(|l| *l != 0) || decompress_enable.iter().any(|d| *d),
             CtxInner::BodyPhase(c) => c.is_some(),
         }
     }
@@ -104,11 +104,7 @@
     /// algorithm name, in bytes, out bytes, time took for the compression
     pub fn get_info(&self) -> Option<(&'static str, usize, usize, Duration)> {
         match &self.0 {
-            CtxInner::HeaderPhase {
-                decompress_enable: _,
-                accept_encoding: _,
-                encoding_levels: _,
-            } => None,
+            CtxInner::HeaderPhase { .. } => None,
             CtxInner::BodyPhase(c) => c.as_ref().map(|c| c.stat()),
         }
     }
@@ -119,9 +115,8 @@
     pub fn adjust_level(&mut self, new_level: u32) {
         match &mut self.0 {
             CtxInner::HeaderPhase {
-                decompress_enable: _,
-                accept_encoding: _,
                 encoding_levels: levels,
+                ..
             } => {
                 *levels = [new_level; Algorithm::COUNT];
             }
@@ -135,9 +130,8 @@
     pub fn adjust_algorithm_level(&mut self, algorithm: Algorithm, new_level: u32) {
         match &mut self.0 {
             CtxInner::HeaderPhase {
-                decompress_enable: _,
-                accept_encoding: _,
                 encoding_levels: levels,
+                ..
             } => {
                 levels[algorithm.index()] = new_level;
             }
@@ -145,17 +139,29 @@
         }
     }
 
-    /// Adjust the decompression flag.
+    /// Adjust the decompression flag for all compression algorithms.
     /// # Panic
     /// This function will panic if it has already started encoding the response body.
     pub fn adjust_decompression(&mut self, enabled: bool) {
         match &mut self.0 {
             CtxInner::HeaderPhase {
-                decompress_enable,
-                accept_encoding: _,
-                encoding_levels: _,
+                decompress_enable, ..
             } => {
-                *decompress_enable = enabled;
+                *decompress_enable = [enabled; Algorithm::COUNT];
+            }
+            CtxInner::BodyPhase(_) => panic!("Wrong phase: BodyPhase"),
+        }
+    }
+
+    /// Adjust the decompression flag for a specific algorithm.
+    /// # Panic
+    /// This function will panic if it has already started encoding the response body.
+    pub fn adjust_algorithm_decompression(&mut self, algorithm: Algorithm, enabled: bool) {
+        match &mut self.0 {
+            CtxInner::HeaderPhase {
+                decompress_enable, ..
+            } => {
+                decompress_enable[algorithm.index()] = enabled;
             }
             CtxInner::BodyPhase(_) => panic!("Wrong phase: BodyPhase"),
         }
@@ -208,7 +214,9 @@ impl ResponseCompressionCtx {
         let encoder = match action {
             Action::Noop => None,
             Action::Compress(algorithm) => algorithm.compressor(levels[algorithm.index()]),
-            Action::Decompress(algorithm) => algorithm.decompressor(*decompress_enable),
+            Action::Decompress(algorithm) => {
+                algorithm.decompressor(decompress_enable[algorithm.index()])
+            }
         };
         if encoder.is_some() {
             adjust_response_header(resp, &action);
@@ -317,6 +325,7 @@ impl Algorithm {
             None
         } else {
             match self {
+                Self::Gzip => Some(Box::new(gzip::Decompressor::new())),
                 Self::Brotli => Some(Box::new(brotli::Decompressor::new())),
                 _ => None, // not implemented
             }
diff --git a/pingora-core/src/protocols/http/v2/client.rs b/pingora-core/src/protocols/http/v2/client.rs
index 86a3fe3d..9bdbff46 100644
--- a/pingora-core/src/protocols/http/v2/client.rs
+++ b/pingora-core/src/protocols/http/v2/client.rs
@@ -349,6 +349,19 @@ impl Http2Session {
         if self.ping_timedout() {
             e.etype = PING_TIMEDOUT;
         }
+
+        // is_go_away: retry via another connection, this connection is being torn down
+        // should retry
+        if self.response_header.is_none() {
+            if let Some(err) = e.root_cause().downcast_ref::<h2::Error>() {
+                if err.is_go_away()
                    && err.is_remote()
+                    && err.reason().map_or(false, |r| r == h2::Reason::NO_ERROR)
+                {
+                    e.retry = true.into();
+                }
+            }
+        }
         e
     }
 }
@@ -367,7 +380,7 @@ pub fn write_body(send_body: &mut SendStream<Bytes>, data: Bytes, end: bool) -> Result<()> {
 /* Types of errors during h2 header read
  1. peer requests to downgrade to h1, mostly IIS server for NTLM: we will downgrade and retry
  2. peer sends invalid h2 frames, usually sending h1 only header: we will downgrade and retry
- 3. peer sends GO_AWAY(NO_ERROR) on reused conn, usually hit http2_max_requests: we will retry
+ 3. peer sends GO_AWAY(NO_ERROR), the connection is being shut down: we will retry
  4. peer IO error on reused conn, usually firewall kills old conn: we will retry
  5. All other errors will terminate the request */
@@ -393,9 +406,8 @@ fn handle_read_header_error(e: h2::Error) -> Box<Error> {
         && e.reason().map_or(false, |r| r == h2::Reason::NO_ERROR)
     {
         // is_go_away: retry via another connection, this connection is being teardown
-        // only retry if the connection is reused
         let mut err = Error::because(H2Error, "while reading h2 header", e);
-        err.retry = RetryType::ReusedOnly;
+        err.retry = true.into();
         err
     } else if e.is_io() {
         // is_io: typical if a previously reused connection silently drops it
diff --git a/pingora-core/src/protocols/raw_connect.rs b/pingora-core/src/protocols/raw_connect.rs
index df82413c..4aeb10c5 100644
--- a/pingora-core/src/protocols/raw_connect.rs
+++ b/pingora-core/src/protocols/raw_connect.rs
@@ -68,9 +68,13 @@
 where
     H: Iterator<Item = (&'a String, &'a Vec<u8>)>,
 {
     // TODO: valid that host doesn't have port
-    // TODO: support adding ad-hoc headers
-    let authority = format!("{host}:{port}");
+    let authority = if host.parse::<std::net::Ipv6Addr>().is_ok() {
+        format!("[{host}]:{port}")
+    } else {
+        format!("{host}:{port}")
+    };
+
     let req = http::request::Builder::new()
         .version(http::Version::HTTP_11)
         .method(http::method::Method::CONNECT)
@@ -217,6 +221,19 @@ mod test_sync {
         assert_eq!(req.headers.get("Host").unwrap(), "pingora.org:123");
         assert_eq!(req.headers.get("foo").unwrap(), "bar");
     }
+
+    #[test]
+    fn test_generate_connect_header_ipv6() {
+        let mut headers = BTreeMap::new();
+        headers.insert(String::from("foo"), b"bar".to_vec());
+        let req = generate_connect_header("::1", 123, headers.iter()).unwrap();
+
+        assert_eq!(req.method, http::method::Method::CONNECT);
+        assert_eq!(req.uri.authority().unwrap(), "[::1]:123");
+        assert_eq!(req.headers.get("Host").unwrap(), "[::1]:123");
+        assert_eq!(req.headers.get("foo").unwrap(), "bar");
+    }
+
     #[test]
     fn test_request_to_wire_auth_form() {
         let new_request = http::Request::builder()
diff --git a/pingora-load-balancing/src/health_check.rs b/pingora-load-balancing/src/health_check.rs
index ba5e6c0e..ac63579b 100644
--- a/pingora-load-balancing/src/health_check.rs
+++ b/pingora-load-balancing/src/health_check.rs
@@ -24,6 +24,16 @@ use pingora_http::{RequestHeader, ResponseHeader};
 use std::sync::Arc;
 use std::time::Duration;
 
+/// [HealthObserve] is an interface for observing health changes of backends;
+/// this is what's used for our health observation callback.
+#[async_trait]
+pub trait HealthObserve {
+    /// Observes the health of a [Backend]; can be used for monitoring purposes.
+    async fn observe(&self, target: &Backend, healthy: bool);
+}
+/// Provided to a [HealthCheck] to observe changes to [Backend] health.
+pub type HealthObserveCallback = Box<dyn HealthObserve + Send + Sync>;
+
 /// [HealthCheck] is the interface to implement health check for backends
 #[async_trait]
 pub trait HealthCheck {
@@ -31,6 +41,10 @@ pub trait HealthCheck {
     ///
     /// `Ok(())` if the check passes, otherwise the check fails.
     async fn check(&self, target: &Backend) -> Result<()>;
+
+    /// Called when the health changes for a [Backend].
+    async fn health_status_change(&self, _target: &Backend, _healthy: bool) {}
+
     /// This function defines how many *consecutive* checks should flip the health of a backend.
     ///
     /// For example: with `success`: `true`: this function should return the
@@ -56,6 +70,8 @@ pub struct TcpHealthCheck {
     /// set, it will also try to establish a TLS connection on top of the TCP connection.
     pub peer_template: BasicPeer,
     connector: TransportConnector,
+    /// A callback that is invoked when the `healthy` status changes for a [Backend].
+    pub health_changed_callback: Option<HealthObserveCallback>,
 }
 
 impl Default for TcpHealthCheck {
@@ -67,6 +83,7 @@
             consecutive_failure: 1,
             peer_template,
             connector: TransportConnector::new(None),
+            health_changed_callback: None,
         }
     }
 }
@@ -110,6 +127,12 @@ impl HealthCheck for TcpHealthCheck {
         peer._address = target.addr.clone();
         self.connector.get_stream(&peer).await.map(|_| {})
     }
+
+    async fn health_status_change(&self, target: &Backend, healthy: bool) {
+        if let Some(callback) = &self.health_changed_callback {
+            callback.observe(target, healthy).await;
+        }
+    }
 }
 
 type Validator = Box<dyn Fn(&ResponseHeader) -> Result<()> + Send + Sync>;
 
@@ -147,6 +170,8 @@ pub struct HttpHealthCheck {
     /// Sometimes the health check endpoint lives on a different port than the actual backend.
     /// Setting this option allows the health check to perform on the given port of the backend IP.
     pub port_override: Option<u16>,
+    /// A callback that is invoked when the `healthy` status changes for a [Backend].
+    pub health_changed_callback: Option<HealthObserveCallback>,
 }
 
 impl HttpHealthCheck {
@@ -174,6 +199,7 @@
             req,
             validator: None,
             port_override: None,
+            health_changed_callback: None,
         }
     }
 
@@ -235,6 +261,11 @@ impl HealthCheck for HttpHealthCheck {
         Ok(())
     }
+
+    async fn health_status_change(&self, target: &Backend, healthy: bool) {
+        if let Some(callback) = &self.health_changed_callback {
+            callback.observe(target, healthy).await;
+        }
+    }
 }
 
 #[derive(Clone)]
@@ -313,8 +344,14 @@ impl Health {
 
 #[cfg(test)]
 mod test {
+    use std::{
+        collections::{BTreeSet, HashMap},
+        sync::atomic::{AtomicU16, Ordering},
+    };
+
     use super::*;
-    use crate::SocketAddr;
+    use crate::{discovery, Backends, SocketAddr};
+    use async_trait::async_trait;
     use http::Extensions;
 
     #[tokio::test]
@@ -387,4 +424,78 @@ mod test {
         assert!(http_check.check(&backend).await.is_ok());
     }
+
+    #[tokio::test]
+    async fn test_health_observe() {
+        struct Observe {
+            unhealthy_count: Arc<AtomicU16>,
+        }
+        #[async_trait]
+        impl HealthObserve for Observe {
+            async fn observe(&self, _target: &Backend, healthy: bool) {
+                if !healthy {
+                    self.unhealthy_count.fetch_add(1, Ordering::Relaxed);
+                }
+            }
+        }
+
+        let good_backend = Backend::new("127.0.0.1:79").unwrap();
+        let new_good_backends = || -> (BTreeSet<Backend>, HashMap<u64, bool>) {
+            let mut healthy = HashMap::new();
+            healthy.insert(good_backend.hash_key(), true);
+            let mut backends = BTreeSet::new();
+            backends.extend(vec![good_backend.clone()]);
+            (backends, healthy)
+        };
+        // tcp health check
+        {
+            let unhealthy_count = Arc::new(AtomicU16::new(0));
+            let ob = Observe {
+                unhealthy_count: unhealthy_count.clone(),
+            };
+            let bob = Box::new(ob);
+            let tcp_check = TcpHealthCheck {
+                health_changed_callback: Some(bob),
+                ..Default::default()
+            };
+
+            let discovery = discovery::Static::default();
+            let mut backends = Backends::new(Box::new(discovery));
+            backends.set_health_check(Box::new(tcp_check));
+            let result = new_good_backends();
+            backends.do_update(result.0, result.1, |_backend: Arc<BTreeSet<Backend>>| {});
             // the backend is ready
+            assert!(backends.ready(&good_backend));
+
+            // run health check
+            backends.run_health_check(false).await;
+            assert!(1 == unhealthy_count.load(Ordering::Relaxed));
+            // backend is unhealthy
+            assert!(!backends.ready(&good_backend));
+        }
+
+        // http health check
+        {
+            let unhealthy_count = Arc::new(AtomicU16::new(0));
+            let ob = Observe {
+                unhealthy_count: unhealthy_count.clone(),
+            };
+            let bob = Box::new(ob);
+
+            let mut https_check = HttpHealthCheck::new("one.one.one.one", true);
+            https_check.health_changed_callback = Some(bob);
+
+            let discovery = discovery::Static::default();
+            let mut backends = Backends::new(Box::new(discovery));
+            backends.set_health_check(Box::new(https_check));
+            let result = new_good_backends();
+            backends.do_update(result.0, result.1, |_backend: Arc<BTreeSet<Backend>>| {});
+            // the backend is ready
+            assert!(backends.ready(&good_backend));
+            // run health check
+            backends.run_health_check(false).await;
+            assert!(1 == unhealthy_count.load(Ordering::Relaxed));
+            assert!(!backends.ready(&good_backend));
+        }
+    }
 }
diff --git a/pingora-load-balancing/src/lib.rs b/pingora-load-balancing/src/lib.rs
index c5adea30..4a7433ee 100644
--- a/pingora-load-balancing/src/lib.rs
+++ b/pingora-load-balancing/src/lib.rs
@@ -77,12 +77,18 @@ impl Backend {
     /// Create a new [Backend] with `weight` 1. The function will try to parse
     /// `addr` into a [std::net::SocketAddr].
     pub fn new(addr: &str) -> Result<Backend> {
+        Self::new_with_weight(addr, 1)
+    }
+
+    /// Creates a new [Backend] with the specified `weight`. The function will try to parse
+    /// `addr` into a [std::net::SocketAddr].
+    pub fn new_with_weight(addr: &str, weight: usize) -> Result<Backend> {
         let addr = addr
             .parse()
             .or_err(ErrorType::InternalError, "invalid socket addr")?;
         Ok(Backend {
             addr: SocketAddr::Inet(addr),
-            weight: 1,
+            weight,
             ext: Extensions::new(),
         })
         // TODO: UDS
@@ -260,6 +266,7 @@ impl Backends {
         let flipped =
             h.observe_health(errored.is_none(), check.health_threshold(errored.is_none()));
         if flipped {
+            check.health_status_change(backend, errored.is_none()).await;
             if let Some(e) = errored {
                 warn!("{backend:?} becomes unhealthy, {e}");
             } else {
diff --git a/pingora-proxy/Cargo.toml b/pingora-proxy/Cargo.toml
index 0a52dffd..a672867f 100644
--- a/pingora-proxy/Cargo.toml
+++ b/pingora-proxy/Cargo.toml
@@ -44,6 +44,7 @@ env_logger = "0.9"
 hyperlocal = "0.8"
 hyper = "0.14"
 tokio-tungstenite = "0.20.1"
+pingora-limits = { version = "0.3.0", path = "../pingora-limits" }
 pingora-load-balancing = { version = "0.3.0", path = "../pingora-load-balancing" }
 prometheus = "0"
 futures-util = "0.3"
diff --git a/pingora-proxy/examples/rate_limiter.rs b/pingora-proxy/examples/rate_limiter.rs
new file mode 100644
index 00000000..d2c8b7ec
--- /dev/null
+++ b/pingora-proxy/examples/rate_limiter.rs
@@ -0,0 +1,117 @@
+use async_trait::async_trait;
+use once_cell::sync::Lazy;
+use pingora_core::prelude::*;
+use pingora_http::{RequestHeader, ResponseHeader};
+use pingora_limits::rate::Rate;
+use pingora_load_balancing::prelude::{RoundRobin, TcpHealthCheck};
+use pingora_load_balancing::LoadBalancer;
+use pingora_proxy::{http_proxy_service, ProxyHttp, Session};
+use std::sync::Arc;
+use std::time::Duration;
+
+fn main() {
+    let mut server = Server::new(Some(Opt::default())).unwrap();
+    server.bootstrap();
+    let mut upstreams = LoadBalancer::try_from_iter(["1.1.1.1:443", "1.0.0.1:443"]).unwrap();
+    // Set health check
+    let hc = TcpHealthCheck::new();
+    upstreams.set_health_check(hc);
+    upstreams.health_check_frequency = Some(Duration::from_secs(1));
+    // Set background service
+    let background =
 background_service("health check", upstreams);
+    let upstreams = background.task();
+    // Set load balancer
+    let mut lb = http_proxy_service(&server.configuration, LB(upstreams));
+    lb.add_tcp("0.0.0.0:6188");
+
+    server.add_service(background);
+    server.add_service(lb);
+    server.run_forever();
+}
+
+pub struct LB(Arc<LoadBalancer<RoundRobin>>);
+
+impl LB {
+    pub fn get_request_appid(&self, session: &mut Session) -> Option<String> {
+        match session
+            .req_header()
+            .headers
+            .get("appid")
+            .map(|v| v.to_str())
+        {
+            None => None,
+            Some(v) => match v {
+                Ok(v) => Some(v.to_string()),
+                Err(_) => None,
+            },
+        }
+    }
+}
+
+// Rate limiter
+static RATE_LIMITER: Lazy<Rate> = Lazy::new(|| Rate::new(Duration::from_secs(1)));
+
+// max requests per second per client
+static MAX_REQ_PER_SEC: isize = 1;
+
+#[async_trait]
+impl ProxyHttp for LB {
+    type CTX = ();
+
+    fn new_ctx(&self) {}
+
+    async fn upstream_peer(
+        &self,
+        _session: &mut Session,
+        _ctx: &mut Self::CTX,
+    ) -> Result<Box<HttpPeer>> {
+        let upstream = self.0.select(b"", 256).unwrap();
+        // Set SNI
+        let peer = Box::new(HttpPeer::new(upstream, true, "one.one.one.one".to_string()));
+        Ok(peer)
+    }
+
+    async fn upstream_request_filter(
+        &self,
+        _session: &mut Session,
+        upstream_request: &mut RequestHeader,
+        _ctx: &mut Self::CTX,
+    ) -> Result<()>
+    where
+        Self::CTX: Send + Sync,
+    {
+        upstream_request
+            .insert_header("Host", "one.one.one.one")
+            .unwrap();
+        Ok(())
+    }
+
+    async fn request_filter(&self, session: &mut Session, _ctx: &mut Self::CTX) -> Result<bool>
+    where
+        Self::CTX: Send + Sync,
+    {
+        let appid = match self.get_request_appid(session) {
+            None => return Ok(false), // no client appid found, skip rate limiting
+            Some(appid) => appid,
+        };
+
+        // retrieve the current window request count
+        let curr_window_requests = RATE_LIMITER.observe(&appid, 1);
+        if curr_window_requests > MAX_REQ_PER_SEC {
+            // rate limited, return 429
+            let mut header = ResponseHeader::build(429, None).unwrap();
+            header
+                .insert_header("X-Rate-Limit-Limit",
 MAX_REQ_PER_SEC.to_string())
+                .unwrap();
+            header.insert_header("X-Rate-Limit-Remaining", "0").unwrap();
+            header.insert_header("X-Rate-Limit-Reset", "1").unwrap();
+            session.set_keepalive(None);
+            session
+                .write_response_header(Box::new(header), true)
+                .await?;
+            return Ok(true);
+        }
+        Ok(false)
+    }
+}
diff --git a/pingora-proxy/src/proxy_cache.rs b/pingora-proxy/src/proxy_cache.rs
index 5e90b188..77b73842 100644
--- a/pingora-proxy/src/proxy_cache.rs
+++ b/pingora-proxy/src/proxy_cache.rs
@@ -165,7 +165,9 @@ impl HttpProxy {
                         } else {
                             break None;
                         }
-                    } // else continue to serve stale
+                    }
+                    // else continue to serve stale
+                    session.cache.set_stale_updating();
                 } else if session.cache.is_cache_lock_writer() {
                     // stale while revalidate logic for the writer
                     let will_serve_stale = session.cache.can_serve_stale_updating()
@@ -182,6 +184,7 @@
                         new_app.process_subrequest(subrequest, sub_req_ctx).await;
                     });
                     // continue to serve stale for this request
+                    session.cache.set_stale_updating();
                 } else {
                     // return to fetch from upstream
                     break None;
diff --git a/pingora-proxy/tests/test_upstream.rs b/pingora-proxy/tests/test_upstream.rs
index 26008328..c2499f59 100644
--- a/pingora-proxy/tests/test_upstream.rs
+++ b/pingora-proxy/tests/test_upstream.rs
@@ -1373,7 +1373,7 @@ mod test_cache {
             .unwrap();
         assert_eq!(res.status(), StatusCode::OK);
         let headers = res.headers();
-        assert_eq!(headers["x-cache-status"], "stale");
+        assert_eq!(headers["x-cache-status"], "stale-updating");
         assert_eq!(res.text().await.unwrap(), "hello world");
     });
     // sleep just a little to make sure the req above gets the cache lock
@@ -1387,7 +1387,7 @@
             .unwrap();
         assert_eq!(res.status(), StatusCode::OK);
         let headers = res.headers();
-        assert_eq!(headers["x-cache-status"], "stale");
+        assert_eq!(headers["x-cache-status"], "stale-updating");
         assert_eq!(res.text().await.unwrap(), "hello world");
     });
     let task3 = tokio::spawn(async move {
         .unwrap();
         assert_eq!(res.status(), StatusCode::OK);
         let headers = res.headers();
-        assert_eq!(headers["x-cache-status"], "stale");
+        assert_eq!(headers["x-cache-status"], "stale-updating");
         assert_eq!(res.text().await.unwrap(), "hello world");
     });
@@ -1436,7 +1436,7 @@
         .unwrap();
     assert_eq!(res.status(), StatusCode::OK);
     let headers = res.headers();
-    assert_eq!(headers["x-cache-status"], "stale");
+    assert_eq!(headers["x-cache-status"], "stale-updating");
     assert_eq!(res.text().await.unwrap(), "hello world");
 
     // wait for the background request to finish
diff --git a/pingora-proxy/tests/utils/server_utils.rs b/pingora-proxy/tests/utils/server_utils.rs
index ec1a9627..4f03f212 100644
--- a/pingora-proxy/tests/utils/server_utils.rs
+++ b/pingora-proxy/tests/utils/server_utils.rs
@@ -473,6 +473,9 @@ impl ProxyHttp for ExampleProxyCache {
             CachePhase::Hit => upstream_response.insert_header("x-cache-status", "hit")?,
             CachePhase::Miss => upstream_response.insert_header("x-cache-status", "miss")?,
             CachePhase::Stale => upstream_response.insert_header("x-cache-status", "stale")?,
+            CachePhase::StaleUpdating => {
+                upstream_response.insert_header("x-cache-status", "stale-updating")?
+            }
             CachePhase::Expired => {
                 upstream_response.insert_header("x-cache-status", "expired")?
             }
diff --git a/tinyufo/Cargo.toml b/tinyufo/Cargo.toml
index 47bcae1b..110b6d1a 100644
--- a/tinyufo/Cargo.toml
+++ b/tinyufo/Cargo.toml
@@ -17,7 +17,7 @@ path = "src/lib.rs"
 [dependencies]
 ahash = { workspace = true }
-flurry = "<0.5.0" # Try not to require Rust 1.71
+flurry = "0.5"
 parking_lot = "0"
 crossbeam-queue = "0"
 crossbeam-skiplist = "0"
@@ -28,7 +28,7 @@
 lru = "0"
 zipf = "7"
 moka = { version = "0", features = ["sync"] }
 dhat = "0"
-quick_cache = "0.4"
+quick_cache = "0.6"
 triomphe = "<=0.1.11" # 0.1.12 requires Rust 1.76
 
 [[bench]]
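
The quickstart above leans on `pingora_limits::rate::Rate`: `observe(key, 1)` records one event for the key and returns the updated count for the current window, which `request_filter` then compares against a threshold. To make that observe-then-compare shape concrete without the crate, here is a minimal, std-only sketch of a *fixed-window* counter. `FixedWindowRate` is a made-up name for illustration, not the `pingora-limits` implementation (the real `Rate` uses a sliding-window estimate across two windows and is safe to share across threads):

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// A toy fixed-window rate counter keyed by client id.
/// `observe(key, n)` adds `n` events for `key` in the current window
/// and returns the updated count, mirroring the shape of `Rate::observe`.
pub struct FixedWindowRate {
    interval: Duration,
    window_start: Instant,
    counts: HashMap<String, isize>,
}

impl FixedWindowRate {
    pub fn new(interval: Duration) -> Self {
        FixedWindowRate {
            interval,
            window_start: Instant::now(),
            counts: HashMap::new(),
        }
    }

    pub fn observe(&mut self, key: &str, events: isize) -> isize {
        // start a fresh window once the interval has elapsed,
        // dropping all per-key counts from the previous window
        if self.window_start.elapsed() >= self.interval {
            self.counts.clear();
            self.window_start = Instant::now();
        }
        let count = self.counts.entry(key.to_string()).or_insert(0);
        *count += events;
        *count
    }
}

fn main() {
    let mut rate = FixedWindowRate::new(Duration::from_secs(1));
    let max_req_per_sec = 1;
    // first request from "app1" is admitted, the second in the same window exceeds the limit
    assert!(rate.observe("app1", 1) <= max_req_per_sec);
    assert!(rate.observe("app1", 1) > max_req_per_sec);
    // a different appid has its own independent counter
    assert!(rate.observe("app2", 1) <= max_req_per_sec);
}
```

Unlike this sketch, a fixed window resets abruptly at the boundary (a client can burst at the end of one window and the start of the next), which is one reason `pingora-limits` estimates the rate over sliding windows instead.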