wsforge_core/middleware/
mod.rs

1//! Middleware system for request/response processing.
2//!
3//! This module provides a flexible middleware chain system that allows you to intercept
4//! and process WebSocket messages before they reach handlers. Middleware can modify messages,
5//! perform authentication, logging, rate limiting, and more.
6//!
7//! # Overview
8//!
9//! The middleware system is built around three core types:
10//! - [`Middleware`] - Trait that all middleware must implement
11//! - [`MiddlewareChain`] - Container that holds and executes middleware in order
12//! - [`Next`] - Represents the next step in the middleware chain
13//!
14//! # Architecture
15//!
//! ```text
//! Message → Middleware 1 → Middleware 2 → ... → Handler → Response
//!              ↓              ↓                      ↓
//!           Next::run     Next::run            Handler::call
//! ```
21//!
22//! Each middleware can:
23//! - Inspect the incoming message
24//! - Modify the message before passing it forward
25//! - Short-circuit the chain by not calling `next.run()`
26//! - Modify the response after calling `next.run()`
27//! - Handle errors and transform responses
28//!
29//! # Examples
30//!
31//! ## Using Built-in Logger Middleware
32//!
33//! ```
34//! use wsforge::prelude::*;
35//!
36//! async fn echo(msg: Message) -> Result<Message> {
37//!     Ok(msg)
38//! }
39//!
40//! # async fn example() -> Result<()> {
41//! let router = Router::new()
42//!     .layer(LoggerMiddleware::new())
43//!     .default_handler(handler(echo));
44//!
45//! router.listen("127.0.0.1:8080").await?;
46//! # Ok(())
47//! # }
48//! ```
49//!
50//! ## Creating Custom Middleware
51//!
52//! ```
53//! use wsforge::prelude::*;
54//! use async_trait::async_trait;
55//!
56//! struct AuthMiddleware {
57//!     secret: String,
58//! }
59//!
60//! #[async_trait]
61//! impl Middleware for AuthMiddleware {
62//!     async fn handle(
63//!         &self,
64//!         message: Message,
65//!         conn: Connection,
66//!         state: AppState,
67//!         extensions: Extensions,
68//!         mut next: Next,
69//!     ) -> Result<Option<Message>> {
70//!         // Check for auth token in message
71//!         if let Some(text) = message.as_text() {
72//!             if !text.contains(&self.secret) {
73//!                 return Err(Error::custom("Unauthorized"));
74//!             }
75//!         }
76//!
77//!         // Continue to next middleware/handler
78//!         next.run(message, conn, state, extensions).await
79//!     }
80//! }
81//! ```
82//!
83//! ## Function-based Middleware
84//!
85//! ```
86//! use wsforge::prelude::*;
87//!
88//! # async fn example() {
89//! let logging_middleware = from_fn(|msg, conn, state, ext, mut next| async move {
90//!     println!("Before handler: {:?}", msg.as_text());
91//!     let response = next.run(msg, conn, state, ext).await?;
92//!     println!("After handler");
93//!     Ok(response)
94//! });
95//!
96//! // Use in router
97//! // router.layer(logging_middleware);
98//! # }
99//! ```
100//!
101//! ## Chaining Multiple Middleware
102//!
103//! ```
//! use wsforge::prelude::*;
//! # use std::sync::Arc;
105//!
106//! # async fn example() -> Result<()> {
107//! let router = Router::new()
108//!     .layer(LoggerMiddleware::new())
109//!     .layer(auth_middleware())
110//!     .layer(rate_limit_middleware())
111//!     .default_handler(handler(my_handler));
112//! # Ok(())
113//! # }
114//! # async fn my_handler() -> Result<String> { Ok("".to_string()) }
115//! # fn auth_middleware() -> Arc<dyn Middleware> { unimplemented!() }
116//! # fn rate_limit_middleware() -> Arc<dyn Middleware> { unimplemented!() }
117//! ```
118
119pub mod logger;
120
121pub use logger::LoggerMiddleware;
122
123use crate::connection::Connection;
124use crate::error::Result;
125use crate::extractor::Extensions;
126use crate::message::Message;
127use crate::state::AppState;
128use async_trait::async_trait;
129use std::sync::Arc;
130
131/// Represents the next middleware or handler in the chain.
132///
133/// `Next` is used to pass control to the next step in the middleware pipeline.
134/// When a middleware calls `next.run()`, it invokes the next middleware or,
135/// if there are no more middleware, the final handler.
136///
137/// # Examples
138///
139/// ```
140/// use wsforge::prelude::*;
141/// use async_trait::async_trait;
142///
143/// struct MyMiddleware;
144///
145/// #[async_trait]
146/// impl Middleware for MyMiddleware {
147///     async fn handle(
148///         &self,
149///         message: Message,
150///         conn: Connection,
151///         state: AppState,
152///         extensions: Extensions,
153///         mut next: Next,
154///     ) -> Result<Option<Message>> {
155///         println!("Before next");
156///
157///         // Call the next middleware/handler
158///         let response = next.run(message, conn, state, extensions).await?;
159///
160///         println!("After next");
161///         Ok(response)
162///     }
163/// }
164/// ```
165pub struct Next {
166    chain: Arc<MiddlewareChain>,
167    index: usize,
168}
169
170impl Next {
171    /// Creates a new `Next` instance.
172    ///
173    /// # Arguments
174    ///
175    /// * `chain` - The middleware chain to execute
176    /// * `index` - Current position in the chain
177    pub fn new(chain: Arc<MiddlewareChain>, index: usize) -> Self {
178        Self { chain, index }
179    }
180
181    /// Call the next middleware in the chain.
182    ///
183    /// This method executes the next middleware in the sequence. If all middleware
184    /// have been executed, it calls the final handler.
185    ///
186    /// # Arguments
187    ///
188    /// * `message` - The WebSocket message being processed
189    /// * `conn` - The connection that sent the message
190    /// * `state` - Application state
191    /// * `extensions` - Request-scoped extension data
192    ///
193    /// # Returns
194    ///
195    /// Returns the response from the next middleware or handler, or `None` if
196    /// no response should be sent.
197    ///
198    /// # Examples
199    ///
200    /// ```
201    /// use wsforge::prelude::*;
202    /// use async_trait::async_trait;
203    ///
204    /// struct TimingMiddleware;
205    ///
206    /// #[async_trait]
207    /// impl Middleware for TimingMiddleware {
208    ///     async fn handle(
209    ///         &self,
210    ///         message: Message,
211    ///         conn: Connection,
212    ///         state: AppState,
213    ///         extensions: Extensions,
214    ///         mut next: Next,
215    ///     ) -> Result<Option<Message>> {
216    ///         let start = std::time::Instant::now();
217    ///
218    ///         let response = next.run(message, conn, state, extensions).await?;
219    ///
220    ///         let duration = start.elapsed();
221    ///         println!("Request took: {:?}", duration);
222    ///
223    ///         Ok(response)
224    ///     }
225    /// }
226    /// ```
227    pub async fn run(
228        mut self,
229        message: Message,
230        conn: Connection,
231        state: AppState,
232        extensions: Extensions,
233    ) -> Result<Option<Message>> {
234        if self.index < self.chain.middlewares.len() {
235            let middleware = self.chain.middlewares[self.index].clone();
236            self.index += 1;
237            middleware
238                .handle(message, conn, state, extensions, self)
239                .await
240        } else if let Some(ref handler) = self.chain.handler {
241            handler.call(message, conn, state, extensions).await
242        } else {
243            Ok(None)
244        }
245    }
246}
247
248/// Middleware trait that all middleware must implement.
249///
250/// Middleware can intercept messages before they reach handlers, perform
251/// transformations, add metadata to extensions, or short-circuit the request.
252///
253/// # Implementation Guidelines
254///
255/// - **Always call `next.run()`** unless you want to short-circuit
256/// - **Use extensions** to pass data to handlers or other middleware
257/// - **Handle errors gracefully** and provide clear error messages
258/// - **Be mindful of performance** - middleware runs on every message
259///
260/// # Examples
261///
262/// ## Authentication Middleware
263///
264/// ```
265/// use wsforge::prelude::*;
266/// use async_trait::async_trait;
267///
268/// struct AuthMiddleware {
269///     required_token: String,
270/// }
271///
272/// #[async_trait]
273/// impl Middleware for AuthMiddleware {
274///     async fn handle(
275///         &self,
276///         message: Message,
277///         conn: Connection,
278///         state: AppState,
279///         extensions: Extensions,
280///         mut next: Next,
281///     ) -> Result<Option<Message>> {
282///         if let Some(text) = message.as_text() {
283///             if let Some(token) = text.strip_prefix("TOKEN:") {
284///                 if token == self.required_token {
285///                     extensions.insert("authenticated", true);
286///                     return next.run(message, conn, state, extensions).await;
287///                 }
288///             }
289///         }
290///
291///         Err(Error::custom("Unauthorized"))
292///     }
293/// }
294/// ```
295///
296/// ## Rate Limiting Middleware
297///
298/// ```
299/// use wsforge::prelude::*;
300/// use async_trait::async_trait;
301/// use std::sync::Arc;
302/// use tokio::sync::RwLock;
303/// use std::collections::HashMap;
304///
305/// struct RateLimitMiddleware {
306///     limits: Arc<RwLock<HashMap<String, u32>>>,
307///     max_requests: u32,
308/// }
309///
310/// #[async_trait]
311/// impl Middleware for RateLimitMiddleware {
312///     async fn handle(
313///         &self,
314///         message: Message,
315///         conn: Connection,
316///         state: AppState,
317///         extensions: Extensions,
318///         mut next: Next,
319///     ) -> Result<Option<Message>> {
320///         let conn_id = conn.id();
321///         let mut limits = self.limits.write().await;
322///         let count = limits.entry(conn_id.clone()).or_insert(0);
323///
324///         if *count >= self.max_requests {
325///             return Err(Error::custom("Rate limit exceeded"));
326///         }
327///
328///         *count += 1;
329///         drop(limits);
330///
331///         next.run(message, conn, state, extensions).await
332///     }
333/// }
334/// ```
335///
336/// ## Request ID Middleware
337///
338/// ```
339/// use wsforge::prelude::*;
340/// use async_trait::async_trait;
341///
342/// struct RequestIdMiddleware;
343///
344/// #[async_trait]
345/// impl Middleware for RequestIdMiddleware {
346///     async fn handle(
347///         &self,
348///         message: Message,
349///         conn: Connection,
350///         state: AppState,
351///         extensions: Extensions,
352///         mut next: Next,
353///     ) -> Result<Option<Message>> {
354///         use std::sync::atomic::{AtomicU64, Ordering};
355///         static COUNTER: AtomicU64 = AtomicU64::new(0);
356///
357///         let request_id = COUNTER.fetch_add(1, Ordering::SeqCst);
358///         extensions.insert("request_id", request_id);
359///
360///         next.run(message, conn, state, extensions).await
361///     }
362/// }
363/// ```
364#[async_trait]
365pub trait Middleware: Send + Sync + 'static {
366    /// Handle a message and optionally pass it to the next middleware.
367    ///
368    /// # Arguments
369    ///
370    /// * `message` - The incoming WebSocket message
371    /// * `conn` - The connection that sent the message
372    /// * `state` - Application state
373    /// * `extensions` - Request-scoped extension data
374    /// * `next` - The next step in the middleware chain
375    ///
376    /// # Returns
377    ///
378    /// Returns an optional message to send back to the client, or an error.
379    async fn handle(
380        &self,
381        message: Message,
382        conn: Connection,
383        state: AppState,
384        extensions: Extensions,
385        next: Next,
386    ) -> Result<Option<Message>>;
387}
388
389/// Middleware chain holds all middlewares and the final handler.
390///
391/// The chain executes middleware in the order they were added, and finally
392/// calls the handler if all middleware pass control forward.
393///
394/// # Examples
395///
396/// ```
397/// use wsforge::prelude::*;
398///
399/// # fn example() {
400/// let mut chain = MiddlewareChain::new();
401///
402/// // Add middleware
403/// chain.layer(LoggerMiddleware::new());
404///
405/// // Set final handler
406/// chain.handler(handler(my_handler));
407/// # }
408/// # async fn my_handler() -> Result<String> { Ok("".to_string()) }
409/// ```
410#[derive(Clone)]
411pub struct MiddlewareChain {
412    /// All middleware in the chain, executed in order
413    pub middlewares: Vec<Arc<dyn Middleware>>,
414    /// The final handler to call after all middleware
415    pub handler: Option<Arc<dyn crate::handler::Handler>>,
416}
417
418impl MiddlewareChain {
419    /// Creates a new empty middleware chain.
420    ///
421    /// # Examples
422    ///
423    /// ```
424    /// use wsforge::prelude::*;
425    ///
426    /// let chain = MiddlewareChain::new();
427    /// ```
428    pub fn new() -> Self {
429        Self {
430            middlewares: Vec::new(),
431            handler: None,
432        }
433    }
434
435    /// Add a middleware to the chain.
436    ///
437    /// Middleware are executed in the order they are added.
438    ///
439    /// # Arguments
440    ///
441    /// * `middleware` - The middleware to add
442    ///
443    /// # Examples
444    ///
445    /// ```
446    /// use wsforge::prelude::*;
447    ///
448    /// # fn example() {
449    /// let mut chain = MiddlewareChain::new();
450    ///
451    /// chain.layer(LoggerMiddleware::new());
452    /// # }
453    /// ```
454    pub fn layer(mut self, middleware: Arc<dyn Middleware>) -> Self {
455        self.middlewares.push(middleware);
456        self
457    }
458
459    /// Set the final handler for the chain.
460    ///
461    /// The handler is called after all middleware have been executed.
462    ///
463    /// # Arguments
464    ///
465    /// * `handler` - The handler to call
466    ///
467    /// # Examples
468    ///
469    /// ```
470    /// use wsforge::prelude::*;
471    ///
472    /// async fn my_handler(msg: Message) -> Result<String> {
473    ///     Ok("response".to_string())
474    /// }
475    ///
476    /// # fn example() {
477    /// let mut chain = MiddlewareChain::new();
478    /// chain.handler(handler(my_handler));
479    /// # }
480    /// ```
481    pub fn handler(mut self, handler: Arc<dyn crate::handler::Handler>) -> Self {
482        self.handler = Some(handler);
483        self
484    }
485
486    /// Execute the middleware chain.
487    ///
488    /// This runs all middleware in order, then calls the handler if present.
489    ///
490    /// # Arguments
491    ///
492    /// * `message` - The message to process
493    /// * `conn` - The connection
494    /// * `state` - Application state
495    /// * `extensions` - Extension data
496    ///
497    /// # Examples
498    ///
499    /// ```
500    /// use wsforge::prelude::*;
501    ///
502    /// # async fn example(chain: MiddlewareChain, msg: Message, conn: Connection) -> Result<()> {
503    /// let state = AppState::new();
504    /// let extensions = Extensions::new();
505    ///
506    /// let response = chain.execute(msg, conn, state, extensions).await?;
507    /// # Ok(())
508    /// # }
509    /// ```
510    pub async fn execute(
511        &self,
512        message: Message,
513        conn: Connection,
514        state: AppState,
515        extensions: Extensions,
516    ) -> Result<Option<Message>> {
517        let next = Next::new(Arc::new(self.clone()), 0);
518        next.run(message, conn, state, extensions).await
519    }
520}
521
522impl Default for MiddlewareChain {
523    fn default() -> Self {
524        Self::new()
525    }
526}
527
528/// Function-based Middleware
529///
530/// Helper to create middleware from async functions without implementing
531/// the full `Middleware` trait.
532pub struct FnMiddleware<F> {
533    func: F,
534}
535
536impl<F> FnMiddleware<F> {
537    /// Creates a new function-based middleware.
538    ///
539    /// # Examples
540    ///
541    /// ```
542    /// use wsforge::prelude::*;
543    ///
544    /// # fn example() {
545    /// let middleware = FnMiddleware::new(|msg, conn, state, ext, mut next| async move {
546    ///     println!("Before handler");
547    ///     let response = next.run(msg, conn, state, ext).await?;
548    ///     println!("After handler");
549    ///     Ok(response)
550    /// });
551    /// # }
552    /// ```
553    pub fn new(func: F) -> Arc<Self> {
554        Arc::new(Self { func })
555    }
556}
557
558#[async_trait]
559impl<F, Fut> Middleware for FnMiddleware<F>
560where
561    F: Fn(Message, Connection, AppState, Extensions, Next) -> Fut + Send + Sync + 'static,
562    Fut: std::future::Future<Output = Result<Option<Message>>> + Send + 'static,
563{
564    async fn handle(
565        &self,
566        message: Message,
567        conn: Connection,
568        state: AppState,
569        extensions: Extensions,
570        next: Next,
571    ) -> Result<Option<Message>> {
572        (self.func)(message, conn, state, extensions, next).await
573    }
574}
575
576/// Helper function to create middleware from async functions.
577///
578/// This is a convenience function that wraps an async function in a middleware.
579///
580/// # Arguments
581///
582/// * `f` - Async function with signature matching middleware requirements
583///
584/// # Examples
585///
586/// ## Simple Logging
587///
588/// ```
589/// use wsforge::prelude::*;
590///
591/// # fn example() {
592/// let logging = from_fn(|msg, conn, state, ext, mut next| async move {
593///     println!("Processing message from {}", conn.id());
594///     next.run(msg, conn, state, ext).await
595/// });
596/// # }
597/// ```
598///
599/// ## With State Access
600///
601/// ```
602/// use wsforge::prelude::*;
603/// use std::sync::Arc;
604///
605/// # fn example() {
606/// let counter = from_fn(|msg, conn, state, ext, mut next| async move {
607///     // Access state
608///     if let Some(counter) = state.get::<Arc<std::sync::atomic::AtomicU64>>() {
609///         counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
610///     }
611///     next.run(msg, conn, state, ext).await
612/// });
613/// # }
614/// ```
615pub fn from_fn<F, Fut>(f: F) -> Arc<FnMiddleware<F>>
616where
617    F: Fn(Message, Connection, AppState, Extensions, Next) -> Fut + Send + Sync + 'static,
618    Fut: std::future::Future<Output = Result<Option<Message>>> + Send + 'static,
619{
620    FnMiddleware::new(f)
621}