Middleware

Trait Middleware 

Source
/// Contract implemented by every wsforge middleware.
///
/// A middleware sits between the incoming message and the handler. It may
/// inspect or transform the message, record data in `extensions`, forward the
/// message via `next`, or stop processing by returning without forwarding.
/// Implementors must be `Send + Sync + 'static` so they can be shared across
/// the tasks that serve connections.
pub trait Middleware: Send + Sync + 'static {
    /// Processes `message` and optionally hands it to the next stage.
    ///
    /// This is the boxed-future signature produced by `#[async_trait]`;
    /// implementors normally write `async fn handle(&self, ...) ->
    /// Result<Option<Message>>` inside an `#[async_trait] impl` and let the
    /// macro generate this form.
    fn handle<'life0, 'async_trait>(
        &'life0 self,
        message: Message,
        conn: Connection,
        state: AppState,
        extensions: Extensions,
        next: Next,
    ) -> Pin<Box<dyn Future<Output = Result<Option<Message>>> + Send + 'async_trait>>
    where
        Self: 'async_trait,
        'life0: 'async_trait;
}
Expand description

Middleware trait that all middleware must implement.

Middleware can intercept messages before they reach handlers, perform transformations, add metadata to extensions, or short-circuit the request.

§Implementation Guidelines

  • Call next.run() to pass the message on; omit it only when you deliberately short-circuit (for example, to reject the message)
  • Use extensions to pass data to handlers or other middleware
  • Report failures by returning an error with a clear message (e.g. Error::custom("Unauthorized")) rather than panicking
  • Be mindful of performance — middleware runs on every message; in particular, release any locks before awaiting next.run()

§Examples

§Authentication Middleware

use wsforge::prelude::*;
use async_trait::async_trait;

struct AuthMiddleware {
    required_token: String,
}

#[async_trait]
impl Middleware for AuthMiddleware {
    async fn handle(
        &self,
        message: Message,
        conn: Connection,
        state: AppState,
        extensions: Extensions,
        mut next: Next,
    ) -> Result<Option<Message>> {
        if let Some(text) = message.as_text() {
            if let Some(token) = text.strip_prefix("TOKEN:") {
                if token == self.required_token {
                    extensions.insert("authenticated", true);
                    return next.run(message, conn, state, extensions).await;
                }
            }
        }

        Err(Error::custom("Unauthorized"))
    }
}

§Rate Limiting Middleware

use wsforge::prelude::*;
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::RwLock;
use std::collections::HashMap;

/// Caps how many messages each connection may send.
///
/// Every message bumps a counter keyed by the connection id. Once that
/// counter reaches `max_requests`, further messages from the connection are
/// rejected with a `"Rate limit exceeded"` error instead of being forwarded.
///
/// Limitations of this example:
/// - Counters are only ever incremented. Nothing resets them or removes an
///   entry, so this is a lifetime quota per connection rather than a
///   time-windowed rate limit.
/// - The map grows with every distinct connection id it sees.
///
/// A production limiter would track a window start time and evict entries
/// when connections close.
struct RateLimitMiddleware {
    /// Messages accepted so far, keyed by connection id.
    limits: Arc<RwLock<HashMap<String, u32>>>,
    /// Maximum number of messages a single connection may send.
    max_requests: u32,
}

#[async_trait]
impl Middleware for RateLimitMiddleware {
    async fn handle(
        &self,
        message: Message,
        conn: Connection,
        state: AppState,
        extensions: Extensions,
        mut next: Next,
    ) -> Result<Option<Message>> {
        let key = conn.id();

        // Confine the write guard to this scope so the lock is released
        // before the rest of the chain is awaited.
        {
            let mut counters = self.limits.write().await;
            let used = counters.entry(key.clone()).or_insert(0);

            if *used >= self.max_requests {
                return Err(Error::custom("Rate limit exceeded"));
            }
            *used += 1;
        }

        next.run(message, conn, state, extensions).await
    }
}

§Request ID Middleware

use wsforge::prelude::*;
use async_trait::async_trait;

/// Tags every message with a numeric id before passing it on.
///
/// The id is stored in the extensions under the key `"request_id"`, so
/// handlers and later middleware can read it. Ids come from a single counter
/// shared by all instances: they start at 0 and increase by one per message,
/// wrapping on `u64` overflow.
struct RequestIdMiddleware;

impl RequestIdMiddleware {
    /// Returns the next value of the process-wide id counter.
    fn next_request_id() -> u64 {
        use std::sync::atomic::{AtomicU64, Ordering};

        static COUNTER: AtomicU64 = AtomicU64::new(0);
        COUNTER.fetch_add(1, Ordering::SeqCst)
    }
}

#[async_trait]
impl Middleware for RequestIdMiddleware {
    async fn handle(
        &self,
        message: Message,
        conn: Connection,
        state: AppState,
        extensions: Extensions,
        mut next: Next,
    ) -> Result<Option<Message>> {
        extensions.insert("request_id", Self::next_request_id());
        next.run(message, conn, state, extensions).await
    }
}

Required Methods§

Source

fn handle<'life0, 'async_trait>( &'life0 self, message: Message, conn: Connection, state: AppState, extensions: Extensions, next: Next, ) -> Pin<Box<dyn Future<Output = Result<Option<Message>>> + Send + 'async_trait>>
where Self: 'async_trait, 'life0: 'async_trait,

Handle a message and optionally pass it to the next middleware.

§Arguments
  • message - The incoming WebSocket message
  • conn - The connection that sent the message
  • state - Application state
  • extensions - Request-scoped extension data
  • next - The next step in the middleware chain
§Returns

Returns Ok(Some(message)) to send a message back to the client, Ok(None) to send nothing, or Err if processing failed.

Implementors§

Source§

impl Middleware for LoggerMiddleware

Source§

impl<F, Fut> Middleware for FnMiddleware<F>
where F: Fn(Message, Connection, AppState, Extensions, Next) -> Fut + Send + Sync + 'static, Fut: Future<Output = Result<Option<Message>>> + Send + 'static,