// synapse_pingora/waf/types.rs

//! Core types for the WAF engine.

use std::borrow::Cow;
use std::collections::HashMap;

use percent_encoding::percent_decode_str;
use serde::{Deserialize, Serialize};

8/// HTTP request to analyze.
9#[derive(Debug, Clone, Default)]
10pub struct Request<'a> {
11    /// HTTP method (GET, POST, etc.)
12    pub method: &'a str,
13    /// Request path including query string
14    pub path: &'a str,
15    /// Query string (if separate from path)
16    pub query: Option<&'a str>,
17    /// Request headers
18    pub headers: Vec<Header<'a>>,
19    /// Request body
20    pub body: Option<&'a [u8]>,
21    /// Client IP address
22    pub client_ip: &'a str,
23    /// Whether this is static content
24    pub is_static: bool,
25}
26
/// HTTP header key-value pair.
///
/// Both parts borrow from the caller's buffers; no normalization is applied
/// here.
#[derive(Debug, Clone)]
pub struct Header<'a> {
    /// Header name exactly as received.
    pub name: &'a str,
    /// Header value exactly as received.
    pub value: &'a str,
}

impl<'a> Header<'a> {
    /// Creates a header from a borrowed name and value.
    pub const fn new(name: &'a str, value: &'a str) -> Self {
        Self { name, value }
    }
}

40/// Analysis result.
41#[derive(Debug, Clone)]
42pub struct Verdict {
43    /// Recommended action
44    pub action: Action,
45    /// Combined risk score (0-1000 for extended range, 0-100 for default)
46    pub risk_score: u16,
47    /// IDs of matched rules
48    pub matched_rules: Vec<u32>,
49    /// Entity (IP) cumulative risk score (0.0-max_risk)
50    pub entity_risk: f64,
51    /// Whether the entity is blocked (risk or rule-based)
52    pub entity_blocked: bool,
53    /// Reason for blocking (if entity_blocked is true)
54    pub block_reason: Option<String>,
55    /// Per-rule risk contributions for explainability
56    pub risk_contributions: Vec<RiskContribution>,
57
58    // Anomaly detection fields
59    /// Endpoint template (e.g., "/api/users/{id}")
60    pub endpoint_template: Option<String>,
61    /// Aggregate endpoint risk score (0-100)
62    pub endpoint_risk: Option<f32>,
63    /// Per-request anomaly score (-10 to +10)
64    pub anomaly_score: Option<f64>,
65    /// Adjusted blocking threshold used for this request
66    pub adjusted_threshold: Option<f64>,
67    /// Anomaly signals detected for observability
68    pub anomaly_signals: Vec<AnomalySignal>,
69
70    // Timeout fields
71    /// Whether evaluation timed out (partial result)
72    pub timed_out: bool,
73    /// Number of rules evaluated before timeout (if timed_out)
74    pub rules_evaluated: Option<u32>,
75}
76
77impl Default for Verdict {
78    fn default() -> Self {
79        Self {
80            action: Action::Allow,
81            risk_score: 0,
82            matched_rules: Vec::new(),
83            entity_risk: 0.0,
84            entity_blocked: false,
85            block_reason: None,
86            risk_contributions: Vec::new(),
87            endpoint_template: None,
88            endpoint_risk: None,
89            anomaly_score: None,
90            adjusted_threshold: None,
91            anomaly_signals: Vec::new(),
92            timed_out: false,
93            rules_evaluated: None,
94        }
95    }
96}
97
/// Action recommendation.
///
/// The discriminants are fixed by `#[repr(u8)]` and explicit values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Action {
    /// Let the request through.
    Allow = 0,
    /// Reject the request.
    Block = 1,
}

/// Per-rule risk contribution for explainability.
#[derive(Debug, Clone)]
pub struct RiskContribution {
    /// Rule ID that contributed risk.
    pub rule_id: u32,
    /// Base risk from rule.effective_risk().
    pub base_risk: f64,
    /// Repeat offender multiplier (1.0 = first match).
    pub multiplier: f64,
    /// Final risk after multiplier: base_risk * multiplier.
    pub final_risk: f64,
}

impl RiskContribution {
    /// Creates a contribution, deriving `final_risk` as `base_risk * multiplier`.
    #[inline]
    pub fn new(rule_id: u32, base_risk: f64, multiplier: f64) -> Self {
        let final_risk = base_risk * multiplier;
        Self {
            rule_id,
            base_risk,
            multiplier,
            final_risk,
        }
    }
}

132/// Anomaly signal detected during request analysis.
133#[derive(Debug, Clone)]
134pub struct AnomalySignal {
135    /// Type of anomaly signal
136    pub signal_type: AnomalySignalType,
137    /// Severity score (0-100)
138    pub severity: f32,
139    /// Human-readable detail
140    pub detail: String,
141}
142
143impl AnomalySignal {
144    /// Convert to AnomalyType for entity tracking.
145    pub fn to_anomaly_type(&self) -> AnomalyType {
146        match self.signal_type {
147            AnomalySignalType::PayloadSize => AnomalyType::OversizedRequest,
148            AnomalySignalType::RequestRate => AnomalyType::VelocitySpike,
149            AnomalySignalType::ErrorRate => AnomalyType::TimingAnomaly,
150            AnomalySignalType::ParameterAnomaly => AnomalyType::Custom,
151            AnomalySignalType::ContentTypeAnomaly => AnomalyType::Custom,
152            AnomalySignalType::TimingAnomaly => AnomalyType::TimingAnomaly,
153            AnomalySignalType::SchemaViolation => AnomalyType::Custom,
154        }
155    }
156}
157
/// Types of anomaly signals.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AnomalySignalType {
    /// Request payload size outside normal distribution.
    PayloadSize,
    /// Request rate exceeds baseline.
    RequestRate,
    /// Error rate spike detected.
    ErrorRate,
    /// Unexpected parameters in request.
    ParameterAnomaly,
    /// Unexpected content type.
    ContentTypeAnomaly,
    /// Request timing pattern anomaly.
    TimingAnomaly,
    /// Schema validation violation.
    SchemaViolation,
}

177/// Modes for behavioral anomaly blocking.
178#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
179pub enum BlockingMode {
180    /// Only log anomalies, never block.
181    #[default]
182    Learning,
183    /// Block requests that exceed the anomaly threshold.
184    Enforcement,
185}
186
187/// Risk calculation configuration.
188#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct RiskConfig {
190    /// Maximum risk score (default: 100.0, can be 1000.0 for extended range).
191    pub max_risk: f64,
192    /// Whether to apply repeat offender multipliers.
193    pub enable_repeat_multipliers: bool,
194    /// Custom anomaly risk overrides (type -> risk).
195    pub anomaly_risk_overrides: HashMap<AnomalyType, f64>,
196    /// Threshold for anomaly-based blocking (0 = disabled, default: 10.0)
197    pub anomaly_blocking_threshold: f64,
198    /// Behavioral blocking mode
199    pub blocking_mode: BlockingMode,
200}
201
202impl Default for RiskConfig {
203    fn default() -> Self {
204        Self {
205            max_risk: 100.0,
206            enable_repeat_multipliers: true,
207            anomaly_risk_overrides: HashMap::new(),
208            anomaly_blocking_threshold: 10.0,
209            blocking_mode: BlockingMode::Learning, // Default to safe mode
210        }
211    }
212}
213
214impl RiskConfig {
215    /// Create config with extended risk range (1000).
216    pub fn with_extended_range() -> Self {
217        Self {
218            max_risk: 1000.0,
219            ..Default::default()
220        }
221    }
222
223    /// Get risk for anomaly type (override or default).
224    #[inline]
225    pub fn anomaly_risk(&self, anomaly_type: AnomalyType) -> f64 {
226        self.anomaly_risk_overrides
227            .get(&anomaly_type)
228            .copied()
229            .unwrap_or_else(|| anomaly_type.default_risk())
230    }
231
232    /// Set custom risk for an anomaly type.
233    pub fn set_anomaly_risk(&mut self, anomaly_type: AnomalyType, risk: f64) {
234        self.anomaly_risk_overrides.insert(anomaly_type, risk);
235    }
236
237    /// Reset anomaly type to default risk.
238    pub fn reset_anomaly_risk(&mut self, anomaly_type: AnomalyType) {
239        self.anomaly_risk_overrides.remove(&anomaly_type);
240    }
241}
242
243/// Anomaly risk contribution for explainability.
244///
245/// Tracks anomaly-based risk applied to an entity.
246#[derive(Debug, Clone)]
247pub struct AnomalyContribution {
248    /// Anomaly type that contributed risk.
249    pub anomaly_type: AnomalyType,
250    /// Risk score applied.
251    pub risk: f64,
252    /// Optional custom reason.
253    pub reason: Option<String>,
254    /// Timestamp when applied (ms since epoch).
255    pub applied_at: u64,
256}
257
258impl AnomalyContribution {
259    /// Create a new anomaly contribution.
260    pub fn new(anomaly_type: AnomalyType, risk: f64, reason: Option<String>, now: u64) -> Self {
261        Self {
262            anomaly_type,
263            risk,
264            reason,
265            applied_at: now,
266        }
267    }
268}
269
/// Calculate repeat offender multiplier based on match count.
///
/// Multiplier tiers:
/// - 0-1 matches: 1.0x (no boost)
/// - 2-5 matches: 1.25x
/// - 6-10 matches: 1.5x
/// - 11+ matches: 2.0x
///
/// # Arguments
/// * `match_count` - Number of times the rule has matched for this entity
///
/// # Returns
/// Multiplier to apply to base risk
#[inline]
pub fn repeat_multiplier(match_count: u32) -> f64 {
    match match_count {
        0..=1 => 1.0,
        2..=5 => 1.25,
        6..=10 => 1.5,
        _ => 2.0,
    }
}

293/// Anomaly types for behavioral risk scoring.
294#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
295#[repr(u8)]
296pub enum AnomalyType {
297    /// Device fingerprint changed for same IP.
298    FingerprintChange = 0,
299    /// Same session token used from multiple IPs.
300    SessionSharing = 1,
301    /// Auth token reused after expiration.
302    TokenReuse = 2,
303    /// Sudden spike in request velocity.
304    VelocitySpike = 3,
305    /// Suspicious rotation pattern (IPs, user agents).
306    RotationPattern = 4,
307    /// Request timing anomaly.
308    TimingAnomaly = 5,
309    /// Geographic impossibility (too fast travel).
310    ImpossibleTravel = 6,
311    /// Request body exceeds normal size.
312    OversizedRequest = 7,
313    /// Response body exceeds normal size.
314    OversizedResponse = 8,
315    /// Sudden bandwidth consumption spike.
316    BandwidthSpike = 9,
317    /// Large responses, small requests (data theft).
318    ExfiltrationPattern = 10,
319    /// Large requests, small responses (file upload).
320    UploadPattern = 11,
321    /// Custom anomaly with explicit risk.
322    Custom = 255,
323}
324
325impl AnomalyType {
326    /// Default risk score for each anomaly type.
327    #[inline]
328    pub const fn default_risk(self) -> f64 {
329        match self {
330            AnomalyType::SessionSharing => 50.0,
331            AnomalyType::ExfiltrationPattern => 40.0,
332            AnomalyType::TokenReuse => 40.0,
333            AnomalyType::RotationPattern => 35.0,
334            AnomalyType::UploadPattern => 35.0,
335            AnomalyType::FingerprintChange => 30.0,
336            AnomalyType::BandwidthSpike => 25.0,
337            AnomalyType::ImpossibleTravel => 25.0,
338            AnomalyType::OversizedRequest => 20.0,
339            AnomalyType::OversizedResponse => 15.0,
340            AnomalyType::VelocitySpike => 15.0,
341            AnomalyType::TimingAnomaly => 10.0,
342            AnomalyType::Custom => 0.0,
343        }
344    }
345
346    /// Get the name of this anomaly type.
347    pub const fn name(self) -> &'static str {
348        match self {
349            AnomalyType::FingerprintChange => "fingerprint_change",
350            AnomalyType::SessionSharing => "session_sharing",
351            AnomalyType::TokenReuse => "token_reuse",
352            AnomalyType::VelocitySpike => "velocity_spike",
353            AnomalyType::RotationPattern => "rotation_pattern",
354            AnomalyType::TimingAnomaly => "timing_anomaly",
355            AnomalyType::ImpossibleTravel => "impossible_travel",
356            AnomalyType::OversizedRequest => "oversized_request",
357            AnomalyType::OversizedResponse => "oversized_response",
358            AnomalyType::BandwidthSpike => "bandwidth_spike",
359            AnomalyType::ExfiltrationPattern => "exfiltration_pattern",
360            AnomalyType::UploadPattern => "upload_pattern",
361            AnomalyType::Custom => "custom",
362        }
363    }
364}
365
366/// Internal evaluation context (converted from Request).
367#[derive(Debug)]
368pub struct EvalContext<'a> {
369    pub ip: &'a str,
370    pub method: &'a str,
371    pub url: &'a str,
372    pub headers: HashMap<String, &'a str>,
373    pub args: Vec<String>,
374    pub arg_entries: Vec<ArgEntry>,
375    pub body_text: Option<&'a str>,
376    pub raw_body: Option<&'a [u8]>,
377    pub is_static: bool,
378    pub json_text: Option<String>,
379    /// Deadline for rule evaluation (prevents DoS via complex regexes)
380    pub deadline: Option<std::time::Instant>,
381}
382
/// A single request argument as a key/value pair.
#[derive(Debug, Clone)]
pub struct ArgEntry {
    /// Argument name.
    pub key: String,
    /// Argument value (empty when the argument carried none).
    pub value: String,
}

389impl<'a> EvalContext<'a> {
390    /// Convert a Request to an EvalContext.
391    pub fn from_request(req: &'a Request<'a>) -> Self {
392        // Build headers map (lowercase keys)
393        let mut headers = HashMap::new();
394        for h in &req.headers {
395            headers.insert(h.name.to_ascii_lowercase(), h.value);
396        }
397
398        // Parse query string into args and arg_entries
399        let (mut args, mut arg_entries) = parse_query_args(req.path, req.query);
400
401        // Extract body text
402        let body_text = req.body.and_then(|b| std::str::from_utf8(b).ok());
403
404        // Parse body args if content-type is x-www-form-urlencoded
405        if let Some(text) = body_text {
406            if headers
407                .get("content-type")
408                .map(|ct| ct.contains("application/x-www-form-urlencoded"))
409                .unwrap_or(false)
410            {
411                // Parse body as query string and append to existing args
412                let (body_args, body_entries) = parse_query_args("", Some(text));
413                args.extend(body_args);
414                arg_entries.extend(body_entries);
415            }
416        }
417
418        // Try to parse JSON
419        let json_text = body_text.and_then(|text| {
420            if headers
421                .get("content-type")
422                .map(|ct| ct.contains("application/json"))
423                .unwrap_or(false)
424            {
425                // Attempt to parse JSON and flatten into args
426                if let Ok(value) = serde_json::from_str::<serde_json::Value>(text) {
427                    flatten_json(&value, &mut args, &mut arg_entries);
428                }
429
430                // Just store the raw JSON for pattern matching
431                Some(text.to_string())
432            } else {
433                None
434            }
435        });
436
437        // Handle Multipart/Form-Data
438        if let Some(raw_body) = req.body {
439            if let Some(content_type) = headers.get("content-type") {
440                if content_type.contains("multipart/form-data") {
441                    if let Some(boundary) = extract_multipart_boundary(content_type) {
442                        let (mp_args, mp_entries) = parse_multipart(raw_body, &boundary);
443                        args.extend(mp_args);
444                        arg_entries.extend(mp_entries);
445                    }
446                }
447            }
448        }
449
450        Self {
451            ip: req.client_ip,
452            method: req.method,
453            url: req.path,
454            headers,
455            args,
456            arg_entries,
457            body_text,
458            raw_body: req.body,
459            is_static: req.is_static,
460            json_text,
461            deadline: None,
462        }
463    }
464
465    /// Creates an EvalContext with a deadline for timeout protection.
466    pub fn from_request_with_deadline(req: &'a Request<'a>, deadline: std::time::Instant) -> Self {
467        let mut ctx = Self::from_request(req);
468        ctx.deadline = Some(deadline);
469        ctx
470    }
471
472    /// Checks if the evaluation deadline has been exceeded.
473    #[inline]
474    pub fn is_deadline_exceeded(&self) -> bool {
475        self.deadline
476            .map(|d| std::time::Instant::now() >= d)
477            .unwrap_or(false)
478    }
479}
480
/// Extracts the `boundary` parameter from a `Content-Type` header value.
///
/// Parameters are split on `;`. The first one whose key equals `boundary`
/// (ASCII case-insensitive) is used, with surrounding whitespace and double
/// quotes removed. Returns `None` when there is no such parameter or its value
/// is empty.
fn extract_multipart_boundary(content_type: &str) -> Option<String> {
    let boundary = content_type.split(';').find_map(|param| {
        let (key, value) = param.split_once('=')?;
        if key.trim().eq_ignore_ascii_case("boundary") {
            Some(value.trim().trim_matches('"'))
        } else {
            None
        }
    })?;

    if boundary.is_empty() {
        None
    } else {
        Some(boundary.to_string())
    }
}

496fn parse_multipart(raw_body: &[u8], boundary: &str) -> (Vec<String>, Vec<ArgEntry>) {
497    let mut args = Vec::new();
498    let mut entries = Vec::new();
499
500    // Naive implementation: search for boundary and Content-Disposition
501    let body_str = String::from_utf8_lossy(raw_body);
502    let marker = format!("--{}", boundary);
503
504    for part in body_str.split(&marker) {
505        // Each part has headers \r\n\r\n body \r\n
506        let part = part.trim_matches('\r').trim_matches('\n');
507        if part.is_empty() || part == "--" {
508            continue;
509        }
510
511        if let Some((headers, body)) = part.split_once("\r\n\r\n") {
512            // Extract name from Content-Disposition
513            // Content-Disposition: form-data; name="fieldName"
514            let name = headers
515                .lines()
516                .find(|l| l.to_ascii_lowercase().starts_with("content-disposition"))
517                .and_then(|l| {
518                    l.split(';')
519                        .find(|p| p.trim().starts_with("name="))
520                        .map(|p| {
521                            p.trim()
522                                .trim_start_matches("name=")
523                                .trim_matches('"')
524                                .to_string()
525                        })
526                });
527
528            if let Some(key) = name {
529                let value = body.trim_end_matches("\r\n").to_string();
530                args.push(value.clone());
531                entries.push(ArgEntry { key, value });
532            }
533        }
534    }
535
536    (args, entries)
537}
538
539/// Maximum JSON nesting depth to prevent stack overflow attacks
540const MAX_JSON_DEPTH: usize = 32;
541/// Maximum total elements to extract from JSON to prevent memory exhaustion
542const MAX_JSON_ELEMENTS: usize = 1000;
543
544fn flatten_json(value: &serde_json::Value, args: &mut Vec<String>, entries: &mut Vec<ArgEntry>) {
545    let mut element_count = 0usize;
546    flatten_json_recursive(value, args, entries, 0, &mut element_count);
547}
548
549fn flatten_json_recursive(
550    value: &serde_json::Value,
551    args: &mut Vec<String>,
552    entries: &mut Vec<ArgEntry>,
553    depth: usize,
554    element_count: &mut usize,
555) {
556    // Guard: prevent stack overflow from deeply nested JSON
557    if depth > MAX_JSON_DEPTH {
558        return;
559    }
560    // Guard: prevent memory exhaustion from large JSON
561    if *element_count >= MAX_JSON_ELEMENTS {
562        return;
563    }
564
565    match value {
566        serde_json::Value::Object(map) => {
567            for (k, v) in map {
568                *element_count += 1;
569                if *element_count >= MAX_JSON_ELEMENTS {
570                    return;
571                }
572                match v {
573                    serde_json::Value::String(s) => {
574                        args.push(s.clone());
575                        entries.push(ArgEntry {
576                            key: k.clone(),
577                            value: s.clone(),
578                        });
579                    }
580                    serde_json::Value::Number(n) => {
581                        let s = n.to_string();
582                        args.push(s.clone());
583                        entries.push(ArgEntry {
584                            key: k.clone(),
585                            value: s,
586                        });
587                    }
588                    serde_json::Value::Bool(b) => {
589                        let s = b.to_string();
590                        args.push(s.clone());
591                        entries.push(ArgEntry {
592                            key: k.clone(),
593                            value: s,
594                        });
595                    }
596                    _ => flatten_json_recursive(v, args, entries, depth + 1, element_count),
597                }
598            }
599        }
600        serde_json::Value::Array(arr) => {
601            for v in arr {
602                *element_count += 1;
603                if *element_count >= MAX_JSON_ELEMENTS {
604                    return;
605                }
606                flatten_json_recursive(v, args, entries, depth + 1, element_count);
607            }
608        }
609        _ => {}
610    }
611}
612
613fn parse_query_args(path: &str, query: Option<&str>) -> (Vec<String>, Vec<ArgEntry>) {
614    let mut args = Vec::new();
615    let mut arg_entries = Vec::new();
616
617    // Get query string from path or explicit query param
618    let query_str = if let Some(q) = query {
619        q
620    } else if let Some(idx) = path.find('?') {
621        &path[idx + 1..]
622    } else {
623        return (args, arg_entries);
624    };
625
626    for pair in query_str.split('&') {
627        if pair.is_empty() {
628            continue;
629        }
630
631        // Add raw value to args
632        args.push(pair.to_string());
633
634        // Parse key=value and decode (handling + as space for form encoding)
635        if let Some((key, value)) = pair.split_once('=') {
636            let key_fixed = key.replace('+', " ");
637            let value_fixed = value.replace('+', " ");
638            let decoded_key = percent_decode_str(&key_fixed)
639                .decode_utf8_lossy()
640                .to_string();
641            let decoded_value = percent_decode_str(&value_fixed)
642                .decode_utf8_lossy()
643                .to_string();
644            arg_entries.push(ArgEntry {
645                key: decoded_key,
646                value: decoded_value,
647            });
648        } else {
649            let pair_fixed = pair.replace('+', " ");
650            let decoded_key = percent_decode_str(&pair_fixed)
651                .decode_utf8_lossy()
652                .to_string();
653            arg_entries.push(ArgEntry {
654                key: decoded_key,
655                value: String::new(),
656            });
657        }
658    }
659    (args, arg_entries)
660}
661
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_query_args() {
        let (args, entries) = parse_query_args("/api/users?id=1&name=test", None);
        assert_eq!(args.len(), 2);
        assert_eq!(entries.len(), 2);
        assert_eq!(entries[0].key, "id");
        assert_eq!(entries[0].value, "1");
        assert_eq!(entries[1].key, "name");
        assert_eq!(entries[1].value, "test");
    }

    #[test]
    fn test_parse_query_args_decoding() {
        // '+' becomes a space, %2B stays a literal '+', and a bare key gets
        // an empty value; raw pairs are kept undecoded in `args`.
        let (args, entries) = parse_query_args("/search?q=a+b%2Bc&flag", None);
        assert_eq!(args, vec!["q=a+b%2Bc", "flag"]);
        assert_eq!(entries[0].key, "q");
        assert_eq!(entries[0].value, "a b+c");
        assert_eq!(entries[1].key, "flag");
        assert_eq!(entries[1].value, "");
    }

    #[test]
    fn test_parse_query_args_explicit_query_wins() {
        let (_, entries) = parse_query_args("/x?a=1", Some("b=2"));
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].key, "b");
        assert_eq!(entries[0].value, "2");
    }

    #[test]
    fn test_eval_context_from_request() {
        let req = Request {
            method: "POST",
            path: "/api/login?username=admin",
            headers: vec![Header::new("Content-Type", "application/json")],
            body: Some(b"{\"password\": \"test\"}"),
            client_ip: "192.168.1.1",
            ..Default::default()
        };

        let ctx = EvalContext::from_request(&req);
        assert_eq!(ctx.method, "POST");
        assert_eq!(ctx.ip, "192.168.1.1");
        // 2 entries: username from query + password from JSON body (flattened)
        assert_eq!(ctx.arg_entries.len(), 2);
        assert!(ctx.json_text.is_some());
    }

    #[test]
    fn test_extract_multipart_boundary() {
        assert_eq!(
            extract_multipart_boundary("multipart/form-data; boundary=\"abc\""),
            Some("abc".to_string())
        );
        assert_eq!(extract_multipart_boundary("multipart/form-data"), None);
    }

    #[test]
    fn test_repeat_multiplier_tiers() {
        assert_eq!(repeat_multiplier(1), 1.0);
        assert_eq!(repeat_multiplier(5), 1.25);
        assert_eq!(repeat_multiplier(10), 1.5);
        assert_eq!(repeat_multiplier(11), 2.0);
    }

    #[test]
    fn test_anomaly_type_default_risk() {
        assert_eq!(AnomalyType::SessionSharing.default_risk(), 50.0);
        assert_eq!(AnomalyType::ImpossibleTravel.default_risk(), 25.0);
        assert_eq!(AnomalyType::Custom.default_risk(), 0.0);
    }
}