Skip to main content

spikard_core/
http.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::sync::OnceLock;
4
5/// HTTP method
6#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
7pub enum Method {
8    #[default]
9    Get,
10    Post,
11    Put,
12    Patch,
13    Delete,
14    Head,
15    Options,
16    Trace,
17}
18
19impl Method {
20    #[must_use]
21    pub const fn as_str(&self) -> &'static str {
22        match self {
23            Self::Get => "GET",
24            Self::Post => "POST",
25            Self::Put => "PUT",
26            Self::Patch => "PATCH",
27            Self::Delete => "DELETE",
28            Self::Head => "HEAD",
29            Self::Options => "OPTIONS",
30            Self::Trace => "TRACE",
31        }
32    }
33}
34
35impl std::fmt::Display for Method {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        write!(f, "{}", self.as_str())
38    }
39}
40
41impl From<Method> for http::method::Method {
42    fn from(method: Method) -> Self {
43        match method {
44            Method::Get => Self::GET,
45            Method::Post => Self::POST,
46            Method::Put => Self::PUT,
47            Method::Patch => Self::PATCH,
48            Method::Delete => Self::DELETE,
49            Method::Head => Self::HEAD,
50            Method::Options => Self::OPTIONS,
51            Method::Trace => Self::TRACE,
52        }
53    }
54}
55
56impl From<&Method> for http::method::Method {
57    fn from(method: &Method) -> Self {
58        match method {
59            Method::Get => Self::GET,
60            Method::Post => Self::POST,
61            Method::Put => Self::PUT,
62            Method::Patch => Self::PATCH,
63            Method::Delete => Self::DELETE,
64            Method::Head => Self::HEAD,
65            Method::Options => Self::OPTIONS,
66            Method::Trace => Self::TRACE,
67        }
68    }
69}
70
71impl std::str::FromStr for Method {
72    type Err = String;
73
74    fn from_str(s: &str) -> Result<Self, Self::Err> {
75        match s.to_uppercase().as_str() {
76            "GET" => Ok(Self::Get),
77            "POST" => Ok(Self::Post),
78            "PUT" => Ok(Self::Put),
79            "PATCH" => Ok(Self::Patch),
80            "DELETE" => Ok(Self::Delete),
81            "HEAD" => Ok(Self::Head),
82            "OPTIONS" => Ok(Self::Options),
83            "TRACE" => Ok(Self::Trace),
84            _ => Err(format!("Unknown HTTP method: {s}")),
85        }
86    }
87}
88
89/// CORS configuration for a route
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct CorsConfig {
92    pub allowed_origins: Vec<String>,
93    pub allowed_methods: Vec<String>,
94    #[serde(default)]
95    pub allowed_headers: Vec<String>,
96    #[serde(skip_serializing_if = "Option::is_none")]
97    pub expose_headers: Option<Vec<String>>,
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub max_age: Option<u32>,
100    #[serde(skip_serializing_if = "Option::is_none")]
101    pub allow_credentials: Option<bool>,
102
103    // Optimized caches (lazy-initialized on first use)
104    #[serde(skip)]
105    #[doc(hidden)]
106    #[cfg_attr(alef, alef(skip))]
107    pub methods_joined_cache: OnceLock<String>,
108    #[serde(skip)]
109    #[doc(hidden)]
110    #[cfg_attr(alef, alef(skip))]
111    pub headers_joined_cache: OnceLock<String>,
112}
113
114impl CorsConfig {
115    /// Get the cached joined methods string for preflight responses
116    pub fn allowed_methods_joined(&self) -> &str {
117        self.methods_joined_cache
118            .get_or_init(|| self.allowed_methods.join(", "))
119    }
120
121    /// Get the cached joined headers string for preflight responses
122    pub fn allowed_headers_joined(&self) -> &str {
123        self.headers_joined_cache
124            .get_or_init(|| self.allowed_headers.join(", "))
125    }
126
127    /// Check if an origin is allowed (O(1) with wildcard, O(n) for exact match)
128    pub fn is_origin_allowed(&self, origin: &str) -> bool {
129        if origin.is_empty() {
130            return false;
131        }
132        self.allowed_origins.iter().any(|o| o == "*" || o == origin)
133    }
134
135    /// Check if a method is allowed (O(1) with wildcard, O(n) for exact match)
136    pub fn is_method_allowed(&self, method: &str) -> bool {
137        self.allowed_methods
138            .iter()
139            .any(|m| m == "*" || m.eq_ignore_ascii_case(method))
140    }
141
142    /// Check if all requested headers are allowed (O(n) where n = num requested headers)
143    pub fn are_headers_allowed(&self, requested: &[&str]) -> bool {
144        // Check if wildcard is set
145        if self.allowed_headers.iter().any(|h| h == "*") {
146            return true;
147        }
148
149        // Check each requested header
150        requested.iter().all(|req_header| {
151            self.allowed_headers
152                .iter()
153                .any(|h| h.to_lowercase() == req_header.to_lowercase())
154        })
155    }
156}
157
158impl Default for CorsConfig {
159    fn default() -> Self {
160        Self {
161            allowed_origins: vec!["*".to_string()],
162            allowed_methods: vec!["*".to_string()],
163            allowed_headers: vec![],
164            expose_headers: None,
165            max_age: None,
166            allow_credentials: None,
167            methods_joined_cache: OnceLock::new(),
168            headers_joined_cache: OnceLock::new(),
169        }
170    }
171}
172
173/// Route metadata extracted from bindings
174#[derive(Debug, Clone, Serialize, Deserialize)]
175pub struct RouteMetadata {
176    pub method: String,
177    pub path: String,
178    pub handler_name: String,
179    pub request_schema: Option<Value>,
180    pub response_schema: Option<Value>,
181    pub parameter_schema: Option<Value>,
182    #[serde(skip_serializing_if = "Option::is_none")]
183    pub file_params: Option<Value>,
184    #[serde(default)]
185    pub is_async: bool,
186    pub cors: Option<CorsConfig>,
187    /// Name of the body parameter (defaults to "body" if not specified)
188    #[serde(skip_serializing_if = "Option::is_none")]
189    pub body_param_name: Option<String>,
190    /// List of dependency keys this handler requires (for DI)
191    #[cfg(feature = "di")]
192    #[serde(skip_serializing_if = "Option::is_none")]
193    pub handler_dependencies: Option<Vec<String>>,
194    /// JSON-RPC method metadata (if this route is exposed as a JSON-RPC method)
195    #[serde(skip_serializing_if = "Option::is_none")]
196    pub jsonrpc_method: Option<Value>,
197    /// Optional static response configuration: `{"status": 200, "body": "OK", "content_type": "text/plain"}`
198    /// When present, the handler is replaced by a `StaticResponseHandler` that bypasses the full
199    /// middleware pipeline for maximum throughput.
200    #[serde(skip_serializing_if = "Option::is_none")]
201    pub static_response: Option<Value>,
202}
203
204impl Default for RouteMetadata {
205    fn default() -> Self {
206        Self {
207            method: "GET".to_string(),
208            path: "/".to_string(),
209            handler_name: String::new(),
210            request_schema: None,
211            response_schema: None,
212            parameter_schema: None,
213            file_params: None,
214            is_async: true,
215            cors: None,
216            body_param_name: None,
217            #[cfg(feature = "di")]
218            handler_dependencies: None,
219            jsonrpc_method: None,
220            static_response: None,
221        }
222    }
223}
224
225/// Compression configuration shared across runtimes
226#[derive(Debug, Clone, Serialize, Deserialize)]
227pub struct CompressionConfig {
228    /// Enable gzip compression
229    #[serde(default = "default_true")]
230    pub gzip: bool,
231    /// Enable brotli compression
232    #[serde(default = "default_true")]
233    pub brotli: bool,
234    /// Minimum response size to compress (bytes)
235    #[serde(default = "default_compression_min_size")]
236    pub min_size: usize,
237    /// Compression quality (0-11 for brotli, 0-9 for gzip)
238    #[serde(default = "default_compression_quality")]
239    pub quality: u32,
240}
241
/// Serde default helper for boolean fields that should be enabled when the
/// key is absent; always returns `true`.
const fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
245
/// Serde default helper: the minimum response size (in bytes) used for
/// compression when none is configured. Always returns 1024.
const fn default_compression_min_size() -> usize {
    const MIN_SIZE_BYTES: usize = 1024;
    MIN_SIZE_BYTES
}
249
/// Serde default helper: the compression quality used when none is
/// configured. Always returns 6.
const fn default_compression_quality() -> u32 {
    const QUALITY: u32 = 6;
    QUALITY
}
253
254impl Default for CompressionConfig {
255    fn default() -> Self {
256        Self {
257            gzip: true,
258            brotli: true,
259            min_size: default_compression_min_size(),
260            quality: default_compression_quality(),
261        }
262    }
263}
264
265/// Rate limiting configuration shared across runtimes
266#[derive(Debug, Clone, Serialize, Deserialize)]
267pub struct RateLimitConfig {
268    /// Requests per second
269    pub per_second: u64,
270    /// Burst allowance
271    pub burst: u32,
272    /// Use IP-based rate limiting
273    #[serde(default = "default_true")]
274    pub ip_based: bool,
275}
276
277impl Default for RateLimitConfig {
278    fn default() -> Self {
279        Self {
280            per_second: 100,
281            burst: 200,
282            ip_based: true,
283        }
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    // ---- Method::as_str --------------------------------------------------

    #[test]
    fn test_method_as_str_get() {
        let token = Method::Get.as_str();
        assert_eq!(token, "GET");
    }

    #[test]
    fn test_method_as_str_post() {
        let token = Method::Post.as_str();
        assert_eq!(token, "POST");
    }

    #[test]
    fn test_method_as_str_put() {
        let token = Method::Put.as_str();
        assert_eq!(token, "PUT");
    }

    #[test]
    fn test_method_as_str_patch() {
        let token = Method::Patch.as_str();
        assert_eq!(token, "PATCH");
    }

    #[test]
    fn test_method_as_str_delete() {
        let token = Method::Delete.as_str();
        assert_eq!(token, "DELETE");
    }

    #[test]
    fn test_method_as_str_head() {
        let token = Method::Head.as_str();
        assert_eq!(token, "HEAD");
    }

    #[test]
    fn test_method_as_str_options() {
        let token = Method::Options.as_str();
        assert_eq!(token, "OPTIONS");
    }

    #[test]
    fn test_method_as_str_trace() {
        let token = Method::Trace.as_str();
        assert_eq!(token, "TRACE");
    }

    // ---- Display ---------------------------------------------------------

    #[test]
    fn test_method_display_get() {
        assert_eq!(format!("{}", Method::Get), "GET");
    }

    #[test]
    fn test_method_display_post() {
        assert_eq!(format!("{}", Method::Post), "POST");
    }

    #[test]
    fn test_method_display_put() {
        assert_eq!(format!("{}", Method::Put), "PUT");
    }

    #[test]
    fn test_method_display_patch() {
        assert_eq!(format!("{}", Method::Patch), "PATCH");
    }

    #[test]
    fn test_method_display_delete() {
        assert_eq!(format!("{}", Method::Delete), "DELETE");
    }

    #[test]
    fn test_method_display_head() {
        assert_eq!(format!("{}", Method::Head), "HEAD");
    }

    #[test]
    fn test_method_display_options() {
        assert_eq!(format!("{}", Method::Options), "OPTIONS");
    }

    #[test]
    fn test_method_display_trace() {
        assert_eq!(format!("{}", Method::Trace), "TRACE");
    }

    // ---- FromStr ---------------------------------------------------------

    #[test]
    fn test_from_str_get() {
        assert_eq!("GET".parse::<Method>(), Ok(Method::Get));
    }

    #[test]
    fn test_from_str_post() {
        assert_eq!("POST".parse::<Method>(), Ok(Method::Post));
    }

    #[test]
    fn test_from_str_put() {
        assert_eq!("PUT".parse::<Method>(), Ok(Method::Put));
    }

    #[test]
    fn test_from_str_patch() {
        assert_eq!("PATCH".parse::<Method>(), Ok(Method::Patch));
    }

    #[test]
    fn test_from_str_delete() {
        assert_eq!("DELETE".parse::<Method>(), Ok(Method::Delete));
    }

    #[test]
    fn test_from_str_head() {
        assert_eq!("HEAD".parse::<Method>(), Ok(Method::Head));
    }

    #[test]
    fn test_from_str_options() {
        assert_eq!("OPTIONS".parse::<Method>(), Ok(Method::Options));
    }

    #[test]
    fn test_from_str_trace() {
        assert_eq!("TRACE".parse::<Method>(), Ok(Method::Trace));
    }

    #[test]
    fn test_from_str_lowercase() {
        assert_eq!("get".parse::<Method>(), Ok(Method::Get));
    }

    #[test]
    fn test_from_str_mixed_case() {
        assert_eq!("PoSt".parse::<Method>(), Ok(Method::Post));
    }

    #[test]
    fn test_from_str_invalid_method() {
        let message = Method::from_str("INVALID").expect_err("unknown token must be rejected");
        assert_eq!(message, "Unknown HTTP method: INVALID");
    }

    #[test]
    fn test_from_str_empty_string() {
        let message = Method::from_str("").expect_err("empty token must be rejected");
        assert_eq!(message, "Unknown HTTP method: ");
    }

    // ---- Config defaults and helpers ---------------------------------------

    #[test]
    fn test_compression_config_default() {
        let CompressionConfig {
            gzip,
            brotli,
            min_size,
            quality,
        } = CompressionConfig::default();
        assert!(gzip);
        assert!(brotli);
        assert_eq!(min_size, 1024);
        assert_eq!(quality, 6);
    }

    #[test]
    fn test_default_true() {
        let flag = default_true();
        assert!(flag);
    }

    #[test]
    fn test_default_compression_min_size() {
        let bytes = default_compression_min_size();
        assert_eq!(bytes, 1024);
    }

    #[test]
    fn test_default_compression_quality() {
        let level = default_compression_quality();
        assert_eq!(level, 6);
    }

    #[test]
    fn test_rate_limit_config_default() {
        let RateLimitConfig {
            per_second,
            burst,
            ip_based,
        } = RateLimitConfig::default();
        assert_eq!(per_second, 100);
        assert_eq!(burst, 200);
        assert!(ip_based);
    }

    // ---- Derived traits on Method -----------------------------------------

    #[test]
    fn test_method_equality() {
        let (get, post) = (Method::Get, Method::Post);
        assert_eq!(get, Method::Get);
        assert_ne!(get, post);
    }

    #[test]
    fn test_method_clone() {
        let original = Method::Post;
        let duplicate = original.clone();
        assert_eq!(original, duplicate);
    }

    // ---- Explicit construction --------------------------------------------

    #[test]
    fn test_compression_config_custom_values() {
        let custom = CompressionConfig {
            quality: 11,
            min_size: 2048,
            brotli: false,
            gzip: false,
        };
        assert!(!custom.gzip);
        assert!(!custom.brotli);
        assert_eq!(custom.min_size, 2048);
        assert_eq!(custom.quality, 11);
    }

    #[test]
    fn test_rate_limit_config_custom_values() {
        let custom = RateLimitConfig {
            ip_based: false,
            burst: 100,
            per_second: 50,
        };
        assert_eq!(custom.per_second, 50);
        assert_eq!(custom.burst, 100);
        assert!(!custom.ip_based);
    }

    #[test]
    fn test_cors_config_construction() {
        let cors = CorsConfig {
            allowed_origins: vec![String::from("http://localhost:3000")],
            allowed_methods: vec![String::from("GET"), String::from("POST")],
            allowed_headers: Vec::new(),
            expose_headers: None,
            max_age: None,
            allow_credentials: None,
            // The cache fields come from the default value.
            ..CorsConfig::default()
        };
        assert_eq!(cors.allowed_origins.len(), 1);
        assert_eq!(cors.allowed_methods.len(), 2);
        assert_eq!(cors.allowed_headers.len(), 0);
    }

    #[test]
    fn test_route_metadata_construction() {
        let route = RouteMetadata {
            method: String::from("GET"),
            path: String::from("/api/users"),
            handler_name: String::from("get_users"),
            request_schema: None,
            response_schema: None,
            parameter_schema: None,
            file_params: None,
            is_async: true,
            cors: None,
            body_param_name: None,
            #[cfg(feature = "di")]
            handler_dependencies: None,
            jsonrpc_method: None,
            static_response: None,
        };
        assert_eq!(route.method, "GET");
        assert_eq!(route.path, "/api/users");
        assert_eq!(route.handler_name, "get_users");
        assert!(route.is_async);
    }
}