diff --git a/fact/src/config/mod.rs b/fact/src/config/mod.rs index 9922584e..190c1585 100644 --- a/fact/src/config/mod.rs +++ b/fact/src/config/mod.rs @@ -3,32 +3,46 @@ use std::{ net::SocketAddr, path::{Path, PathBuf}, str::FromStr, + sync::LazyLock, }; use anyhow::{bail, Context}; use clap::Parser; use log::info; -use yaml_rust2::{Yaml, YamlLoader}; +use yaml_rust2::{yaml, Yaml, YamlLoader}; -#[derive(Debug, Default, PartialEq, Eq)] +pub mod reloader; +#[cfg(test)] +mod tests; + +const CONFIG_FILES: [&str; 4] = [ + "/etc/stackrox/fact.yml", + "/etc/stackrox/fact.yaml", + "fact.yml", + "fact.yaml", +]; + +#[derive(Debug, Default, PartialEq, Eq, Clone)] pub struct FactConfig { paths: Option>, url: Option, certs: Option, - endpoint: Option, - expose_metrics: Option, - health_check: Option, + pub endpoint: EndpointConfig, skip_pre_flight: Option, json: Option, ringbuf_size: Option, + hotreload: Option, } -#[cfg(test)] -mod tests; - impl FactConfig { - pub fn new(paths: &[&str]) -> anyhow::Result { - let mut config = paths + pub fn new() -> anyhow::Result { + let config = FactConfig::build()?; + info!("{config:#?}"); + Ok(config) + } + + fn build() -> anyhow::Result { + let mut config = CONFIG_FILES .iter() .filter_map(|p| { let p = Path::new(p); @@ -53,10 +67,9 @@ impl FactConfig { )?; // Once file configuration is handled, apply CLI arguments - let args = FactCli::parse(); - config.update(&args.to_config()); + static CLI_ARGS: LazyLock = LazyLock::new(|| FactCli::parse().to_config()); + config.update(&CLI_ARGS); - info!("{config:#?}"); Ok(config) } @@ -73,17 +86,7 @@ impl FactConfig { self.certs = Some(certs.to_owned()); } - if let Some(endpoint) = from.endpoint { - self.endpoint = Some(endpoint); - } - - if let Some(expose_metrics) = from.expose_metrics { - self.expose_metrics = Some(expose_metrics); - } - - if let Some(health_check) = from.health_check { - self.health_check = Some(health_check); - } + self.endpoint.update(&from.endpoint); if let Some(skip_pre_flight) = from.skip_pre_flight { self.skip_pre_flight = Some(skip_pre_flight); @@ -96,6 +99,10 @@ impl FactConfig { if let Some(ringbuf_size) = from.ringbuf_size { self.ringbuf_size = Some(ringbuf_size); } + + if let Some(hotreload) = from.hotreload { + self.hotreload = Some(hotreload); + } } pub fn paths(&self) -> &[PathBuf] { @@ -110,19 +117,6 @@ impl FactConfig { self.certs.as_deref() } - pub fn endpoint(&self) -> SocketAddr { - self.endpoint - .unwrap_or(SocketAddr::from(([0, 0, 0, 0], 9000))) - } - - pub fn expose_metrics(&self) -> bool { - self.expose_metrics.unwrap_or(false) - } - - pub fn health_check(&self) -> bool { - self.health_check.unwrap_or(false) - } - pub fn skip_pre_flight(&self) -> bool { self.skip_pre_flight.unwrap_or(false) } @@ -135,6 +129,10 @@ impl FactConfig { self.ringbuf_size.unwrap_or(8192) } + pub fn hotreload(&self) -> bool { + self.hotreload.unwrap_or(true) + } + #[cfg(test)] pub fn set_paths(&mut self, paths: Vec) { self.paths = Some(paths); @@ -207,27 +205,9 @@ impl TryFrom> for FactConfig { }; config.certs = Some(PathBuf::from(certs)); } - "endpoint" => { - let Some(endpoint) = v.as_str() else { - bail!("endpoint field has incorrect type: {v:?}"); - }; - let endpoint = match SocketAddr::from_str(endpoint) { - Ok(endpoint) => endpoint, - Err(e) => bail!("Failed to parse endpoint: {e}"), - }; - config.endpoint = Some(endpoint); - } - "expose_metrics" => { - let Some(em) = v.as_bool() else { - bail!("expose_metrics field has incorrect type: {v:?}"); - }; - config.expose_metrics = Some(em); - } - "health_check" => { - let Some(hc) = v.as_bool() else { - bail!("health_check field has incorrect type: {v:?}"); - }; - config.health_check = Some(hc); + "endpoint" if v.is_hash() => { + let endpoint = v.as_hash().unwrap(); + config.endpoint = EndpointConfig::try_from(endpoint)?; } "skip_pre_flight" => { let Some(spf) = v.as_bool() else { @@ -254,6 +234,12 @@ impl TryFrom> for FactConfig { } config.ringbuf_size = Some(rb_size); } + "hotreload" => { + let Some(hotreload) = v.as_bool() else { + bail!("hotreload field has incorrect type: {v:?}"); + }; + config.hotreload = Some(hotreload); + } name => bail!("Invalid field '{name}' with value: {v:?}"), } } @@ -262,6 +248,83 @@ impl TryFrom> for FactConfig { } } +#[derive(Debug, Default, PartialEq, Eq, Clone)] +pub struct EndpointConfig { + address: Option, + expose_metrics: Option, + health_check: Option, +} + +impl EndpointConfig { + fn update(&mut self, from: &EndpointConfig) { + if let Some(address) = from.address { + self.address = Some(address); + } + + if let Some(expose_metrics) = from.expose_metrics { + self.expose_metrics = Some(expose_metrics); + } + + if let Some(health_check) = from.health_check { + self.health_check = Some(health_check); + } + } + + pub fn address(&self) -> SocketAddr { + self.address + .unwrap_or(SocketAddr::from(([0, 0, 0, 0], 9000))) + } + + pub fn expose_metrics(&self) -> bool { + self.expose_metrics.unwrap_or(false) + } + + pub fn health_check(&self) -> bool { + self.health_check.unwrap_or(false) + } +} + +impl TryFrom<&yaml::Hash> for EndpointConfig { + type Error = anyhow::Error; + + fn try_from(value: &yaml::Hash) -> Result { + let mut endpoint = EndpointConfig::default(); + for (k, v) in value.iter() { + let Some(k) = k.as_str() else { + bail!("key is not string: {k:?}"); + }; + + match k { + "address" => { + let Some(addr) = v.as_str() else { + bail!("endpoint.address field has incorrect type: {v:?}"); + }; + let address = match SocketAddr::from_str(addr) { + Ok(a) => a, + Err(e) => bail!("Failed to parse endpoint.address: {e}"), + }; + endpoint.address = Some(address); + } + "expose_metrics" => { + let Some(em) = v.as_bool() else { + bail!("endpoint.expose_metrics field has incorrect type: {v:?}"); + }; + endpoint.expose_metrics = Some(em); + } + "health_check" => { + let Some(hc) = v.as_bool() else { + bail!("endpoint.health_check field has incorrect type: {v:?}"); + }; + endpoint.health_check = Some(hc); + } + name => bail!("Invalid field 'endpoint.{name}' with value: {v:?}"), + } + } + + Ok(endpoint) + } +} + #[derive(Debug, Parser)] #[clap(version, about)] pub struct FactCli { @@ -278,17 +341,25 @@ pub struct FactCli { certs: Option, /// The port to bind for all exposed endpoints - #[arg(long, short, env = "FACT_ENDPOINT")] - endpoint: Option, + #[arg(long, short, env = "FACT_ENDPOINT_ADDRESS")] + address: Option, /// Whether prometheus metrics should be collected and exposed - #[arg(long, overrides_with("no_expose_metrics"), env = "FACT_EXPOSE_METRICS")] + #[arg( + long, + overrides_with("no_expose_metrics"), + env = "FACT_ENDPOINT_EXPOSE_METRICS" + )] expose_metrics: bool, #[arg(long, overrides_with = "expose_metrics", hide(true))] no_expose_metrics: bool, /// Whether a small health_check probe should be run - #[arg(long, overrides_with("no_health_check"), env = "FACT_HEALTH_CHECK")] + #[arg( + long, + overrides_with("no_health_check"), + env = "FACT_ENDPOINT_HEALTH_CHECK" + )] health_check: bool, #[arg(long, overrides_with = "health_check", hide(true))] no_health_check: bool, @@ -319,6 +390,12 @@ pub struct FactCli { /// Default value is 8MB. #[arg(long, short, env = "FACT_RINGBUF_SIZE")] ringbuf_size: Option, + + /// Whether configuration should be hotreloaded + #[arg(long, overrides_with = "no_hotreload", env = "FACT_HOTRELOAD")] + hotreload: bool, + #[arg(long, overrides_with = "hotreload", hide(true))] + no_hotreload: bool, } impl FactCli { @@ -327,12 +404,15 @@ impl FactCli { paths: self.paths.clone(), url: self.url.clone(), certs: self.certs.clone(), - endpoint: self.endpoint, - expose_metrics: resolve_bool_arg(self.expose_metrics, self.no_expose_metrics), - health_check: resolve_bool_arg(self.health_check, self.no_health_check), + endpoint: EndpointConfig { + address: self.address, + expose_metrics: resolve_bool_arg(self.expose_metrics, self.no_expose_metrics), + health_check: resolve_bool_arg(self.health_check, self.no_health_check), + }, skip_pre_flight: resolve_bool_arg(self.skip_pre_flight, self.no_skip_pre_flight), json: resolve_bool_arg(self.json, self.no_json), ringbuf_size: self.ringbuf_size, + hotreload: resolve_bool_arg(self.hotreload, self.no_hotreload), } } } diff --git a/fact/src/config/reloader.rs b/fact/src/config/reloader.rs new file mode 100644 index 00000000..f35291ee --- /dev/null +++ b/fact/src/config/reloader.rs @@ -0,0 +1,173 @@ +use std::{ + collections::HashMap, os::unix::fs::MetadataExt, path::PathBuf, sync::Arc, time::Duration, +}; + +use log::{debug, info, warn}; +use tokio::{ + sync::{watch, Notify}, + task::JoinHandle, + time::interval, +}; + +use super::{EndpointConfig, FactConfig, CONFIG_FILES}; + +pub struct Reloader { + config: FactConfig, + endpoint: watch::Sender, + files: HashMap<&'static str, i64>, + trigger: Arc, +} + +impl Reloader { + /// Consume the reloader into a task + /// + /// The resulting task will handle reloading the configuration and + /// forwarding the changes to any parts of the program that might + /// need to take action accordingly. + /// + /// If hotreload is disabled on startup the task will not be + /// spawned. + pub fn start(mut self, mut running: watch::Receiver) -> Option> { + if !self.config.hotreload() { + info!("Configuration hotreload is disabled, changes will require a restart."); + return None; + } + + let handle = tokio::spawn(async move { + let mut ticker = interval(Duration::from_secs(10)); + loop { + tokio::select! { + _ = ticker.tick() => self.reload(), + _ = self.trigger.notified() => self.reload(), + _ = running.changed() => { + if !*running.borrow() { + info!("Stopping config reloader..."); + return; + } + } + } + } + }); + Some(handle) + } + + /// Subscribe to get notifications when endpoint configuration is + /// changed. + pub fn endpoint(&self) -> watch::Receiver { + self.endpoint.subscribe() + } + + /// Get a reference to the internal trigger for manual reloading of + /// configuration. + /// + /// Mainly meant as a way to handle the SIGHUP signal, but could be + /// extended to other use cases. + pub fn get_trigger(&self) -> Arc { + self.trigger.clone() + } + + /// Go through the configuration files and reload the modification + /// time for each of them. + /// + /// Returns true if any file has been modified. + fn update_cache(&mut self) -> bool { + let mut res = false; + + for file in CONFIG_FILES { + let path = PathBuf::from(file); + if path.exists() { + let mtime = match path.metadata() { + Ok(m) => m.mtime(), + Err(e) => { + warn!("Failed to stat {file}: {e}"); + warn!("Configuration reloading may not work"); + continue; + } + }; + match self.files.get_mut(&file) { + Some(old) if *old == mtime => {} + Some(old) => { + debug!("Updating '{file}'"); + res = true; + *old = mtime; + } + None => { + debug!("New configuration file '{file}'"); + res = true; + self.files.insert(file, mtime); + } + } + } else if self.files.contains_key(&file) { + debug!("'{file}' no longer exists, removing from cache"); + res = true; + self.files.remove(&file); + } + } + res + } + + /// Recreate the configuration and notify of changes to any + /// subscribers. + fn reload(&mut self) { + if !self.update_cache() { + return; + } + + let new = match FactConfig::build() { + Ok(config) => config, + Err(e) => { + warn!("Configuration reloading failed: {e}"); + return; + } + }; + info!("Updated configuration: {new:#?}"); + + self.endpoint.send_if_modified(|old| { + if *old != new.endpoint { + debug!("Sending new endpoint configuration..."); + *old = new.endpoint.clone(); + true + } else { + false + } + }); + + if self.config.hotreload() != new.hotreload() { + warn!("Changes to the hotreload field only take effect on startup"); + } + + self.config = new; + } +} + +impl From for Reloader { + fn from(config: FactConfig) -> Self { + let files = CONFIG_FILES + .iter() + .filter_map(|path| { + let p = PathBuf::from(path); + if p.exists() { + let mtime = match p.metadata() { + Ok(m) => m.mtime(), + Err(e) => { + warn!("Failed to stat {path}: {e}"); + warn!("Configuration reloading may not work"); + return None; + } + }; + Some((*path, mtime)) + } else { + None + } + }) + .collect(); + let (endpoint, _) = watch::channel(config.endpoint.clone()); + let trigger = Arc::new(Notify::new()); + Reloader { + config, + endpoint, + files, + trigger, + } + } +} diff --git a/fact/src/config/tests.rs b/fact/src/config/tests.rs index 221dc0b7..74322a45 100644 --- a/fact/src/config/tests.rs +++ b/fact/src/config/tests.rs @@ -33,64 +33,112 @@ fn parsing() { }, ), ( - "endpoint: 0.0.0.0:8080", + r#" + endpoint: + address: 0.0.0.0:8080 + "#, FactConfig { - endpoint: Some(SocketAddr::from(([0, 0, 0, 0], 8080))), + endpoint: EndpointConfig { + address: Some(SocketAddr::from(([0, 0, 0, 0], 8080))), + ..Default::default() + }, ..Default::default() }, ), ( - "endpoint: 127.0.0.1:8080", + r#" + endpoint: + address: 127.0.0.1:8080 + "#, FactConfig { - endpoint: Some(SocketAddr::from(([127, 0, 0, 1], 8080))), + endpoint: EndpointConfig { + address: Some(SocketAddr::from(([127, 0, 0, 1], 8080))), + ..Default::default() + }, ..Default::default() }, ), ( - "endpoint: '[::]:8080'", + r#" + endpoint: + address: '[::]:8080' + "#, FactConfig { - endpoint: Some(SocketAddr::from(( - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], - 8080, - ))), + endpoint: EndpointConfig { + address: Some(SocketAddr::from(( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + 8080, + ))), + ..Default::default() + }, ..Default::default() }, ), ( - "endpoint: '[::1]:8080'", + r#" + endpoint: + address: '[::1]:8080' + "#, FactConfig { - endpoint: Some(SocketAddr::from(( - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - 8080, - ))), + endpoint: EndpointConfig { + address: Some(SocketAddr::from(( + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], + 8080, + ))), + ..Default::default() + }, ..Default::default() }, ), ( - "expose_metrics: true", + r#" + endpoint: + expose_metrics: true + "#, FactConfig { - expose_metrics: Some(true), + endpoint: EndpointConfig { + expose_metrics: Some(true), + ..Default::default() + }, ..Default::default() }, ), ( - "expose_metrics: false", + r#" + endpoint: + expose_metrics: false + "#, FactConfig { - expose_metrics: Some(false), + endpoint: EndpointConfig { + expose_metrics: Some(false), + ..Default::default() + }, ..Default::default() }, ), ( - "health_check: true", + r#" + endpoint: + health_check: true + "#, FactConfig { - health_check: Some(true), + endpoint: EndpointConfig { + health_check: Some(true), + ..Default::default() + }, ..Default::default() }, ), ( - "health_check: false", + r#" + endpoint: + health_check: false + "#, FactConfig { - health_check: Some(false), + endpoint: EndpointConfig { + health_check: Some(false), + ..Default::default() + }, ..Default::default() }, ), @@ -129,29 +177,48 @@ fn parsing() { ..Default::default() }, ), + ( + "hotreload: true", + FactConfig { + hotreload: Some(true), + ..Default::default() + }, + ), + ( + "hotreload: false", + FactConfig { + hotreload: Some(false), + ..Default::default() + }, + ), ( r#" paths: - /etc url: https://svc.sensor.stackrox:9090 certs: /etc/stackrox/certs - endpoint: 0.0.0.0:8080 - expose_metrics: true - health_check: true + endpoint: + address: 0.0.0.0:8080 + expose_metrics: true + health_check: true skip_pre_flight: false json: false ringbuf_size: 8192 + hotreload: false "#, FactConfig { paths: Some(vec![PathBuf::from("/etc")]), url: Some(String::from("https://svc.sensor.stackrox:9090")), certs: Some(PathBuf::from("/etc/stackrox/certs")), - endpoint: Some(SocketAddr::from(([0, 0, 0, 0], 8080))), - expose_metrics: Some(true), - health_check: Some(true), + endpoint: EndpointConfig { + address: Some(SocketAddr::from(([0, 0, 0, 0], 8080))), + expose_metrics: Some(true), + health_check: Some(true), + }, skip_pre_flight: Some(false), json: Some(false), ringbuf_size: Some(8192), + hotreload: Some(false), }, ), ]; @@ -193,43 +260,84 @@ paths: ), ( "endpoint: true", - "endpoint field has incorrect type: Boolean(true)", + "Invalid field 'endpoint' with value: Boolean(true)", + ), + ( + r#" + endpoint: + address: true + "#, + "endpoint.address field has incorrect type: Boolean(true)", + ), + ( + r#" + endpoint: + address: 127.0.0.1 + "#, + "Failed to parse endpoint.address: invalid socket address syntax", ), ( - "endpoint: 127.0.0.1", - "Failed to parse endpoint: invalid socket address syntax", + r#" + endpoint: + address: :8080 + "#, + "Failed to parse endpoint.address: invalid socket address syntax", ), ( - "endpoint: :8080", - "Failed to parse endpoint: invalid socket address syntax", + r#" + endpoint: + address: 127.0.0.:8080 + "#, + "Failed to parse endpoint.address: invalid socket address syntax", ), ( - "endpoint: 127.0.0.:8080", - "Failed to parse endpoint: invalid socket address syntax", + r#" + endpoint: + address: '[::]' + "#, + "Failed to parse endpoint.address: invalid socket address syntax", ), ( - "endpoint: '[::]'", - "Failed to parse endpoint: invalid socket address syntax", + r#" + endpoint: + address: '[::1]' + "#, + "Failed to parse endpoint.address: invalid socket address syntax", ), ( - "endpoint: '[::1]'", - "Failed to parse endpoint: invalid socket address syntax", + r#" + endpoint: + address: '[:::1]:8080' + "#, + "Failed to parse endpoint.address: invalid socket address syntax", ), ( - "endpoint: '[:::1]:8080'", - "Failed to parse endpoint: invalid socket address syntax", + r#" + endpoint: + address: '[::cafe::1]:8080' + "#, + "Failed to parse endpoint.address: invalid socket address syntax", ), ( - "endpoint: '[::cafe::1]:8080'", - "Failed to parse endpoint: invalid socket address syntax", + r#" + endpoint: + expose_metrics: 4 + "#, + "endpoint.expose_metrics field has incorrect type: Integer(4)", ), ( - "expose_metrics: 4", - "expose_metrics field has incorrect type: Integer(4)", + r#" + endpoint: + health_check: 4 + "#, + "endpoint.health_check field has incorrect type: Integer(4)", ), ( - "health_check: 4", - "health_check field has incorrect type: Integer(4)", + r#" + endpoint: + unknown: 4 + "#, + "Invalid field 'endpoint.unknown' with value: Integer(4)", ), ( "skip_pre_flight: 4", @@ -247,6 +355,10 @@ paths: &format!("ringbuf_size out of range: {}", u32::MAX), ), ("ringbuf_size: 65", "ringbuf_size is not a power of 2: 65"), + ( + "hotreload: 4", + "hotreload field has incorrect type: Integer(4)", + ), ("unknown:", "Invalid field 'unknown' with value: Null"), ]; for (input, expected) in tests { @@ -371,62 +483,110 @@ fn update() { }, ), ( - "expose_metrics: true", + r#" + endpoint: + expose_metrics: true + "#, FactConfig::default(), FactConfig { - expose_metrics: Some(true), + endpoint: EndpointConfig { + expose_metrics: Some(true), + ..Default::default() + }, ..Default::default() }, ), ( - "expose_metrics: true", + r#" + endpoint: + expose_metrics: true + "#, FactConfig { - expose_metrics: Some(false), + endpoint: EndpointConfig { + expose_metrics: Some(false), + ..Default::default() + }, ..Default::default() }, FactConfig { - expose_metrics: Some(true), + endpoint: EndpointConfig { + expose_metrics: Some(true), + ..Default::default() + }, ..Default::default() }, ), ( - "expose_metrics: true", + r#" + endpoint: + expose_metrics: true + "#, FactConfig { - expose_metrics: Some(true), + endpoint: EndpointConfig { + expose_metrics: Some(true), + ..Default::default() + }, ..Default::default() }, FactConfig { - expose_metrics: Some(true), + endpoint: EndpointConfig { + expose_metrics: Some(true), + ..Default::default() + }, ..Default::default() }, ), ( - "health_check: true", + r#" + endpoint: + health_check: true + "#, FactConfig::default(), FactConfig { - health_check: Some(true), + endpoint: EndpointConfig { + health_check: Some(true), + ..Default::default() + }, ..Default::default() }, ), ( - "health_check: true", + r#" + endpoint: + health_check: true + "#, FactConfig { - health_check: Some(false), + endpoint: EndpointConfig { + health_check: Some(false), + ..Default::default() + }, ..Default::default() }, FactConfig { - health_check: Some(true), + endpoint: EndpointConfig { + health_check: Some(true), + ..Default::default() + }, ..Default::default() }, ), ( - "health_check: true", + r#" + endpoint: + health_check: true + "#, FactConfig { - health_check: Some(true), + endpoint: EndpointConfig { + health_check: Some(true), + ..Default::default() + }, ..Default::default() }, FactConfig { - health_check: Some(true), + endpoint: EndpointConfig { + health_check: Some(true), + ..Default::default() + }, ..Default::default() }, ), @@ -490,40 +650,78 @@ fn update() { ..Default::default() }, ), + ( + "hotreload: false", + FactConfig::default(), + FactConfig { + hotreload: Some(false), + ..Default::default() + }, + ), + ( + "hotreload: true", + FactConfig { + hotreload: Some(false), + ..Default::default() + }, + FactConfig { + hotreload: Some(true), + ..Default::default() + }, + ), + ( + "hotreload: true", + FactConfig { + hotreload: Some(true), + ..Default::default() + }, + FactConfig { + hotreload: Some(true), + ..Default::default() + }, + ), ( r#" paths: - /etc url: https://svc.sensor.stackrox:9090 certs: /etc/stackrox/certs - endpoint: 127.0.0.1:8080 - expose_metrics: true - health_check: true + endpoint: + address: 127.0.0.1:8080 + expose_metrics: true + health_check: true skip_pre_flight: false json: false ringbuf_size: 16384 + hotreload: false "#, FactConfig { paths: Some(vec![PathBuf::from("/etc"), PathBuf::from("/bin")]), url: Some(String::from("http://localhost")), certs: Some(PathBuf::from("/etc/certs")), - endpoint: Some(SocketAddr::from(([0, 0, 0, 0], 9000))), - expose_metrics: Some(false), - health_check: Some(false), + endpoint: EndpointConfig { + address: Some(SocketAddr::from(([0, 0, 0, 0], 9000))), + expose_metrics: Some(false), + health_check: Some(false), + }, skip_pre_flight: Some(true), json: Some(true), ringbuf_size: Some(64), + hotreload: Some(true), }, FactConfig { paths: Some(vec![PathBuf::from("/etc")]), url: Some(String::from("https://svc.sensor.stackrox:9090")), certs: Some(PathBuf::from("/etc/stackrox/certs")), - endpoint: Some(SocketAddr::from(([127, 0, 0, 1], 8080))), - expose_metrics: Some(true), - health_check: Some(true), + endpoint: EndpointConfig { + address: Some(SocketAddr::from(([127, 0, 0, 1], 8080))), + expose_metrics: Some(true), + health_check: Some(true), + }, skip_pre_flight: Some(false), json: Some(false), ringbuf_size: Some(16384), + hotreload: Some(false), }, ), ]; @@ -544,9 +742,14 @@ fn defaults() { assert_eq!(config.paths(), default_paths); assert_eq!(config.url(), None); assert_eq!(config.certs(), None); - assert!(!config.expose_metrics()); - assert!(!config.health_check()); + assert_eq!( + config.endpoint.address(), + SocketAddr::from(([0, 0, 0, 0], 9000)) + ); + assert!(!config.endpoint.expose_metrics()); + assert!(!config.endpoint.health_check()); assert!(!config.skip_pre_flight()); assert!(!config.json()); assert_eq!(config.ringbuf_size(), 8192); + assert!(config.hotreload()); } diff --git a/fact/src/endpoints.rs b/fact/src/endpoints.rs index 62d5c2b3..b842c183 100644 --- a/fact/src/endpoints.rs +++ b/fact/src/endpoints.rs @@ -1,4 +1,4 @@ -use std::{future::Future, net::SocketAddr, pin::Pin}; +use std::{future::Future, pin::Pin}; use http_body_util::Full; use hyper::{ @@ -11,69 +11,104 @@ use hyper_util::rt::TokioIo; use log::{info, warn}; use tokio::{net::TcpListener, sync::watch, task::JoinHandle}; -use crate::metrics::exporter::Exporter; +use crate::{config::EndpointConfig, metrics::exporter::Exporter}; #[derive(Clone)] pub struct Server { - addr: SocketAddr, - metrics: Option, - health_check: bool, + metrics: Exporter, + config: watch::Receiver, + running: watch::Receiver, } impl Server { pub fn new( - addr: SocketAddr, metrics: Exporter, - expose_metrics: bool, - health_check: bool, + config: watch::Receiver, + running: watch::Receiver, ) -> Self { - let metrics = if expose_metrics { Some(metrics) } else { None }; Server { - addr, metrics, - health_check, + config, + running, } } - pub async fn start( - self, - mut running: watch::Receiver, - ) -> Option>> { - // If there is nothing to expose, we don't run the hyper server - if self.metrics.is_none() && !self.health_check { - return None; - } - - let listener = match TcpListener::bind(self.addr).await { - Ok(l) => l, - Err(e) => { - return Some(Err(e.into())); - } - }; - - let handle = tokio::spawn(async move { + /// Consume the Server into a task that will serve the endpoints. + /// + /// If all endpoints are disabled, no port will be listened on and + /// the task goes into an idle state waiting for configuration + /// changes. + pub fn start(mut self) -> JoinHandle<()> { + tokio::spawn(async move { loop { - tokio::select! { - Ok((stream, _)) = listener.accept() => { - let io = TokioIo::new(stream); - let s = self.clone(); - tokio::spawn(async move { - if let Err(e) = http1::Builder::new().serve_connection(io, s).await { - warn!("Error serving connection: {e:?}"); - } - }); - }, - _ = running.changed() => { - if !*running.borrow() { - drop(listener); + let res = if self.is_active() { + self.serve().await + } else { + self.idle().await + }; + + match res { + Ok(running) => { + if running { + info!("Reloading endpoints..."); + } else { info!("Stopping endpoints..."); break; } } - } + Err(e) => { + warn!("endpoints error: {e}"); + } + }; + } + }) + } + + /// Wait for configuration changes or fact to stop. + async fn idle(&mut self) -> anyhow::Result { + tokio::select! { + _ = self.config.changed() => Ok(true), + _ = self.running.changed() => Ok(*self.running.borrow()), + } + } + + /// Serve requests on the configured endpoints. + /// + /// If a configuration change is detected, returning from this + /// method will handle reloading it. + async fn serve(&mut self) -> anyhow::Result { + let addr = self.config.borrow().address(); + let listener = TcpListener::bind(addr).await?; + + loop { + tokio::select! { + Ok((stream, _)) = listener.accept() => { + let io = TokioIo::new(stream); + let s = self.clone(); + tokio::spawn(async move { + if let Err(e) = http1::Builder::new().serve_connection(io, s).await { + warn!("Error serving connection: {e:?}"); + } + }); + }, + _ = self.config.changed() => return Ok(true), + _ = self.running.changed() => return Ok(*self.running.borrow()), } - }); - Some(Ok(handle)) + } + } + + /// Check if there are active endpoints to serve. + fn is_active(&self) -> bool { + let config = self.config.borrow(); + config.health_check() || config.expose_metrics() + } + + fn health_check_is_active(&self) -> bool { + self.config.borrow().health_check() + } + + fn metrics_is_active(&self) -> bool { + self.config.borrow().expose_metrics() } fn make_response( @@ -87,23 +122,24 @@ impl Server { } fn handle_metrics(&self) -> Result>, anyhow::Error> { - match &self.metrics { - Some(metrics) => metrics.encode().map(|buf| { - let body = Full::new(Bytes::from(buf)); - Response::builder() - .header( - hyper::header::CONTENT_TYPE, - "application/openmetrics-text; version=1.0.0; charset=utf-8", - ) - .body(body) - .map_err(anyhow::Error::new) - })?, - None => Server::make_response(StatusCode::SERVICE_UNAVAILABLE, String::new()), + if !self.metrics_is_active() { + return Server::make_response(StatusCode::SERVICE_UNAVAILABLE, String::new()); } + + self.metrics.encode().map(|buf| { + let body = Full::new(Bytes::from(buf)); + Response::builder() + .header( + hyper::header::CONTENT_TYPE, + "application/openmetrics-text; version=1.0.0; charset=utf-8", + ) + .body(body) + .map_err(anyhow::Error::new) + })? } fn handle_health_check(&self) -> Result>, anyhow::Error> { - let res = if self.health_check { + let res = if self.health_check_is_active() { StatusCode::OK } else { StatusCode::SERVICE_UNAVAILABLE diff --git a/fact/src/lib.rs b/fact/src/lib.rs index bcc1a08a..090a2e51 100644 --- a/fact/src/lib.rs +++ b/fact/src/lib.rs @@ -66,7 +66,7 @@ pub async fn run(config: FactConfig) -> anyhow::Result<()> { // Log system information as early as possible so we have it // available in case of a crash log_system_information(); - let (run_tx, run_rx) = watch::channel(true); + let (running, _) = watch::channel(true); let (tx, rx) = broadcast::channel(100); if !config.skip_pre_flight() { @@ -80,29 +80,35 @@ pub async fn run(config: FactConfig) -> anyhow::Result<()> { let exporter = Exporter::new(bpf.get_metrics()?); - let server = endpoints::Server::new( - config.endpoint(), - exporter.clone(), - config.expose_metrics(), - config.health_check(), - ); - if let Some(Err(e)) = server.start(run_rx.clone()).await { - warn!("Failed to start endpoints server: {e}"); - }; - - let output = Output::new(run_rx.clone(), rx, exporter.metrics.output.clone()); + let output = Output::new(running.subscribe(), rx, exporter.metrics.output.clone()); output.start(&config)?; + let reloader = config::reloader::Reloader::from(config); + let config_trigger = reloader.get_trigger(); + + endpoints::Server::new(exporter.clone(), reloader.endpoint(), running.subscribe()).start(); + + reloader.start(running.subscribe()); + // Gather events from the ring buffer and print them out. - Bpf::start_worker(tx, bpf.fd, run_rx, exporter.metrics.bpf_worker.clone()); + Bpf::start_worker( + tx, + bpf.fd, + running.subscribe(), + exporter.metrics.bpf_worker.clone(), + ); let mut sigterm = signal(SignalKind::terminate())?; - tokio::select! { - _ = tokio::signal::ctrl_c() => {} - _ = sigterm.recv() => {} + let mut sighup = signal(SignalKind::hangup())?; + loop { + tokio::select! { + _ = tokio::signal::ctrl_c() => break, + _ = sigterm.recv() => break, + _ = sighup.recv() => config_trigger.notify_one(), + } } - run_tx.send(false)?; + running.send(false)?; info!("Exiting..."); Ok(()) diff --git a/fact/src/main.rs b/fact/src/main.rs index 8b5a1a1f..1ae88408 100644 --- a/fact/src/main.rs +++ b/fact/src/main.rs @@ -3,12 +3,7 @@ use fact::config::FactConfig; #[tokio::main] async fn main() -> anyhow::Result<()> { fact::init_log()?; - let config = FactConfig::new(&[ - "/etc/stackrox/fact.yml", - "/etc/stackrox/fact.yaml", - "fact.yml", - "fact.yaml", - ])?; + let config = FactConfig::new()?; fact::run(config).await } diff --git a/tests/conftest.py b/tests/conftest.py index e624b4ae..e343eaa8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,12 @@ -from concurrent import futures import os from shutil import rmtree -from tempfile import mkdtemp +from tempfile import NamedTemporaryFile, mkdtemp from time import sleep import docker import pytest import requests +import yaml from server import FileActivityService @@ -81,21 +81,38 @@ def dump_logs(container, file): @pytest.fixture -def fact(request, docker_client, monitored_dir, server, logs_dir): +def fact_config(request, monitored_dir, logs_dir): + cwd = os.getcwd() + config = { + 'paths': [monitored_dir], + 'url': 'http://127.0.0.1:9999', + 'endpoint': { + 'address': '127.0.0.1:9000', + 'expose_metrics': True, + 'health_check': True, + }, + 'json': True, + } + config_file = NamedTemporaryFile( + prefix='fact-config-', suffix='.yml', dir=cwd, mode='w') + yaml.dump(config, config_file) + + yield config, config_file.name + with open(os.path.join(logs_dir, 'fact.yml'), 'w') as f: + with open(config_file.name, 'r') as r: + f.write(r.read()) + config_file.close() + + +@pytest.fixture +def fact(request, docker_client, fact_config, server, logs_dir): """ Run the fact docker container for integration tests. """ - command = [ - 'http://127.0.0.1:9999', - '-p', monitored_dir, - '--expose-metrics', - '--health-check', - '--json', - ] + config, config_file = fact_config image = request.config.getoption('--image') container = docker_client.containers.run( image, - command=command, detach=True, environment={ 'FACT_LOGLEVEL': 'debug', @@ -121,6 +138,10 @@ def fact(request, docker_client, monitored_dir, server, logs_dir): 'bind': '/host/usr/lib/os-release', 'mode': 'ro', }, + config_file: { + 'bind': '/etc/stackrox/fact.yml', + 'mode': 'ro', + } }, ) @@ -128,7 +149,8 @@ def fact(request, docker_client, monitored_dir, server, logs_dir): # Wait for container to be ready for _ in range(3): try: - resp = requests.get('http://127.0.0.1:9000/health_check') + resp = requests.get( + f'http://{config["endpoint"]["address"]}/health_check') if resp.status_code == 200: break except (requests.RequestException, requests.ConnectionError) as e: @@ -143,11 +165,13 @@ def fact(request, docker_client, monitored_dir, server, logs_dir): yield container # Capture prometheus metrics before stopping the container - metric_log = os.path.join(logs_dir, 'metrics') - resp = requests.get('http://127.0.0.1:9000/metrics') - if resp.status_code == 200: - with open(metric_log, 'w') as f: - f.write(resp.text) + if config['endpoint']['expose_metrics']: + metric_log = os.path.join(logs_dir, 'metrics') + resp = requests.get( + f'http://{config["endpoint"]["address"]}/metrics') + if resp.status_code == 200: + with open(metric_log, 'w') as f: + f.write(resp.text) container.stop(timeout=1) exit_status = container.wait(timeout=1) diff --git a/tests/requirements.txt b/tests/requirements.txt index 3ebe5046..97079261 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -3,3 +3,4 @@ grpcio==1.73.1 grpcio-tools==1.73.1 pytest==8.4.1 requests==2.32.4 +pyyaml==6.0.3 diff --git a/tests/test_config_hotreload.py b/tests/test_config_hotreload.py new file mode 100644 index 00000000..46a4c8ea --- /dev/null +++ b/tests/test_config_hotreload.py @@ -0,0 +1,66 @@ +from time import sleep +import pytest +import requests +import yaml + +DEFAULT_URL = 'http://127.0.0.1:9000' + + +def assert_endpoint(endpoint, status_code=200): + resp = requests.get(f'{DEFAULT_URL}/{endpoint}') + assert resp.status_code == status_code + + +def reload_config(fact, config, file): + with open(file, 'w') as f: + yaml.dump(config, f) + fact.kill('SIGHUP') + sleep(0.1) + + +cases = [('metrics', 'expose_metrics'), ('health_check', 'health_check')] + + +@pytest.mark.parametrize('case', cases, ids=['metrics', 'health_check']) +def test_endpoint(fact, fact_config, case): + """ + Test the endpoints configurability + """ + endpoint, field = case + + # Endpoints are assumed to start up enabled. + assert_endpoint(endpoint) + + # Mark the endpoint as off and reload configuration + config, config_file = fact_config + config['endpoint'][field] = False + reload_config(fact, config, config_file) + + assert_endpoint(endpoint, 503) + + +def test_endpoint_disable_all(fact, fact_config): + """ + Disable all endpoints and check the default port is not bound + """ + config, config_file = fact_config + config['endpoint'] = { + 'health_check': False, + 'expose_metrics': False, + } + reload_config(fact, config, config_file) + + with pytest.raises(requests.ConnectionError): + requests.get(f'{DEFAULT_URL}/metrics') + + +def test_endpoint_address_change(fact, fact_config): + config, config_file = fact_config + config['endpoint']['address'] = '127.0.0.1:9001' + reload_config(fact, config, config_file) + + with pytest.raises(requests.ConnectionError): + requests.get(f'{DEFAULT_URL}/metrics') + + resp = requests.get('http://127.0.0.1:9001/metrics') + assert resp.status_code == 200