// torch_web/extractors/query.rs

//! Query parameter extraction
//!
//! Extract and deserialize query parameters from the URL.
4
5use std::collections::HashMap;
6use std::pin::Pin;
7use std::future::Future;
8use crate::{Request, extractors::{FromRequestParts, ExtractionError}};
9
/// Extract query parameters from the request URL.
///
/// The wrapped type must implement [`DeserializeFromQuery`]. This module
/// provides implementations for:
///
/// * `HashMap<String, String>`: every parameter; for a repeated key the last
///   value wins.
/// * `Vec<(String, String)>`: every parameter in URL order, duplicates kept.
///
/// To deserialize into your own `serde::Deserialize` struct, use the
/// `SerdeQuery` extractor instead (only available with the `json` feature).
///
/// # Example
///
/// ```rust,no_run
/// use torch_web::extractors::Query;
/// use std::collections::HashMap;
///
/// // Extract as HashMap
/// async fn search(Query(params): Query<HashMap<String, String>>) {
///     // params contains all query parameters
///     let _term = params.get("q");
/// }
///
/// // Extract as ordered pairs, keeping duplicates such as `?tag=a&tag=b`
/// async fn tags(Query(pairs): Query<Vec<(String, String)>>) {
///     let _tags: Vec<&str> = pairs
///         .iter()
///         .filter(|(key, _)| key == "tag")
///         .map(|(_, value)| value.as_str())
///         .collect();
/// }
/// ```
pub struct Query<T>(pub T);
37
38impl<T> FromRequestParts for Query<T>
39where
40    T: DeserializeFromQuery,
41{
42    type Error = ExtractionError;
43
44    fn from_request_parts(
45        req: &mut Request,
46    ) -> Pin<Box<dyn Future<Output = Result<Self, Self::Error>> + Send + 'static>> {
47        let query_string = req.query_string().unwrap_or("").to_string();
48        
49        Box::pin(async move {
50            let value = T::deserialize_from_query(&query_string)?;
51            Ok(Query(value))
52        })
53    }
54}
55
56/// Trait for types that can be deserialized from query parameters
57pub trait DeserializeFromQuery: Sized {
58    fn deserialize_from_query(query: &str) -> Result<Self, ExtractionError>;
59}
60
61// Implement for HashMap<String, String> to get all parameters
62impl DeserializeFromQuery for HashMap<String, String> {
63    fn deserialize_from_query(query: &str) -> Result<Self, ExtractionError> {
64        let mut params = HashMap::new();
65        
66        if query.is_empty() {
67            return Ok(params);
68        }
69
70        for pair in query.split('&') {
71            if let Some((key, value)) = pair.split_once('=') {
72                let key = urlencoding::decode(key)
73                    .map_err(|e| ExtractionError::InvalidQuery(format!("Invalid key encoding: {}", e)))?
74                    .into_owned();
75                let value = urlencoding::decode(value)
76                    .map_err(|e| ExtractionError::InvalidQuery(format!("Invalid value encoding: {}", e)))?
77                    .into_owned();
78                params.insert(key, value);
79            } else {
80                // Handle keys without values (e.g., "?debug&verbose")
81                let key = urlencoding::decode(pair)
82                    .map_err(|e| ExtractionError::InvalidQuery(format!("Invalid key encoding: {}", e)))?
83                    .into_owned();
84                params.insert(key, String::new());
85            }
86        }
87
88        Ok(params)
89    }
90}
91
92// Implement for Vec<(String, String)> to preserve order and duplicates
93impl DeserializeFromQuery for Vec<(String, String)> {
94    fn deserialize_from_query(query: &str) -> Result<Self, ExtractionError> {
95        let mut params = Vec::new();
96        
97        if query.is_empty() {
98            return Ok(params);
99        }
100
101        for pair in query.split('&') {
102            if let Some((key, value)) = pair.split_once('=') {
103                let key = urlencoding::decode(key)
104                    .map_err(|e| ExtractionError::InvalidQuery(format!("Invalid key encoding: {}", e)))?
105                    .into_owned();
106                let value = urlencoding::decode(value)
107                    .map_err(|e| ExtractionError::InvalidQuery(format!("Invalid value encoding: {}", e)))?
108                    .into_owned();
109                params.push((key, value));
110            } else {
111                let key = urlencoding::decode(pair)
112                    .map_err(|e| ExtractionError::InvalidQuery(format!("Invalid key encoding: {}", e)))?
113                    .into_owned();
114                params.push((key, String::new()));
115            }
116        }
117
118        Ok(params)
119    }
120}
121
/// Query extractor backed by `serde`.
///
/// It deserializes the query string into any `T: serde::de::DeserializeOwned`,
/// such as a `#[derive(Deserialize)]` struct. Raw values are first converted to
/// JSON values (booleans, `null`, numbers, comma-separated arrays or strings),
/// so field types should match the shape of the incoming values.
///
/// Only compiled when the `json` feature is enabled.
#[cfg(feature = "json")]
pub struct SerdeQuery<T>(pub T);
125
#[cfg(feature = "json")]
impl<T: serde::de::DeserializeOwned> FromRequestParts for SerdeQuery<T> {
    type Error = ExtractionError;

    /// Parses the request's query string (or `""` when the URL has none) into
    /// `T` via `deserialize_query_with_serde`.
    fn from_request_parts(
        req: &mut Request,
    ) -> Pin<Box<dyn Future<Output = Result<Self, Self::Error>> + Send + 'static>> {
        // Own the query text so the `'static` future does not borrow `req`.
        let raw_query = req.query_string().map_or_else(String::new, str::to_owned);

        Box::pin(async move { deserialize_query_with_serde::<T>(&raw_query).map(SerdeQuery) })
    }
}
144
/// Deserialize a raw query string into `T` by way of a `serde_json::Value`.
///
/// The query is first parsed into a `HashMap<String, String>`, so a repeated
/// key keeps only its last value. Each value is then turned into a JSON value
/// by [`query_value_to_json`], and `serde_json::from_value` maps the resulting
/// object onto `T`.
///
/// Known limitation: the conversion does not know the target type. A `String`
/// field whose value looks like a number, boolean or `null` (e.g. `?zip=02134`
/// or `?name=true`), or is empty (`?q=`), will fail to deserialize.
///
/// # Errors
///
/// Returns [`ExtractionError::InvalidQuery`] if the query cannot be decoded or
/// if the JSON object does not match `T`.
#[cfg(feature = "json")]
fn deserialize_query_with_serde<T: serde::de::DeserializeOwned>(
    query: &str,
) -> Result<T, ExtractionError> {
    let params: HashMap<String, String> = DeserializeFromQuery::deserialize_from_query(query)?;

    let json_map: serde_json::Map<String, serde_json::Value> = params
        .into_iter()
        .map(|(key, value)| (key, query_value_to_json(value)))
        .collect();

    serde_json::from_value(serde_json::Value::Object(json_map)).map_err(|e| {
        ExtractionError::InvalidQuery(format!("Failed to deserialize query parameters: {}", e))
    })
}

/// Convert one decoded query value into a JSON value.
///
/// Rules, checked in order:
/// * empty (flag-style `?debug`) or `"true"` becomes `true`;
/// * `"false"` becomes `false`;
/// * `"null"` becomes `null`;
/// * an integer or finite float becomes a number;
/// * a value containing `,` becomes an array, where each trimmed item is a
///   number or a string;
/// * anything else stays a string.
#[cfg(feature = "json")]
fn query_value_to_json(value: String) -> serde_json::Value {
    use serde_json::Value;

    match value.as_str() {
        "" | "true" => return Value::Bool(true),
        "false" => return Value::Bool(false),
        "null" => return Value::Null,
        _ => {}
    }

    if let Some(number) = parse_json_number(&value) {
        number
    } else if value.contains(',') {
        Value::Array(
            value
                .split(',')
                .map(|item| {
                    let item = item.trim();
                    parse_json_number(item).unwrap_or_else(|| Value::String(item.to_owned()))
                })
                .collect(),
        )
    } else {
        Value::String(value)
    }
}

/// Parse `text` as a JSON number, or return `None`.
///
/// `i64` is tried first, then `u64`, so integers above `i64::MAX` keep full
/// precision instead of being rounded through `f64`. After that `f64` is tried.
/// Non-finite floats (`inf`, `NaN`) are rejected because JSON cannot hold them.
#[cfg(feature = "json")]
fn parse_json_number(text: &str) -> Option<serde_json::Value> {
    if let Ok(int) = text.parse::<i64>() {
        return Some(serde_json::Value::Number(int.into()));
    }
    if let Ok(uint) = text.parse::<u64>() {
        return Some(serde_json::Value::Number(uint.into()));
    }
    text.parse::<f64>()
        .ok()
        .and_then(serde_json::Number::from_f64)
        .map(serde_json::Value::Number)
}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `query` through the `HashMap` implementation, panicking on error.
    fn parse_map(query: &str) -> HashMap<String, String> {
        <HashMap<String, String> as DeserializeFromQuery>::deserialize_from_query(query)
            .expect("query should parse into a HashMap")
    }

    /// Parse `query` through the ordered `Vec` implementation, panicking on error.
    fn parse_pairs(query: &str) -> Vec<(String, String)> {
        <Vec<(String, String)> as DeserializeFromQuery>::deserialize_from_query(query)
            .expect("query should parse into a Vec")
    }

    /// Owned `(key, value)` pair, to keep assertions short.
    fn pair(key: &str, value: &str) -> (String, String) {
        (key.to_owned(), value.to_owned())
    }

    #[test]
    fn test_empty_query() {
        assert!(parse_map("").is_empty());
        assert!(parse_pairs("").is_empty());
    }

    #[test]
    fn test_simple_query() {
        let params = parse_map("name=john&age=30");
        assert_eq!(params.get("name").map(String::as_str), Some("john"));
        assert_eq!(params.get("age").map(String::as_str), Some("30"));
    }

    #[test]
    fn test_url_encoded_query() {
        let params = parse_map("name=John%20Doe&city=New%20York");
        assert_eq!(params.get("name").map(String::as_str), Some("John Doe"));
        assert_eq!(params.get("city").map(String::as_str), Some("New York"));
    }

    #[test]
    fn test_flag_parameters() {
        let params = parse_map("debug&verbose&name=test");
        assert_eq!(params.get("debug").map(String::as_str), Some(""));
        assert_eq!(params.get("verbose").map(String::as_str), Some(""));
        assert_eq!(params.get("name").map(String::as_str), Some("test"));
    }

    #[test]
    fn test_vec_preserves_order() {
        assert_eq!(
            parse_pairs("a=1&b=2&a=3"),
            vec![pair("a", "1"), pair("b", "2"), pair("a", "3")]
        );
    }

    #[test]
    fn test_map_keeps_last_duplicate() {
        let params = parse_map("a=1&b=2&a=3");
        assert_eq!(params.len(), 2);
        assert_eq!(params.get("a").map(String::as_str), Some("3"));
    }

    #[test]
    fn test_value_may_contain_equals_sign() {
        // Only the first `=` separates the key from the value.
        let params = parse_map("expr=a=b");
        assert_eq!(params.get("expr").map(String::as_str), Some("a=b"));
    }

    #[test]
    fn test_vec_keeps_flag_parameters() {
        assert_eq!(parse_pairs("debug&x=1"), vec![pair("debug", ""), pair("x", "1")]);
    }

    #[test]
    fn test_invalid_utf8_escape_is_rejected() {
        let result: Result<HashMap<String, String>, _> =
            DeserializeFromQuery::deserialize_from_query("name=%FF");
        assert!(result.is_err());
    }
}