torch_web/
request.rs

1use std::collections::HashMap;
2use std::any::{Any, TypeId};
3use std::sync::Arc;
4use http::{HeaderMap, Method, Uri, Version};
5use http_body_util::BodyExt;
6use hyper::body::Incoming;
7use crate::extractors::state::StateMap;
8
9/// HTTP Request wrapper that provides convenient access to request data
10#[derive(Debug)]
11pub struct Request {
12    method: Method,
13    uri: Uri,
14    version: Version,
15    headers: HeaderMap,
16    body: Vec<u8>,
17    params: HashMap<String, String>,
18    query: HashMap<String, String>,
19    extensions: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
20}
21
22impl Request {
23    /// Create a simple empty request (for internal use)
24    pub fn new() -> Self {
25        Self {
26            method: Method::GET,
27            uri: "/".parse().unwrap(),
28            version: Version::HTTP_11,
29            headers: HeaderMap::new(),
30            body: Vec::new(),
31            params: HashMap::new(),
32            query: HashMap::new(),
33            extensions: HashMap::new(),
34        }
35    }
36
37    /// Create a new Request from hyper's request parts and body
38    pub async fn from_hyper(
39        parts: http::request::Parts,
40        body: Incoming,
41    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
42        let body_bytes = body.collect().await?.to_bytes().to_vec();
43        
44        let query = Self::parse_query_string(parts.uri.query().unwrap_or(""));
45        
46        Ok(Request {
47            method: parts.method,
48            uri: parts.uri,
49            version: parts.version,
50            headers: parts.headers,
51            body: body_bytes,
52            params: HashMap::new(),
53            query,
54            extensions: HashMap::new(),
55        })
56    }
57
58    /// Get the HTTP method
59    pub fn method(&self) -> &Method {
60        &self.method
61    }
62
63    /// Get the URI
64    pub fn uri(&self) -> &Uri {
65        &self.uri
66    }
67
68    /// Get the path from the URI
69    pub fn path(&self) -> &str {
70        self.uri.path()
71    }
72
73    /// Get the HTTP version
74    pub fn version(&self) -> Version {
75        self.version
76    }
77
78    /// Get the headers
79    pub fn headers(&self) -> &HeaderMap {
80        &self.headers
81    }
82
83    /// Get a specific header value
84    pub fn header(&self, name: &str) -> Option<&str> {
85        self.headers.get(name)?.to_str().ok()
86    }
87
88    /// Get the request body as bytes
89    pub fn body(&self) -> &[u8] {
90        &self.body
91    }
92
93    /// Get the request body as a string
94    pub fn body_string(&self) -> Result<String, std::string::FromUtf8Error> {
95        String::from_utf8(self.body.clone())
96    }
97
98    /// Parse the request body as JSON (requires "json" feature)
99    #[cfg(feature = "json")]
100    pub async fn json<T>(&self) -> Result<T, serde_json::Error>
101    where
102        T: serde::de::DeserializeOwned,
103    {
104        serde_json::from_slice(&self.body)
105    }
106
107    /// Get a path parameter by name
108    pub fn param(&self, name: &str) -> Option<&str> {
109        self.params.get(name).map(|s| s.as_str())
110    }
111
112    /// Get all path parameters
113    pub fn params(&self) -> &HashMap<String, String> {
114        &self.params
115    }
116
117    /// Get all path parameters (for extractors)
118    pub fn path_params(&self) -> &HashMap<String, String> {
119        &self.params
120    }
121
122    /// Set a path parameter (used internally by the router)
123    pub(crate) fn set_param(&mut self, name: String, value: String) {
124        self.params.insert(name, value);
125    }
126
127    /// Get a reference to the request extensions
128    pub fn extensions(&self) -> &HashMap<TypeId, Box<dyn Any + Send + Sync>> {
129        &self.extensions
130    }
131
132    /// Get a mutable reference to the request extensions
133    pub fn extensions_mut(&mut self) -> &mut HashMap<TypeId, Box<dyn Any + Send + Sync>> {
134        &mut self.extensions
135    }
136
137    /// Insert a value into the request extensions
138    pub fn insert_extension<T: Send + Sync + 'static>(&mut self, value: T) {
139        self.extensions.insert(TypeId::of::<T>(), Box::new(value));
140    }
141
142    /// Get a value from the request extensions
143    pub fn get_extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
144        self.extensions
145            .get(&TypeId::of::<T>())
146            .and_then(|boxed| boxed.downcast_ref())
147    }
148
149    /// Get a query parameter by name
150    pub fn query(&self, name: &str) -> Option<&str> {
151        self.query.get(name).map(|s| s.as_str())
152    }
153
154    /// Get all query parameters
155    pub fn query_params(&self) -> &HashMap<String, String> {
156        &self.query
157    }
158
159    /// Get the raw query string
160    pub fn query_string(&self) -> Option<&str> {
161        self.uri.query()
162    }
163
164    /// Get the request body as bytes (for extractors)
165    pub fn body_bytes(&self) -> &[u8] {
166        &self.body
167    }
168
169    /// Set the request body (for testing)
170    #[cfg(test)]
171    pub fn set_body(&mut self, body: Vec<u8>) {
172        self.body = body;
173    }
174
175    /// Get mutable access to headers (for extractors)
176    pub fn headers_mut(&mut self) -> &mut HeaderMap {
177        &mut self.headers
178    }
179
180    /// Parse query string into a HashMap
181    fn parse_query_string(query: &str) -> HashMap<String, String> {
182        let mut params = HashMap::new();
183        
184        for pair in query.split('&') {
185            if let Some((key, value)) = pair.split_once('=') {
186                let key = urlencoding::decode(key).unwrap_or_else(|_| key.into()).into_owned();
187                let value = urlencoding::decode(value).unwrap_or_else(|_| value.into()).into_owned();
188                params.insert(key, value);
189            } else if !pair.is_empty() {
190                let key = urlencoding::decode(pair).unwrap_or_else(|_| pair.into()).into_owned();
191                params.insert(key, String::new());
192            }
193        }
194        
195        params
196    }
197}
198
199/// Implementation of RequestStateExt for Request
200impl crate::extractors::state::RequestStateExt for Request {
201    fn get_state(&self, type_id: TypeId) -> Option<&Arc<dyn Any + Send + Sync>> {
202        // Check if we have a StateMap stored in extensions
203        if let Some(state_map_any) = self.extensions.get(&TypeId::of::<StateMap>()) {
204            if let Some(state_map) = state_map_any.downcast_ref::<StateMap>() {
205                return state_map.get_by_type_id(type_id);
206            }
207        }
208        None
209    }
210
211    fn set_state_map(&mut self, state_map: StateMap) {
212        self.extensions.insert(TypeId::of::<StateMap>(), Box::new(state_map));
213    }
214
215    fn state_map(&self) -> Option<&StateMap> {
216        self.extensions
217            .get(&TypeId::of::<StateMap>())
218            .and_then(|state_map_any| state_map_any.downcast_ref::<StateMap>())
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_query_string() {
        let query = "name=John&age=30&city=New%20York";
        let params = Request::parse_query_string(query);

        assert_eq!(params.get("name"), Some(&"John".to_string()));
        assert_eq!(params.get("age"), Some(&"30".to_string()));
        assert_eq!(params.get("city"), Some(&"New York".to_string()));
    }

    #[test]
    fn test_parse_empty_query_string() {
        let params = Request::parse_query_string("");
        assert!(params.is_empty());
    }

    #[test]
    fn test_parse_query_string_bare_and_empty_values() {
        // A key without `=` and a key with an empty value both map to "".
        let params = Request::parse_query_string("flag&empty=&x=1");
        assert_eq!(params.len(), 3);
        assert_eq!(params.get("flag").map(String::as_str), Some(""));
        assert_eq!(params.get("empty").map(String::as_str), Some(""));
        assert_eq!(params.get("x").map(String::as_str), Some("1"));
    }

    #[test]
    fn test_parse_query_string_skips_empty_segments() {
        let params = Request::parse_query_string("a=1&&b=2&");
        assert_eq!(params.len(), 2);
        assert_eq!(params.get("a").map(String::as_str), Some("1"));
        assert_eq!(params.get("b").map(String::as_str), Some("2"));
    }

    #[test]
    fn test_parse_query_string_last_duplicate_wins() {
        let params = Request::parse_query_string("k=first&k=second");
        assert_eq!(params.len(), 1);
        assert_eq!(params.get("k").map(String::as_str), Some("second"));
    }

    #[test]
    fn test_parse_query_string_splits_on_first_equals() {
        let params = Request::parse_query_string("expr=a=b");
        assert_eq!(params.get("expr").map(String::as_str), Some("a=b"));
    }

    #[test]
    fn test_parse_query_string_percent_encoded_plus() {
        let params = Request::parse_query_string("op=%2B");
        assert_eq!(params.get("op").map(String::as_str), Some("+"));
    }

    #[test]
    fn test_new_request_defaults() {
        let req = Request::new();
        assert_eq!(req.method(), &Method::GET);
        assert_eq!(req.path(), "/");
        assert_eq!(req.version(), Version::HTTP_11);
        assert!(req.headers().is_empty());
        assert!(req.body().is_empty());
        assert!(req.params().is_empty());
        assert!(req.query_params().is_empty());
        assert_eq!(req.query_string(), None);
    }

    #[test]
    fn test_extensions_roundtrip() {
        let mut req = Request::new();
        assert_eq!(req.get_extension::<u32>(), None);

        req.insert_extension(42u32);
        assert_eq!(req.get_extension::<u32>(), Some(&42));
        // Lookup is keyed by exact type.
        assert_eq!(req.get_extension::<u64>(), None);

        req.insert_extension(7u32);
        assert_eq!(req.get_extension::<u32>(), Some(&7));
    }

    #[test]
    fn test_param_and_body_helpers() {
        let mut req = Request::new();
        req.set_param("id".to_string(), "7".to_string());
        assert_eq!(req.param("id"), Some("7"));
        assert_eq!(req.param("missing"), None);
        assert_eq!(req.path_params().len(), 1);

        req.set_body(b"hello".to_vec());
        assert_eq!(req.body_string().unwrap(), "hello");
        assert_eq!(req.body_bytes(), &b"hello"[..]);

        req.set_body(vec![0xff, 0xfe]);
        assert!(req.body_string().is_err());
    }
}