diff --git a/benches/bench.rs b/benches/bench.rs index 4d9be02..d40f3df 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -2,6 +2,7 @@ use blunt::websocket::{WebSocketHandler, WebSocketMessage, WebSocketSession}; use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; use std::time::{Duration, Instant}; use tokio::task::JoinHandle; +use tungstenite::{connect, Message}; #[derive(Debug, Default)] pub struct EchoServer; @@ -10,7 +11,7 @@ pub struct EchoServer; impl WebSocketHandler for EchoServer { async fn on_open(&mut self, _ws: &WebSocketSession) {} async fn on_message(&mut self, ws: &WebSocketSession, msg: WebSocketMessage) { - ws.send(msg).expect("Unable to send message"); + let _ = ws.send(msg); } async fn on_close(&mut self, _ws: &WebSocketSession, _msg: WebSocketMessage) {} @@ -27,113 +28,125 @@ fn start_echo_server() -> JoinHandle<()> { }) } -fn single_echo_benchmark(c: &mut Criterion) { - use tungstenite::{connect, Message}; +fn single_server_send_only_iter_custom() -> impl FnMut(u64) -> Duration { + |iters| { + let (mut socket, _response) = connect("ws://localhost:9999/echo").expect("Can't connect"); + let ws_message = Message::Text(String::from("Hello World!")); + let start = Instant::now(); + for _n in 0..iters { + if let Err(e) = socket.write_message(black_box(ws_message.clone())) { + eprintln!("error: {:?}", e); + } + } + + start.elapsed() + } +} - let rt = tokio::runtime::Runtime::new().unwrap(); - rt.block_on(async { - let _server_handle = start_echo_server(); +fn single_server_send_recv_iter_custom() -> impl FnMut(u64) -> Duration { + |iters| { + let ws_message = Message::Text(String::from("Hello World!")); let (mut socket, _response) = connect("ws://localhost:9999/echo").expect("Can't connect"); - let mut group = c.benchmark_group("single echo server"); - group.throughput(Throughput::Elements(100)); - group.bench_function("100 - Send only", |b| { - b.iter_custom(|iters| { - let ws_message = Message::Text(String::from("Hello World!")); - let start = Instant::now(); - for _n in 0..iters { - socket.write_message(black_box(ws_message.clone())).unwrap(); - } - - start.elapsed() - }) - }); - - group.finish(); - - let mut group = c.benchmark_group("single echo server"); - group.throughput(Throughput::Elements(100)); - group.bench_function("100 - Send and receive", |b| { - b.iter_custom(|iters| { - let ws_message = Message::Text(String::from("Hello World!")); - let start = Instant::now(); - for _n in 0..iters { - socket.write_message(black_box(ws_message.clone())).unwrap(); - let _ = socket.read_message().expect("Error reading message"); - } - - start.elapsed() - }) - }); - - group.finish(); - - let mut group = c.benchmark_group("single echo server"); - group.throughput(Throughput::Elements(1000)); - group.bench_function("1000 - Send only", |b| { - b.iter_custom(|iters| { - let ws_message = Message::Text(String::from("Hello World!")); - let start = Instant::now(); - for _n in 0..iters { - socket.write_message(black_box(ws_message.clone())).unwrap(); - } - - start.elapsed() - }) - }); - - group.finish(); - - let mut group = c.benchmark_group("single echo server"); - group.throughput(Throughput::Elements(1000)); - group.bench_function("1000 - Send and receive", |b| { - b.iter_custom(|iters| { - let ws_message = Message::Text(String::from("Hello World!")); - let start = Instant::now(); - for _n in 0..iters { - socket.write_message(black_box(ws_message.clone())).unwrap(); - let _ = socket.read_message().expect("Error reading message"); - } - - start.elapsed() - }) - }); - - group.finish(); - - let mut group = c.benchmark_group("single echo server"); - group.throughput(Throughput::Elements(10000)); - group.bench_function("10000 - Send only", |b| { - b.iter_custom(|iters| { - let ws_message = Message::Text(String::from("Hello World!")); - let start = Instant::now(); - for _n in 0..iters { - socket.write_message(black_box(ws_message.clone())).unwrap(); - } - - start.elapsed() - }) - }); - - group.finish(); - - let mut group = c.benchmark_group("single echo server"); - group.throughput(Throughput::Elements(10000)); - group.bench_function("10000 - Send and receive", |b| { - b.iter_custom(|iters| { - let ws_message = Message::Text(String::from("Hello World!")); - let start = Instant::now(); - for _n in 0..iters { - socket.write_message(black_box(ws_message.clone())).unwrap(); - let _ = socket.read_message().expect("Error reading message"); - } - - start.elapsed() - }) - }); - - group.finish(); + let start = Instant::now(); + for _n in 0..iters { + socket.write_message(black_box(ws_message.clone())).unwrap(); + let _ = socket.read_message().expect("Error reading message"); + } + + start.elapsed() + } +} + +enum ServerType { + SingleServer, + MultiServer, +} + +fn start_tokio_rt(typ: ServerType) -> tokio::sync::mpsc::UnboundedSender { + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + std::thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + let _server_handle = match typ { + ServerType::SingleServer => start_echo_server(), + ServerType::MultiServer => start_multi_echo_server(), + }; + + let _ = rx.recv().await; + }) + }); + + std::thread::sleep(Duration::from_secs(2)); + tx +} + +#[allow(dead_code)] +fn test_async(c: &mut Criterion) { + let _rt_guard = start_tokio_rt(ServerType::SingleServer); + + let mut group = c.benchmark_group("single echo server"); + group.throughput(Throughput::Elements(100)); + group.measurement_time(Duration::from_secs(60)); + group.sample_size(100); + group.bench_function("100 - Send only", |b| { + b.iter_custom(single_server_send_only_iter_custom()) + }); + + group.finish(); +} + +fn single_echo_benchmark(c: &mut Criterion) { + let _rt_guard = start_tokio_rt(ServerType::SingleServer); + + let mut group = c.benchmark_group("single echo server"); + group.throughput(Throughput::Elements(100)); + group.measurement_time(Duration::from_secs(15)); + group.bench_function("100 - Send only", |b| { + b.iter_custom(single_server_send_only_iter_custom()) + }); + + group.finish(); + + let mut group = c.benchmark_group("single echo server"); + group.throughput(Throughput::Elements(100)); + group.bench_function("100 - Send and receive", |b| { + b.iter_custom(single_server_send_recv_iter_custom()) + }); + + group.finish(); + + let mut group = c.benchmark_group("single echo server"); + group.throughput(Throughput::Elements(1000)); + group.bench_function("1000 - Send only", |b| { + b.iter_custom(single_server_send_only_iter_custom()) + }); + + group.finish(); + + let mut group = c.benchmark_group("single echo server"); + group.throughput(Throughput::Elements(1000)); + group.bench_function("1000 - Send and receive", |b| { + b.iter_custom(single_server_send_recv_iter_custom()) + }); + + group.finish(); + + let mut group = c.benchmark_group("single echo server"); + group.throughput(Throughput::Elements(10000)); + group.bench_function("10000 - Send only", |b| { + b.iter_custom(single_server_send_only_iter_custom()) + }); + + group.finish(); + + let mut group = c.benchmark_group("single echo server"); + group.throughput(Throughput::Elements(10000)); + group.bench_function("10000 - Send and receive", |b| { + b.iter_custom(single_server_send_recv_iter_custom()) }); + + group.finish(); } fn start_multi_echo_server() -> JoinHandle<()> { @@ -149,158 +162,105 @@ fn start_multi_echo_server() -> JoinHandle<()> { }) } -fn multi_echo_benchmark(c: &mut Criterion) { - use tungstenite::{connect, Message}; +fn multi_server_send_only_iter_custom() -> impl FnMut(u64) -> Duration { + |iters| { + let (mut socket, _response) = connect("ws://localhost:9999/echo").expect("Can't connect"); + let (mut socket2, _response) = connect("ws://localhost:9999/echo2").expect("Can't connect"); + let ws_message = Message::Text(String::from("Hello World!")); + let start = Instant::now(); + for n in 0..iters { + if n % 2 == 0 { + socket.write_message(black_box(ws_message.clone())).unwrap(); + } else { + socket2 + .write_message(black_box(ws_message.clone())) + .unwrap(); + } + } + + start.elapsed() + } +} - let rt = tokio::runtime::Runtime::new().unwrap(); - rt.block_on(async { - let _multi_server_handle = start_multi_echo_server(); +fn multi_server_send_recv_iter_custom() -> impl FnMut(u64) -> Duration { + |iters| { let (mut socket, _response) = connect("ws://localhost:9999/echo").expect("Can't connect"); let (mut socket2, _response) = connect("ws://localhost:9999/echo2").expect("Can't connect"); + let ws_message = Message::Text(String::from("Hello World!")); + let start = Instant::now(); + for n in 0..iters { + if n % 2 == 0 { + socket.write_message(black_box(ws_message.clone())).unwrap(); + let _ = socket.read_message().expect("Error reading message"); + } else { + socket2 + .write_message(black_box(ws_message.clone())) + .unwrap(); + let _ = socket2.read_message().expect("Error reading message"); + } + } + + start.elapsed() + } +} + +fn multi_echo_benchmark(c: &mut Criterion) { + let _rt_guard = start_tokio_rt(ServerType::MultiServer); - let mut group = c.benchmark_group("multi echo server"); - group.throughput(Throughput::Elements(100)); - group.bench_function("100 - Send only", |b| { - b.iter_custom(|iters| { - let ws_message = Message::Text(String::from("Hello World!")); - let start = Instant::now(); - for n in 0..iters { - if n % 2 == 0 { - socket.write_message(black_box(ws_message.clone())).unwrap(); - } else { - socket2 - .write_message(black_box(ws_message.clone())) - .unwrap(); - } - } - - start.elapsed() - }) - }); - - group.finish(); - - let mut group = c.benchmark_group("multi echo server"); - group.throughput(Throughput::Elements(100)); - group.bench_function("100 - Send and receive", |b| { - b.iter_custom(|iters| { - let ws_message = Message::Text(String::from("Hello World!")); - let start = Instant::now(); - for n in 0..iters { - if n % 2 == 0 { - socket.write_message(black_box(ws_message.clone())).unwrap(); - let _ = socket.read_message().expect("Error reading message"); - } else { - socket2 - .write_message(black_box(ws_message.clone())) - .unwrap(); - let _ = socket2.read_message().expect("Error reading message"); - } - } - - start.elapsed() - }) - }); - - group.finish(); - - let mut group = c.benchmark_group("multi echo server"); - group.measurement_time(Duration::from_secs(15)); - group.throughput(Throughput::Elements(1000)); - group.bench_function("1000 - Send only", |b| { - b.iter_custom(|iters| { - let ws_message = Message::Text(String::from("Hello World!")); - let start = Instant::now(); - for n in 0..iters { - if n % 2 == 0 { - socket.write_message(black_box(ws_message.clone())).unwrap(); - } else { - socket2 - .write_message(black_box(ws_message.clone())) - .unwrap(); - } - } - - start.elapsed() - }) - }); - - group.finish(); - - let mut group = c.benchmark_group("multi echo server"); - group.measurement_time(Duration::from_secs(15)); - group.throughput(Throughput::Elements(1000)); - group.bench_function("1000 - Send and receive", |b| { - b.iter_custom(|iters| { - let ws_message = Message::Text(String::from("Hello World!")); - let start = Instant::now(); - for n in 0..iters { - if n % 2 == 0 { - socket.write_message(black_box(ws_message.clone())).unwrap(); - let _ = socket.read_message().expect("Error reading message"); - } else { - socket2 - .write_message(black_box(ws_message.clone())) - .unwrap(); - let _ = socket2.read_message().expect("Error reading message"); - } - } - - start.elapsed() - }) - }); - - group.finish(); - - let mut group = c.benchmark_group("multi echo server"); - group.measurement_time(Duration::from_secs(100)); - group.throughput(Throughput::Elements(10000)); - group.bench_function("10000 - Send only", |b| { - b.iter_custom(|iters| { - let ws_message = Message::Text(String::from("Hello World!")); - let start = Instant::now(); - for n in 0..iters { - if n % 2 == 0 { - socket.write_message(black_box(ws_message.clone())).unwrap(); - } else { - socket2 - .write_message(black_box(ws_message.clone())) - .unwrap(); - } - } - - start.elapsed() - }) - }); - - group.finish(); - - let mut group = c.benchmark_group("multi echo server"); - group.measurement_time(Duration::from_secs(100)); - group.throughput(Throughput::Elements(10000)); - group.bench_function("10000 - Send and receive", |b| { - b.iter_custom(|iters| { - let ws_message = Message::Text(String::from("Hello World!")); - let start = Instant::now(); - for n in 0..iters { - if n % 2 == 0 { - socket.write_message(black_box(ws_message.clone())).unwrap(); - let _ = socket.read_message().expect("Error reading message"); - } else { - socket2 - .write_message(black_box(ws_message.clone())) - .unwrap(); - let _ = socket2.read_message().expect("Error reading message"); - } - } - - start.elapsed() - }) - }); - - group.finish(); + let mut group = c.benchmark_group("multi echo server"); + group.throughput(Throughput::Elements(100)); + group.bench_function("100 - Send only", |b| { + b.iter_custom(multi_server_send_only_iter_custom()) }); + + group.finish(); + + let mut group = c.benchmark_group("multi echo server"); + group.throughput(Throughput::Elements(100)); + group.bench_function("100 - Send and receive", |b| { + b.iter_custom(multi_server_send_recv_iter_custom()) + }); + + group.finish(); + + let mut group = c.benchmark_group("multi echo server"); + group.measurement_time(Duration::from_secs(15)); + group.throughput(Throughput::Elements(1000)); + group.bench_function("1000 - Send only", |b| { + b.iter_custom(multi_server_send_only_iter_custom()) + }); + + group.finish(); + + let mut group = c.benchmark_group("multi echo server"); + group.measurement_time(Duration::from_secs(15)); + group.throughput(Throughput::Elements(1000)); + group.bench_function("1000 - Send and receive", |b| { + b.iter_custom(multi_server_send_recv_iter_custom()) + }); + + group.finish(); + + let mut group = c.benchmark_group("multi echo server"); + group.measurement_time(Duration::from_secs(100)); + group.throughput(Throughput::Elements(10000)); + group.bench_function("10000 - Send only", |b| { + b.iter_custom(multi_server_send_only_iter_custom()) + }); + + group.finish(); + + let mut group = c.benchmark_group("multi echo server"); + group.measurement_time(Duration::from_secs(100)); + group.throughput(Throughput::Elements(10000)); + group.bench_function("10000 - Send and receive", |b| { + b.iter_custom(multi_server_send_recv_iter_custom()) + }); + + group.finish(); } criterion_group!(benches, single_echo_benchmark, multi_echo_benchmark); +//criterion_group!(benches, single_echo_benchmark); +//criterion_group!(benches, test_async); criterion_main!(benches); diff --git a/examples/chat.rs b/examples/chat.rs index f727979..4bbe5ce 100644 --- a/examples/chat.rs +++ b/examples/chat.rs @@ -23,12 +23,15 @@ type UserCollection = Arc pub struct ChatServer(UserCollection); impl ChatServer { - async fn broadcast(&mut self, except_id: Uuid, msg: WebSocketMessage) { - self.0.read().await.iter().for_each(|entry| { - if entry.0 != &except_id { + async fn broadcast(&self, except_id: Uuid, msg: WebSocketMessage) { + self.0 + .read() + .await + .iter() + .filter(|entry| entry.0 != &except_id) + .for_each(|entry| { let _ = entry.1.send(msg.clone()); - } - }); + }); } } diff --git a/src/endpoints.rs b/src/endpoints.rs index 6487fab..7d94f5e 100644 --- a/src/endpoints.rs +++ b/src/endpoints.rs @@ -142,7 +142,7 @@ impl Endpoints { } pub(crate) async fn handle_web_request( - &mut self, + &self, request: Request, ) -> Arc>> { let result = self.web_channels.get(request.uri().path()).map(|tx| { diff --git a/src/lib.rs b/src/lib.rs index 005373a..88ce14d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -58,17 +58,18 @@ impl Server { // async task to receive messages from the web socket connection let ws_server2 = self.clone(); - websocket::register_recv_ws_message_handling(ws_server2, ws_session_rx, session_id).await; + websocket::mediator(ws_server2, ws_session_rx, session_id, ws_session_tx, rx).await; + //websocket::register_recv_ws_message_handling(ws_server2, ws_session_rx, session_id).await; // async task to send messages to the web socket connection - websocket::register_send_to_ws_message_handling(ws_session_tx, rx).await; + //websocket::register_send_to_ws_message_handling(ws_session_tx, rx).await; self.add_session(session).await; } #[tracing::instrument(level = "trace", skip(self, request))] async fn handle_web_request( - &mut self, + &self, request: Request, ) -> Arc>> { self.endpoints.handle_web_request(request).await @@ -100,10 +101,12 @@ impl Server { /// Removed a web socket session from the server #[tracing::instrument(level = "trace", skip(self))] async fn remove_session(&self, session_id: Uuid) { - let s = { self.sessions.write().await.remove(session_id.borrow()) }; - drop(s); + let (s, len) = { + let mut guard = self.sessions.write().await; + (guard.remove(session_id.borrow()), guard.len()) + }; - let len = { self.sessions.read().await.len() }; + drop(s); tracing::debug!("Current total active sessions: {}", len); } diff --git a/src/websocket.rs b/src/websocket.rs index 2a1dff6..d481fda 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -19,7 +19,58 @@ use tracing::{error, trace, trace_span, warn}; use tracing_futures::Instrument; use uuid::Uuid; +#[tracing::instrument( + level = "trace", + skip(server, ws_session_rx, session_id, ws_session_tx, rx) +)] +pub(crate) async fn mediator( + mut server: Server, + mut ws_session_rx: SplitStream>>, + session_id: impl Into, + mut ws_session_tx: SplitSink>, WebSocketMessage>, + mut rx: UnboundedReceiver, +) { + let session_id = session_id.into(); + tokio::spawn(async move { + loop { + tokio::select! { + Some(result) = rx.recv() => { + { trace!("Sending to websocket: {:?}", result); } + if let Err(e) = ws_session_tx.send(result).await { + error!("Sending to websocket: {:?}", e); + warn!("Dropping channel server -> 'ws_session_rx'"); + return; + } + }, + Some(result) = ws_session_rx.next() => { + match result { + Ok(msg) => server.recv(session_id, msg).await, + Err(e) => { + let error_message = { + let m = format!("Receive from websocket: {:?}", e); + error!("{}", m); + m + }; + + let frame = CloseFrame { + code: CloseCode::Abnormal, + reason: std::borrow::Cow::Owned(error_message), + }; + + { warn!("Dropping channel 'ws_session_rx' -> server::recv()"); } + server.recv(session_id, Message::Close(Some(frame))).await; + return; + } + }; + }, + else => break, + } + } + }); +} + /// Async task to receive messages from the web socket connection +#[allow(dead_code)] pub(crate) async fn register_recv_ws_message_handling( mut server: Server, mut ws_session_rx: SplitStream>>, @@ -54,6 +105,7 @@ pub(crate) async fn register_recv_ws_message_handling( } /// Async task to send messages to the web socket connection +#[allow(dead_code)] pub(crate) async fn register_send_to_ws_message_handling( mut ws_session_tx: SplitSink>, WebSocketMessage>, mut rx: UnboundedReceiver, @@ -168,3 +220,28 @@ impl ConnectionContext { self.query.clone() } } + +#[cfg(test)] +mod tests { + use crate::websocket::{ConnectionContext, WebSocketMessage, WebSocketSession}; + use hyper::HeaderMap; + use std::marker::{Send, Sync}; + use tokio::sync::mpsc::unbounded_channel; + + #[test] + fn web_socket_session_is_sync_and_send() { + let (tx, _) = unbounded_channel::(); + let ctx = ConnectionContext::new( + None, + HeaderMap::new(), + String::with_capacity(0), + String::with_capacity(0), + ); + + let ws = WebSocketSession::new(ctx, tx); + test_it(&ws); + assert!(true); + } + + fn test_it(_ws: &T) {} +}