wsforge_core/extractor.rs
1//! Type-safe data extraction from WebSocket messages and context.
2//!
3//! This module provides a powerful type extraction system inspired by frameworks like Axum,
4//! allowing handlers to declaratively specify what data they need. Extractors automatically
5//! parse and validate data from messages, connection state, and application context.
6//!
7//! # Overview
8//!
9//! Extractors are types that implement the [`FromMessage`] trait. They can extract:
10//! - **Message content**: JSON, binary data, text
11//! - **Connection info**: Client address, connection ID, metadata
12//! - **Application state**: Shared data like database pools, configuration
13//! - **Route parameters**: Path and query parameters from routing
14//! - **Custom extensions**: User-defined request-scoped data
15//!
16//! # Design Philosophy
17//!
18//! The extractor system follows these principles:
//! - **Type safety**: Handlers declare the data they need as typed parameters; extraction failures are reported at runtime as descriptive errors
//! - **Composability**: Multiple extractors can be used in a single handler
//! - **Single pass**: Each extractor runs only once per handler invocation
22//! - **Flexibility**: Custom extractors can be easily implemented
23//!
24//! # Built-in Extractors
25//!
26//! | Extractor | Description | Example |
27//! |-----------|-------------|---------|
28//! | [`Json<T>`] | Deserialize JSON from message | `Json(user): Json<User>` |
29//! | [`State<T>`] | Extract shared application state | `State(db): State<Arc<Database>>` |
30//! | [`Connection`] | Get the active connection | `conn: Connection` |
31//! | [`ConnectInfo`] | Get connection metadata | `ConnectInfo(info)` |
32//! | [`Message`] | Get raw message | `msg: Message` |
33//! | [`Data`] | Extract binary data | `Data(bytes)` |
34//! | [`Path<T>`] | Extract path parameters | `Path(id): Path<UserId>` |
35//! | [`Query<T>`] | Extract query parameters | `Query(params): Query<SearchParams>` |
36//! | [`Extension<T>`] | Extract custom extensions | `Extension(auth): Extension<Auth>` |
37//!
38//! # Examples
39//!
40//! ## Simple JSON Extraction
41//!
42//! ```
43//! use wsforge::prelude::*;
44//! use serde::Deserialize;
45//!
46//! #[derive(Deserialize)]
47//! struct ChatMessage {
48//! username: String,
49//! text: String,
50//! }
51//!
52//! async fn chat_handler(Json(msg): Json<ChatMessage>) -> Result<String> {
53//! println!("{} says: {}", msg.username, msg.text);
54//! Ok(format!("Message from {} received", msg.username))
55//! }
56//! ```
57//!
58//! ## Multiple Extractors
59//!
60//! ```
61//! use wsforge::prelude::*;
62//! use serde::Deserialize;
63//! use std::sync::Arc;
64//!
65//! #[derive(Deserialize)]
66//! struct GameMove {
67//! player: String,
68//! action: String,
69//! }
70//!
71//! async fn game_handler(
72//! Json(game_move): Json<GameMove>,
73//! conn: Connection,
74//! State(manager): State<Arc<ConnectionManager>>,
75//! ) -> Result<()> {
76//! println!("Player {} from connection {} made move: {}",
77//! game_move.player, conn.id(), game_move.action);
78//!
79//! // Broadcast to other players
80//! manager.broadcast_except(conn.id(),
81//! Message::text(format!("{} moved", game_move.player)));
82//!
83//! Ok(())
84//! }
85//! ```
86//!
87//! ## Custom Extractors
88//!
89//! ```
90//! use wsforge::prelude::*;
91//! use async_trait::async_trait;
92//!
93//! // Custom extractor for authenticated users
94//! struct AuthUser {
95//! user_id: u64,
96//! username: String,
97//! }
98//!
99//! #[async_trait]
100//! impl FromMessage for AuthUser {
101//! async fn from_message(
102//! message: &Message,
103//! conn: &Connection,
104//! state: &AppState,
105//! extensions: &Extensions,
106//! ) -> Result<Self> {
107//! // Extract authentication token from message
108//! let text = message.as_text()
109//! .ok_or_else(|| Error::extractor("Message must be text"))?;
110//!
111//! // Validate and extract user info
112//! // (In production, verify JWT, session token, etc.)
113//! Ok(AuthUser {
114//! user_id: 123,
115//! username: "user".to_string(),
116//! })
117//! }
118//! }
119//!
120//! async fn protected_handler(user: AuthUser) -> Result<String> {
121//! Ok(format!("Hello, {}!", user.username))
122//! }
123//! ```
124
125use crate::connection::{Connection, ConnectionInfo};
126use crate::error::{Error, Result};
127use crate::message::Message;
128use crate::state::AppState;
129use async_trait::async_trait;
130use dashmap::DashMap;
131use serde::Serialize;
132use serde::de::DeserializeOwned;
133use std::sync::Arc;
134
/// Trait for types that can be extracted from WebSocket messages and context.
///
/// This trait is the core of the extractor system. Types that implement `FromMessage`
/// can be used as handler parameters, and the framework will automatically extract
/// and validate the data before calling the handler.
///
/// This module implements `FromMessage` for [`State`], [`Json`], [`Connection`],
/// [`ConnectInfo`], [`Message`], [`Data`], [`Path`], [`Query`], and [`Extension`].
///
/// The trait is declared with `#[async_trait]`, so implementations must carry the
/// `#[async_trait]` attribute as well (see the examples below).
///
/// # Implementation Guidelines
///
/// When implementing custom extractors:
/// 1. **Be specific**: Return clear error messages when extraction fails
/// 2. **Be efficient**: Avoid expensive operations if possible
/// 3. **Be safe**: Validate all extracted data
/// 4. **Document**: Explain what data is extracted and any requirements
///
/// # Examples
///
/// ## Simple Extractor
///
/// ```
/// use wsforge::prelude::*;
/// use async_trait::async_trait;
///
/// struct MessageLength(usize);
///
/// #[async_trait]
/// impl FromMessage for MessageLength {
///     async fn from_message(
///         message: &Message,
///         _conn: &Connection,
///         _state: &AppState,
///         _extensions: &Extensions,
///     ) -> Result<Self> {
///         let len = message.as_bytes().len();
///         Ok(MessageLength(len))
///     }
/// }
///
/// async fn handler(MessageLength(len): MessageLength) -> Result<String> {
///     Ok(format!("Message length: {}", len))
/// }
/// ```
///
/// ## Extractor with Validation
///
/// ```
/// use wsforge::prelude::*;
/// use async_trait::async_trait;
///
/// struct ValidatedText(String);
///
/// #[async_trait]
/// impl FromMessage for ValidatedText {
///     async fn from_message(
///         message: &Message,
///         _conn: &Connection,
///         _state: &AppState,
///         _extensions: &Extensions,
///     ) -> Result<Self> {
///         let text = message.as_text()
///             .ok_or_else(|| Error::extractor("Message must be text"))?;
///
///         if text.is_empty() {
///             return Err(Error::extractor("Text cannot be empty"));
///         }
///
///         if text.len() > 1000 {
///             return Err(Error::extractor("Text too long (max 1000 characters)"));
///         }
///
///         Ok(ValidatedText(text.to_string()))
///     }
/// }
/// ```
#[async_trait]
pub trait FromMessage: Sized {
    /// Extracts `Self` from the message and context.
    ///
    /// # Arguments
    ///
    /// * `message` - The WebSocket message being processed
    /// * `conn` - The connection that sent the message
    /// * `state` - The application state
    /// * `extensions` - Request-scoped extension data
    ///
    /// All inputs are borrowed for the duration of the call, while the result is
    /// returned by value. Data taken from the inputs therefore has to be cloned,
    /// or shared through an `Arc`.
    ///
    /// # Errors
    ///
    /// Returns an error if extraction fails. Common reasons include:
    /// - Required data is missing
    /// - Data format is invalid
    /// - Type mismatch
    /// - Validation failure
    async fn from_message(
        message: &Message,
        conn: &Connection,
        state: &AppState,
        extensions: &Extensions,
    ) -> Result<Self>;
}
233
234/// Container for request-scoped extension data.
235///
236/// Extensions provide a way to pass arbitrary data through the request pipeline.
237/// This is useful for middleware to attach data (like authentication info, request IDs)
238/// that handlers can later extract.
239///
240/// # Thread Safety
241///
242/// Extensions are thread-safe and can be safely shared across tasks.
243///
244/// # Examples
245///
246/// ## Adding and Retrieving Data
247///
248/// ```
249/// use wsforge::prelude::*;
250///
251/// # fn example() {
252/// let extensions = Extensions::new();
253///
254/// // Add data
255/// extensions.insert("request_id", "req_123");
256/// extensions.insert("user_id", 42_u64);
257///
258/// // Retrieve data
259/// if let Some(request_id) = extensions.get::<&str>("request_id") {
260/// println!("Request ID: {}", request_id);
261/// }
262///
263/// if let Some(user_id) = extensions.get::<u64>("user_id") {
264/// println!("User ID: {}", user_id);
265/// }
266/// # }
267/// ```
268///
269/// ## Use in Middleware
270///
271/// ```
272/// use wsforge::prelude::*;
273///
274/// async fn auth_middleware(
275/// msg: Message,
276/// conn: Connection,
277/// extensions: &Extensions,
278/// ) -> Result<()> {
279/// // Extract and validate auth token
280/// let token = extract_token(&msg)?;
281/// let user_id = validate_token(&token)?;
282///
283/// // Store for handler to use
284/// extensions.insert("user_id", user_id);
285///
286/// Ok(())
287/// }
288///
289/// # fn extract_token(_: &Message) -> Result<String> { Ok("token".to_string()) }
290/// # fn validate_token(_: &str) -> Result<u64> { Ok(123) }
291/// ```
292#[derive(Clone)]
293pub struct Extensions {
294 data: Arc<DashMap<String, Arc<dyn std::any::Any + Send + Sync>>>,
295}
296
297impl Extensions {
298 /// Creates a new empty `Extensions` container.
299 ///
300 /// # Examples
301 ///
302 /// ```
303 /// use wsforge::prelude::*;
304 ///
305 /// let extensions = Extensions::new();
306 /// ```
307 pub fn new() -> Self {
308 Self {
309 data: Arc::new(DashMap::new()),
310 }
311 }
312
313 /// Inserts a value into the extensions.
314 ///
315 /// The value is stored under the given key and can be retrieved later
316 /// using the same key and type.
317 ///
318 /// # Arguments
319 ///
320 /// * `key` - A unique identifier for this value
321 /// * `value` - The value to store (must be `Send + Sync + 'static`)
322 ///
323 /// # Examples
324 ///
325 /// ```
326 /// use wsforge::prelude::*;
327 ///
328 /// # fn example() {
329 /// let extensions = Extensions::new();
330 ///
331 /// // Store different types
332 /// extensions.insert("count", 42_u32);
333 /// extensions.insert("name", "Alice".to_string());
334 /// extensions.insert("active", true);
335 /// # }
336 /// ```
337 pub fn insert<T: Send + Sync + 'static>(&self, key: impl Into<String>, value: T) {
338 self.data.insert(key.into(), Arc::new(value));
339 }
340
341 /// Retrieves a value from the extensions.
342 ///
343 /// Returns `None` if the key doesn't exist or if the stored type doesn't
344 /// match the requested type.
345 ///
346 /// # Type Safety
347 ///
348 /// The returned value must match the type that was originally inserted.
349 /// Attempting to retrieve with a different type will return `None`.
350 ///
351 /// # Examples
352 ///
353 /// ```
354 /// use wsforge::prelude::*;
355 ///
356 /// # fn example() {
357 /// let extensions = Extensions::new();
358 /// extensions.insert("count", 42_u32);
359 ///
360 /// // Correct type - succeeds
361 /// let count: Option<Arc<u32>> = extensions.get("count");
362 /// assert_eq!(*count.unwrap(), 42);
363 ///
364 /// // Wrong type - returns None
365 /// let wrong: Option<Arc<String>> = extensions.get("count");
366 /// assert!(wrong.is_none());
367 /// # }
368 /// ```
369 pub fn get<T: Send + Sync + 'static>(&self, key: &str) -> Option<Arc<T>> {
370 self.data
371 .get(key)
372 .and_then(|arc| arc.value().clone().downcast::<T>().ok())
373 }
374}
375
376impl Default for Extensions {
377 fn default() -> Self {
378 Self::new()
379 }
380}
381
382/// Extractor for shared application state.
383///
384/// Use this to access data that's shared across all connections, such as:
385/// - Database connection pools
386/// - Configuration
387/// - Caches
388/// - Connection managers
389///
390/// # Type Parameter
391///
392/// The generic parameter `T` should be wrapped in `Arc` since state is shared.
393///
394/// # Examples
395///
396/// ## Accessing Connection Manager
397///
398/// ```
399/// use wsforge::prelude::*;
400/// use std::sync::Arc;
401///
402/// async fn broadcast_handler(
403/// msg: Message,
404/// State(manager): State<Arc<ConnectionManager>>,
405/// ) -> Result<()> {
406/// manager.broadcast(msg);
407/// Ok(())
408/// }
409/// ```
410///
411/// ## Custom State Type
412///
413/// ```
414/// use wsforge::prelude::*;
415/// use std::sync::Arc;
416///
417/// struct AppConfig {
418/// max_message_size: usize,
419/// rate_limit: u32,
420/// }
421///
422/// async fn handler(State(config): State<Arc<AppConfig>>) -> Result<String> {
423/// Ok(format!("Max message size: {}", config.max_message_size))
424/// }
425/// ```
426pub struct State<T>(pub Arc<T>);
427
428#[async_trait]
429impl<T: Send + Sync + 'static> FromMessage for State<T> {
430 async fn from_message(
431 _message: &Message,
432 _conn: &Connection,
433 state: &AppState,
434 _extensions: &Extensions,
435 ) -> Result<Self> {
436 state
437 .get::<T>()
438 .ok_or_else(|| Error::extractor("State not found"))
439 .map(State)
440 }
441}
442
443/// Extractor for JSON data from messages.
444///
445/// Automatically deserializes the message content as JSON into the specified type.
446/// The type must implement `serde::Deserialize`.
447///
448/// # Errors
449///
450/// Returns an error if:
451/// - The message is not text
452/// - The JSON is malformed
453/// - Required fields are missing
454/// - Type constraints are not satisfied
455///
456/// # Examples
457///
458/// ## Simple Struct
459///
460/// ```
461/// use wsforge::prelude::*;
462/// use serde::Deserialize;
463///
464/// #[derive(Deserialize)]
465/// struct LoginRequest {
466/// username: String,
467/// password: String,
468/// }
469///
470/// async fn login_handler(Json(req): Json<LoginRequest>) -> Result<String> {
471/// // Validate credentials
472/// Ok(format!("Login attempt by {}", req.username))
473/// }
474/// ```
475///
476/// ## With Validation
477///
478/// ```
479/// use wsforge::prelude::*;
480/// use serde::Deserialize;
481///
482/// #[derive(Deserialize)]
483/// struct CreateUser {
484/// #[serde(deserialize_with = "validate_username")]
485/// username: String,
486/// age: u8,
487/// }
488///
489/// async fn create_user(Json(user): Json<CreateUser>) -> Result<String> {
490/// Ok(format!("Creating user: {}", user.username))
491/// }
492///
493/// # fn validate_username<'de, D>(_: D) -> std::result::Result<String, D::Error>
494/// # where D: serde::Deserializer<'de> {
495/// # Ok("valid".to_string())
496/// # }
497/// ```
498///
499/// ## Nested Structures
500///
501/// ```
502/// use wsforge::prelude::*;
503/// use serde::Deserialize;
504///
505/// #[derive(Deserialize)]
506/// struct GameState {
507/// player: Player,
508/// score: u32,
509/// }
510///
511/// #[derive(Deserialize)]
512/// struct Player {
513/// id: u64,
514/// name: String,
515/// }
516///
517/// async fn update_game(Json(state): Json<GameState>) -> Result<()> {
518/// println!("Player {} score: {}", state.player.name, state.score);
519/// Ok(())
520/// }
521/// ```
522pub struct Json<T>(pub T);
523
524#[async_trait]
525impl<T: DeserializeOwned + Send> FromMessage for Json<T> {
526 async fn from_message(
527 message: &Message,
528 _conn: &Connection,
529 _state: &AppState,
530 _extensions: &Extensions,
531 ) -> Result<Self> {
532 let data: T = message.json()?;
533 Ok(Json(data))
534 }
535}
536
537impl<T: Serialize> Json<T> {
538 /// Converts this JSON extractor back into a message.
539 ///
540 /// This is useful for echoing back modified data or creating responses.
541 ///
542 /// # Examples
543 ///
544 /// ```
545 /// use wsforge::prelude::*;
546 /// use serde::{Deserialize, Serialize};
547 ///
548 /// #[derive(Deserialize, Serialize)]
549 /// struct Echo {
550 /// text: String,
551 /// }
552 ///
553 /// async fn echo_handler(Json(data): Json<Echo>) -> Result<Message> {
554 /// Json(data).into_message()
555 /// }
556 /// ```
557 pub fn into_message(self) -> Result<Message> {
558 let json = serde_json::to_string(&self.0)?;
559 Ok(Message::text(json))
560 }
561}
562
563/// Extractor for the active connection.
564///
565/// Provides access to the connection that sent the message, allowing you to:
566/// - Get the connection ID
567/// - Access connection metadata
568/// - Send messages back to the specific client
569///
570/// # Examples
571///
572/// ## Sending Response
573///
574/// ```
575/// use wsforge::prelude::*;
576///
577/// async fn handler(msg: Message, conn: Connection) -> Result<()> {
578/// conn.send_text("Message received!")?;
579/// Ok(())
580/// }
581/// ```
582///
583/// ## Using Connection Info
584///
585/// ```
586/// use wsforge::prelude::*;
587///
588/// async fn handler(conn: Connection) -> Result<String> {
589/// let info = conn.info();
590/// Ok(format!("Client from {} connected at {}",
591/// info.addr, info.connected_at))
592/// }
593/// ```
594#[async_trait]
595impl FromMessage for Connection {
596 async fn from_message(
597 _message: &Message,
598 conn: &Connection,
599 _state: &AppState,
600 _extensions: &Extensions,
601 ) -> Result<Self> {
602 Ok(conn.clone())
603 }
604}
605
606/// Extractor for connection metadata.
607///
608/// Provides detailed information about the connection, including:
609/// - Connection ID
610/// - Client socket address
611/// - Connection timestamp
612/// - Protocol information
613///
614/// # Examples
615///
616/// ## Logging Connection Info
617///
618/// ```
619/// use wsforge::prelude::*;
620///
621/// async fn handler(ConnectInfo(info): ConnectInfo) -> Result<String> {
622/// println!("Connection {} from {} at {}",
623/// info.id, info.addr, info.connected_at);
624/// Ok("Connected".to_string())
625/// }
626/// ```
627pub struct ConnectInfo(pub ConnectionInfo);
628
629#[async_trait]
630impl FromMessage for ConnectInfo {
631 async fn from_message(
632 _message: &Message,
633 conn: &Connection,
634 _state: &AppState,
635 _extensions: &Extensions,
636 ) -> Result<Self> {
637 Ok(ConnectInfo(conn.info.clone()))
638 }
639}
640
641/// Extractor for the raw message.
642///
643/// Use this when you need access to the complete message without
644/// automatic deserialization.
645///
646/// # Examples
647///
648/// ## Raw Message Processing
649///
650/// ```
651/// use wsforge::prelude::*;
652///
653/// async fn handler(msg: Message) -> Result<String> {
654/// if msg.is_text() {
655/// Ok(format!("Text: {}", msg.as_text().unwrap()))
656/// } else if msg.is_binary() {
657/// Ok(format!("Binary: {} bytes", msg.as_bytes().len()))
658/// } else {
659/// Ok("Unknown message type".to_string())
660/// }
661/// }
662/// ```
663#[async_trait]
664impl FromMessage for Message {
665 async fn from_message(
666 message: &Message,
667 _conn: &Connection,
668 _state: &AppState,
669 _extensions: &Extensions,
670 ) -> Result<Self> {
671 Ok(message.clone())
672 }
673}
674
675/// Extractor for raw binary data.
676///
677/// Extracts the message payload as raw bytes. Works with both text and binary messages.
678///
679/// # Examples
680///
681/// ## Processing Binary Data
682///
683/// ```
684/// use wsforge::prelude::*;
685///
686/// async fn handler(Data(bytes): Data) -> Result<String> {
687/// println!("Received {} bytes", bytes.len());
688/// Ok(format!("Processed {} bytes", bytes.len()))
689/// }
690/// ```
691pub struct Data(pub Vec<u8>);
692
693#[async_trait]
694impl FromMessage for Data {
695 async fn from_message(
696 message: &Message,
697 _conn: &Connection,
698 _state: &AppState,
699 _extensions: &Extensions,
700 ) -> Result<Self> {
701 Ok(Data(message.data.clone()))
702 }
703}
704
705/// Extractor for path parameters.
706///
707/// Extracts typed parameters from the request path. The type must implement
708/// `serde::Deserialize` and be stored in extensions by routing middleware.
709///
710/// # Examples
711///
712/// ## Single Parameter
713///
714/// ```
715/// use wsforge::prelude::*;
716/// use serde::Deserialize;
717///
718/// #[derive(Deserialize)]
719/// struct UserId(u64);
720///
721/// async fn get_user(Path(UserId(id)): Path<UserId>) -> Result<String> {
722/// Ok(format!("Getting user {}", id))
723/// }
724/// ```
725///
726/// ## Multiple Parameters
727///
728/// ```
729/// use wsforge::prelude::*;
730/// use serde::Deserialize;
731///
732/// #[derive(Deserialize)]
733/// struct RoomParams {
734/// room_id: String,
735/// user_id: u64,
736/// }
737///
738/// async fn join_room(Path(params): Path<RoomParams>) -> Result<String> {
739/// Ok(format!("User {} joining room {}", params.user_id, params.room_id))
740/// }
741/// ```
742pub struct Path<T>(pub T);
743
744#[async_trait]
745impl<T: DeserializeOwned + Send + Sync + Clone + 'static> FromMessage for Path<T> {
746 async fn from_message(
747 _message: &Message,
748 _conn: &Connection,
749 _state: &AppState,
750 extensions: &Extensions,
751 ) -> Result<Self> {
752 extensions
753 .get::<T>("path_params")
754 .ok_or_else(|| Error::extractor("Path parameters not found"))
755 .map(|arc| Path((*arc).clone()))
756 }
757}
758
759/// Extractor for query parameters.
760///
761/// Extracts typed parameters from the query string. The type must implement
762/// `serde::Deserialize` and be stored in extensions during connection establishment.
763///
764/// # Examples
765///
766/// ## Search Parameters
767///
768/// ```
769/// use wsforge::prelude::*;
770/// use serde::Deserialize;
771///
772/// #[derive(Deserialize)]
773/// struct SearchQuery {
774/// q: String,
775/// limit: Option<u32>,
776/// }
777///
778/// async fn search(Query(params): Query<SearchQuery>) -> Result<String> {
779/// let limit = params.limit.unwrap_or(10);
780/// Ok(format!("Searching for '{}' (limit: {})", params.q, limit))
781/// }
782/// ```
783pub struct Query<T>(pub T);
784
785#[async_trait]
786impl<T: DeserializeOwned + Send + Sync + Clone + 'static> FromMessage for Query<T> {
787 async fn from_message(
788 _message: &Message,
789 _conn: &Connection,
790 _state: &AppState,
791 extensions: &Extensions,
792 ) -> Result<Self> {
793 extensions
794 .get::<T>("query_params")
795 .ok_or_else(|| Error::extractor("Query parameters not found"))
796 .map(|arc| Query((*arc).clone()))
797 }
798}
799
800/// Extractor for custom extension data.
801///
802/// Retrieves data that was previously stored in extensions by middleware or other handlers.
803///
804/// # Examples
805///
806/// ## Authentication Data
807///
808/// ```
809/// use wsforge::prelude::*;
810/// use std::sync::Arc;
811///
812/// #[derive(Clone)]
813/// struct AuthData {
814/// user_id: u64,
815/// role: String,
816/// }
817///
818/// async fn protected_handler(Extension(auth): Extension<AuthData>) -> Result<String> {
819/// Ok(format!("User {} with role {}", auth.user_id, auth.role))
820/// }
821/// ```
822pub struct Extension<T>(pub Arc<T>);
823
824#[async_trait]
825impl<T: Send + Sync + Clone + 'static> FromMessage for Extension<T> {
826 async fn from_message(
827 _message: &Message,
828 _conn: &Connection,
829 _state: &AppState,
830 extensions: &Extensions,
831 ) -> Result<Self> {
832 extensions
833 .get::<T>(std::any::type_name::<T>())
834 .ok_or_else(|| Error::extractor("Extension not found"))
835 .map(Extension)
836 }
837}