torch_web/extractors/
query.rs1use std::collections::HashMap;
6use std::pin::Pin;
7use std::future::Future;
8use crate::{Request, extractors::{FromRequestParts, ExtractionError}};
9
/// Extractor that wraps a value parsed from the request's query string.
///
/// The inner value is produced by `T`'s [`DeserializeFromQuery`]
/// implementation and is exposed as the public tuple field, so handlers can
/// destructure it directly: `Query(params)`.
///
/// The derives are conditional on `T`: `Query<T>` is `Debug`/`Clone`/
/// `PartialEq`/`Eq` exactly when `T` is.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Query<T>(pub T);
37
38impl<T> FromRequestParts for Query<T>
39where
40 T: DeserializeFromQuery,
41{
42 type Error = ExtractionError;
43
44 fn from_request_parts(
45 req: &mut Request,
46 ) -> Pin<Box<dyn Future<Output = Result<Self, Self::Error>> + Send + 'static>> {
47 let query_string = req.query_string().unwrap_or("").to_string();
48
49 Box::pin(async move {
50 let value = T::deserialize_from_query(&query_string)?;
51 Ok(Query(value))
52 })
53 }
54}
55
56pub trait DeserializeFromQuery: Sized {
58 fn deserialize_from_query(query: &str) -> Result<Self, ExtractionError>;
59}
60
61impl DeserializeFromQuery for HashMap<String, String> {
63 fn deserialize_from_query(query: &str) -> Result<Self, ExtractionError> {
64 let mut params = HashMap::new();
65
66 if query.is_empty() {
67 return Ok(params);
68 }
69
70 for pair in query.split('&') {
71 if let Some((key, value)) = pair.split_once('=') {
72 let key = urlencoding::decode(key)
73 .map_err(|e| ExtractionError::InvalidQuery(format!("Invalid key encoding: {}", e)))?
74 .into_owned();
75 let value = urlencoding::decode(value)
76 .map_err(|e| ExtractionError::InvalidQuery(format!("Invalid value encoding: {}", e)))?
77 .into_owned();
78 params.insert(key, value);
79 } else {
80 let key = urlencoding::decode(pair)
82 .map_err(|e| ExtractionError::InvalidQuery(format!("Invalid key encoding: {}", e)))?
83 .into_owned();
84 params.insert(key, String::new());
85 }
86 }
87
88 Ok(params)
89 }
90}
91
92impl DeserializeFromQuery for Vec<(String, String)> {
94 fn deserialize_from_query(query: &str) -> Result<Self, ExtractionError> {
95 let mut params = Vec::new();
96
97 if query.is_empty() {
98 return Ok(params);
99 }
100
101 for pair in query.split('&') {
102 if let Some((key, value)) = pair.split_once('=') {
103 let key = urlencoding::decode(key)
104 .map_err(|e| ExtractionError::InvalidQuery(format!("Invalid key encoding: {}", e)))?
105 .into_owned();
106 let value = urlencoding::decode(value)
107 .map_err(|e| ExtractionError::InvalidQuery(format!("Invalid value encoding: {}", e)))?
108 .into_owned();
109 params.push((key, value));
110 } else {
111 let key = urlencoding::decode(pair)
112 .map_err(|e| ExtractionError::InvalidQuery(format!("Invalid key encoding: {}", e)))?
113 .into_owned();
114 params.push((key, String::new()));
115 }
116 }
117
118 Ok(params)
119 }
120}
121
/// Extractor that deserializes the request's query string into any
/// `serde::de::DeserializeOwned` type.
///
/// Only available with the `json` feature. The inner value is exposed as the
/// public tuple field, so handlers can destructure it: `SerdeQuery(params)`.
///
/// The derives are conditional on `T`: `SerdeQuery<T>` is `Debug`/`Clone`/
/// `PartialEq`/`Eq` exactly when `T` is.
#[cfg(feature = "json")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SerdeQuery<T>(pub T);
125
/// Builds a [`SerdeQuery<T>`] from the query string of a request.
///
/// A request without a query string is treated as having an empty one. The
/// conversion itself is done by `deserialize_query_with_serde`.
#[cfg(feature = "json")]
impl<T: serde::de::DeserializeOwned> FromRequestParts for SerdeQuery<T> {
    type Error = ExtractionError;

    fn from_request_parts(
        req: &mut Request,
    ) -> Pin<Box<dyn Future<Output = Result<Self, Self::Error>> + Send + 'static>> {
        // The returned future is `'static`, so it cannot borrow from `req`;
        // take an owned copy of the query string up front.
        let raw_query = req.query_string().unwrap_or_default().to_owned();

        // Deserialization is deferred until the future is polled.
        Box::pin(async move { deserialize_query_with_serde::<T>(&raw_query).map(SerdeQuery) })
    }
}
144
/// Deserializes a query string into `T` by way of an intermediate JSON object.
///
/// The query is first parsed into a `HashMap<String, String>` (so for a
/// repeated key only one value survives), each value is converted to a JSON
/// value by [`infer_json_value`], and the resulting object is handed to
/// `serde_json::from_value`.
///
/// Because value types are guessed from the text, a field's Rust type must
/// match the guess: for example `id=123` becomes a JSON number and will not
/// deserialize into a `String` field.
///
/// # Errors
///
/// Returns `ExtractionError::InvalidQuery` when the query cannot be
/// percent-decoded or when the inferred JSON object does not match `T`.
#[cfg(feature = "json")]
fn deserialize_query_with_serde<T: serde::de::DeserializeOwned>(
    query: &str,
) -> Result<T, ExtractionError> {
    let params: HashMap<String, String> = DeserializeFromQuery::deserialize_from_query(query)?;

    let json_map: serde_json::Map<String, serde_json::Value> = params
        .into_iter()
        .map(|(key, value)| (key, infer_json_value(value)))
        .collect();

    serde_json::from_value(serde_json::Value::Object(json_map)).map_err(|e| {
        ExtractionError::InvalidQuery(format!("Failed to deserialize query parameters: {}", e))
    })
}

/// Guesses the JSON value a single query-parameter value stands for.
///
/// Rules, applied in order:
///
/// 1. empty string (a bare flag such as `?debug`, or `debug=`) → `true`;
/// 2. `true` / `false` → boolean, `null` → JSON null;
/// 3. anything [`parse_json_number`] accepts → number;
/// 4. text containing `,` → array of the trimmed items, each either a number
///    or a string (booleans and `null` are *not* recognised inside arrays);
/// 5. everything else → string.
#[cfg(feature = "json")]
fn infer_json_value(value: String) -> serde_json::Value {
    match value.as_str() {
        "" | "true" => return serde_json::Value::Bool(true),
        "false" => return serde_json::Value::Bool(false),
        "null" => return serde_json::Value::Null,
        _ => {}
    }

    if let Some(number) = parse_json_number(&value) {
        return number;
    }

    if value.contains(',') {
        let items = value
            .split(',')
            .map(|item| {
                let item = item.trim();
                parse_json_number(item)
                    .unwrap_or_else(|| serde_json::Value::String(item.to_owned()))
            })
            .collect();
        serde_json::Value::Array(items)
    } else {
        serde_json::Value::String(value)
    }
}

/// Parses `text` as a JSON number, or returns `None` if it is not one.
///
/// Integers are tried as `i64` and then as `u64`, so values above
/// `i64::MAX` keep full precision instead of being rounded through `f64`.
/// Floats that JSON cannot represent (`inf`, `NaN`) yield `None`.
#[cfg(feature = "json")]
fn parse_json_number(text: &str) -> Option<serde_json::Value> {
    let number = if let Ok(int) = text.parse::<i64>() {
        serde_json::Number::from(int)
    } else if let Ok(uint) = text.parse::<u64>() {
        serde_json::Number::from(uint)
    } else {
        serde_json::Number::from_f64(text.parse::<f64>().ok()?)?
    };
    Some(serde_json::Value::Number(number))
}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `query` into a map, panicking if parsing fails.
    fn parse_map(query: &str) -> HashMap<String, String> {
        <HashMap<String, String> as DeserializeFromQuery>::deserialize_from_query(query).unwrap()
    }

    /// Looks up `key` and returns the value as `&str` for terse comparisons.
    fn lookup<'a>(params: &'a HashMap<String, String>, key: &str) -> Option<&'a str> {
        params.get(key).map(String::as_str)
    }

    #[test]
    fn test_empty_query() {
        assert!(parse_map("").is_empty());
    }

    #[test]
    fn test_simple_query() {
        let params = parse_map("name=john&age=30");

        assert_eq!(lookup(&params, "name"), Some("john"));
        assert_eq!(lookup(&params, "age"), Some("30"));
    }

    #[test]
    fn test_url_encoded_query() {
        let params = parse_map("name=John%20Doe&city=New%20York");

        assert_eq!(lookup(&params, "name"), Some("John Doe"));
        assert_eq!(lookup(&params, "city"), Some("New York"));
    }

    #[test]
    fn test_flag_parameters() {
        let params = parse_map("debug&verbose&name=test");

        // Flags (no `=`) are stored with an empty value.
        assert_eq!(lookup(&params, "debug"), Some(""));
        assert_eq!(lookup(&params, "verbose"), Some(""));
        assert_eq!(lookup(&params, "name"), Some("test"));
    }

    #[test]
    fn test_vec_preserves_order() {
        let params =
            <Vec<(String, String)> as DeserializeFromQuery>::deserialize_from_query("a=1&b=2&a=3")
                .unwrap();

        // Repeated keys are kept, in the order they appeared.
        let expected: Vec<(String, String)> = [("a", "1"), ("b", "2"), ("a", "3")]
            .iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect();
        assert_eq!(params, expected);
    }
}