wsforge_core/middleware/mod.rs
1//! Middleware system for request/response processing.
2//!
3//! This module provides a flexible middleware chain system that allows you to intercept
4//! and process WebSocket messages before they reach handlers. Middleware can modify messages,
5//! perform authentication, logging, rate limiting, and more.
6//!
7//! # Overview
8//!
9//! The middleware system is built around three core types:
10//! - [`Middleware`] - Trait that all middleware must implement
11//! - [`MiddlewareChain`] - Container that holds and executes middleware in order
12//! - [`Next`] - Represents the next step in the middleware chain
13//!
14//! # Architecture
15//!
//! ```text
//! Message → Middleware 1 → Middleware 2 → ... → Handler → Response
//!               ↓              ↓                   ↓
//!           Next::run      Next::run         Handler::call
//! ```
21//!
22//! Each middleware can:
23//! - Inspect the incoming message
24//! - Modify the message before passing it forward
25//! - Short-circuit the chain by not calling `next.run()`
26//! - Modify the response after calling `next.run()`
27//! - Handle errors and transform responses
28//!
29//! # Examples
30//!
31//! ## Using Built-in Logger Middleware
32//!
33//! ```
34//! use wsforge::prelude::*;
35//!
36//! async fn echo(msg: Message) -> Result<Message> {
37//! Ok(msg)
38//! }
39//!
40//! # async fn example() -> Result<()> {
41//! let router = Router::new()
42//! .layer(LoggerMiddleware::new())
43//! .default_handler(handler(echo));
44//!
45//! router.listen("127.0.0.1:8080").await?;
46//! # Ok(())
47//! # }
48//! ```
49//!
50//! ## Creating Custom Middleware
51//!
52//! ```
53//! use wsforge::prelude::*;
54//! use async_trait::async_trait;
55//!
56//! struct AuthMiddleware {
57//! secret: String,
58//! }
59//!
60//! #[async_trait]
61//! impl Middleware for AuthMiddleware {
62//! async fn handle(
63//! &self,
64//! message: Message,
65//! conn: Connection,
66//! state: AppState,
67//! extensions: Extensions,
68//! mut next: Next,
69//! ) -> Result<Option<Message>> {
70//! // Check for auth token in message
71//! if let Some(text) = message.as_text() {
72//! if !text.contains(&self.secret) {
73//! return Err(Error::custom("Unauthorized"));
74//! }
75//! }
76//!
77//! // Continue to next middleware/handler
78//! next.run(message, conn, state, extensions).await
79//! }
80//! }
81//! ```
82//!
83//! ## Function-based Middleware
84//!
85//! ```
86//! use wsforge::prelude::*;
87//!
88//! # async fn example() {
89//! let logging_middleware = from_fn(|msg, conn, state, ext, mut next| async move {
90//! println!("Before handler: {:?}", msg.as_text());
91//! let response = next.run(msg, conn, state, ext).await?;
92//! println!("After handler");
93//! Ok(response)
94//! });
95//!
96//! // Use in router
97//! // router.layer(logging_middleware);
98//! # }
99//! ```
100//!
101//! ## Chaining Multiple Middleware
102//!
103//! ```
104//! use wsforge::prelude::*;
105//!
106//! # async fn example() -> Result<()> {
107//! let router = Router::new()
108//! .layer(LoggerMiddleware::new())
109//! .layer(auth_middleware())
110//! .layer(rate_limit_middleware())
111//! .default_handler(handler(my_handler));
112//! # Ok(())
113//! # }
114//! # async fn my_handler() -> Result<String> { Ok("".to_string()) }
115//! # fn auth_middleware() -> Arc<dyn Middleware> { unimplemented!() }
116//! # fn rate_limit_middleware() -> Arc<dyn Middleware> { unimplemented!() }
117//! ```
118
119pub mod logger;
120
121pub use logger::LoggerMiddleware;
122
123use crate::connection::Connection;
124use crate::error::Result;
125use crate::extractor::Extensions;
126use crate::message::Message;
127use crate::state::AppState;
128use async_trait::async_trait;
129use std::sync::Arc;
130
131/// Represents the next middleware or handler in the chain.
132///
133/// `Next` is used to pass control to the next step in the middleware pipeline.
134/// When a middleware calls `next.run()`, it invokes the next middleware or,
135/// if there are no more middleware, the final handler.
136///
137/// # Examples
138///
139/// ```
140/// use wsforge::prelude::*;
141/// use async_trait::async_trait;
142///
143/// struct MyMiddleware;
144///
145/// #[async_trait]
146/// impl Middleware for MyMiddleware {
147/// async fn handle(
148/// &self,
149/// message: Message,
150/// conn: Connection,
151/// state: AppState,
152/// extensions: Extensions,
153/// mut next: Next,
154/// ) -> Result<Option<Message>> {
155/// println!("Before next");
156///
157/// // Call the next middleware/handler
158/// let response = next.run(message, conn, state, extensions).await?;
159///
160/// println!("After next");
161/// Ok(response)
162/// }
163/// }
164/// ```
165pub struct Next {
166 chain: Arc<MiddlewareChain>,
167 index: usize,
168}
169
170impl Next {
171 /// Creates a new `Next` instance.
172 ///
173 /// # Arguments
174 ///
175 /// * `chain` - The middleware chain to execute
176 /// * `index` - Current position in the chain
177 pub fn new(chain: Arc<MiddlewareChain>, index: usize) -> Self {
178 Self { chain, index }
179 }
180
181 /// Call the next middleware in the chain.
182 ///
183 /// This method executes the next middleware in the sequence. If all middleware
184 /// have been executed, it calls the final handler.
185 ///
186 /// # Arguments
187 ///
188 /// * `message` - The WebSocket message being processed
189 /// * `conn` - The connection that sent the message
190 /// * `state` - Application state
191 /// * `extensions` - Request-scoped extension data
192 ///
193 /// # Returns
194 ///
195 /// Returns the response from the next middleware or handler, or `None` if
196 /// no response should be sent.
197 ///
198 /// # Examples
199 ///
200 /// ```
201 /// use wsforge::prelude::*;
202 /// use async_trait::async_trait;
203 ///
204 /// struct TimingMiddleware;
205 ///
206 /// #[async_trait]
207 /// impl Middleware for TimingMiddleware {
208 /// async fn handle(
209 /// &self,
210 /// message: Message,
211 /// conn: Connection,
212 /// state: AppState,
213 /// extensions: Extensions,
214 /// mut next: Next,
215 /// ) -> Result<Option<Message>> {
216 /// let start = std::time::Instant::now();
217 ///
218 /// let response = next.run(message, conn, state, extensions).await?;
219 ///
220 /// let duration = start.elapsed();
221 /// println!("Request took: {:?}", duration);
222 ///
223 /// Ok(response)
224 /// }
225 /// }
226 /// ```
227 pub async fn run(
228 mut self,
229 message: Message,
230 conn: Connection,
231 state: AppState,
232 extensions: Extensions,
233 ) -> Result<Option<Message>> {
234 if self.index < self.chain.middlewares.len() {
235 let middleware = self.chain.middlewares[self.index].clone();
236 self.index += 1;
237 middleware
238 .handle(message, conn, state, extensions, self)
239 .await
240 } else if let Some(ref handler) = self.chain.handler {
241 handler.call(message, conn, state, extensions).await
242 } else {
243 Ok(None)
244 }
245 }
246}
247
248/// Middleware trait that all middleware must implement.
249///
250/// Middleware can intercept messages before they reach handlers, perform
251/// transformations, add metadata to extensions, or short-circuit the request.
252///
253/// # Implementation Guidelines
254///
255/// - **Always call `next.run()`** unless you want to short-circuit
256/// - **Use extensions** to pass data to handlers or other middleware
257/// - **Handle errors gracefully** and provide clear error messages
258/// - **Be mindful of performance** - middleware runs on every message
259///
260/// # Examples
261///
262/// ## Authentication Middleware
263///
264/// ```
265/// use wsforge::prelude::*;
266/// use async_trait::async_trait;
267///
268/// struct AuthMiddleware {
269/// required_token: String,
270/// }
271///
272/// #[async_trait]
273/// impl Middleware for AuthMiddleware {
274/// async fn handle(
275/// &self,
276/// message: Message,
277/// conn: Connection,
278/// state: AppState,
279/// extensions: Extensions,
280/// mut next: Next,
281/// ) -> Result<Option<Message>> {
282/// if let Some(text) = message.as_text() {
283/// if let Some(token) = text.strip_prefix("TOKEN:") {
284/// if token == self.required_token {
285/// extensions.insert("authenticated", true);
286/// return next.run(message, conn, state, extensions).await;
287/// }
288/// }
289/// }
290///
291/// Err(Error::custom("Unauthorized"))
292/// }
293/// }
294/// ```
295///
296/// ## Rate Limiting Middleware
297///
298/// ```
299/// use wsforge::prelude::*;
300/// use async_trait::async_trait;
301/// use std::sync::Arc;
302/// use tokio::sync::RwLock;
303/// use std::collections::HashMap;
304///
305/// struct RateLimitMiddleware {
306/// limits: Arc<RwLock<HashMap<String, u32>>>,
307/// max_requests: u32,
308/// }
309///
310/// #[async_trait]
311/// impl Middleware for RateLimitMiddleware {
312/// async fn handle(
313/// &self,
314/// message: Message,
315/// conn: Connection,
316/// state: AppState,
317/// extensions: Extensions,
318/// mut next: Next,
319/// ) -> Result<Option<Message>> {
320/// let conn_id = conn.id();
321/// let mut limits = self.limits.write().await;
322/// let count = limits.entry(conn_id.clone()).or_insert(0);
323///
324/// if *count >= self.max_requests {
325/// return Err(Error::custom("Rate limit exceeded"));
326/// }
327///
328/// *count += 1;
329/// drop(limits);
330///
331/// next.run(message, conn, state, extensions).await
332/// }
333/// }
334/// ```
335///
336/// ## Request ID Middleware
337///
338/// ```
339/// use wsforge::prelude::*;
340/// use async_trait::async_trait;
341///
342/// struct RequestIdMiddleware;
343///
344/// #[async_trait]
345/// impl Middleware for RequestIdMiddleware {
346/// async fn handle(
347/// &self,
348/// message: Message,
349/// conn: Connection,
350/// state: AppState,
351/// extensions: Extensions,
352/// mut next: Next,
353/// ) -> Result<Option<Message>> {
354/// use std::sync::atomic::{AtomicU64, Ordering};
355/// static COUNTER: AtomicU64 = AtomicU64::new(0);
356///
357/// let request_id = COUNTER.fetch_add(1, Ordering::SeqCst);
358/// extensions.insert("request_id", request_id);
359///
360/// next.run(message, conn, state, extensions).await
361/// }
362/// }
363/// ```
364#[async_trait]
365pub trait Middleware: Send + Sync + 'static {
366 /// Handle a message and optionally pass it to the next middleware.
367 ///
368 /// # Arguments
369 ///
370 /// * `message` - The incoming WebSocket message
371 /// * `conn` - The connection that sent the message
372 /// * `state` - Application state
373 /// * `extensions` - Request-scoped extension data
374 /// * `next` - The next step in the middleware chain
375 ///
376 /// # Returns
377 ///
378 /// Returns an optional message to send back to the client, or an error.
379 async fn handle(
380 &self,
381 message: Message,
382 conn: Connection,
383 state: AppState,
384 extensions: Extensions,
385 next: Next,
386 ) -> Result<Option<Message>>;
387}
388
389/// Middleware chain holds all middlewares and the final handler.
390///
391/// The chain executes middleware in the order they were added, and finally
392/// calls the handler if all middleware pass control forward.
393///
394/// # Examples
395///
396/// ```
397/// use wsforge::prelude::*;
398///
399/// # fn example() {
400/// let mut chain = MiddlewareChain::new();
401///
402/// // Add middleware
403/// chain.layer(LoggerMiddleware::new());
404///
405/// // Set final handler
406/// chain.handler(handler(my_handler));
407/// # }
408/// # async fn my_handler() -> Result<String> { Ok("".to_string()) }
409/// ```
410#[derive(Clone)]
411pub struct MiddlewareChain {
412 /// All middleware in the chain, executed in order
413 pub middlewares: Vec<Arc<dyn Middleware>>,
414 /// The final handler to call after all middleware
415 pub handler: Option<Arc<dyn crate::handler::Handler>>,
416}
417
418impl MiddlewareChain {
419 /// Creates a new empty middleware chain.
420 ///
421 /// # Examples
422 ///
423 /// ```
424 /// use wsforge::prelude::*;
425 ///
426 /// let chain = MiddlewareChain::new();
427 /// ```
428 pub fn new() -> Self {
429 Self {
430 middlewares: Vec::new(),
431 handler: None,
432 }
433 }
434
435 /// Add a middleware to the chain.
436 ///
437 /// Middleware are executed in the order they are added.
438 ///
439 /// # Arguments
440 ///
441 /// * `middleware` - The middleware to add
442 ///
443 /// # Examples
444 ///
445 /// ```
446 /// use wsforge::prelude::*;
447 ///
448 /// # fn example() {
449 /// let mut chain = MiddlewareChain::new();
450 ///
451 /// chain.layer(LoggerMiddleware::new());
452 /// # }
453 /// ```
454 pub fn layer(mut self, middleware: Arc<dyn Middleware>) -> Self {
455 self.middlewares.push(middleware);
456 self
457 }
458
459 /// Set the final handler for the chain.
460 ///
461 /// The handler is called after all middleware have been executed.
462 ///
463 /// # Arguments
464 ///
465 /// * `handler` - The handler to call
466 ///
467 /// # Examples
468 ///
469 /// ```
470 /// use wsforge::prelude::*;
471 ///
472 /// async fn my_handler(msg: Message) -> Result<String> {
473 /// Ok("response".to_string())
474 /// }
475 ///
476 /// # fn example() {
477 /// let mut chain = MiddlewareChain::new();
478 /// chain.handler(handler(my_handler));
479 /// # }
480 /// ```
481 pub fn handler(mut self, handler: Arc<dyn crate::handler::Handler>) -> Self {
482 self.handler = Some(handler);
483 self
484 }
485
486 /// Execute the middleware chain.
487 ///
488 /// This runs all middleware in order, then calls the handler if present.
489 ///
490 /// # Arguments
491 ///
492 /// * `message` - The message to process
493 /// * `conn` - The connection
494 /// * `state` - Application state
495 /// * `extensions` - Extension data
496 ///
497 /// # Examples
498 ///
499 /// ```
500 /// use wsforge::prelude::*;
501 ///
502 /// # async fn example(chain: MiddlewareChain, msg: Message, conn: Connection) -> Result<()> {
503 /// let state = AppState::new();
504 /// let extensions = Extensions::new();
505 ///
506 /// let response = chain.execute(msg, conn, state, extensions).await?;
507 /// # Ok(())
508 /// # }
509 /// ```
510 pub async fn execute(
511 &self,
512 message: Message,
513 conn: Connection,
514 state: AppState,
515 extensions: Extensions,
516 ) -> Result<Option<Message>> {
517 let next = Next::new(Arc::new(self.clone()), 0);
518 next.run(message, conn, state, extensions).await
519 }
520}
521
522impl Default for MiddlewareChain {
523 fn default() -> Self {
524 Self::new()
525 }
526}
527
/// Middleware backed by a plain function or closure.
///
/// Wraps a callable so it can be used wherever a [`Middleware`] is expected,
/// without writing a dedicated type and trait implementation.
pub struct FnMiddleware<F> {
    /// The wrapped callable invoked for every message.
    func: F,
}
535
536impl<F> FnMiddleware<F> {
537 /// Creates a new function-based middleware.
538 ///
539 /// # Examples
540 ///
541 /// ```
542 /// use wsforge::prelude::*;
543 ///
544 /// # fn example() {
545 /// let middleware = FnMiddleware::new(|msg, conn, state, ext, mut next| async move {
546 /// println!("Before handler");
547 /// let response = next.run(msg, conn, state, ext).await?;
548 /// println!("After handler");
549 /// Ok(response)
550 /// });
551 /// # }
552 /// ```
553 pub fn new(func: F) -> Arc<Self> {
554 Arc::new(Self { func })
555 }
556}
557
558#[async_trait]
559impl<F, Fut> Middleware for FnMiddleware<F>
560where
561 F: Fn(Message, Connection, AppState, Extensions, Next) -> Fut + Send + Sync + 'static,
562 Fut: std::future::Future<Output = Result<Option<Message>>> + Send + 'static,
563{
564 async fn handle(
565 &self,
566 message: Message,
567 conn: Connection,
568 state: AppState,
569 extensions: Extensions,
570 next: Next,
571 ) -> Result<Option<Message>> {
572 (self.func)(message, conn, state, extensions, next).await
573 }
574}
575
576/// Helper function to create middleware from async functions.
577///
578/// This is a convenience function that wraps an async function in a middleware.
579///
580/// # Arguments
581///
582/// * `f` - Async function with signature matching middleware requirements
583///
584/// # Examples
585///
586/// ## Simple Logging
587///
588/// ```
589/// use wsforge::prelude::*;
590///
591/// # fn example() {
592/// let logging = from_fn(|msg, conn, state, ext, mut next| async move {
593/// println!("Processing message from {}", conn.id());
594/// next.run(msg, conn, state, ext).await
595/// });
596/// # }
597/// ```
598///
599/// ## With State Access
600///
601/// ```
602/// use wsforge::prelude::*;
603/// use std::sync::Arc;
604///
605/// # fn example() {
606/// let counter = from_fn(|msg, conn, state, ext, mut next| async move {
607/// // Access state
608/// if let Some(counter) = state.get::<Arc<std::sync::atomic::AtomicU64>>() {
609/// counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
610/// }
611/// next.run(msg, conn, state, ext).await
612/// });
613/// # }
614/// ```
615pub fn from_fn<F, Fut>(f: F) -> Arc<FnMiddleware<F>>
616where
617 F: Fn(Message, Connection, AppState, Extensions, Next) -> Fut + Send + Sync + 'static,
618 Fut: std::future::Future<Output = Result<Option<Message>>> + Send + 'static,
619{
620 FnMiddleware::new(f)
621}