wsforge_core/
extractor.rs

1//! Type-safe data extraction from WebSocket messages and context.
2//!
3//! This module provides a powerful type extraction system inspired by frameworks like Axum,
4//! allowing handlers to declaratively specify what data they need. Extractors automatically
5//! parse and validate data from messages, connection state, and application context.
6//!
7//! # Overview
8//!
9//! Extractors are types that implement the [`FromMessage`] trait. They can extract:
10//! - **Message content**: JSON, binary data, text
11//! - **Connection info**: Client address, connection ID, metadata
12//! - **Application state**: Shared data like database pools, configuration
13//! - **Route parameters**: Path and query parameters from routing
14//! - **Custom extensions**: User-defined request-scoped data
15//!
16//! # Design Philosophy
17//!
18//! The extractor system follows these principles:
//! - **Type safety**: Handlers declare the exact types they need; when extraction fails, it fails at runtime with a clear error
20//! - **Composability**: Multiple extractors can be used in a single handler
21//! - **Zero cost**: Extraction happens only once per handler invocation
22//! - **Flexibility**: Custom extractors can be easily implemented
23//!
24//! # Built-in Extractors
25//!
26//! | Extractor | Description | Example |
27//! |-----------|-------------|---------|
28//! | [`Json<T>`] | Deserialize JSON from message | `Json(user): Json<User>` |
29//! | [`State<T>`] | Extract shared application state | `State(db): State<Arc<Database>>` |
30//! | [`Connection`] | Get the active connection | `conn: Connection` |
31//! | [`ConnectInfo`] | Get connection metadata | `ConnectInfo(info)` |
32//! | [`Message`] | Get raw message | `msg: Message` |
33//! | [`Data`] | Extract binary data | `Data(bytes)` |
34//! | [`Path<T>`] | Extract path parameters | `Path(id): Path<UserId>` |
35//! | [`Query<T>`] | Extract query parameters | `Query(params): Query<SearchParams>` |
36//! | [`Extension<T>`] | Extract custom extensions | `Extension(auth): Extension<Auth>` |
37//!
38//! # Examples
39//!
40//! ## Simple JSON Extraction
41//!
42//! ```
43//! use wsforge::prelude::*;
44//! use serde::Deserialize;
45//!
46//! #[derive(Deserialize)]
47//! struct ChatMessage {
48//!     username: String,
49//!     text: String,
50//! }
51//!
52//! async fn chat_handler(Json(msg): Json<ChatMessage>) -> Result<String> {
53//!     println!("{} says: {}", msg.username, msg.text);
54//!     Ok(format!("Message from {} received", msg.username))
55//! }
56//! ```
57//!
58//! ## Multiple Extractors
59//!
60//! ```
61//! use wsforge::prelude::*;
62//! use serde::Deserialize;
63//! use std::sync::Arc;
64//!
65//! #[derive(Deserialize)]
66//! struct GameMove {
67//!     player: String,
68//!     action: String,
69//! }
70//!
71//! async fn game_handler(
72//!     Json(game_move): Json<GameMove>,
73//!     conn: Connection,
74//!     State(manager): State<Arc<ConnectionManager>>,
75//! ) -> Result<()> {
76//!     println!("Player {} from connection {} made move: {}",
77//!         game_move.player, conn.id(), game_move.action);
78//!
79//!     // Broadcast to other players
80//!     manager.broadcast_except(conn.id(),
81//!         Message::text(format!("{} moved", game_move.player)));
82//!
83//!     Ok(())
84//! }
85//! ```
86//!
87//! ## Custom Extractors
88//!
89//! ```
90//! use wsforge::prelude::*;
91//! use async_trait::async_trait;
92//!
93//! // Custom extractor for authenticated users
94//! struct AuthUser {
95//!     user_id: u64,
96//!     username: String,
97//! }
98//!
99//! #[async_trait]
100//! impl FromMessage for AuthUser {
101//!     async fn from_message(
102//!         message: &Message,
103//!         conn: &Connection,
104//!         state: &AppState,
105//!         extensions: &Extensions,
106//!     ) -> Result<Self> {
107//!         // Extract authentication token from message
108//!         let text = message.as_text()
109//!             .ok_or_else(|| Error::extractor("Message must be text"))?;
110//!
111//!         // Validate and extract user info
112//!         // (In production, verify JWT, session token, etc.)
113//!         Ok(AuthUser {
114//!             user_id: 123,
115//!             username: "user".to_string(),
116//!         })
117//!     }
118//! }
119//!
120//! async fn protected_handler(user: AuthUser) -> Result<String> {
121//!     Ok(format!("Hello, {}!", user.username))
122//! }
123//! ```
124
125use crate::connection::{Connection, ConnectionInfo};
126use crate::error::{Error, Result};
127use crate::message::Message;
128use crate::state::AppState;
129use async_trait::async_trait;
130use dashmap::DashMap;
131use serde::Serialize;
132use serde::de::DeserializeOwned;
133use std::sync::Arc;
134
135/// Trait for types that can be extracted from WebSocket messages and context.
136///
137/// This trait is the core of the extractor system. Types that implement `FromMessage`
138/// can be used as handler parameters, and the framework will automatically extract
139/// and validate the data before calling the handler.
140///
141/// # Implementation Guidelines
142///
143/// When implementing custom extractors:
144/// 1. **Be specific**: Return clear error messages when extraction fails
145/// 2. **Be efficient**: Avoid expensive operations if possible
146/// 3. **Be safe**: Validate all extracted data
147/// 4. **Document**: Explain what data is extracted and any requirements
148///
149/// # Examples
150///
151/// ## Simple Extractor
152///
153/// ```
154/// use wsforge::prelude::*;
155/// use async_trait::async_trait;
156///
157/// struct MessageLength(usize);
158///
159/// #[async_trait]
160/// impl FromMessage for MessageLength {
161///     async fn from_message(
162///         message: &Message,
163///         _conn: &Connection,
164///         _state: &AppState,
165///         _extensions: &Extensions,
166///     ) -> Result<Self> {
167///         let len = message.as_bytes().len();
168///         Ok(MessageLength(len))
169///     }
170/// }
171///
172/// async fn handler(MessageLength(len): MessageLength) -> Result<String> {
173///     Ok(format!("Message length: {}", len))
174/// }
175/// ```
176///
177/// ## Extractor with Validation
178///
179/// ```
180/// use wsforge::prelude::*;
181/// use async_trait::async_trait;
182///
183/// struct ValidatedText(String);
184///
185/// #[async_trait]
186/// impl FromMessage for ValidatedText {
187///     async fn from_message(
188///         message: &Message,
189///         _conn: &Connection,
190///         _state: &AppState,
191///         _extensions: &Extensions,
192///     ) -> Result<Self> {
193///         let text = message.as_text()
194///             .ok_or_else(|| Error::extractor("Message must be text"))?;
195///
196///         if text.is_empty() {
197///             return Err(Error::extractor("Text cannot be empty"));
198///         }
199///
200///         if text.len() > 1000 {
201///             return Err(Error::extractor("Text too long (max 1000 characters)"));
202///         }
203///
204///         Ok(ValidatedText(text.to_string()))
205///     }
206/// }
207/// ```
208#[async_trait]
209pub trait FromMessage: Sized {
210    /// Extracts `Self` from the message and context.
211    ///
212    /// # Arguments
213    ///
214    /// * `message` - The WebSocket message being processed
215    /// * `conn` - The connection that sent the message
216    /// * `state` - The application state
217    /// * `extensions` - Request-scoped extension data
218    ///
219    /// # Errors
220    ///
221    /// Returns an error if extraction fails. Common reasons include:
222    /// - Required data is missing
223    /// - Data format is invalid
224    /// - Type mismatch
225    /// - Validation failure
226    async fn from_message(
227        message: &Message,
228        conn: &Connection,
229        state: &AppState,
230        extensions: &Extensions,
231    ) -> Result<Self>;
232}
233
234/// Container for request-scoped extension data.
235///
236/// Extensions provide a way to pass arbitrary data through the request pipeline.
237/// This is useful for middleware to attach data (like authentication info, request IDs)
238/// that handlers can later extract.
239///
240/// # Thread Safety
241///
242/// Extensions are thread-safe and can be safely shared across tasks.
243///
244/// # Examples
245///
246/// ## Adding and Retrieving Data
247///
248/// ```
249/// use wsforge::prelude::*;
250///
251/// # fn example() {
252/// let extensions = Extensions::new();
253///
254/// // Add data
255/// extensions.insert("request_id", "req_123");
256/// extensions.insert("user_id", 42_u64);
257///
258/// // Retrieve data
259/// if let Some(request_id) = extensions.get::<&str>("request_id") {
260///     println!("Request ID: {}", request_id);
261/// }
262///
263/// if let Some(user_id) = extensions.get::<u64>("user_id") {
264///     println!("User ID: {}", user_id);
265/// }
266/// # }
267/// ```
268///
269/// ## Use in Middleware
270///
271/// ```
272/// use wsforge::prelude::*;
273///
274/// async fn auth_middleware(
275///     msg: Message,
276///     conn: Connection,
277///     extensions: &Extensions,
278/// ) -> Result<()> {
279///     // Extract and validate auth token
280///     let token = extract_token(&msg)?;
281///     let user_id = validate_token(&token)?;
282///
283///     // Store for handler to use
284///     extensions.insert("user_id", user_id);
285///
286///     Ok(())
287/// }
288///
289/// # fn extract_token(_: &Message) -> Result<String> { Ok("token".to_string()) }
290/// # fn validate_token(_: &str) -> Result<u64> { Ok(123) }
291/// ```
292#[derive(Clone)]
293pub struct Extensions {
294    data: Arc<DashMap<String, Arc<dyn std::any::Any + Send + Sync>>>,
295}
296
297impl Extensions {
298    /// Creates a new empty `Extensions` container.
299    ///
300    /// # Examples
301    ///
302    /// ```
303    /// use wsforge::prelude::*;
304    ///
305    /// let extensions = Extensions::new();
306    /// ```
307    pub fn new() -> Self {
308        Self {
309            data: Arc::new(DashMap::new()),
310        }
311    }
312
313    /// Inserts a value into the extensions.
314    ///
315    /// The value is stored under the given key and can be retrieved later
316    /// using the same key and type.
317    ///
318    /// # Arguments
319    ///
320    /// * `key` - A unique identifier for this value
321    /// * `value` - The value to store (must be `Send + Sync + 'static`)
322    ///
323    /// # Examples
324    ///
325    /// ```
326    /// use wsforge::prelude::*;
327    ///
328    /// # fn example() {
329    /// let extensions = Extensions::new();
330    ///
331    /// // Store different types
332    /// extensions.insert("count", 42_u32);
333    /// extensions.insert("name", "Alice".to_string());
334    /// extensions.insert("active", true);
335    /// # }
336    /// ```
337    pub fn insert<T: Send + Sync + 'static>(&self, key: impl Into<String>, value: T) {
338        self.data.insert(key.into(), Arc::new(value));
339    }
340
341    /// Retrieves a value from the extensions.
342    ///
343    /// Returns `None` if the key doesn't exist or if the stored type doesn't
344    /// match the requested type.
345    ///
346    /// # Type Safety
347    ///
348    /// The returned value must match the type that was originally inserted.
349    /// Attempting to retrieve with a different type will return `None`.
350    ///
351    /// # Examples
352    ///
353    /// ```
354    /// use wsforge::prelude::*;
355    ///
356    /// # fn example() {
357    /// let extensions = Extensions::new();
358    /// extensions.insert("count", 42_u32);
359    ///
360    /// // Correct type - succeeds
361    /// let count: Option<Arc<u32>> = extensions.get("count");
362    /// assert_eq!(*count.unwrap(), 42);
363    ///
364    /// // Wrong type - returns None
365    /// let wrong: Option<Arc<String>> = extensions.get("count");
366    /// assert!(wrong.is_none());
367    /// # }
368    /// ```
369    pub fn get<T: Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
370        self.data
371            .get(key)
372            .and_then(|arc| arc.value().clone().downcast::<T>().ok())
373    }
374}
375
376impl Default for Extensions {
377    fn default() -> Self {
378        Self::new()
379    }
380}
381
382/// Extractor for shared application state.
383///
384/// Use this to access data that's shared across all connections, such as:
385/// - Database connection pools
386/// - Configuration
387/// - Caches
388/// - Connection managers
389///
390/// # Type Parameter
391///
392/// The generic parameter `T` should be wrapped in `Arc` since state is shared.
393///
394/// # Examples
395///
396/// ## Accessing Connection Manager
397///
398/// ```
399/// use wsforge::prelude::*;
400/// use std::sync::Arc;
401///
402/// async fn broadcast_handler(
403///     msg: Message,
404///     State(manager): State<Arc<ConnectionManager>>,
405/// ) -> Result<()> {
406///     manager.broadcast(msg);
407///     Ok(())
408/// }
409/// ```
410///
411/// ## Custom State Type
412///
413/// ```
414/// use wsforge::prelude::*;
415/// use std::sync::Arc;
416///
417/// struct AppConfig {
418///     max_message_size: usize,
419///     rate_limit: u32,
420/// }
421///
422/// async fn handler(State(config): State<Arc<AppConfig>>) -> Result<String> {
423///     Ok(format!("Max message size: {}", config.max_message_size))
424/// }
425/// ```
426pub struct State<T>(pub Arc<T>);
427
428#[async_trait]
429impl<T: Send + Sync + 'static> FromMessage for State<T> {
430    async fn from_message(
431        _message: &Message,
432        _conn: &Connection,
433        state: &AppState,
434        _extensions: &Extensions,
435    ) -> Result<Self> {
436        state
437            .get::<T>()
438            .ok_or_else(|| Error::extractor("State not found"))
439            .map(State)
440    }
441}
442
443/// Extractor for JSON data from messages.
444///
445/// Automatically deserializes the message content as JSON into the specified type.
446/// The type must implement `serde::Deserialize`.
447///
448/// # Errors
449///
450/// Returns an error if:
451/// - The message is not text
452/// - The JSON is malformed
453/// - Required fields are missing
454/// - Type constraints are not satisfied
455///
456/// # Examples
457///
458/// ## Simple Struct
459///
460/// ```
461/// use wsforge::prelude::*;
462/// use serde::Deserialize;
463///
464/// #[derive(Deserialize)]
465/// struct LoginRequest {
466///     username: String,
467///     password: String,
468/// }
469///
470/// async fn login_handler(Json(req): Json<LoginRequest>) -> Result<String> {
471///     // Validate credentials
472///     Ok(format!("Login attempt by {}", req.username))
473/// }
474/// ```
475///
476/// ## With Validation
477///
478/// ```
479/// use wsforge::prelude::*;
480/// use serde::Deserialize;
481///
482/// #[derive(Deserialize)]
483/// struct CreateUser {
484///     #[serde(deserialize_with = "validate_username")]
485///     username: String,
486///     age: u8,
487/// }
488///
489/// async fn create_user(Json(user): Json<CreateUser>) -> Result<String> {
490///     Ok(format!("Creating user: {}", user.username))
491/// }
492///
493/// # fn validate_username<'de, D>(_: D) -> std::result::Result<String, D::Error>
494/// # where D: serde::Deserializer<'de> {
495/// #     Ok("valid".to_string())
496/// # }
497/// ```
498///
499/// ## Nested Structures
500///
501/// ```
502/// use wsforge::prelude::*;
503/// use serde::Deserialize;
504///
505/// #[derive(Deserialize)]
506/// struct GameState {
507///     player: Player,
508///     score: u32,
509/// }
510///
511/// #[derive(Deserialize)]
512/// struct Player {
513///     id: u64,
514///     name: String,
515/// }
516///
517/// async fn update_game(Json(state): Json<GameState>) -> Result<()> {
518///     println!("Player {} score: {}", state.player.name, state.score);
519///     Ok(())
520/// }
521/// ```
522pub struct Json<T>(pub T);
523
524#[async_trait]
525impl<T: DeserializeOwned + Send> FromMessage for Json<T> {
526    async fn from_message(
527        message: &Message,
528        _conn: &Connection,
529        _state: &AppState,
530        _extensions: &Extensions,
531    ) -> Result<Self> {
532        let data: T = message.json()?;
533        Ok(Json(data))
534    }
535}
536
537impl<T: Serialize> Json<T> {
538    /// Converts this JSON extractor back into a message.
539    ///
540    /// This is useful for echoing back modified data or creating responses.
541    ///
542    /// # Examples
543    ///
544    /// ```
545    /// use wsforge::prelude::*;
546    /// use serde::{Deserialize, Serialize};
547    ///
548    /// #[derive(Deserialize, Serialize)]
549    /// struct Echo {
550    ///     text: String,
551    /// }
552    ///
553    /// async fn echo_handler(Json(data): Json<Echo>) -> Result<Message> {
554    ///     Json(data).into_message()
555    /// }
556    /// ```
557    pub fn into_message(self) -> Result<Message> {
558        let json = serde_json::to_string(&self.0)?;
559        Ok(Message::text(json))
560    }
561}
562
563/// Extractor for the active connection.
564///
565/// Provides access to the connection that sent the message, allowing you to:
566/// - Get the connection ID
567/// - Access connection metadata
568/// - Send messages back to the specific client
569///
570/// # Examples
571///
572/// ## Sending Response
573///
574/// ```
575/// use wsforge::prelude::*;
576///
577/// async fn handler(msg: Message, conn: Connection) -> Result<()> {
578///     conn.send_text("Message received!")?;
579///     Ok(())
580/// }
581/// ```
582///
583/// ## Using Connection Info
584///
585/// ```
586/// use wsforge::prelude::*;
587///
588/// async fn handler(conn: Connection) -> Result<String> {
589///     let info = conn.info();
590///     Ok(format!("Client from {} connected at {}",
591///         info.addr, info.connected_at))
592/// }
593/// ```
594#[async_trait]
595impl FromMessage for Connection {
596    async fn from_message(
597        _message: &Message,
598        conn: &Connection,
599        _state: &AppState,
600        _extensions: &Extensions,
601    ) -> Result<Self> {
602        Ok(conn.clone())
603    }
604}
605
606/// Extractor for connection metadata.
607///
608/// Provides detailed information about the connection, including:
609/// - Connection ID
610/// - Client socket address
611/// - Connection timestamp
612/// - Protocol information
613///
614/// # Examples
615///
616/// ## Logging Connection Info
617///
618/// ```
619/// use wsforge::prelude::*;
620///
621/// async fn handler(ConnectInfo(info): ConnectInfo) -> Result<String> {
622///     println!("Connection {} from {} at {}",
623///         info.id, info.addr, info.connected_at);
624///     Ok("Connected".to_string())
625/// }
626/// ```
627pub struct ConnectInfo(pub ConnectionInfo);
628
629#[async_trait]
630impl FromMessage for ConnectInfo {
631    async fn from_message(
632        _message: &Message,
633        conn: &Connection,
634        _state: &AppState,
635        _extensions: &Extensions,
636    ) -> Result<Self> {
637        Ok(ConnectInfo(conn.info.clone()))
638    }
639}
640
641/// Extractor for the raw message.
642///
643/// Use this when you need access to the complete message without
644/// automatic deserialization.
645///
646/// # Examples
647///
648/// ## Raw Message Processing
649///
650/// ```
651/// use wsforge::prelude::*;
652///
653/// async fn handler(msg: Message) -> Result<String> {
654///     if msg.is_text() {
655///         Ok(format!("Text: {}", msg.as_text().unwrap()))
656///     } else if msg.is_binary() {
657///         Ok(format!("Binary: {} bytes", msg.as_bytes().len()))
658///     } else {
659///         Ok("Unknown message type".to_string())
660///     }
661/// }
662/// ```
663#[async_trait]
664impl FromMessage for Message {
665    async fn from_message(
666        message: &Message,
667        _conn: &Connection,
668        _state: &AppState,
669        _extensions: &Extensions,
670    ) -> Result<Self> {
671        Ok(message.clone())
672    }
673}
674
675/// Extractor for raw binary data.
676///
677/// Extracts the message payload as raw bytes. Works with both text and binary messages.
678///
679/// # Examples
680///
681/// ## Processing Binary Data
682///
683/// ```
684/// use wsforge::prelude::*;
685///
686/// async fn handler(Data(bytes): Data) -> Result<String> {
687///     println!("Received {} bytes", bytes.len());
688///     Ok(format!("Processed {} bytes", bytes.len()))
689/// }
690/// ```
691pub struct Data(pub Vec<u8>);
692
693#[async_trait]
694impl FromMessage for Data {
695    async fn from_message(
696        message: &Message,
697        _conn: &Connection,
698        _state: &AppState,
699        _extensions: &Extensions,
700    ) -> Result<Self> {
701        Ok(Data(message.data.clone()))
702    }
703}
704
705/// Extractor for path parameters.
706///
707/// Extracts typed parameters from the request path. The type must implement
708/// `serde::Deserialize` and be stored in extensions by routing middleware.
709///
710/// # Examples
711///
712/// ## Single Parameter
713///
714/// ```
715/// use wsforge::prelude::*;
716/// use serde::Deserialize;
717///
718/// #[derive(Deserialize)]
719/// struct UserId(u64);
720///
721/// async fn get_user(Path(UserId(id)): Path<UserId>) -> Result<String> {
722///     Ok(format!("Getting user {}", id))
723/// }
724/// ```
725///
726/// ## Multiple Parameters
727///
728/// ```
729/// use wsforge::prelude::*;
730/// use serde::Deserialize;
731///
732/// #[derive(Deserialize)]
733/// struct RoomParams {
734///     room_id: String,
735///     user_id: u64,
736/// }
737///
738/// async fn join_room(Path(params): Path<RoomParams>) -> Result<String> {
739///     Ok(format!("User {} joining room {}", params.user_id, params.room_id))
740/// }
741/// ```
742pub struct Path<T>(pub T);
743
744#[async_trait]
745impl<T: DeserializeOwned + Send + Sync + Clone + 'static> FromMessage for Path<T> {
746    async fn from_message(
747        _message: &Message,
748        _conn: &Connection,
749        _state: &AppState,
750        extensions: &Extensions,
751    ) -> Result<Self> {
752        extensions
753            .get::<T>("path_params")
754            .ok_or_else(|| Error::extractor("Path parameters not found"))
755            .map(|arc| Path((*arc).clone()))
756    }
757}
758
759/// Extractor for query parameters.
760///
761/// Extracts typed parameters from the query string. The type must implement
762/// `serde::Deserialize` and be stored in extensions during connection establishment.
763///
764/// # Examples
765///
766/// ## Search Parameters
767///
768/// ```
769/// use wsforge::prelude::*;
770/// use serde::Deserialize;
771///
772/// #[derive(Deserialize)]
773/// struct SearchQuery {
774///     q: String,
775///     limit: Option<u32>,
776/// }
777///
778/// async fn search(Query(params): Query<SearchQuery>) -> Result<String> {
779///     let limit = params.limit.unwrap_or(10);
780///     Ok(format!("Searching for '{}' (limit: {})", params.q, limit))
781/// }
782/// ```
783pub struct Query<T>(pub T);
784
785#[async_trait]
786impl<T: DeserializeOwned + Send + Sync + Clone + 'static> FromMessage for Query<T> {
787    async fn from_message(
788        _message: &Message,
789        _conn: &Connection,
790        _state: &AppState,
791        extensions: &Extensions,
792    ) -> Result<Self> {
793        extensions
794            .get::<T>("query_params")
795            .ok_or_else(|| Error::extractor("Query parameters not found"))
796            .map(|arc| Query((*arc).clone()))
797    }
798}
799
800/// Extractor for custom extension data.
801///
802/// Retrieves data that was previously stored in extensions by middleware or other handlers.
803///
804/// # Examples
805///
806/// ## Authentication Data
807///
808/// ```
809/// use wsforge::prelude::*;
810/// use std::sync::Arc;
811///
812/// #[derive(Clone)]
813/// struct AuthData {
814///     user_id: u64,
815///     role: String,
816/// }
817///
818/// async fn protected_handler(Extension(auth): Extension<AuthData>) -> Result<String> {
819///     Ok(format!("User {} with role {}", auth.user_id, auth.role))
820/// }
821/// ```
822pub struct Extension<T>(pub Arc<T>);
823
824#[async_trait]
825impl<T: Send + Sync + Clone + 'static> FromMessage for Extension<T> {
826    async fn from_message(
827        _message: &Message,
828        _conn: &Connection,
829        _state: &AppState,
830        extensions: &Extensions,
831    ) -> Result<Self> {
832        extensions
833            .get::<T>(std::any::type_name::<T>())
834            .ok_or_else(|| Error::extractor("Extension not found"))
835            .map(Extension)
836    }
837}