torch_web/request.rs

1use std::collections::HashMap;
2use std::any::{Any, TypeId};
3use http::{HeaderMap, Method, Uri, Version};
4use http_body_util::BodyExt;
5use hyper::body::Incoming;
6
7/// HTTP Request wrapper that provides convenient access to request data
8#[derive(Debug)]
9pub struct Request {
10    method: Method,
11    uri: Uri,
12    version: Version,
13    headers: HeaderMap,
14    body: Vec<u8>,
15    params: HashMap<String, String>,
16    query: HashMap<String, String>,
17    extensions: HashMap<TypeId, Box<dyn Any + Send + Sync>>,
18}
19
20impl Request {
21    /// Create a simple empty request (for internal use)
22    pub fn new() -> Self {
23        Self {
24            method: Method::GET,
25            uri: "/".parse().unwrap(),
26            version: Version::HTTP_11,
27            headers: HeaderMap::new(),
28            body: Vec::new(),
29            params: HashMap::new(),
30            query: HashMap::new(),
31            extensions: HashMap::new(),
32        }
33    }
34
35    /// Create a new Request from hyper's request parts and body
36    pub async fn from_hyper(
37        parts: http::request::Parts,
38        body: Incoming,
39    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
40        let body_bytes = body.collect().await?.to_bytes().to_vec();
41        
42        let query = Self::parse_query_string(parts.uri.query().unwrap_or(""));
43        
44        Ok(Request {
45            method: parts.method,
46            uri: parts.uri,
47            version: parts.version,
48            headers: parts.headers,
49            body: body_bytes,
50            params: HashMap::new(),
51            query,
52            extensions: HashMap::new(),
53        })
54    }
55
56    /// Get the HTTP method
57    pub fn method(&self) -> &Method {
58        &self.method
59    }
60
61    /// Get the URI
62    pub fn uri(&self) -> &Uri {
63        &self.uri
64    }
65
66    /// Get the path from the URI
67    pub fn path(&self) -> &str {
68        self.uri.path()
69    }
70
71    /// Get the HTTP version
72    pub fn version(&self) -> Version {
73        self.version
74    }
75
76    /// Get the headers
77    pub fn headers(&self) -> &HeaderMap {
78        &self.headers
79    }
80
81    /// Get a specific header value
82    pub fn header(&self, name: &str) -> Option<&str> {
83        self.headers.get(name)?.to_str().ok()
84    }
85
86    /// Get the request body as bytes
87    pub fn body(&self) -> &[u8] {
88        &self.body
89    }
90
91    /// Get the request body as a string
92    pub fn body_string(&self) -> Result<String, std::string::FromUtf8Error> {
93        String::from_utf8(self.body.clone())
94    }
95
96    /// Parse the request body as JSON (requires "json" feature)
97    #[cfg(feature = "json")]
98    pub async fn json<T>(&self) -> Result<T, serde_json::Error>
99    where
100        T: serde::de::DeserializeOwned,
101    {
102        serde_json::from_slice(&self.body)
103    }
104
105    /// Get a path parameter by name
106    pub fn param(&self, name: &str) -> Option<&str> {
107        self.params.get(name).map(|s| s.as_str())
108    }
109
110    /// Get all path parameters
111    pub fn params(&self) -> &HashMap<String, String> {
112        &self.params
113    }
114
115    /// Set a path parameter (used internally by the router)
116    pub(crate) fn set_param(&mut self, name: String, value: String) {
117        self.params.insert(name, value);
118    }
119
120    /// Get a reference to the request extensions
121    pub fn extensions(&self) -> &HashMap<TypeId, Box<dyn Any + Send + Sync>> {
122        &self.extensions
123    }
124
125    /// Get a mutable reference to the request extensions
126    pub fn extensions_mut(&mut self) -> &mut HashMap<TypeId, Box<dyn Any + Send + Sync>> {
127        &mut self.extensions
128    }
129
130    /// Insert a value into the request extensions
131    pub fn insert_extension<T: Send + Sync + 'static>(&mut self, value: T) {
132        self.extensions.insert(TypeId::of::<T>(), Box::new(value));
133    }
134
135    /// Get a value from the request extensions
136    pub fn get_extension<T: Send + Sync + 'static>(&self) -> Option<&T> {
137        self.extensions
138            .get(&TypeId::of::<T>())
139            .and_then(|boxed| boxed.downcast_ref())
140    }
141
142    /// Get a query parameter by name
143    pub fn query(&self, name: &str) -> Option<&str> {
144        self.query.get(name).map(|s| s.as_str())
145    }
146
147    /// Get all query parameters
148    pub fn query_params(&self) -> &HashMap<String, String> {
149        &self.query
150    }
151
152    /// Parse query string into a HashMap
153    fn parse_query_string(query: &str) -> HashMap<String, String> {
154        let mut params = HashMap::new();
155        
156        for pair in query.split('&') {
157            if let Some((key, value)) = pair.split_once('=') {
158                let key = urlencoding::decode(key).unwrap_or_else(|_| key.into()).into_owned();
159                let value = urlencoding::decode(value).unwrap_or_else(|_| value.into()).into_owned();
160                params.insert(key, value);
161            } else if !pair.is_empty() {
162                let key = urlencoding::decode(pair).unwrap_or_else(|_| pair.into()).into_owned();
163                params.insert(key, String::new());
164            }
165        }
166        
167        params
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Percent-encoded values are decoded, and every `key=value` pair is kept.
    #[test]
    fn test_parse_query_string() {
        let parsed = Request::parse_query_string("name=John&age=30&city=New%20York");
        let lookup = |key: &str| parsed.get(key).map(String::as_str);

        assert_eq!(lookup("name"), Some("John"));
        assert_eq!(lookup("age"), Some("30"));
        assert_eq!(lookup("city"), Some("New York"));
    }

    /// An empty query string produces no parameters at all.
    #[test]
    fn test_parse_empty_query_string() {
        assert!(Request::parse_query_string("").is_empty());
    }
}