Skip to main content

spikard_core/
http.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3use std::sync::OnceLock;
4
5/// HTTP method
6#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
7pub enum Method {
8    Get,
9    Post,
10    Put,
11    Patch,
12    Delete,
13    Head,
14    Options,
15    Trace,
16}
17
18impl Method {
19    #[must_use]
20    pub const fn as_str(&self) -> &'static str {
21        match self {
22            Self::Get => "GET",
23            Self::Post => "POST",
24            Self::Put => "PUT",
25            Self::Patch => "PATCH",
26            Self::Delete => "DELETE",
27            Self::Head => "HEAD",
28            Self::Options => "OPTIONS",
29            Self::Trace => "TRACE",
30        }
31    }
32}
33
34impl std::fmt::Display for Method {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        write!(f, "{}", self.as_str())
37    }
38}
39
40impl std::str::FromStr for Method {
41    type Err = String;
42
43    fn from_str(s: &str) -> Result<Self, Self::Err> {
44        match s.to_uppercase().as_str() {
45            "GET" => Ok(Self::Get),
46            "POST" => Ok(Self::Post),
47            "PUT" => Ok(Self::Put),
48            "PATCH" => Ok(Self::Patch),
49            "DELETE" => Ok(Self::Delete),
50            "HEAD" => Ok(Self::Head),
51            "OPTIONS" => Ok(Self::Options),
52            "TRACE" => Ok(Self::Trace),
53            _ => Err(format!("Unknown HTTP method: {s}")),
54        }
55    }
56}
57
58/// CORS configuration for a route
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct CorsConfig {
61    pub allowed_origins: Vec<String>,
62    pub allowed_methods: Vec<String>,
63    #[serde(default)]
64    pub allowed_headers: Vec<String>,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub expose_headers: Option<Vec<String>>,
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub max_age: Option<u32>,
69    #[serde(skip_serializing_if = "Option::is_none")]
70    pub allow_credentials: Option<bool>,
71
72    // Optimized caches (lazy-initialized on first use)
73    #[serde(skip)]
74    #[doc(hidden)]
75    pub methods_joined_cache: OnceLock<String>,
76    #[serde(skip)]
77    #[doc(hidden)]
78    pub headers_joined_cache: OnceLock<String>,
79}
80
81impl CorsConfig {
82    /// Get the cached joined methods string for preflight responses
83    pub fn allowed_methods_joined(&self) -> &str {
84        self.methods_joined_cache
85            .get_or_init(|| self.allowed_methods.join(", "))
86    }
87
88    /// Get the cached joined headers string for preflight responses
89    pub fn allowed_headers_joined(&self) -> &str {
90        self.headers_joined_cache
91            .get_or_init(|| self.allowed_headers.join(", "))
92    }
93
94    /// Check if an origin is allowed (O(1) with wildcard, O(n) for exact match)
95    pub fn is_origin_allowed(&self, origin: &str) -> bool {
96        if origin.is_empty() {
97            return false;
98        }
99        self.allowed_origins.iter().any(|o| o == "*" || o == origin)
100    }
101
102    /// Check if a method is allowed (O(1) with wildcard, O(n) for exact match)
103    pub fn is_method_allowed(&self, method: &str) -> bool {
104        self.allowed_methods
105            .iter()
106            .any(|m| m == "*" || m.eq_ignore_ascii_case(method))
107    }
108
109    /// Check if all requested headers are allowed (O(n) where n = num requested headers)
110    pub fn are_headers_allowed(&self, requested: &[&str]) -> bool {
111        // Check if wildcard is set
112        if self.allowed_headers.iter().any(|h| h == "*") {
113            return true;
114        }
115
116        // Check each requested header
117        requested.iter().all(|req_header| {
118            self.allowed_headers
119                .iter()
120                .any(|h| h.to_lowercase() == req_header.to_lowercase())
121        })
122    }
123}
124
125impl Default for CorsConfig {
126    fn default() -> Self {
127        Self {
128            allowed_origins: vec!["*".to_string()],
129            allowed_methods: vec!["*".to_string()],
130            allowed_headers: vec![],
131            expose_headers: None,
132            max_age: None,
133            allow_credentials: None,
134            methods_joined_cache: OnceLock::new(),
135            headers_joined_cache: OnceLock::new(),
136        }
137    }
138}
139
140/// Route metadata extracted from bindings
141#[derive(Debug, Clone, Serialize, Deserialize)]
142pub struct RouteMetadata {
143    pub method: String,
144    pub path: String,
145    pub handler_name: String,
146    pub request_schema: Option<Value>,
147    pub response_schema: Option<Value>,
148    pub parameter_schema: Option<Value>,
149    #[serde(skip_serializing_if = "Option::is_none")]
150    pub file_params: Option<Value>,
151    #[serde(default)]
152    pub is_async: bool,
153    pub cors: Option<CorsConfig>,
154    /// Name of the body parameter (defaults to "body" if not specified)
155    #[serde(skip_serializing_if = "Option::is_none")]
156    pub body_param_name: Option<String>,
157    /// List of dependency keys this handler requires (for DI)
158    #[cfg(feature = "di")]
159    #[serde(skip_serializing_if = "Option::is_none")]
160    pub handler_dependencies: Option<Vec<String>>,
161    /// JSON-RPC method metadata (if this route is exposed as a JSON-RPC method)
162    #[serde(skip_serializing_if = "Option::is_none")]
163    pub jsonrpc_method: Option<Value>,
164    /// Optional static response configuration: `{"status": 200, "body": "OK", "content_type": "text/plain"}`
165    /// When present, the handler is replaced by a `StaticResponseHandler` that bypasses the full
166    /// middleware pipeline for maximum throughput.
167    #[serde(skip_serializing_if = "Option::is_none")]
168    pub static_response: Option<Value>,
169}
170
171/// Compression configuration shared across runtimes
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct CompressionConfig {
174    /// Enable gzip compression
175    #[serde(default = "default_true")]
176    pub gzip: bool,
177    /// Enable brotli compression
178    #[serde(default = "default_true")]
179    pub brotli: bool,
180    /// Minimum response size to compress (bytes)
181    #[serde(default = "default_compression_min_size")]
182    pub min_size: usize,
183    /// Compression quality (0-11 for brotli, 0-9 for gzip)
184    #[serde(default = "default_compression_quality")]
185    pub quality: u32,
186}
187
/// Serde default for flags that are on unless explicitly disabled
/// (`CompressionConfig::gzip`/`brotli`, `RateLimitConfig::ip_based`).
const fn default_true() -> bool {
    true
}

/// Serde default for `CompressionConfig::min_size`: 1 KiB.
const fn default_compression_min_size() -> usize {
    1_024
}

/// Serde default for `CompressionConfig::quality`: 6, a level inside both the
/// gzip (0-9) and brotli (0-11) ranges.
const fn default_compression_quality() -> u32 {
    6
}
199
200impl Default for CompressionConfig {
201    fn default() -> Self {
202        Self {
203            gzip: true,
204            brotli: true,
205            min_size: default_compression_min_size(),
206            quality: default_compression_quality(),
207        }
208    }
209}
210
211/// Rate limiting configuration shared across runtimes
212#[derive(Debug, Clone, Serialize, Deserialize)]
213pub struct RateLimitConfig {
214    /// Requests per second
215    pub per_second: u64,
216    /// Burst allowance
217    pub burst: u32,
218    /// Use IP-based rate limiting
219    #[serde(default = "default_true")]
220    pub ip_based: bool,
221}
222
223impl Default for RateLimitConfig {
224    fn default() -> Self {
225        Self {
226            per_second: 100,
227            burst: 200,
228            ip_based: true,
229        }
230    }
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use std::str::FromStr;

    /// Every `Method` variant paired with its canonical upper-case token.
    const ALL: [(Method, &str); 8] = [
        (Method::Get, "GET"),
        (Method::Post, "POST"),
        (Method::Put, "PUT"),
        (Method::Patch, "PATCH"),
        (Method::Delete, "DELETE"),
        (Method::Head, "HEAD"),
        (Method::Options, "OPTIONS"),
        (Method::Trace, "TRACE"),
    ];

    #[test]
    fn test_method_as_str() {
        for (method, token) in ALL {
            assert_eq!(method.as_str(), token, "as_str for {method:?}");
        }
    }

    #[test]
    fn test_method_display() {
        for (method, token) in ALL {
            assert_eq!(method.to_string(), token, "Display for {method:?}");
        }
    }

    #[test]
    fn test_from_str_canonical_tokens() {
        for (method, token) in ALL {
            assert_eq!(Method::from_str(token), Ok(method));
        }
    }

    #[test]
    fn test_from_str_is_case_insensitive() {
        assert_eq!(Method::from_str("get"), Ok(Method::Get));
        assert_eq!(Method::from_str("PoSt"), Ok(Method::Post));
    }

    #[test]
    fn test_from_str_rejects_unknown_input() {
        for input in ["INVALID", ""] {
            assert_eq!(
                Method::from_str(input),
                Err(format!("Unknown HTTP method: {input}"))
            );
        }
    }

    #[test]
    fn test_method_equality_and_clone() {
        assert_eq!(Method::Get, Method::Get);
        assert_ne!(Method::Get, Method::Post);
        let original = Method::Post;
        assert_eq!(original.clone(), original);
    }

    #[test]
    fn test_serde_default_helpers() {
        assert!(default_true());
        assert_eq!(default_compression_min_size(), 1024);
        assert_eq!(default_compression_quality(), 6);
    }

    #[test]
    fn test_compression_config_default() {
        let CompressionConfig {
            gzip,
            brotli,
            min_size,
            quality,
        } = CompressionConfig::default();
        assert!(gzip && brotli);
        assert_eq!((min_size, quality), (1024, 6));
    }

    #[test]
    fn test_compression_config_custom_values() {
        let config = CompressionConfig {
            gzip: false,
            brotli: false,
            min_size: 2048,
            quality: 11,
        };
        assert!(!config.gzip && !config.brotli);
        assert_eq!((config.min_size, config.quality), (2048, 11));
    }

    #[test]
    fn test_rate_limit_config_default() {
        let config = RateLimitConfig::default();
        assert_eq!((config.per_second, config.burst), (100, 200));
        assert!(config.ip_based);
    }

    #[test]
    fn test_rate_limit_config_custom_values() {
        let config = RateLimitConfig {
            per_second: 50,
            burst: 100,
            ip_based: false,
        };
        assert_eq!((config.per_second, config.burst), (50, 100));
        assert!(!config.ip_based);
    }

    #[test]
    fn test_cors_config_construction() {
        let cors = CorsConfig {
            allowed_origins: vec!["http://localhost:3000".to_string()],
            allowed_methods: vec!["GET".to_string(), "POST".to_string()],
            ..Default::default()
        };
        assert_eq!(cors.allowed_origins.len(), 1);
        assert_eq!(cors.allowed_methods.len(), 2);
        assert!(cors.allowed_headers.is_empty());
    }

    #[test]
    fn test_route_metadata_construction() {
        let metadata = RouteMetadata {
            method: "GET".to_string(),
            path: "/api/users".to_string(),
            handler_name: "get_users".to_string(),
            request_schema: None,
            response_schema: None,
            parameter_schema: None,
            file_params: None,
            is_async: true,
            cors: None,
            body_param_name: None,
            #[cfg(feature = "di")]
            handler_dependencies: None,
            jsonrpc_method: None,
            static_response: None,
        };
        assert_eq!(
            (
                metadata.method.as_str(),
                metadata.path.as_str(),
                metadata.handler_name.as_str()
            ),
            ("GET", "/api/users", "get_users")
        );
        assert!(metadata.is_async);
    }
}