Skip to main content

spikard_core/
http.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::sync::OnceLock;
4
5/// HTTP method
6#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
7pub enum Method {
8    #[default]
9    Get,
10    Post,
11    Put,
12    Patch,
13    Delete,
14    Head,
15    Options,
16    Connect,
17    Trace,
18}
19
20impl Method {
21    #[must_use]
22    pub const fn as_str(&self) -> &'static str {
23        match self {
24            Self::Get => "GET",
25            Self::Post => "POST",
26            Self::Put => "PUT",
27            Self::Patch => "PATCH",
28            Self::Delete => "DELETE",
29            Self::Head => "HEAD",
30            Self::Options => "OPTIONS",
31            Self::Connect => "CONNECT",
32            Self::Trace => "TRACE",
33        }
34    }
35}
36
37impl std::fmt::Display for Method {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        write!(f, "{}", self.as_str())
40    }
41}
42
43impl From<Method> for http::method::Method {
44    fn from(method: Method) -> Self {
45        match method {
46            Method::Get => Self::GET,
47            Method::Post => Self::POST,
48            Method::Put => Self::PUT,
49            Method::Patch => Self::PATCH,
50            Method::Delete => Self::DELETE,
51            Method::Head => Self::HEAD,
52            Method::Options => Self::OPTIONS,
53            Method::Connect => Self::CONNECT,
54            Method::Trace => Self::TRACE,
55        }
56    }
57}
58
59impl From<&Method> for http::method::Method {
60    fn from(method: &Method) -> Self {
61        match method {
62            Method::Get => Self::GET,
63            Method::Post => Self::POST,
64            Method::Put => Self::PUT,
65            Method::Patch => Self::PATCH,
66            Method::Delete => Self::DELETE,
67            Method::Head => Self::HEAD,
68            Method::Options => Self::OPTIONS,
69            Method::Connect => Self::CONNECT,
70            Method::Trace => Self::TRACE,
71        }
72    }
73}
74
75impl std::str::FromStr for Method {
76    type Err = String;
77
78    fn from_str(s: &str) -> Result<Self, Self::Err> {
79        match s.to_uppercase().as_str() {
80            "GET" => Ok(Self::Get),
81            "POST" => Ok(Self::Post),
82            "PUT" => Ok(Self::Put),
83            "PATCH" => Ok(Self::Patch),
84            "DELETE" => Ok(Self::Delete),
85            "HEAD" => Ok(Self::Head),
86            "OPTIONS" => Ok(Self::Options),
87            "CONNECT" => Ok(Self::Connect),
88            "TRACE" => Ok(Self::Trace),
89            _ => Err(format!("Unknown HTTP method: {s}")),
90        }
91    }
92}
93
/// CORS configuration for a route
///
/// When deserializing, `allowed_origins` and `allowed_methods` are required and
/// `allowed_headers` defaults to an empty list. The `Option` fields are omitted from
/// serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorsConfig {
    /// Origins accepted by `is_origin_allowed`; an entry of `"*"` matches any non-empty origin.
    pub allowed_origins: Vec<String>,
    /// Methods accepted by `is_method_allowed` (compared ignoring ASCII case); `"*"` matches any.
    pub allowed_methods: Vec<String>,
    /// Request headers accepted by `are_headers_allowed`; an entry of `"*"` accepts every header.
    #[serde(default)]
    pub allowed_headers: Vec<String>,
    /// Not interpreted in this file; presumably emitted by the CORS middleware — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expose_headers: Option<Vec<String>>,
    /// Not interpreted in this file; presumably a preflight cache lifetime in seconds — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_age: Option<u32>,
    /// Not interpreted in this file; presumably controls the credentials response header — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allow_credentials: Option<bool>,

    // Optimized caches (lazy-initialized on first use)
    //
    // Filled by `allowed_methods_joined` / `allowed_headers_joined` via
    // `OnceLock::get_or_init`, so each is computed at most once. Because the list
    // fields above are `pub`, mutating `allowed_methods` / `allowed_headers` after
    // the first call leaves the cached string stale. `Clone` copies whatever the
    // cache currently holds. `#[serde(skip)]` means a deserialized config starts
    // with empty caches. `alef(skip)` presumably hides them from the `alef`
    // binding generator — TODO confirm.
    #[serde(skip)]
    #[doc(hidden)]
    #[cfg_attr(alef, alef(skip))]
    pub methods_joined_cache: OnceLock<String>,
    #[serde(skip)]
    #[doc(hidden)]
    #[cfg_attr(alef, alef(skip))]
    pub headers_joined_cache: OnceLock<String>,
}
118
119impl CorsConfig {
120    /// Get the cached joined methods string for preflight responses
121    pub fn allowed_methods_joined(&self) -> &str {
122        self.methods_joined_cache
123            .get_or_init(|| self.allowed_methods.join(", "))
124    }
125
126    /// Get the cached joined headers string for preflight responses
127    pub fn allowed_headers_joined(&self) -> &str {
128        self.headers_joined_cache
129            .get_or_init(|| self.allowed_headers.join(", "))
130    }
131
132    /// Check if an origin is allowed (O(1) with wildcard, O(n) for exact match)
133    pub fn is_origin_allowed(&self, origin: &str) -> bool {
134        if origin.is_empty() {
135            return false;
136        }
137        self.allowed_origins.iter().any(|o| o == "*" || o == origin)
138    }
139
140    /// Check if a method is allowed (O(1) with wildcard, O(n) for exact match)
141    pub fn is_method_allowed(&self, method: &str) -> bool {
142        self.allowed_methods
143            .iter()
144            .any(|m| m == "*" || m.eq_ignore_ascii_case(method))
145    }
146
147    /// Check if all requested headers are allowed (O(n) where n = num requested headers)
148    pub fn are_headers_allowed(&self, requested: &[&str]) -> bool {
149        // Check if wildcard is set
150        if self.allowed_headers.iter().any(|h| h == "*") {
151            return true;
152        }
153
154        // Check each requested header
155        requested.iter().all(|req_header| {
156            self.allowed_headers
157                .iter()
158                .any(|h| h.to_lowercase() == req_header.to_lowercase())
159        })
160    }
161}
162
163impl Default for CorsConfig {
164    fn default() -> Self {
165        Self {
166            allowed_origins: vec!["*".to_string()],
167            allowed_methods: vec!["*".to_string()],
168            allowed_headers: vec![],
169            expose_headers: None,
170            max_age: None,
171            allow_credentials: None,
172            methods_joined_cache: OnceLock::new(),
173            headers_joined_cache: OnceLock::new(),
174        }
175    }
176}
177
/// Route metadata extracted from bindings
///
/// When deserializing, `method`, `path` and `handler_name` are required. The
/// `Option` fields become `None` when absent. `request_schema`, `response_schema`,
/// `parameter_schema` and `cors` have no `skip_serializing_if`, so they serialize as
/// `null` when `None`; the other optional fields are omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RouteMetadata {
    /// HTTP method as a string (e.g. `"GET"`); not validated against `Method` here.
    pub method: String,
    /// Route path; the `Default` impl uses `"/"`.
    pub path: String,
    /// Name of the handler in the host binding.
    pub handler_name: String,
    /// Presumably a JSON Schema for the request body — confirm with the binding layer.
    pub request_schema: Option<Value>,
    /// Presumably a JSON Schema for the response body — confirm with the binding layer.
    pub response_schema: Option<Value>,
    /// Presumably a JSON Schema for path/query/header parameters — confirm.
    pub parameter_schema: Option<Value>,
    /// File-upload parameter description; its JSON shape is not defined in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file_params: Option<Value>,
    /// Whether the handler is async.
    ///
    /// NOTE(review): `#[serde(default)]` makes this `false` when the key is absent,
    /// while `RouteMetadata::default()` sets it to `true`. Confirm the asymmetry is
    /// intended.
    #[serde(default)]
    pub is_async: bool,
    /// Per-route CORS settings, if any.
    pub cors: Option<CorsConfig>,
    /// Per-route compression settings, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compression: Option<CompressionConfig>,
    /// Name of the body parameter (defaults to "body" if not specified)
    ///
    /// The `"body"` fallback is applied outside this file; here an unspecified
    /// value is simply `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub body_param_name: Option<String>,
    /// List of dependency keys this handler requires (for DI)
    ///
    /// Only present when the crate is built with the `di` feature.
    #[cfg(feature = "di")]
    #[serde(skip_serializing_if = "Option::is_none")]
    pub handler_dependencies: Option<Vec<String>>,
    /// JSON-RPC method metadata (if this route is exposed as a JSON-RPC method)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jsonrpc_method: Option<Value>,
    /// Optional static response configuration: `{"status": 200, "body": "OK", "content_type": "text/plain"}`
    /// When present, the handler is replaced by a `StaticResponseHandler` that bypasses the full
    /// middleware pipeline for maximum throughput.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub static_response: Option<Value>,
}
210
211impl Default for RouteMetadata {
212    fn default() -> Self {
213        Self {
214            method: "GET".to_string(),
215            path: "/".to_string(),
216            handler_name: String::new(),
217            request_schema: None,
218            response_schema: None,
219            parameter_schema: None,
220            file_params: None,
221            is_async: true,
222            cors: None,
223            body_param_name: None,
224            #[cfg(feature = "di")]
225            handler_dependencies: None,
226            jsonrpc_method: None,
227            static_response: None,
228            compression: None,
229        }
230    }
231}
232
/// Compression configuration shared across runtimes
///
/// Every field has a serde default, so an empty object (`{}`) deserializes to the
/// same values that `CompressionConfig::default()` produces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressionConfig {
    /// Enable gzip compression
    #[serde(default = "default_true")]
    pub gzip: bool,
    /// Enable brotli compression
    #[serde(default = "default_true")]
    pub brotli: bool,
    /// Minimum response size to compress (bytes)
    #[serde(default = "default_compression_min_size")]
    pub min_size: usize,
    /// Compression quality (0-11 for brotli, 0-9 for gzip)
    ///
    /// NOTE(review): the range is not validated in this file. Presumably the
    /// compression layer clamps or rejects out-of-range values — confirm.
    #[serde(default = "default_compression_quality")]
    pub quality: u32,
}
249
/// Serde `default` helper for boolean options that stay enabled unless explicitly
/// turned off.
const fn default_true() -> bool {
    true
}
253
/// Serde `default` helper for `CompressionConfig::min_size`: 1 KiB.
///
/// This is the minimum response size, in bytes, that is compressed.
const fn default_compression_min_size() -> usize {
    1 << 10
}
257
/// Serde `default` helper for `CompressionConfig::quality`.
///
/// The level 6 lies inside both documented ranges (gzip 0-9, brotli 0-11).
const fn default_compression_quality() -> u32 {
    6
}
261
262impl Default for CompressionConfig {
263    fn default() -> Self {
264        Self {
265            gzip: true,
266            brotli: true,
267            min_size: default_compression_min_size(),
268            quality: default_compression_quality(),
269        }
270    }
271}
272
/// Rate limiting configuration shared across runtimes
///
/// When deserializing, `per_second` and `burst` are required; `ip_based`
/// defaults to `true`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Requests per second
    pub per_second: u64,
    /// Burst allowance
    ///
    /// NOTE(review): the exact semantics (e.g. token-bucket capacity) are defined
    /// by the limiter that consumes this config, not in this file — confirm there.
    pub burst: u32,
    /// Use IP-based rate limiting
    #[serde(default = "default_true")]
    pub ip_based: bool,
}
284
285impl Default for RateLimitConfig {
286    fn default() -> Self {
287        Self {
288            per_second: 100,
289            burst: 200,
290            ip_based: true,
291        }
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Every `Method` variant paired with its canonical upper-case token.
    const ALL_METHODS: [(Method, &str); 9] = [
        (Method::Get, "GET"),
        (Method::Post, "POST"),
        (Method::Put, "PUT"),
        (Method::Patch, "PATCH"),
        (Method::Delete, "DELETE"),
        (Method::Head, "HEAD"),
        (Method::Options, "OPTIONS"),
        (Method::Connect, "CONNECT"),
        (Method::Trace, "TRACE"),
    ];

    /// Converts string literals into an owned list.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| (*s).to_string()).collect()
    }

    /// Builds a `CorsConfig` with the given lists and defaults for everything else.
    fn cors(origins: &[&str], methods: &[&str], headers: &[&str]) -> CorsConfig {
        CorsConfig {
            allowed_origins: strings(origins),
            allowed_methods: strings(methods),
            allowed_headers: strings(headers),
            ..Default::default()
        }
    }

    // ---- Method ---------------------------------------------------------

    #[test]
    fn test_method_as_str_all_variants() {
        for (method, token) in &ALL_METHODS {
            assert_eq!(method.as_str(), *token, "as_str for {method:?}");
        }
    }

    #[test]
    fn test_method_display_all_variants() {
        for (method, token) in &ALL_METHODS {
            assert_eq!(method.to_string(), *token, "Display for {method:?}");
        }
    }

    #[test]
    fn test_from_str_all_variants() {
        for (method, token) in &ALL_METHODS {
            assert_eq!(Method::from_str(token), Ok(method.clone()));
        }
    }

    #[test]
    fn test_from_str_lowercase_all_variants() {
        for (method, token) in &ALL_METHODS {
            assert_eq!(Method::from_str(&token.to_lowercase()), Ok(method.clone()));
        }
    }

    #[test]
    fn test_from_str_mixed_case() {
        assert_eq!(Method::from_str("PoSt"), Ok(Method::Post));
    }

    #[test]
    fn test_display_round_trips_through_from_str() {
        for (method, _) in &ALL_METHODS {
            assert_eq!(Method::from_str(&method.to_string()), Ok(method.clone()));
        }
    }

    #[test]
    fn test_from_str_invalid_method() {
        let result = Method::from_str("INVALID");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Unknown HTTP method: INVALID");
    }

    #[test]
    fn test_from_str_empty_string() {
        let result = Method::from_str("");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "Unknown HTTP method: ");
    }

    #[test]
    fn test_method_default_is_get() {
        assert_eq!(Method::default(), Method::Get);
    }

    #[test]
    fn test_method_equality() {
        assert_eq!(Method::Get, Method::Get);
        assert_ne!(Method::Get, Method::Post);
    }

    #[test]
    fn test_method_clone() {
        let method = Method::Post;
        let cloned = method.clone();
        assert_eq!(method, cloned);
    }

    #[test]
    fn test_into_http_method_all_variants() {
        for (method, token) in &ALL_METHODS {
            let by_ref = http::method::Method::from(method);
            let by_value = http::method::Method::from(method.clone());
            assert_eq!(by_ref.as_str(), *token);
            assert_eq!(by_value, by_ref);
        }
    }

    // ---- Compression / rate limiting -------------------------------------

    #[test]
    fn test_compression_config_default() {
        let config = CompressionConfig::default();
        assert!(config.gzip);
        assert!(config.brotli);
        assert_eq!(config.min_size, 1024);
        assert_eq!(config.quality, 6);
    }

    #[test]
    fn test_compression_config_deserialize_applies_defaults() {
        let config: CompressionConfig =
            serde_json::from_str("{}").expect("empty object is a valid CompressionConfig");
        assert!(config.gzip);
        assert!(config.brotli);
        assert_eq!(config.min_size, 1024);
        assert_eq!(config.quality, 6);
    }

    #[test]
    fn test_default_true() {
        assert!(default_true());
    }

    #[test]
    fn test_default_compression_min_size() {
        assert_eq!(default_compression_min_size(), 1024);
    }

    #[test]
    fn test_default_compression_quality() {
        assert_eq!(default_compression_quality(), 6);
    }

    #[test]
    fn test_compression_config_custom_values() {
        let config = CompressionConfig {
            gzip: false,
            brotli: false,
            min_size: 2048,
            quality: 11,
        };
        assert!(!config.gzip);
        assert!(!config.brotli);
        assert_eq!(config.min_size, 2048);
        assert_eq!(config.quality, 11);
    }

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!(config.per_second, 100);
        assert_eq!(config.burst, 200);
        assert!(config.ip_based);
    }

    #[test]
    fn test_rate_limit_config_deserialize_defaults_ip_based() {
        let config: RateLimitConfig = serde_json::from_str(r#"{"per_second": 5, "burst": 10}"#)
            .expect("valid RateLimitConfig JSON");
        assert_eq!(config.per_second, 5);
        assert_eq!(config.burst, 10);
        assert!(config.ip_based);
    }

    #[test]
    fn test_rate_limit_config_custom_values() {
        let config = RateLimitConfig {
            per_second: 50,
            burst: 100,
            ip_based: false,
        };
        assert_eq!(config.per_second, 50);
        assert_eq!(config.burst, 100);
        assert!(!config.ip_based);
    }

    // ---- CORS ------------------------------------------------------------

    #[test]
    fn test_cors_config_construction() {
        let cors = CorsConfig {
            allowed_origins: vec!["http://localhost:3000".to_string()],
            allowed_methods: vec!["GET".to_string(), "POST".to_string()],
            allowed_headers: vec![],
            expose_headers: None,
            max_age: None,
            allow_credentials: None,
            ..Default::default()
        };
        assert_eq!(cors.allowed_origins.len(), 1);
        assert_eq!(cors.allowed_methods.len(), 2);
        assert_eq!(cors.allowed_headers.len(), 0);
    }

    #[test]
    fn test_cors_config_default() {
        let config = CorsConfig::default();
        assert_eq!(config.allowed_origins, vec!["*".to_string()]);
        assert_eq!(config.allowed_methods, vec!["*".to_string()]);
        assert!(config.allowed_headers.is_empty());
        assert!(config.expose_headers.is_none());
        assert!(config.max_age.is_none());
        assert!(config.allow_credentials.is_none());
    }

    #[test]
    fn test_cors_config_deserialize_defaults_headers() {
        let config: CorsConfig =
            serde_json::from_str(r#"{"allowed_origins": ["*"], "allowed_methods": ["GET"]}"#)
                .expect("valid CorsConfig JSON");
        assert!(config.allowed_headers.is_empty());
        assert!(config.expose_headers.is_none());
        assert_eq!(config.allowed_methods_joined(), "GET");
    }

    #[test]
    fn test_cors_joined_strings() {
        let config = cors(&["*"], &["GET", "POST"], &["Content-Type", "X-Token"]);
        assert_eq!(config.allowed_methods_joined(), "GET, POST");
        assert_eq!(config.allowed_headers_joined(), "Content-Type, X-Token");
        assert_eq!(cors(&[], &[], &[]).allowed_headers_joined(), "");
    }

    #[test]
    fn test_cors_origin_matching() {
        let exact = cors(&["https://a.example"], &[], &[]);
        assert!(exact.is_origin_allowed("https://a.example"));
        assert!(!exact.is_origin_allowed("https://b.example"));
        assert!(!exact.is_origin_allowed(""));

        let any = cors(&["*"], &[], &[]);
        assert!(any.is_origin_allowed("https://b.example"));
        assert!(!any.is_origin_allowed(""));
    }

    #[test]
    fn test_cors_method_matching_ignores_ascii_case() {
        let config = cors(&[], &["GET", "post"], &[]);
        assert!(config.is_method_allowed("get"));
        assert!(config.is_method_allowed("POST"));
        assert!(!config.is_method_allowed("DELETE"));
        assert!(cors(&[], &["*"], &[]).is_method_allowed("PATCH"));
    }

    #[test]
    fn test_cors_header_matching() {
        let config = cors(&[], &[], &["Content-Type", "X-Custom"]);
        assert!(config.are_headers_allowed(&[]));
        assert!(config.are_headers_allowed(&["content-type", "X-CUSTOM"]));
        assert!(!config.are_headers_allowed(&["content-type", "authorization"]));
        assert!(cors(&[], &[], &["*"]).are_headers_allowed(&["anything"]));
    }

    // ---- Route metadata --------------------------------------------------

    #[test]
    fn test_route_metadata_default() {
        let metadata = RouteMetadata::default();
        assert_eq!(metadata.method, "GET");
        assert_eq!(metadata.path, "/");
        assert!(metadata.handler_name.is_empty());
        assert!(metadata.is_async);
        assert!(metadata.cors.is_none());
        assert!(metadata.compression.is_none());
        assert!(metadata.static_response.is_none());
    }

    #[test]
    fn test_route_metadata_construction() {
        let metadata = RouteMetadata {
            method: "GET".to_string(),
            path: "/api/users".to_string(),
            handler_name: "get_users".to_string(),
            request_schema: None,
            response_schema: None,
            parameter_schema: None,
            file_params: None,
            is_async: true,
            cors: None,
            body_param_name: None,
            #[cfg(feature = "di")]
            handler_dependencies: None,
            jsonrpc_method: None,
            static_response: None,
            compression: None,
        };
        assert_eq!(metadata.method, "GET");
        assert_eq!(metadata.path, "/api/users");
        assert_eq!(metadata.handler_name, "get_users");
        assert!(metadata.is_async);
    }
}