pub trait Middleware:
Send
+ Sync
+ 'static {
// Required method
fn handle<'life0, 'async_trait>(
&'life0 self,
message: Message,
conn: Connection,
state: AppState,
extensions: Extensions,
next: Next,
) -> Pin<Box<dyn Future<Output = Result<Option<Message>>> + Send + 'async_trait>>
where Self: 'async_trait,
'life0: 'async_trait;
}Expand description
Middleware trait that all middleware must implement.
Middleware can intercept messages before they reach handlers, perform transformations, add metadata to extensions, or short-circuit the request.
§Implementation Guidelines
- Always call `next.run()` unless you want to short-circuit
- Use extensions to pass data to handlers or other middleware
- Handle errors gracefully and provide clear error messages
- Be mindful of performance - middleware runs on every message
§Examples
§Authentication Middleware
use wsforge::prelude::*;
use async_trait::async_trait;
/// Rejects every message that does not start with `TOKEN:` followed by the
/// expected token.
///
/// Accepted messages are forwarded unchanged (the `TOKEN:` prefix is *not*
/// stripped) after `"authenticated" => true` is stored in the extensions.
struct AuthMiddleware {
    /// The token that must follow the `TOKEN:` prefix.
    required_token: String,
}

/// Compares two byte strings without stopping at the first mismatch, so the
/// time taken does not reveal how much of a guessed token was correct.
///
/// Unequal lengths return `false` immediately, which leaks only the length.
/// This is best-effort: the optimizer is not formally prevented from adding an
/// early exit, so production code should prefer a vetted constant-time crate.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
}

#[async_trait]
impl Middleware for AuthMiddleware {
    async fn handle(
        &self,
        message: Message,
        conn: Connection,
        state: AppState,
        extensions: Extensions,
        mut next: Next,
    ) -> Result<Option<Message>> {
        // Non-text messages, a missing prefix and a wrong token all produce the
        // same error, so the client cannot tell which check failed.
        let authorized = match message.as_text().and_then(|text| text.strip_prefix("TOKEN:")) {
            Some(token) => constant_time_eq(token.as_bytes(), self.required_token.as_bytes()),
            None => false,
        };
        if !authorized {
            return Err(Error::custom("Unauthorized"));
        }
        extensions.insert("authenticated", true);
        next.run(message, conn, state, extensions).await
    }
}
§Rate Limiting Middleware
use wsforge::prelude::*;
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::RwLock;
use std::collections::HashMap;
/// Caps how many messages each connection may send.
///
/// NOTE: counts are only ever incremented. Nothing in this example resets them
/// or removes entries, so this limits the *total* number of messages per
/// connection id rather than messages per time window. The map also keeps one
/// entry for every id it has seen. A production limiter would additionally
/// track a window start (e.g. an `Instant`) and evict entries when connections
/// close.
struct RateLimitMiddleware {
    /// Messages accepted so far, keyed by connection id.
    limits: Arc<RwLock<HashMap<String, u32>>>,
    /// Maximum number of messages accepted from a single connection.
    max_requests: u32,
}

#[async_trait]
impl Middleware for RateLimitMiddleware {
    async fn handle(
        &self,
        message: Message,
        conn: Connection,
        state: AppState,
        extensions: Extensions,
        mut next: Next,
    ) -> Result<Option<Message>> {
        let id = conn.id();

        // Scope the write guard so the lock is released before awaiting the
        // rest of the chain; otherwise every other connection would block on it.
        let allowed = {
            let mut table = self.limits.write().await;
            let seen = table.entry(id.clone()).or_insert(0);
            if *seen < self.max_requests {
                *seen += 1;
                true
            } else {
                false
            }
        };

        if !allowed {
            return Err(Error::custom("Rate limit exceeded"));
        }
        next.run(message, conn, state, extensions).await
    }
}
§Request ID Middleware
use wsforge::prelude::*;
use async_trait::async_trait;
use std::sync::atomic::{AtomicU64, Ordering};

/// Process-wide counter that supplies request ids; the first id handed out is 0.
static NEXT_REQUEST_ID: AtomicU64 = AtomicU64::new(0);

/// Tags every message with a sequential `request_id` before passing it on.
struct RequestIdMiddleware;

#[async_trait]
impl Middleware for RequestIdMiddleware {
    async fn handle(
        &self,
        message: Message,
        conn: Connection,
        state: AppState,
        extensions: Extensions,
        mut next: Next,
    ) -> Result<Option<Message>> {
        // fetch_add returns the value *before* the increment, so concurrent
        // callers each receive a distinct id.
        extensions.insert("request_id", NEXT_REQUEST_ID.fetch_add(1, Ordering::SeqCst));
        next.run(message, conn, state, extensions).await
    }
}
Required Methods§
Sourcefn handle<'life0, 'async_trait>(
&'life0 self,
message: Message,
conn: Connection,
state: AppState,
extensions: Extensions,
next: Next,
) -> Pin<Box<dyn Future<Output = Result<Option<Message>>> + Send + 'async_trait>>where
Self: 'async_trait,
'life0: 'async_trait,
fn handle<'life0, 'async_trait>(
&'life0 self,
message: Message,
conn: Connection,
state: AppState,
extensions: Extensions,
next: Next,
) -> Pin<Box<dyn Future<Output = Result<Option<Message>>> + Send + 'async_trait>>where
Self: 'async_trait,
'life0: 'async_trait,
Handle a message and optionally pass it to the next middleware.
§Arguments
- `message` - The incoming WebSocket message
- `conn` - The connection that sent the message
- `state` - Application state
- `extensions` - Request-scoped extension data
- `next` - The next step in the middleware chain
§Returns
Returns an optional message to send back to the client, or an error.