diff --git a/scylla-server/src/controllers/rule_controller.rs b/scylla-server/src/controllers/rule_controller.rs
index f8778a52..8d475046 100644
--- a/scylla-server/src/controllers/rule_controller.rs
+++ b/scylla-server/src/controllers/rule_controller.rs
@@ -5,7 +5,6 @@ use axum_extra::{
     headers::{authorization::Basic, Authorization},
     TypedHeader,
 };
-use tokio::sync::RwLock;
 use tracing::debug;
 
 use crate::{
@@ -16,7 +15,7 @@ use crate::{
 #[debug_handler]
 pub async fn add_rule(
     TypedHeader(auth): TypedHeader<Authorization<Basic>>,
-    Extension(rules_manager): Extension<Arc<RwLock<RuleManager>>>,
+    Extension(rules_manager): Extension<Arc<RuleManager>>,
     Json(rule): Json<Rule>,
 ) -> Result<Json<String>, ScyllaError> {
     debug!(
@@ -25,9 +24,8 @@ pub async fn add_rule(
         auth.username().to_string()
     );
     match rules_manager
-        .write()
-        .await
         .add_rule(ClientId(auth.username().to_string()), rule)
+        .await
     {
         Ok(_) => Ok(Json::from("Rule added!".to_owned())),
         Err(err) => Err(ScyllaError::RuleError(err)),
@@ -37,7 +35,7 @@ pub async fn add_rule(
 #[debug_handler]
 pub async fn delete_rule(
     TypedHeader(auth): TypedHeader<Authorization<Basic>>,
-    Extension(rules_manager): Extension<Arc<RwLock<RuleManager>>>,
+    Extension(rules_manager): Extension<Arc<RuleManager>>,
     Path(rule_id): Path<String>,
 ) -> Result<(), ScyllaError> {
     debug!(
@@ -46,9 +44,8 @@ pub async fn delete_rule(
         auth.username().to_string()
     );
     match rules_manager
-        .write()
-        .await
         .delete_rule(ClientId(auth.username().to_string()), RuleId(rule_id))
+        .await
     {
         Ok(_) => Ok(()),
         Err(err) => Err(ScyllaError::RuleError(err)),
diff --git a/scylla-server/src/main.rs b/scylla-server/src/main.rs
index d51d3fc1..1a2e3fc1 100755
--- a/scylla-server/src/main.rs
+++ b/scylla-server/src/main.rs
@@ -43,7 +43,7 @@ use scylla_server::{
 use socketioxide::{extract::SocketRef, SocketIo};
 use tokio::{
     signal,
-    sync::{broadcast, mpsc, RwLock},
+    sync::{broadcast, mpsc},
 };
 use tokio_util::{sync::CancellationToken, task::TaskTracker};
 use tower::ServiceBuilder;
@@ -249,7 +249,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let (db_send, db_receive) = mpsc::channel::<Vec<ClientData>>(1000);
 
     // the rules manager
-    let rules_manager = Arc::new(RwLock::new(RuleManager::new()));
+    let rules_manager = Arc::new(RuleManager::new());
 
     // the below two threads need to cancel cleanly to ensure all queued messages are sent. therefore they are part of the a task tracker group.
     // create a task tracker and cancellation token
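With the lock pushed inside `RuleManager`, the handlers above now take `Extension<Arc<RuleManager>>` directly. A minimal wiring sketch follows; the route paths, axum 0.7-style path params, and exact module paths are assumptions for illustration, not taken from this diff:

```rust
use std::sync::Arc;

use axum::{
    routing::{delete, post},
    Extension, Router,
};
// Assumed module paths, mirroring the file layout above.
use scylla_server::controllers::rule_controller::{add_rule, delete_rule};
use scylla_server::rule_structs::RuleManager;

fn rule_router() -> Router {
    // Sharing the manager is now a plain Arc clone; each handler awaits
    // the manager's async methods instead of taking an outer write guard.
    let rules_manager = Arc::new(RuleManager::new());
    Router::new()
        .route("/rules", post(add_rule))
        .route("/rules/:id", delete(delete_rule))
        .layer(Extension(rules_manager))
}
```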
diff --git a/scylla-server/src/rule_structs.rs b/scylla-server/src/rule_structs.rs
index 8d5ca58c..9df51d35 100644
--- a/scylla-server/src/rule_structs.rs
+++ b/scylla-server/src/rule_structs.rs
@@ -10,10 +10,14 @@ use rustc_hash::FxHashSet;
 use serde::{Deserialize, Serialize};
 use serde_with::serde_as;
 use serde_with::DurationSeconds;
+use std::borrow::Borrow;
+use std::hash::Hash;
 use std::time::Duration;
+use tokio::sync::RwLock;
 use tracing::trace;
 use tracing::warn;
 
+use crate::rule_structs::BiMapRemoveResult::*;
 use crate::ClientData;
 
 static ASCII_LOWER: [char; 26] = [
@@ -21,6 +25,158 @@ static ASCII_LOWER: [char; 26] = [
     't', 'u', 'v', 'w', 'x', 'y', 'z',
 ];
 
+#[derive(Debug, Clone)]
+pub enum BiMapRemoveResult<T> {
+    /// Removed successfully, and also removed any mappings left empty. \
+    /// Contains the data that was dropped from the map because nothing referenced it anymore.
+    RemovedWithCleanUp(T),
+    /// Removed successfully, no empty mappings to clean up
+    RemovedOnly,
+    NothingToRemove,
+}
+
+pub struct BiMultiMap<L: Hash + Eq + Clone, R: Hash + Eq + Clone> {
+    left_to_right: FxHashMap<L, FxHashSet<R>>,
+    right_to_left: FxHashMap<R, FxHashSet<L>>,
+}
+
+#[allow(dead_code)]
+impl<L: Hash + Eq + Clone, R: Hash + Eq + Clone> BiMultiMap<L, R> {
+    pub fn new() -> Self {
+        Self {
+            left_to_right: FxHashMap::default(),
+            right_to_left: FxHashMap::default(),
+        }
+    }
+
+    pub fn lefts(&self) -> Vec<L> {
+        self.left_to_right.keys().cloned().collect()
+    }
+
+    pub fn rights(&self) -> Vec<R> {
+        self.right_to_left.keys().cloned().collect()
+    }
+
+    pub fn get_right(&self, left: &L) -> Option<&FxHashSet<R>> {
+        self.left_to_right.get(left)
+    }
+
+    pub fn get_left(&self, right: &R) -> Option<&FxHashSet<L>> {
+        self.right_to_left.get(right)
+    }
+
+    pub fn insert(&mut self, left: &L, right: &R) {
+        self.left_to_right
+            .entry(left.clone())
+            .or_insert_with(FxHashSet::default)
+            .insert(right.clone());
+        self.right_to_left
+            .entry(right.clone())
+            .or_insert_with(FxHashSet::default)
+            .insert(left.clone());
+    }
+
+    /// Remove all mappings for a given left key; if no left keys remain for a right key, remove that right key as well. \
+    /// Returns: BiMapRemoveResult with the set of emptied rights that were cleaned from the map.
+    pub fn remove_left(&mut self, left: &L) -> BiMapRemoveResult<FxHashSet<R>> {
+        Self::remove_key(&mut self.left_to_right, &mut self.right_to_left, left)
+    }
+
+    /// Remove all mappings for a given right key; if no right keys remain for a left key, remove that left key as well. \
+    /// Returns: BiMapRemoveResult with the set of emptied lefts that were cleaned from the map.
+    pub fn remove_right(&mut self, right: &R) -> BiMapRemoveResult<FxHashSet<L>> {
+        Self::remove_key(&mut self.right_to_left, &mut self.left_to_right, right)
+    }
+
+    /// Remove a specific mapping from left to right, cleaning up empty entries as needed. \
+    /// Returns: BiMapRemoveResult with the right that was cleaned from the map, if any.
+    pub fn remove_right_from_left(&mut self, left: &L, right: &R) -> BiMapRemoveResult<R> {
+        Self::remove_mapping(
+            &mut self.left_to_right,
+            &mut self.right_to_left,
+            left,
+            right,
+        )
+    }
+
+    /// Remove a specific mapping from right to left, cleaning up empty entries as needed. \
+    /// Returns: BiMapRemoveResult with the left that was cleaned from the map, if any.
+    pub fn remove_left_from_right(&mut self, right: &R, left: &L) -> BiMapRemoveResult<L> {
+        Self::remove_mapping(
+            &mut self.right_to_left,
+            &mut self.left_to_right,
+            right,
+            left,
+        )
+    }
+
+    fn remove_key<K, V>(
+        k_to_v: &mut FxHashMap<K, FxHashSet<V>>,
+        v_to_k: &mut FxHashMap<V, FxHashSet<K>>,
+        key: &K,
+    ) -> BiMapRemoveResult<FxHashSet<V>>
+    where
+        K: Hash + Eq + Clone,
+        V: Hash + Eq + Clone,
+    {
+        let Some(values) = k_to_v.remove(key) else {
+            return NothingToRemove;
+        };
+
+        let mut empty_values = FxHashSet::default();
+        for value in values {
+            if let Some(keys) = v_to_k.get_mut(&value) {
+                keys.remove(key);
+                if keys.is_empty() {
+                    v_to_k.remove(&value);
+                    empty_values.insert(value);
+                }
+            }
+        }
+
+        if empty_values.is_empty() {
+            RemovedOnly
+        } else {
+            RemovedWithCleanUp(empty_values)
+        }
+    }
+
+    fn remove_mapping<K, V>(
+        k_to_v: &mut FxHashMap<K, FxHashSet<V>>,
+        v_to_k: &mut FxHashMap<V, FxHashSet<K>>,
+        key: &K,
+        value: &V,
+    ) -> BiMapRemoveResult<V>
+    where
+        K: Hash + Eq + Clone,
+        V: Hash + Eq + Clone,
+    {
+        let Some(values) = k_to_v.get_mut(key) else {
+            return NothingToRemove;
+        };
+
+        if !values.remove(value) {
+            return NothingToRemove;
+        }
+
+        if values.is_empty() {
+            k_to_v.remove(key);
+        }
+
+        if let Some(keys) = v_to_k.get_mut(value) {
+            keys.remove(key);
+            if keys.is_empty() {
+                v_to_k.remove(value);
+                RemovedWithCleanUp(value.clone())
+            } else {
+                RemovedOnly
+            }
+        } else {
+            NothingToRemove
+        }
+    }
+}
+
 /// cooldown time
 const COOLDOWN_TIME: std::time::Duration = std::time::Duration::from_secs(60);
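Since `BiMultiMap` is the heart of this refactor, a small sketch of the cleanup contract may help; it assumes `ClientId` and `RuleId` stay the plain `pub String` newtypes used elsewhere in this diff (the integration tests at the bottom cover the same ground more thoroughly):

```rust
use scylla_server::rule_structs::{BiMapRemoveResult, BiMultiMap, ClientId, RuleId};

fn main() {
    // Two clients share rule_a; c1 alone holds rule_b.
    let mut map = BiMultiMap::new();
    map.insert(&ClientId("c1".into()), &RuleId("rule_a".into()));
    map.insert(&ClientId("c2".into()), &RuleId("rule_a".into()));
    map.insert(&ClientId("c1".into()), &RuleId("rule_b".into()));

    // Removing c1 orphans rule_b (rule_a is still held by c2), so the
    // orphaned set is handed back for the caller to dispose of.
    match map.remove_left(&ClientId("c1".into())) {
        BiMapRemoveResult::RemovedWithCleanUp(orphans) => {
            assert!(orphans.contains(&RuleId("rule_b".into())));
        }
        _ => panic!("expected cleanup"),
    }
}
```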
@@ -39,7 +195,19 @@ pub struct RuleId(pub String);
 
 /// a MQTT topic to trigger on, add to derives to get more string features
 #[derive(PartialEq, Eq, Hash, Display, Clone, Serialize, Deserialize)]
-pub struct Topic(String);
+pub struct Topic(pub String);
+
+impl Borrow<str> for Topic {
+    fn borrow(&self) -> &str {
+        &self.0
+    }
+}
+
+impl Borrow<String> for Topic {
+    fn borrow(&self) -> &String {
+        &self.0
+    }
+}
 
 #[derive(Serialize)]
 pub struct RuleNotification {
@@ -50,7 +218,7 @@ pub struct RuleNotification {
 }
 
 #[serde_as]
-#[derive(Deserialize)]
+#[derive(Deserialize, Clone)]
 /// A single modular rule, can be serial/deserialized
 pub struct Rule {
     id: RuleId,
@@ -146,6 +314,7 @@ impl Rule {
 }
 
 /// errors seen in the rule manager
+#[derive(Debug)]
 pub enum RuleManagerError {
     NoMatchingRule,
     NoSuchClient,
@@ -155,10 +324,12 @@ pub enum RuleManagerError {
 
 /// the rule manager
 pub struct RuleManager {
-    /// <ClientId, <Topic, rules>>
-    clients_map: FxHashMap<ClientId, FxHashMap<Topic, Vec<Rule>>>,
-    /// <Topic, Set<ClientId>>
-    rules_lookup: FxHashMap<Topic, FxHashSet<ClientId>>,
+    /// <RuleId, Rule>
+    rules: RwLock<FxHashMap<RuleId, Rule>>,
+    /// <Topic, Set<RuleId>>
+    topic_index: RwLock<FxHashMap<Topic, FxHashSet<RuleId>>>,
+    /// <ClientId, RuleId> bimap
+    subscriptions: RwLock<BiMultiMap<ClientId, RuleId>>,
 }
 impl Default for RuleManager {
     fn default() -> Self {
@@ -169,57 +340,86 @@ impl RuleManager {
 impl RuleManager {
     pub fn new() -> Self {
         Self {
-            clients_map: FxHashMap::default(),
-            rules_lookup: FxHashMap::default(),
+            rules: RwLock::new(FxHashMap::default()),
+            topic_index: RwLock::new(FxHashMap::default()),
+            subscriptions: RwLock::new(BiMultiMap::new()),
         }
     }
 
     /// Handles a new socket message, returning a RuleNotification for one to many clients if action should be taken
-    pub fn handle_msg(
-        &mut self,
+    pub async fn handle_msg(
+        &self,
         data: &ClientData,
     ) -> Result<Option<Vec<(ClientId, RuleNotification)>>, RuleManagerError> {
-        // TODO uneccessary clone
-        let topic = Topic(data.name.clone());
-
-        let Some(clients) = self.rules_lookup.get(&topic) else {
-            trace!("(normal) Could not find rule in rule cache: {}", data.name);
-            return Err(RuleManagerError::NoMatchingRule);
-        };
-
-        let mut notifications: Vec<(ClientId, RuleNotification)> = Vec::new();
-        // warning if the clients is empty we havent been cleaning right
-        if clients.is_empty() {
-            warn!("Empty rule cache entry for {}!", data.name);
-        }
-
-        for client_want in clients {
-            let Some(rules) = self.clients_map.get_mut(client_want) else {
-                warn!("Client cached but not found!");
-                return Err(RuleManagerError::Failure);
-            };
-
-            if let Some(rule_set) = rules.get_mut(&topic) {
-                for rule in rule_set {
-                    // return rule failure if underlying tick fails
-                    let Some(is_triggered) = rule.tick(&data.values) else {
-                        return Err(RuleManagerError::RuleFailure);
-                    };
-                    if is_triggered {
-                        notifications.push((
-                            client_want.clone(),
-                            RuleNotification {
-                                id: rule.id.clone(),
-                                topic: topic.clone(),
-                                values: data.values.clone(),
-                                time: data.timestamp,
-                            },
-                        ));
-                    }
-                }
+        // Read from the topic -> rule index and drop the lock immediately
+        let rule_ids = match self.topic_index.read().await.get(&data.name) {
+            Some(rule_ids) => rule_ids.clone(), // Clone so we can drop the read guard
+            None => {
+                trace!("Could not find rule in topic -> rule index: {}", data.name);
+                return Err(RuleManagerError::NoMatchingRule);
+            }
+        };
+
+        let mut notifications: Vec<(ClientId, RuleNotification)> = Vec::new();
+        for rule_id in rule_ids {
+            let (triggered_result, clients_result) = {
+                // Future resolving to whether the rule was triggered
+                let triggered_future = async {
+                    if let Some(rule) = self.rules.write().await.get_mut(&rule_id) {
+                        if let Some(triggered) = rule.tick(&data.values) {
+                            Ok(triggered)
+                        } else {
+                            Err(RuleManagerError::RuleFailure)
+                        }
+                    } else {
+                        trace!("Could not find rule in rules map: {}", rule_id);
+                        Err(RuleManagerError::NoMatchingRule)
+                    }
+                };
+
+                // Future for getting subscribed clients
+                let clients_future =
+                    async { self.subscriptions.read().await.get_left(&rule_id).cloned() };
+
+                tokio::pin!(triggered_future);
+                tokio::pin!(clients_future);
+
+                // Check which operation finished first, then await the other if still needed
+                tokio::select! {
+                    triggered_result = &mut triggered_future => {
+                        match triggered_result? {
+                            true => (Ok(true), clients_future.await),
+                            false => (Ok(false), None),
+                        }
+                    },
+                    clients_result = &mut clients_future => {
+                        match clients_result {
+                            Some(_) => (triggered_future.await, clients_result),
+                            None => (Ok(false), None),
+                        }
+                    }
+                }
+            };
+
+            let triggered = triggered_result?;
+            if !triggered || clients_result.is_none() {
+                continue;
+            }
+
+            // Push notifications for all clients who are subscribed to this rule
+            for client in clients_result.unwrap() {
+                notifications.push((
+                    client.clone(),
+                    RuleNotification {
+                        id: rule_id.clone(),
+                        topic: Topic(data.name.clone()),
+                        values: data.values.clone(),
+                        time: data.timestamp,
+                    },
+                ));
             }
         }
-        // pass back the results
+
         if notifications.is_empty() {
             Ok(None)
         } else {
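The `select!` above races the rule tick (behind the `rules` write lock) against the subscriber lookup (behind the `subscriptions` read lock), and only awaits the second future when its result is still needed. A stripped-down sketch of that shape, with plain locks standing in for the manager's fields:

```rust
use tokio::sync::RwLock;

// Minimal sketch of the race-then-await-the-rest pattern in handle_msg.
async fn race_demo(flag: &RwLock<bool>, subscribers: &RwLock<Vec<String>>) -> Option<Vec<String>> {
    let triggered_future = async { *flag.read().await };
    let clients_future = async { subscribers.read().await.clone() };

    tokio::pin!(triggered_future);
    tokio::pin!(clients_future);

    // Whichever lock is acquired first drives the control flow; the other
    // future is awaited only if its result is still relevant.
    tokio::select! {
        triggered = &mut triggered_future => {
            if triggered { Some(clients_future.await) } else { None }
        }
        clients = &mut clients_future => {
            if clients.is_empty() {
                None
            } else if triggered_future.await {
                Some(clients)
            } else {
                None
            }
        }
    }
}
```

Whether the race wins anything over two sequential awaits depends on lock contention; the sketch just preserves the control flow of `handle_msg`.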
@@ -228,105 +428,86 @@ impl RuleManager {
     }
 
     /// Adds a rule, creating or activating the client if needed
-    pub fn add_rule(&mut self, client: ClientId, rule: Rule) -> Result<(), RuleManagerError> {
-        // go through the topics and add to rules lookup table
-        match self.rules_lookup.get_mut(&rule.topic) {
-            Some(rules) => {
-                rules.insert(client.clone());
-            }
-            None => {
-                let mut new_set = FxHashSet::default();
-                new_set.insert(client.clone());
-                self.rules_lookup.insert(rule.topic.clone(), new_set);
-            }
-        }
-
-        // push rule, make client active and push rule, or create client and push rule
-        match self.clients_map.get_mut(&client) {
-            Some(client) => match client.get_mut(&rule.topic) {
-                Some(set) => set.push(rule),
-                None => {
-                    client.insert(rule.topic.clone(), vec![rule]);
-                }
-            },
-
-            None => {
-                let mut map = FxHashMap::default();
-                map.insert(rule.topic.clone(), vec![rule]);
-                self.clients_map.insert(client, map);
-            }
-        };
+    pub async fn add_rule(&self, client: ClientId, rule: Rule) -> Result<(), RuleManagerError> {
+        let rule_id = rule.id.clone();
+        let topic = rule.topic.clone();
+
+        // Run all three writes concurrently
+        tokio::join!(
+            async {
+                // Add to subscriptions bimap
+                self.subscriptions.write().await.insert(&client, &rule_id);
+            },
+            async {
+                // Add to topic index
+                self.topic_index
+                    .write()
+                    .await
+                    .entry(topic)
+                    .or_insert(FxHashSet::default())
+                    .insert(rule_id.clone());
+            },
+            async {
+                // Add to rules lookup
+                self.rules.write().await.insert(rule_id.clone(), rule);
+            }
+        );
 
         Ok(())
     }
 
-    /// Deletes a rule, leaving the client existing and active no matter what
-    pub fn delete_rule(
-        &mut self,
+    /// Deletes a rule from a client. \
+    /// If no more rules exist for that client, the client is also removed.
+    pub async fn delete_rule(
+        &self,
         client_id: ClientId,
         rule_id: RuleId,
     ) -> Result<(), RuleManagerError> {
-        // first, find the rules from the clients map
-        let Some(rules) = self.clients_map.get_mut(&client_id) else {
-            warn!("Could not find client {}", client_id);
-            return Err(RuleManagerError::NoSuchClient);
-        };
-
-        let mut removed: Option<Rule> = None;
-        for rule_vals in rules.values_mut() {
-            let Some(pos) = rule_vals.iter().position(|a| a.id == rule_id) else {
-                break;
-            };
-            removed = Some(rule_vals.remove(pos));
-        }
-
-        let Some(removed) = removed else {
-            warn!("Could not find rule: {}", rule_id);
-            return Err(RuleManagerError::NoMatchingRule);
-        };
-
-        // now, yeet the rule from the lookup cache, ONLY IF the client doesnt have any rules with the given topic left
-        let lookup_preserve = rules.contains_key(&removed.topic);
-
-        if !lookup_preserve {
-            let Some(clients) = self.rules_lookup.get_mut(&removed.topic) else {
-                warn!("Could not find rule in cache!");
-                return Err(RuleManagerError::Failure);
-            };
-            // remove the client from the cache for that topic, deleting rule from cache if necessary
-            clients.retain(|client| *client != client_id);
-            // delete client from cache is normal, the client could still exist without rules in client_map
-            if clients.is_empty() {
-                self.rules_lookup.remove(&removed.topic);
-            }
-        } // else we dont touch the lookup cache
-
-        Ok(())
+        // Remove rule from client
+        match self
+            .subscriptions
+            .write()
+            .await
+            .remove_right_from_left(&client_id, &rule_id)
+        {
+            RemovedWithCleanUp(_) | RemovedOnly => Ok(()),
+            NothingToRemove => {
+                warn!(
+                    "Could not find client in subscriptions bimap to delete rule: {}",
+                    client_id
+                );
+                Err(RuleManagerError::NoSuchClient)
+            }
+        }
     }
 
-    /// deletes a client, and all of its rules
-    pub fn delete_client(&mut self, client_id: ClientId) -> Result<(), RuleManagerError> {
-        // first, yeet from clients map
-        let Some(rules) = self.clients_map.remove(&client_id) else {
-            warn!("Could not find client to delete: {}", client_id);
-            return Err(RuleManagerError::NoSuchClient);
-        };
-
-        // now, for each unique topic found amongst the rules, yeet it from the lookup
-        // this uses a hashset to de-dup the rules to avoid annoying warnings
-        for rule in rules.keys() {
-            warn!("DELETING {}", rule);
-            let Some(client_list) = self.rules_lookup.get_mut(rule) else {
-                warn!("Could not find topic in rule lookup table!");
-                return Err(RuleManagerError::Failure);
-            };
-            client_list.retain(|client| *client != client_id);
-            // remove the whole entry if no clients exist for the topic
-            if client_list.is_empty() {
-                self.rules_lookup.remove(rule);
-            }
-        }
-
-        Ok(())
+    /// Deletes a client, and all of its rules.
+    /// Removes rules that are no longer subscribed to if needed.
+    pub async fn delete_client(&self, client_id: ClientId) -> Result<(), RuleManagerError> {
+        // Removing from left returns rules that no longer have clients
+        match self.subscriptions.write().await.remove_left(&client_id) {
+            RemovedWithCleanUp(_) | RemovedOnly => Ok(()),
+            NothingToRemove => {
+                warn!(
+                    "Could not find client in subscriptions bimap to delete client: {}",
+                    client_id
+                );
+                Err(RuleManagerError::NoSuchClient)
+            }
+        }
+    }
+
+    pub async fn get_all_rules(&self) -> Vec<Rule> {
+        self.rules.read().await.values().cloned().collect()
+    }
+
+    pub async fn get_all_clients(&self) -> Vec<ClientId> {
+        self.subscriptions.read().await.lefts()
     }
 }
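One subtlety worth calling out in rule_structs.rs: the new `Borrow<String>`/`Borrow<str>` impls are what let `handle_msg` query the `Topic`-keyed index directly with `&data.name` instead of allocating a `Topic` (the old code's "TODO uneccessary clone"). This is sound because `Topic` derives `Hash`/`Eq` on its single `String` field, so the hashes agree. A standalone sketch of the same trick with std's `HashMap`:

```rust
use std::borrow::Borrow;
use std::collections::HashMap;

#[derive(PartialEq, Eq, Hash)]
struct Topic(pub String);

// Hashing Topic hashes its inner String, so a &String can stand in
// for a &Topic during lookups.
impl Borrow<String> for Topic {
    fn borrow(&self) -> &String {
        &self.0
    }
}

fn main() {
    let mut index: HashMap<Topic, u32> = HashMap::new();
    index.insert(Topic("pack/voltage".to_string()), 7);

    // No Topic is constructed here; the lookup borrows the key as &String.
    let name = "pack/voltage".to_string();
    assert_eq!(index.get(&name), Some(&7));
}
```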
diff --git a/scylla-server/src/socket_handler.rs b/scylla-server/src/socket_handler.rs
index 372fa05f..88042255 100644
--- a/scylla-server/src/socket_handler.rs
+++ b/scylla-server/src/socket_handler.rs
@@ -46,7 +46,7 @@ struct AuthData {
 pub async fn socket_handler_with_metadata(
     cancel_token: CancellationToken,
     mut data_channel: broadcast::Receiver<ClientData>,
-    rules_manager: Arc<RwLock<RuleManager>>,
+    rules_manager: Arc<RuleManager>,
     io: SocketIo,
 ) {
     let mut upload_counter = 0u8;
@@ -195,18 +195,18 @@ pub async fn socket_handler_with_metadata(
 /// Handles triggering rules based on a recieved datapoint
 async fn handle_rule_processing(
     data: &ClientData,
-    rule_manager: &Arc<RwLock<RuleManager>>,
+    rule_manager: &Arc<RuleManager>,
     client_socket_map: &Arc<RwLock<HashMap<String, Sid>>>,
     io: &SocketIo,
 ) {
-    let Ok(Some(notifs)) = rule_manager.write().await.handle_msg(data) else {
+    let Ok(Some(notifs)) = rule_manager.handle_msg(data).await else {
         return;
     };
     for notification in notifs {
         let read_clients = client_socket_map.read().await;
         let Some(sid) = read_clients.get(&notification.0 .0) else {
             warn!("Could not find client to deliver notification, deleting client");
-            let _ = rule_manager.write().await.delete_client(notification.0);
+            let _ = rule_manager.delete_client(notification.0).await;
             return;
         };
         debug!(
@@ -215,7 +215,7 @@ async fn handle_rule_processing(
         );
         let Some(socket) = io.get_socket(*sid) else {
             warn!("Could not find client socket, deleting client");
-            let _ = rule_manager.write().await.delete_client(notification.0);
+            let _ = rule_manager.delete_client(notification.0).await;
             return;
         };
         if let Err(err) = socket.emit(RULE_SOCKET_KEY, &notification.1) {
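Back in `add_rule`, the three index writes run under `tokio::join!`. Since each future grabs a different `RwLock`, the join cannot deadlock on itself; a minimal sketch of the shape (dummy locks, not the real fields):

```rust
use tokio::sync::RwLock;

// Three independent locks, as in RuleManager { rules, topic_index, subscriptions }.
async fn concurrent_writes(a: &RwLock<u32>, b: &RwLock<u32>, c: &RwLock<u32>) {
    // join! polls all three futures on the current task; each one acquires
    // only its own lock, so there is no lock-order cycle to worry about.
    tokio::join!(
        async { *a.write().await += 1 },
        async { *b.write().await += 1 },
        async { *c.write().await += 1 },
    );
}
```

The trade-off, which the diff accepts, is that the three writes are not atomic as a group: a concurrent reader may briefly observe a rule present in one map but not yet in another.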
diff --git a/scylla-server/tests/bimultimap_test.rs b/scylla-server/tests/bimultimap_test.rs
new file mode 100644
index 00000000..24b76528
--- /dev/null
+++ b/scylla-server/tests/bimultimap_test.rs
@@ -0,0 +1,240 @@
+use scylla_server::rule_structs::{BiMapRemoveResult, BiMultiMap, ClientId, RuleId};
+
+#[test]
+fn test_bi_multi_map_insert_single() {
+    let mut bimap = BiMultiMap::new();
+    let left = "client1".to_string();
+    let right = 42;
+
+    bimap.insert(&left, &right);
+
+    assert_eq!(bimap.get_right(&left).unwrap().len(), 1);
+    assert!(bimap.get_right(&left).unwrap().contains(&right));
+
+    assert_eq!(bimap.get_left(&right).unwrap().len(), 1);
+    assert!(bimap.get_left(&right).unwrap().contains(&left));
+}
+
+#[test]
+fn test_bi_multi_map_insert_multiple_rights() {
+    let mut bimap = BiMultiMap::new();
+    let left = "client1".to_string();
+    let right1 = 42;
+    let right2 = 43;
+    let right3 = 44;
+
+    bimap.insert(&left, &right1);
+    bimap.insert(&left, &right2);
+    bimap.insert(&left, &right3);
+
+    let rights = bimap.get_right(&left).unwrap();
+    assert_eq!(rights.len(), 3);
+    assert!(rights.contains(&right1));
+    assert!(rights.contains(&right2));
+    assert!(rights.contains(&right3));
+
+    // Each right should map back to the left
+    assert!(bimap.get_left(&right1).unwrap().contains(&left));
+    assert!(bimap.get_left(&right2).unwrap().contains(&left));
+    assert!(bimap.get_left(&right3).unwrap().contains(&left));
+}
+
+#[test]
+fn test_bi_multi_map_insert_many_to_many() {
+    let mut bimap = BiMultiMap::new();
+    let left1 = "client1".to_string();
+    let left2 = "client2".to_string();
+    let right1 = 42;
+    let right2 = 43;
+
+    // Create many-to-many relationships
+    bimap.insert(&left1, &right1);
+    bimap.insert(&left1, &right2);
+    bimap.insert(&left2, &right1);
+    bimap.insert(&left2, &right2);
+
+    // Verify left1 maps to both rights
+    let rights_for_left1 = bimap.get_right(&left1).unwrap();
+    assert_eq!(rights_for_left1.len(), 2);
+    assert!(rights_for_left1.contains(&right1));
+    assert!(rights_for_left1.contains(&right2));
+
+    // Verify left2 maps to both rights
+    let rights_for_left2 = bimap.get_right(&left2).unwrap();
+    assert_eq!(rights_for_left2.len(), 2);
+    assert!(rights_for_left2.contains(&right1));
+    assert!(rights_for_left2.contains(&right2));
+
+    // Verify right1 maps to both lefts
+    let lefts_for_right1 = bimap.get_left(&right1).unwrap();
+    assert_eq!(lefts_for_right1.len(), 2);
+    assert!(lefts_for_right1.contains(&left1));
+    assert!(lefts_for_right1.contains(&left2));
+
+    // Verify right2 maps to both lefts
+    let lefts_for_right2 = bimap.get_left(&right2).unwrap();
+    assert_eq!(lefts_for_right2.len(), 2);
+    assert!(lefts_for_right2.contains(&left1));
+    assert!(lefts_for_right2.contains(&left2));
+}
+
+#[test]
+fn test_bi_multi_map_remove_left_single() {
+    let mut bimap = BiMultiMap::new();
+    let left = "client1".to_string();
+    let right = 42;
+
+    bimap.insert(&left, &right);
+
+    let result = bimap.remove_left(&left);
+    assert!(matches!(result, BiMapRemoveResult::RemovedWithCleanUp(_)));
+
+    if let BiMapRemoveResult::RemovedWithCleanUp(removed_rights) = result {
+        assert_eq!(removed_rights.len(), 1);
+        assert!(removed_rights.contains(&right));
+    }
+
+    // Verify both directions are cleaned up
+    assert!(bimap.get_right(&left).is_none());
+    assert!(bimap.get_left(&right).is_none());
+}
+
+#[test]
+fn test_bi_multi_map_remove_left_shared_right() {
+    let mut bimap = BiMultiMap::new();
+    let left1 = "client1".to_string();
+    let left2 = "client2".to_string();
+    let right = 42;
+
+    bimap.insert(&left1, &right);
+    bimap.insert(&left2, &right);
+
+    let result = bimap.remove_left(&left1);
+    assert!(matches!(result, BiMapRemoveResult::RemovedOnly));
+
+    // left1 should be gone
+    assert!(bimap.get_right(&left1).is_none());
+
+    // right should still exist and map to left2
+    let remaining_lefts = bimap.get_left(&right).unwrap();
+    assert_eq!(remaining_lefts.len(), 1);
+    assert!(remaining_lefts.contains(&left2));
+
+    // left2 should still map to right
+    assert!(bimap.get_right(&left2).unwrap().contains(&right));
+}
+
+#[test]
+fn test_bi_multi_map_remove_left_nonexistent() {
+    let mut bimap: BiMultiMap<String, i32> = BiMultiMap::new();
+    let left = "nonexistent".to_string();
+
+    let result = bimap.remove_left(&left);
+    assert!(matches!(result, BiMapRemoveResult::NothingToRemove));
+}
"client2".to_string(); + let right = 42; + + bimap.insert(&left1, &right); + bimap.insert(&left2, &right); + + let result = bimap.remove_left_from_right(&right, &left1); + assert!(matches!(result, BiMapRemoveResult::RemovedWithCleanUp(_))); + + if let BiMapRemoveResult::RemovedWithCleanUp(removed_left) = result { + assert_eq!(removed_left, left1); + } + + // right should still exist but only map to left2 + let remaining_lefts = bimap.get_left(&right).unwrap(); + assert_eq!(remaining_lefts.len(), 1); + assert!(remaining_lefts.contains(&left2)); + + // left1 should be completely removed + assert!(bimap.get_right(&left1).is_none()); + + // left2 should still map to right + assert!(bimap.get_right(&left2).unwrap().contains(&right)); +} + +#[test] +fn test_bi_multi_map_complex_operations() { + let mut bimap = BiMultiMap::new(); + + // Set up a complex mapping + bimap.insert(&"client1", &"rule1"); + bimap.insert(&"client1", &"rule2"); + bimap.insert(&"client2", &"rule1"); + bimap.insert(&"client2", &"rule3"); + bimap.insert(&"client3", &"rule3"); + + // Remove a shared rule from one client + let result = bimap.remove_right_from_left(&"client1", &"rule1"); + assert!(matches!(result, BiMapRemoveResult::RemovedOnly)); + + // Verify rule1 still exists for client2 + assert!(bimap.get_right(&"client2").unwrap().contains(&"rule1")); + + // Verify client1 still has rule2 + assert!(bimap.get_right(&"client1").unwrap().contains(&"rule2")); + assert!(!bimap.get_right(&"client1").unwrap().contains(&"rule1")); + + // Remove client2 entirely + let result = bimap.remove_left(&"client2"); + assert!(matches!(result, BiMapRemoveResult::RemovedWithCleanUp(_))); + + if let BiMapRemoveResult::RemovedWithCleanUp(removed_rights) = result { + assert_eq!(removed_rights.len(), 1); + assert!(removed_rights.contains(&"rule1")); // rule1 should be cleaned up + } + + // rule3 should still exist for client3 + assert!(bimap.get_left(&"rule3").unwrap().contains(&"client3")); + + // rule1 should be completely gone + assert!(bimap.get_left(&"rule1").is_none()); +} + +#[test] +fn test_bi_multi_map_with_rule_manager_types() { + let mut bimap: BiMultiMap = BiMultiMap::new(); + + let client1 = ClientId("client1".to_string()); + let client2 = ClientId("client2".to_string()); + let rule1 = RuleId("rule1".to_string()); + let rule2 = RuleId("rule2".to_string()); + + bimap.insert(&client1, &rule1); + bimap.insert(&client1, &rule2); + bimap.insert(&client2, &rule1); + + // Test with actual types used in RuleManager + assert_eq!(bimap.get_right(&client1).unwrap().len(), 2); + assert_eq!(bimap.get_left(&rule1).unwrap().len(), 2); + + let result = bimap.remove_left(&client1); + assert!(matches!(result, BiMapRemoveResult::RemovedWithCleanUp(_))); + + if let BiMapRemoveResult::RemovedWithCleanUp(removed_rules) = result { + assert_eq!(removed_rules.len(), 1); + assert!(removed_rules.contains(&rule2)); // rule2 should be cleaned up + } + + // rule1 should still exist for client2 + assert!(bimap.get_left(&rule1).unwrap().contains(&client2)); +} diff --git a/scylla-server/tests/rule_structs_test.rs b/scylla-server/tests/rule_structs_test.rs new file mode 100644 index 00000000..97343a01 --- /dev/null +++ b/scylla-server/tests/rule_structs_test.rs @@ -0,0 +1,544 @@ +use chrono::Utc; +use scylla_server::rule_structs::*; +use scylla_server::ClientData; +use tokio::task::JoinSet; + +#[tokio::test] +async fn test_add_multiple_rules_same_client() -> Result<(), RuleManagerError> { + let rule_manager = RuleManager::new(); + let client = 
ClientId("test_client".to_string()); + + let rule1 = Rule::new( + RuleId("rule_1".to_string()), + Topic("test/topic1".to_string()), + core::time::Duration::from_secs(60), + "a > 10".to_owned(), + ); + + let rule2 = Rule::new( + RuleId("rule_2".to_string()), + Topic("test/topic2".to_string()), + core::time::Duration::from_secs(30), + "b < 5".to_owned(), + ); + + rule_manager.add_rule(client.clone(), rule1).await?; + rule_manager.add_rule(client, rule2).await?; + + assert_eq!(rule_manager.get_all_rules().await.len(), 2); + Ok(()) +} + +#[tokio::test] +async fn test_delete_rule_success() -> Result<(), RuleManagerError> { + let rule_manager = RuleManager::new(); + let client = ClientId("test_client".to_string()); + let rule_id = RuleId("rule_1".to_string()); + + let rule = Rule::new( + rule_id.clone(), + Topic("test/topic".to_string()), + core::time::Duration::from_secs(60), + "a > 10".to_owned(), + ); + + rule_manager.add_rule(client.clone(), rule).await?; + assert_eq!(rule_manager.get_all_rules().await.len(), 1); + + rule_manager.delete_rule(client, rule_id).await?; + assert_eq!(rule_manager.get_all_rules().await.len(), 1); // Rule still exists but client is unsubscribed + + Ok(()) +} + +#[tokio::test] +async fn test_delete_client_success() -> Result<(), RuleManagerError> { + let rule_manager = RuleManager::new(); + let client = ClientId("test_client".to_string()); + + let rule1 = Rule::new( + RuleId("rule_1".to_string()), + Topic("test/topic1".to_string()), + core::time::Duration::from_secs(60), + "a > 10".to_owned(), + ); + + let rule2 = Rule::new( + RuleId("rule_2".to_string()), + Topic("test/topic2".to_string()), + core::time::Duration::from_secs(30), + "b < 5".to_owned(), + ); + + rule_manager.add_rule(client.clone(), rule1).await?; + rule_manager.add_rule(client.clone(), rule2).await?; + assert_eq!(rule_manager.get_all_rules().await.len(), 2); + assert_eq!(rule_manager.get_all_clients().await.len(), 1); + + rule_manager.delete_client(client).await?; + assert!(rule_manager.get_all_clients().await.is_empty()); + assert_eq!(rule_manager.get_all_rules().await.len(), 2); + + Ok(()) +} + +#[tokio::test] +async fn test_handle_msg_rule_triggered() -> Result<(), RuleManagerError> { + let rule_manager = RuleManager::new(); + let client = ClientId("test_client".to_string()); + + let rule = Rule::new( + RuleId("rule_1".to_string()), + Topic("test/topic".to_string()), + core::time::Duration::from_secs(1), + "a > 10".to_owned(), // First value (a) should be > 10 + ); + + rule_manager.add_rule(client.clone(), rule).await?; + + let client_data = ClientData { + run_id: 1, + name: "test/topic".to_string(), + unit: "test_unit".to_string(), + values: vec![15.0], // a = 15.0 > 10, should trigger + timestamp: Utc::now(), + }; + + // First trigger might not fire due to debounce logic + let empty_notifications = rule_manager.handle_msg(&client_data).await; + assert!(empty_notifications.is_ok_and(|op| op.is_none())); + + // Wait for debounce time + tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; + + let result = rule_manager.handle_msg(&client_data).await?; + + let notifications = result.unwrap(); + assert!(!notifications.is_empty()); + assert_eq!(notifications[0].0 .0, client.0); + assert_eq!(notifications[0].1.topic.0, "test/topic"); + + Ok(()) +} + +#[tokio::test] +async fn test_handle_msg_multiple_clients_same_rule() -> Result<(), RuleManagerError> { + let rule_manager = RuleManager::new(); + let client1 = ClientId("client1".to_string()); + let client2 = ClientId("client2".to_string()); + + 
+    let rule1 = Rule::new(
+        RuleId("rule_1".to_string()),
+        Topic("shared/topic".to_string()),
+        core::time::Duration::from_millis(100),
+        "a > 10".to_owned(),
+    );
+
+    let rule2 = Rule::new(
+        RuleId("rule_2".to_string()),
+        Topic("shared/topic".to_string()),
+        core::time::Duration::from_millis(100),
+        "a > 5".to_owned(), // Different condition but same topic
+    );
+
+    rule_manager.add_rule(client1.clone(), rule1).await?;
+    rule_manager.add_rule(client2.clone(), rule2).await?;
+
+    let client_data = ClientData {
+        run_id: 1,
+        name: "shared/topic".to_string(),
+        unit: "test_unit".to_string(),
+        values: vec![15.0],
+        timestamp: Utc::now(),
+    };
+
+    // First trigger to start debounce timers
+    let empty = rule_manager.handle_msg(&client_data).await;
+    assert!(empty.is_ok_and(|op| op.is_none()));
+
+    // Wait for debounce
+    tokio::time::sleep(tokio::time::Duration::from_millis(150)).await;
+
+    let result = rule_manager.handle_msg(&client_data).await?;
+
+    if let Some(notifications) = result {
+        // Both rules should trigger since 15.0 > 10 and 15.0 > 5
+        assert!(notifications.len() >= 1);
+
+        let client_ids: Vec<_> = notifications.iter().map(|(id, _)| id.clone()).collect();
+        assert!(client_ids.contains(&client1) && client_ids.contains(&client2));
+    }
+
+    Ok(())
+}
+
+fn check_rules_present(rules: Vec<Rule>, prefix: &str, k: usize) {
+    assert_eq!(rules.len(), k);
+    let topics = rules.into_iter().map(|r| r.topic.0).collect::<Vec<_>>();
+    assert!((0..k).all(|i| topics.contains(&format!("{}{}", prefix, i))));
+}
+
+fn check_clients_present(clients: Vec<ClientId>, prefix: &str, k: usize) {
+    assert_eq!(clients.len(), k);
+    let client_strings = clients.into_iter().map(|c| c.0).collect::<Vec<_>>();
+    assert!((0..k).all(|i| client_strings.contains(&format!("{}{}", prefix, i))));
+}
+
+#[tokio::test]
+async fn test_rule_manager_concurrent_add_rule() -> Result<(), RuleManagerError> {
+    let num_rules = 10;
+    let rule_manager = std::sync::Arc::new(RuleManager::new());
+
+    (0..num_rules)
+        .fold(JoinSet::new(), |mut set, i| {
+            let rm = rule_manager.clone();
+            set.spawn(async move {
+                let client = ClientId(format!("client_{}", i));
+                let rule = Rule::new(
+                    RuleId(format!("rule_{}", i)),
+                    Topic(format!("topic/{}", i)),
+                    core::time::Duration::from_secs(60),
+                    "a > 5".to_owned(),
+                );
+
+                rm.add_rule(client, rule).await.unwrap();
+            });
+            set
+        })
+        .join_all()
+        .await;
+
+    let clients = rule_manager.get_all_clients().await;
+    check_clients_present(clients, "client_", num_rules);
+
+    let rules = rule_manager.get_all_rules().await;
+    check_rules_present(rules, "topic/", num_rules);
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_rule_manager_concurrent_delete_rule() -> Result<(), RuleManagerError> {
+    let num_rules = 10;
+    let rule_manager = std::sync::Arc::new(RuleManager::new());
+
+    (0..num_rules)
+        .fold(JoinSet::new(), |mut set, i| {
+            let rm = rule_manager.clone();
+            set.spawn(async move {
+                let client = ClientId(format!("client_{}", i));
+                let rule = Rule::new(
+                    RuleId(format!("rule_{}", i)),
+                    Topic(format!("topic/{}", i)),
+                    core::time::Duration::from_secs(60),
+                    "a > 5".to_owned(),
+                );
+
+                rm.add_rule(client, rule).await.unwrap();
+            });
+            set
+        })
+        .join_all()
+        .await;
+
+    check_clients_present(rule_manager.get_all_clients().await, "client_", num_rules);
+    check_rules_present(rule_manager.get_all_rules().await, "topic/", num_rules);
+
+    let f = async || {
+        (0..10)
+            .fold(JoinSet::new(), |mut set, i| {
+                let rm = rule_manager.clone();
+                set.spawn(async move {
+                    let client = ClientId(format!("client_{}", i));
+                    let rule_id = RuleId(format!("rule_{}", i));
RuleId(format!("rule_{}", i)); + rm.delete_rule(client, rule_id).await + }); + set + }) + .join_all() + .await + }; + + // Deleting rules from calling client side code doesn't actually remove rules + let res = f().await; + assert!(res.into_iter().all(|e| e.is_ok())); + check_rules_present(rule_manager.get_all_rules().await, "topic/", num_rules); + assert!(rule_manager.get_all_clients().await.is_empty()); + + // Deleting again will result in NoSuchClient errors + let res = f().await; + assert!(res.into_iter().all(|e| e.is_err())); + check_rules_present(rule_manager.get_all_rules().await, "topic/", num_rules); + assert!(rule_manager.get_all_clients().await.is_empty()); + + Ok(()) +} + +#[tokio::test] +async fn test_concurrent_topic_index_stress() -> Result<(), RuleManagerError> { + let num_topics = 20; + let num_rules_per_topic = 5; + let rule_manager = std::sync::Arc::new(RuleManager::new()); + + // Create multiple rules for the same topics concurrently + let results: Vec<_> = (0..num_topics) + .flat_map(|topic_idx| (0..num_rules_per_topic).map(move |rule_idx| (topic_idx, rule_idx))) + .fold(JoinSet::new(), |mut set, (topic_idx, rule_idx)| { + let rm = rule_manager.clone(); + set.spawn(async move { + let client = ClientId(format!("topic_client_{}_{}", topic_idx, rule_idx)); + let rule = Rule::new( + RuleId(format!("topic_rule_{}_{}", topic_idx, rule_idx)), + Topic(format!("topic/{}", topic_idx)), + core::time::Duration::from_millis(50), + format!("a > {}", rule_idx), + ); + rm.add_rule(client.clone(), rule) + .await + .map(|_| (topic_idx, rule_idx, client)) + }); + set + }) + .join_all() + .await; + + // Verify all operations succeeded + let successful_adds: Vec<_> = results.into_iter().filter_map(|r| r.ok()).collect(); + let total_expected = num_topics * num_rules_per_topic; + assert_eq!(successful_adds.len(), total_expected); + + // Verify final counts + assert_eq!(rule_manager.get_all_rules().await.len(), total_expected); + assert_eq!(rule_manager.get_all_clients().await.len(), total_expected); + + // Verify topic distribution + let all_rules = rule_manager.get_all_rules().await; + let mut topic_counts = std::collections::HashMap::new(); + for rule in all_rules { + *topic_counts.entry(rule.topic.0).or_insert(0) += 1; + } + + assert_eq!(topic_counts.len(), num_topics); + for i in 0..num_topics { + let topic_name = format!("topic/{}", i); + assert_eq!(topic_counts[&topic_name], num_rules_per_topic); + } + + // Test that all topics can handle messages concurrently + let message_results: Vec<_> = (0..num_topics) + .fold(JoinSet::new(), |mut set, topic_idx| { + let rm = rule_manager.clone(); + set.spawn(async move { + let client_data = ClientData { + run_id: 1, + name: format!("topic/{}", topic_idx), + unit: "test".to_string(), + values: vec![10.0], // Should trigger rules with threshold < 10 + timestamp: Utc::now(), + }; + rm.handle_msg(&client_data) + .await + .map(|result| (topic_idx, result)) + }); + set + }) + .join_all() + .await; + + // Verify all messages were processed + let successful_messages: Vec<_> = message_results.into_iter().filter_map(|r| r.ok()).collect(); + assert_eq!(successful_messages.len(), num_topics); + + // Wait for debounce and try again + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let second_round_results: Vec<_> = (0..num_topics) + .fold(JoinSet::new(), |mut set, topic_idx| { + let rm = rule_manager.clone(); + set.spawn(async move { + let client_data = ClientData { + run_id: 1, + name: format!("topic/{}", topic_idx), + unit: 
"test".to_string(), + values: vec![10.0], + timestamp: Utc::now(), + }; + rm.handle_msg(&client_data).await + }); + set + }) + .join_all() + .await; + + // Count total notifications from second round (should have some due to debounce completion) + let total_notifications: usize = second_round_results + .iter() + .filter_map(|r| r.as_ref().ok()) + .map(|result| result.as_ref().map(|n| n.len()).unwrap_or(0)) + .sum(); + + // Should have triggered some rules (those with threshold < 10) + // Each topic has rules with thresholds 0,1,2,3,4 so value 10.0 should trigger all of them + assert!(total_notifications > 0); + println!( + "Total notifications in second round: {}", + total_notifications + ); + + Ok(()) +} + +#[tokio::test] +async fn test_concurrent_high_frequency_messages() -> Result<(), RuleManagerError> { + let rule_manager = std::sync::Arc::new(RuleManager::new()); + + // Set up multiple rules that will receive high-frequency messages + let num_rules = 5; + for i in 0..num_rules { + let client = ClientId(format!("high_freq_client_{}", i)); + let rule = Rule::new( + RuleId(format!("high_freq_rule_{}", i)), + Topic("high_freq/topic".to_string()), + core::time::Duration::from_millis(50), + format!("a > {}", i * 10), // Thresholds: 0, 10, 20, 30, 40 + ); + rule_manager.add_rule(client, rule).await?; + } + + // Verify setup + assert_eq!(rule_manager.get_all_rules().await.len(), num_rules); + assert_eq!(rule_manager.get_all_clients().await.len(), num_rules); + + let messages_per_task = 20; + let num_tasks = 10; + let total_messages = messages_per_task * num_tasks; + + // Send high-frequency messages from multiple tasks + let results: Vec<_> = (0..num_tasks) + .fold(JoinSet::new(), |mut set, task_id| { + let rm = rule_manager.clone(); + set.spawn(async move { + let mut task_results = Vec::new(); + for msg_id in 0..messages_per_task { + let value = (task_id * messages_per_task + msg_id) as f32 % 100.0; + let client_data = ClientData { + run_id: task_id as i32, + name: "high_freq/topic".to_string(), + unit: "test".to_string(), + values: vec![value], + timestamp: Utc::now(), + }; + + let result = rm.handle_msg(&client_data).await; + task_results.push((msg_id, value, result)); + + // Small delay to simulate realistic message timing + if msg_id % 5 == 0 { + tokio::task::yield_now().await; + } + } + (task_id, task_results) + }); + set + }) + .join_all() + .await; + + // Verify all tasks completed + assert_eq!(results.len(), num_tasks); + + // Flatten and verify all message results + let all_message_results: Vec<_> = results + .into_iter() + .flat_map(|(task_id, task_results)| { + task_results + .into_iter() + .map(move |(msg_id, value, result)| (task_id, msg_id, value, result)) + }) + .collect(); + + assert_eq!(all_message_results.len(), total_messages); + + // Verify all messages were processed successfully + let successful_messages: Vec<_> = all_message_results + .iter() + .filter(|(_, _, _, result)| result.is_ok()) + .collect(); + assert_eq!(successful_messages.len(), total_messages); + + println!( + "Successfully processed {} high-frequency messages", + total_messages + ); + + // Wait for any pending debounce timers + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Send final test messages with known values that should trigger specific rules + let test_values = vec![5.0, 15.0, 25.0, 35.0, 45.0]; // Should trigger different numbers of rules + let final_results: Vec<_> = test_values + .into_iter() + .fold(JoinSet::new(), |mut set, value| { + let rm = rule_manager.clone(); 
+            set.spawn(async move {
+                let client_data = ClientData {
+                    run_id: 999,
+                    name: "high_freq/topic".to_string(),
+                    unit: "test".to_string(),
+                    values: vec![value],
+                    timestamp: Utc::now(),
+                };
+
+                // Wait for debounce
+                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+                let result = rm.handle_msg(&client_data).await;
+                (value, result)
+            });
+            set
+        })
+        .join_all()
+        .await;
+
+    // Verify final test results
+    assert_eq!(final_results.len(), 5);
+
+    for (value, result) in final_results {
+        assert!(
+            result.is_ok(),
+            "Failed to process message with value {}",
+            value
+        );
+
+        if let Ok(Some(notifications)) = result {
+            // Count how many rules should trigger for this value
+            let expected_triggers = num_rules - (value as usize / 10).min(num_rules);
+            if expected_triggers > 0 {
+                assert!(
+                    !notifications.is_empty(),
+                    "Value {} should have triggered some rules",
+                    value
+                );
+                assert!(
+                    notifications.len() <= expected_triggers,
+                    "Value {} triggered {} rules, expected at most {}",
+                    value,
+                    notifications.len(),
+                    expected_triggers
+                );
+
+                // Verify notification structure
+                for (client_id, notification) in notifications {
+                    assert!(client_id.0.starts_with("high_freq_client_"));
+                    assert_eq!(notification.topic.0, "high_freq/topic");
+                    assert_eq!(notification.values, vec![value]);
+                }
+            }
+        }
+    }
+
+    // Verify system state is unchanged
+    assert_eq!(rule_manager.get_all_rules().await.len(), num_rules);
+    assert_eq!(rule_manager.get_all_clients().await.len(), num_rules);
+
+    Ok(())
+}
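One behavior the tests pin down deliberately: `delete_rule` and `delete_client` only touch the subscriptions bimap, so orphaned rules stay in `rules` and `topic_index` (see `test_delete_rule_success` and `test_delete_client_success`). If that ever needs tightening, the `RemovedWithCleanUp` payload already carries what a cleanup pass would need. A hypothetical follow-up sketch, not part of this diff (`remove_orphans` is an invented name, and it would have to live inside `impl RuleManager` in rule_structs.rs to reach the private fields):

```rust
use rustc_hash::FxHashSet;

impl RuleManager {
    // Hypothetical follow-up (NOT in this diff): purge rules that the bimap's
    // remove_left reported as orphaned from the other two indexes as well.
    async fn remove_orphans(&self, orphaned: FxHashSet<RuleId>) {
        // Take both write guards up front; no other path holds these two
        // locks together, so this introduces no lock-order cycle.
        let mut rules = self.rules.write().await;
        let mut topic_index = self.topic_index.write().await;
        for rule_id in orphaned {
            if let Some(rule) = rules.remove(&rule_id) {
                if let Some(ids) = topic_index.get_mut(&rule.topic) {
                    ids.remove(&rule_id);
                    if ids.is_empty() {
                        topic_index.remove(&rule.topic);
                    }
                }
            }
        }
    }
}
```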