torch_web/extractors/
path.rs1use std::collections::HashMap;
6use std::pin::Pin;
7use std::future::Future;
8use std::str::FromStr;
9use crate::{Request, extractors::{FromRequestParts, ExtractionError}};
10
/// Extractor that converts a request's path parameters into a `T`.
///
/// The conversion itself is delegated to [`DeserializeFromPath`]; this wrapper
/// only carries the result so handlers can destructure it (`Path(id)`).
pub struct Path<T>(
    /// The value produced from the request's path parameters.
    pub T,
);
40
41impl<T> FromRequestParts for Path<T>
42where
43 T: DeserializeFromPath,
44{
45 type Error = ExtractionError;
46
47 fn from_request_parts(
48 req: &mut Request,
49 ) -> Pin<Box<dyn Future<Output = Result<Self, Self::Error>> + Send + 'static>> {
50 let params = req.path_params().clone();
51
52 Box::pin(async move {
53 let value = T::deserialize_from_path(params)?;
54 Ok(Path(value))
55 })
56 }
57}
58
59pub trait DeserializeFromPath: Sized {
61 fn deserialize_from_path(params: HashMap<String, String>) -> Result<Self, ExtractionError>;
62}
63
/// Marker for scalar types that may be parsed from a single path segment.
///
/// The blanket [`DeserializeFromPath`] impl and the tuple impls in this module
/// accept only `FromStr` types that also carry this marker.
pub trait PathDeserializable {}

/// Implements the [`PathDeserializable`] marker for every listed type.
macro_rules! mark_path_deserializable {
    ($($ty:ty),+ $(,)?) => {
        $(impl PathDeserializable for $ty {})+
    };
}

mark_path_deserializable!(
    // Owned text.
    String,
    // Unsigned integers.
    u8, u16, u32, u64, usize,
    // Signed integers.
    i8, i16, i32, i64, isize,
    // Floating point and boolean.
    f32, f64, bool,
    // Network addresses.
    std::net::IpAddr, std::net::Ipv4Addr, std::net::Ipv6Addr,
);

#[cfg(feature = "uuid")]
impl PathDeserializable for uuid::Uuid {}
88
89impl<T> DeserializeFromPath for T
91where
92 T: FromStr + PathDeserializable,
93 T::Err: std::fmt::Display,
94{
95 fn deserialize_from_path(params: HashMap<String, String>) -> Result<Self, ExtractionError> {
96 if params.len() != 1 {
97 return Err(ExtractionError::InvalidPathParam(
98 format!("Expected exactly one path parameter for type {}, got {}",
99 std::any::type_name::<T>(), params.len())
100 ));
101 }
102
103 let (param_name, value) = params.into_iter().next().unwrap();
104 value.parse().map_err(|e| {
105 ExtractionError::InvalidPathParam(
106 format!("Failed to parse parameter '{}' as {}: {}",
107 param_name, std::any::type_name::<T>(), e)
108 )
109 })
110 }
111}
112
113impl<T1, T2> DeserializeFromPath for (T1, T2)
115where
116 T1: FromStr + PathDeserializable,
117 T2: FromStr + PathDeserializable,
118 T1::Err: std::fmt::Display,
119 T2::Err: std::fmt::Display,
120{
121 fn deserialize_from_path(params: HashMap<String, String>) -> Result<Self, ExtractionError> {
122 if params.len() != 2 {
123 return Err(ExtractionError::InvalidPathParam(
124 format!("Expected exactly 2 path parameters, got {}", params.len())
125 ));
126 }
127
128 let mut param_pairs: Vec<_> = params.into_iter().collect();
130 param_pairs.sort_by(|a, b| a.0.cmp(&b.0)); let first = param_pairs[0].1.parse().map_err(|e| {
133 ExtractionError::InvalidPathParam(
134 format!("Failed to parse parameter '{}' as {}: {}",
135 param_pairs[0].0, std::any::type_name::<T1>(), e)
136 )
137 })?;
138
139 let second = param_pairs[1].1.parse().map_err(|e| {
140 ExtractionError::InvalidPathParam(
141 format!("Failed to parse parameter '{}' as {}: {}",
142 param_pairs[1].0, std::any::type_name::<T2>(), e)
143 )
144 })?;
145
146 Ok((first, second))
147 }
148}
149
150impl DeserializeFromPath for HashMap<String, String> {
152 fn deserialize_from_path(params: HashMap<String, String>) -> Result<Self, ExtractionError> {
153 Ok(params)
154 }
155}
156
157impl<T1, T2, T3> DeserializeFromPath for (T1, T2, T3)
159where
160 T1: FromStr + PathDeserializable,
161 T2: FromStr + PathDeserializable,
162 T3: FromStr + PathDeserializable,
163 T1::Err: std::fmt::Display,
164 T2::Err: std::fmt::Display,
165 T3::Err: std::fmt::Display,
166{
167 fn deserialize_from_path(params: HashMap<String, String>) -> Result<Self, ExtractionError> {
168 if params.len() != 3 {
169 return Err(ExtractionError::InvalidPathParam(
170 format!("Expected exactly 3 path parameters, got {}", params.len())
171 ));
172 }
173
174 let mut param_pairs: Vec<_> = params.into_iter().collect();
175 param_pairs.sort_by(|a, b| a.0.cmp(&b.0));
176
177 let first = param_pairs[0].1.parse().map_err(|e| {
178 ExtractionError::InvalidPathParam(
179 format!("Failed to parse parameter '{}': {}", param_pairs[0].0, e)
180 )
181 })?;
182
183 let second = param_pairs[1].1.parse().map_err(|e| {
184 ExtractionError::InvalidPathParam(
185 format!("Failed to parse parameter '{}': {}", param_pairs[1].0, e)
186 )
187 })?;
188
189 let third = param_pairs[2].1.parse().map_err(|e| {
190 ExtractionError::InvalidPathParam(
191 format!("Failed to parse parameter '{}': {}", param_pairs[2].0, e)
192 )
193 })?;
194
195 Ok((first, second, third))
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned parameter map from `(name, value)` pairs.
    fn param_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_single_param_extraction() {
        let mut params = HashMap::new();
        params.insert("id".to_string(), "123".to_string());

        let result: Result<u32, _> = DeserializeFromPath::deserialize_from_path(params);
        assert_eq!(result.unwrap(), 123);
    }

    #[test]
    fn test_invalid_param_extraction() {
        let mut params = HashMap::new();
        params.insert("id".to_string(), "not_a_number".to_string());

        let result: Result<u32, _> = DeserializeFromPath::deserialize_from_path(params);
        assert!(result.is_err());
    }

    #[test]
    fn test_single_param_requires_exactly_one() {
        let empty: Result<u32, _> = DeserializeFromPath::deserialize_from_path(HashMap::new());
        assert!(empty.is_err());

        let two: Result<u32, _> =
            DeserializeFromPath::deserialize_from_path(param_map(&[("a", "1"), ("b", "2")]));
        assert!(two.is_err());
    }

    #[test]
    fn test_hashmap_extraction() {
        let mut params = HashMap::new();
        params.insert("user_id".to_string(), "123".to_string());
        params.insert("post_id".to_string(), "456".to_string());

        let result: Result<HashMap<String, String>, _> =
            DeserializeFromPath::deserialize_from_path(params.clone());
        assert_eq!(result.unwrap(), params);
    }

    #[test]
    fn test_pair_slots_follow_parameter_name_order() {
        // "post_id" sorts before "user_id", so it fills the first slot.
        let result: Result<(u32, String), _> = DeserializeFromPath::deserialize_from_path(
            param_map(&[("user_id", "alice"), ("post_id", "7")]),
        );
        assert_eq!(result.unwrap(), (7, "alice".to_string()));
    }

    #[test]
    fn test_pair_rejects_wrong_count_and_bad_values() {
        let one: Result<(u32, u32), _> =
            DeserializeFromPath::deserialize_from_path(param_map(&[("a", "1")]));
        assert!(one.is_err());

        let bad: Result<(u32, u32), _> =
            DeserializeFromPath::deserialize_from_path(param_map(&[("a", "1"), ("b", "x")]));
        assert!(bad.is_err());
    }

    #[test]
    fn test_triple_extraction() {
        let result: Result<(String, u32, bool), _> = DeserializeFromPath::deserialize_from_path(
            param_map(&[("c", "true"), ("a", "x"), ("b", "7")]),
        );
        assert_eq!(result.unwrap(), ("x".to_string(), 7, true));
    }

    #[test]
    fn test_triple_rejects_wrong_count_and_bad_values() {
        let two: Result<(u32, u32, u32), _> =
            DeserializeFromPath::deserialize_from_path(param_map(&[("a", "1"), ("b", "2")]));
        assert!(two.is_err());

        let bad: Result<(u32, u32, u32), _> = DeserializeFromPath::deserialize_from_path(
            param_map(&[("a", "1"), ("b", "oops"), ("c", "3")]),
        );
        assert!(bad.is_err());
    }
}