1use std::collections::HashMap;
4
5use percent_encoding::percent_decode_str;
6use serde::{Deserialize, Serialize};
7
/// A borrowed view of an inbound HTTP request, as handed to the evaluator.
///
/// Every string and byte field borrows from the caller for lifetime `'a`; nothing is copied
/// when a `Request` is built.
#[derive(Debug, Clone, Default)]
pub struct Request<'a> {
    /// HTTP method, e.g. `"POST"`.
    pub method: &'a str,
    /// Request path. It may carry a `?query` suffix; `EvalContext::from_request` parses the
    /// query from here when `query` is `None`.
    pub path: &'a str,
    /// Query string supplied separately from `path`; when present it takes precedence.
    /// It is split on `&` as-is, so it should not include the leading `?`.
    pub query: Option<&'a str>,
    /// Request headers. Names are lowercased when an `EvalContext` is built, and for repeated
    /// names the last value wins there.
    pub headers: Vec<Header<'a>>,
    /// Raw request body, if any.
    pub body: Option<&'a [u8]>,
    /// Client IP address as text.
    pub client_ip: &'a str,
    /// Copied into `EvalContext::is_static`; presumably marks static-asset requests (confirm
    /// with the producer of `Request`).
    pub is_static: bool,
}
26
/// One request header as a borrowed name/value pair.
///
/// The name is kept exactly as received; `EvalContext::from_request` lowercases it when it
/// builds its header lookup map.
#[derive(Debug, Clone)]
pub struct Header<'a> {
    /// Header name, as received.
    pub name: &'a str,
    /// Header value, as received.
    pub value: &'a str,
}

impl<'a> Header<'a> {
    /// Pairs `name` with `value` without copying either string.
    pub fn new(name: &'a str, value: &'a str) -> Self {
        Header { name, value }
    }
}
39
40#[derive(Debug, Clone)]
42pub struct Verdict {
43 pub action: Action,
45 pub risk_score: u16,
47 pub matched_rules: Vec<u32>,
49 pub entity_risk: f64,
51 pub entity_blocked: bool,
53 pub block_reason: Option<String>,
55 pub risk_contributions: Vec<RiskContribution>,
57
58 pub endpoint_template: Option<String>,
61 pub endpoint_risk: Option<f32>,
63 pub anomaly_score: Option<f64>,
65 pub adjusted_threshold: Option<f64>,
67 pub anomaly_signals: Vec<AnomalySignal>,
69
70 pub timed_out: bool,
73 pub rules_evaluated: Option<u32>,
75}
76
77impl Default for Verdict {
78 fn default() -> Self {
79 Self {
80 action: Action::Allow,
81 risk_score: 0,
82 matched_rules: Vec::new(),
83 entity_risk: 0.0,
84 entity_blocked: false,
85 block_reason: None,
86 risk_contributions: Vec::new(),
87 endpoint_template: None,
88 endpoint_risk: None,
89 anomaly_score: None,
90 adjusted_threshold: None,
91 anomaly_signals: Vec::new(),
92 timed_out: false,
93 rules_evaluated: None,
94 }
95 }
96}
97
/// Final decision for a request.
///
/// `#[repr(u8)]` with explicit discriminants fixes the numeric values: `Allow = 0`,
/// `Block = 1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Action {
    /// Let the request through.
    Allow = 0,
    /// Reject the request.
    Block = 1,
}
105
/// How much one matched rule contributed to a verdict's risk.
#[derive(Debug, Clone)]
pub struct RiskContribution {
    /// ID of the rule that matched.
    pub rule_id: u32,
    /// Risk before scaling.
    pub base_risk: f64,
    /// Scaling factor applied to `base_risk`.
    pub multiplier: f64,
    /// `base_risk * multiplier`, computed once by `new`.
    pub final_risk: f64,
}

impl RiskContribution {
    /// Records a contribution and derives `final_risk = base_risk * multiplier`.
    #[inline]
    pub fn new(rule_id: u32, base_risk: f64, multiplier: f64) -> Self {
        let final_risk = base_risk * multiplier;
        RiskContribution {
            rule_id,
            base_risk,
            multiplier,
            final_risk,
        }
    }
}
131
132#[derive(Debug, Clone)]
134pub struct AnomalySignal {
135 pub signal_type: AnomalySignalType,
137 pub severity: f32,
139 pub detail: String,
141}
142
143impl AnomalySignal {
144 pub fn to_anomaly_type(&self) -> AnomalyType {
146 match self.signal_type {
147 AnomalySignalType::PayloadSize => AnomalyType::OversizedRequest,
148 AnomalySignalType::RequestRate => AnomalyType::VelocitySpike,
149 AnomalySignalType::ErrorRate => AnomalyType::TimingAnomaly,
150 AnomalySignalType::ParameterAnomaly => AnomalyType::Custom,
151 AnomalySignalType::ContentTypeAnomaly => AnomalyType::Custom,
152 AnomalySignalType::TimingAnomaly => AnomalyType::TimingAnomaly,
153 AnomalySignalType::SchemaViolation => AnomalyType::Custom,
154 }
155 }
156}
157
/// Kind of an `AnomalySignal`; `AnomalySignal::to_anomaly_type` maps each kind onto an
/// `AnomalyType`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AnomalySignalType {
    /// Payload-size signal; maps to `AnomalyType::OversizedRequest`.
    PayloadSize,
    /// Request-rate signal; maps to `AnomalyType::VelocitySpike`.
    RequestRate,
    /// Error-rate signal; maps to `AnomalyType::TimingAnomaly`.
    ErrorRate,
    /// Parameter anomaly; maps to `AnomalyType::Custom`.
    ParameterAnomaly,
    /// Content-type anomaly; maps to `AnomalyType::Custom`.
    ContentTypeAnomaly,
    /// Timing anomaly; maps to `AnomalyType::TimingAnomaly`.
    TimingAnomaly,
    /// Schema violation; maps to `AnomalyType::Custom`.
    SchemaViolation,
}
176
/// Blocking mode for `RiskConfig`.
///
/// `Learning` is the default (`#[default]`). With no serde rename attributes, variants
/// (de)serialize under their Rust names (`"Learning"`, `"Enforcement"`).
/// NOTE(review): presumably `Learning` only records while `Enforcement` actually blocks —
/// confirm at the call sites that read this.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum BlockingMode {
    #[default]
    Learning,
    Enforcement,
}
186
187#[derive(Debug, Clone, Serialize, Deserialize)]
189pub struct RiskConfig {
190 pub max_risk: f64,
192 pub enable_repeat_multipliers: bool,
194 pub anomaly_risk_overrides: HashMap<AnomalyType, f64>,
196 pub anomaly_blocking_threshold: f64,
198 pub blocking_mode: BlockingMode,
200}
201
202impl Default for RiskConfig {
203 fn default() -> Self {
204 Self {
205 max_risk: 100.0,
206 enable_repeat_multipliers: true,
207 anomaly_risk_overrides: HashMap::new(),
208 anomaly_blocking_threshold: 10.0,
209 blocking_mode: BlockingMode::Learning, }
211 }
212}
213
214impl RiskConfig {
215 pub fn with_extended_range() -> Self {
217 Self {
218 max_risk: 1000.0,
219 ..Default::default()
220 }
221 }
222
223 #[inline]
225 pub fn anomaly_risk(&self, anomaly_type: AnomalyType) -> f64 {
226 self.anomaly_risk_overrides
227 .get(&anomaly_type)
228 .copied()
229 .unwrap_or_else(|| anomaly_type.default_risk())
230 }
231
232 pub fn set_anomaly_risk(&mut self, anomaly_type: AnomalyType, risk: f64) {
234 self.anomaly_risk_overrides.insert(anomaly_type, risk);
235 }
236
237 pub fn reset_anomaly_risk(&mut self, anomaly_type: AnomalyType) {
239 self.anomaly_risk_overrides.remove(&anomaly_type);
240 }
241}
242
243#[derive(Debug, Clone)]
247pub struct AnomalyContribution {
248 pub anomaly_type: AnomalyType,
250 pub risk: f64,
252 pub reason: Option<String>,
254 pub applied_at: u64,
256}
257
258impl AnomalyContribution {
259 pub fn new(anomaly_type: AnomalyType, risk: f64, reason: Option<String>, now: u64) -> Self {
261 Self {
262 anomaly_type,
263 risk,
264 reason,
265 applied_at: now,
266 }
267 }
268}
269
/// Maps a rule's match count to a risk multiplier.
///
/// | matches | multiplier |
/// |---------|------------|
/// | 0–1     | 1.0        |
/// | 2–5     | 1.25       |
/// | 6–10    | 1.5        |
/// | 11+     | 2.0        |
#[inline]
pub fn repeat_multiplier(match_count: u32) -> f64 {
    if match_count <= 1 {
        1.0
    } else if match_count <= 5 {
        1.25
    } else if match_count <= 10 {
        1.5
    } else {
        2.0
    }
}
292
/// Category of anomaly, with a built-in baseline risk (`default_risk`) and a snake_case
/// identifier (`name`).
///
/// `#[repr(u8)]` pins the discriminants listed below; `Custom` is `255`. `Hash` + `Eq` let
/// it key `RiskConfig::anomaly_risk_overrides`, and the serde derives let it appear in a
/// serialized `RiskConfig`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
pub enum AnomalyType {
    FingerprintChange = 0,
    SessionSharing = 1,
    TokenReuse = 2,
    VelocitySpike = 3,
    RotationPattern = 4,
    TimingAnomaly = 5,
    ImpossibleTravel = 6,
    OversizedRequest = 7,
    OversizedResponse = 8,
    BandwidthSpike = 9,
    ExfiltrationPattern = 10,
    UploadPattern = 11,
    // Catch-all for anomalies without a dedicated variant; its default risk is 0.0.
    Custom = 255,
}
324
325impl AnomalyType {
326 #[inline]
328 pub const fn default_risk(self) -> f64 {
329 match self {
330 AnomalyType::SessionSharing => 50.0,
331 AnomalyType::ExfiltrationPattern => 40.0,
332 AnomalyType::TokenReuse => 40.0,
333 AnomalyType::RotationPattern => 35.0,
334 AnomalyType::UploadPattern => 35.0,
335 AnomalyType::FingerprintChange => 30.0,
336 AnomalyType::BandwidthSpike => 25.0,
337 AnomalyType::ImpossibleTravel => 25.0,
338 AnomalyType::OversizedRequest => 20.0,
339 AnomalyType::OversizedResponse => 15.0,
340 AnomalyType::VelocitySpike => 15.0,
341 AnomalyType::TimingAnomaly => 10.0,
342 AnomalyType::Custom => 0.0,
343 }
344 }
345
346 pub const fn name(self) -> &'static str {
348 match self {
349 AnomalyType::FingerprintChange => "fingerprint_change",
350 AnomalyType::SessionSharing => "session_sharing",
351 AnomalyType::TokenReuse => "token_reuse",
352 AnomalyType::VelocitySpike => "velocity_spike",
353 AnomalyType::RotationPattern => "rotation_pattern",
354 AnomalyType::TimingAnomaly => "timing_anomaly",
355 AnomalyType::ImpossibleTravel => "impossible_travel",
356 AnomalyType::OversizedRequest => "oversized_request",
357 AnomalyType::OversizedResponse => "oversized_response",
358 AnomalyType::BandwidthSpike => "bandwidth_spike",
359 AnomalyType::ExfiltrationPattern => "exfiltration_pattern",
360 AnomalyType::UploadPattern => "upload_pattern",
361 AnomalyType::Custom => "custom",
362 }
363 }
364}
365
/// Per-request data prepared for rule evaluation; built by `EvalContext::from_request`.
#[derive(Debug)]
pub struct EvalContext<'a> {
    /// Client IP, from `Request::client_ip`.
    pub ip: &'a str,
    /// HTTP method, from `Request::method`.
    pub method: &'a str,
    /// Request path, from `Request::path`, including any `?query` suffix it carries.
    pub url: &'a str,
    /// Headers keyed by lowercased name; for repeated names the last value wins.
    pub headers: HashMap<String, &'a str>,
    /// Argument strings from the query and, depending on Content-Type, from the body.
    /// Query and form-urlencoded entries are the raw `key=value` pair text; JSON and
    /// multipart bodies contribute values only.
    pub args: Vec<String>,
    /// Decoded key/value pairs, drawn from the same sources as `args`.
    pub arg_entries: Vec<ArgEntry>,
    /// Body as text, present only when the body is valid UTF-8.
    pub body_text: Option<&'a str>,
    /// Raw body bytes.
    pub raw_body: Option<&'a [u8]>,
    /// Copied from `Request::is_static`.
    pub is_static: bool,
    /// Owned copy of the body text when the Content-Type indicates `application/json`, kept
    /// even if the JSON failed to parse.
    pub json_text: Option<String>,
    /// Optional evaluation deadline, checked by `is_deadline_exceeded`.
    pub deadline: Option<std::time::Instant>,
}
382
/// A decoded argument name/value pair, extracted from the query string or the request body.
#[derive(Debug, Clone)]
pub struct ArgEntry {
    /// Argument name: query key, form field name, or JSON member name.
    pub key: String,
    /// Argument value; empty for query pairs that contain no `=`.
    pub value: String,
}
388
389impl<'a> EvalContext<'a> {
390 pub fn from_request(req: &'a Request<'a>) -> Self {
392 let mut headers = HashMap::new();
394 for h in &req.headers {
395 headers.insert(h.name.to_ascii_lowercase(), h.value);
396 }
397
398 let (mut args, mut arg_entries) = parse_query_args(req.path, req.query);
400
401 let body_text = req.body.and_then(|b| std::str::from_utf8(b).ok());
403
404 if let Some(text) = body_text {
406 if headers
407 .get("content-type")
408 .map(|ct| ct.contains("application/x-www-form-urlencoded"))
409 .unwrap_or(false)
410 {
411 let (body_args, body_entries) = parse_query_args("", Some(text));
413 args.extend(body_args);
414 arg_entries.extend(body_entries);
415 }
416 }
417
418 let json_text = body_text.and_then(|text| {
420 if headers
421 .get("content-type")
422 .map(|ct| ct.contains("application/json"))
423 .unwrap_or(false)
424 {
425 if let Ok(value) = serde_json::from_str::<serde_json::Value>(text) {
427 flatten_json(&value, &mut args, &mut arg_entries);
428 }
429
430 Some(text.to_string())
432 } else {
433 None
434 }
435 });
436
437 if let Some(raw_body) = req.body {
439 if let Some(content_type) = headers.get("content-type") {
440 if content_type.contains("multipart/form-data") {
441 if let Some(boundary) = extract_multipart_boundary(content_type) {
442 let (mp_args, mp_entries) = parse_multipart(raw_body, &boundary);
443 args.extend(mp_args);
444 arg_entries.extend(mp_entries);
445 }
446 }
447 }
448 }
449
450 Self {
451 ip: req.client_ip,
452 method: req.method,
453 url: req.path,
454 headers,
455 args,
456 arg_entries,
457 body_text,
458 raw_body: req.body,
459 is_static: req.is_static,
460 json_text,
461 deadline: None,
462 }
463 }
464
465 pub fn from_request_with_deadline(req: &'a Request<'a>, deadline: std::time::Instant) -> Self {
467 let mut ctx = Self::from_request(req);
468 ctx.deadline = Some(deadline);
469 ctx
470 }
471
472 #[inline]
474 pub fn is_deadline_exceeded(&self) -> bool {
475 self.deadline
476 .map(|d| std::time::Instant::now() >= d)
477 .unwrap_or(false)
478 }
479}
480
/// Returns the `boundary` parameter of a `multipart/*` Content-Type value.
///
/// The parameter name is matched case-insensitively. The value is trimmed and stripped of
/// surrounding double quotes, with its case preserved. Returns `None` when no `boundary`
/// parameter exists or when the first one found is empty.
fn extract_multipart_boundary(content_type: &str) -> Option<String> {
    let boundary = content_type.split(';').find_map(|param| {
        let (name, raw) = param.trim().split_once('=')?;
        if name.trim().eq_ignore_ascii_case("boundary") {
            Some(raw.trim().trim_matches('"'))
        } else {
            None
        }
    })?;

    if boundary.is_empty() {
        None
    } else {
        Some(boundary.to_owned())
    }
}
495
496fn parse_multipart(raw_body: &[u8], boundary: &str) -> (Vec<String>, Vec<ArgEntry>) {
497 let mut args = Vec::new();
498 let mut entries = Vec::new();
499
500 let body_str = String::from_utf8_lossy(raw_body);
502 let marker = format!("--{}", boundary);
503
504 for part in body_str.split(&marker) {
505 let part = part.trim_matches('\r').trim_matches('\n');
507 if part.is_empty() || part == "--" {
508 continue;
509 }
510
511 if let Some((headers, body)) = part.split_once("\r\n\r\n") {
512 let name = headers
515 .lines()
516 .find(|l| l.to_ascii_lowercase().starts_with("content-disposition"))
517 .and_then(|l| {
518 l.split(';')
519 .find(|p| p.trim().starts_with("name="))
520 .map(|p| {
521 p.trim()
522 .trim_start_matches("name=")
523 .trim_matches('"')
524 .to_string()
525 })
526 });
527
528 if let Some(key) = name {
529 let value = body.trim_end_matches("\r\n").to_string();
530 args.push(value.clone());
531 entries.push(ArgEntry { key, value });
532 }
533 }
534 }
535
536 (args, entries)
537}
538
539const MAX_JSON_DEPTH: usize = 32;
541const MAX_JSON_ELEMENTS: usize = 1000;
543
544fn flatten_json(value: &serde_json::Value, args: &mut Vec<String>, entries: &mut Vec<ArgEntry>) {
545 let mut element_count = 0usize;
546 flatten_json_recursive(value, args, entries, 0, &mut element_count);
547}
548
549fn flatten_json_recursive(
550 value: &serde_json::Value,
551 args: &mut Vec<String>,
552 entries: &mut Vec<ArgEntry>,
553 depth: usize,
554 element_count: &mut usize,
555) {
556 if depth > MAX_JSON_DEPTH {
558 return;
559 }
560 if *element_count >= MAX_JSON_ELEMENTS {
562 return;
563 }
564
565 match value {
566 serde_json::Value::Object(map) => {
567 for (k, v) in map {
568 *element_count += 1;
569 if *element_count >= MAX_JSON_ELEMENTS {
570 return;
571 }
572 match v {
573 serde_json::Value::String(s) => {
574 args.push(s.clone());
575 entries.push(ArgEntry {
576 key: k.clone(),
577 value: s.clone(),
578 });
579 }
580 serde_json::Value::Number(n) => {
581 let s = n.to_string();
582 args.push(s.clone());
583 entries.push(ArgEntry {
584 key: k.clone(),
585 value: s,
586 });
587 }
588 serde_json::Value::Bool(b) => {
589 let s = b.to_string();
590 args.push(s.clone());
591 entries.push(ArgEntry {
592 key: k.clone(),
593 value: s,
594 });
595 }
596 _ => flatten_json_recursive(v, args, entries, depth + 1, element_count),
597 }
598 }
599 }
600 serde_json::Value::Array(arr) => {
601 for v in arr {
602 *element_count += 1;
603 if *element_count >= MAX_JSON_ELEMENTS {
604 return;
605 }
606 flatten_json_recursive(v, args, entries, depth + 1, element_count);
607 }
608 }
609 _ => {}
610 }
611}
612
613fn parse_query_args(path: &str, query: Option<&str>) -> (Vec<String>, Vec<ArgEntry>) {
614 let mut args = Vec::new();
615 let mut arg_entries = Vec::new();
616
617 let query_str = if let Some(q) = query {
619 q
620 } else if let Some(idx) = path.find('?') {
621 &path[idx + 1..]
622 } else {
623 return (args, arg_entries);
624 };
625
626 for pair in query_str.split('&') {
627 if pair.is_empty() {
628 continue;
629 }
630
631 args.push(pair.to_string());
633
634 if let Some((key, value)) = pair.split_once('=') {
636 let key_fixed = key.replace('+', " ");
637 let value_fixed = value.replace('+', " ");
638 let decoded_key = percent_decode_str(&key_fixed)
639 .decode_utf8_lossy()
640 .to_string();
641 let decoded_value = percent_decode_str(&value_fixed)
642 .decode_utf8_lossy()
643 .to_string();
644 arg_entries.push(ArgEntry {
645 key: decoded_key,
646 value: decoded_value,
647 });
648 } else {
649 let pair_fixed = pair.replace('+', " ");
650 let decoded_key = percent_decode_str(&pair_fixed)
651 .decode_utf8_lossy()
652 .to_string();
653 arg_entries.push(ArgEntry {
654 key: decoded_key,
655 value: String::new(),
656 });
657 }
658 }
659 (args, arg_entries)
660}
661
#[cfg(test)]
mod tests {
    use super::*;

    /// Query args embedded in the path yield both raw pairs and decoded entries.
    #[test]
    fn test_parse_query_args() {
        let (args, entries) = parse_query_args("/api/users?id=1&name=test", None);
        assert_eq!(args.len(), 2);

        let decoded: Vec<(&str, &str)> = entries
            .iter()
            .map(|entry| (entry.key.as_str(), entry.value.as_str()))
            .collect();
        assert_eq!(decoded, [("id", "1"), ("name", "test")]);
    }

    /// A JSON POST contributes one query entry plus one flattened body entry.
    #[test]
    fn test_eval_context_from_request() {
        let body: &[u8] = b"{\"password\": \"test\"}";
        let req = Request {
            method: "POST",
            path: "/api/login?username=admin",
            headers: vec![Header::new("Content-Type", "application/json")],
            body: Some(body),
            client_ip: "192.168.1.1",
            ..Default::default()
        };

        let ctx = EvalContext::from_request(&req);
        assert_eq!(ctx.method, "POST");
        assert_eq!(ctx.ip, "192.168.1.1");
        assert_eq!(ctx.arg_entries.len(), 2);
        assert!(ctx.json_text.is_some());
    }

    /// Spot-checks the built-in anomaly risk table.
    #[test]
    fn test_anomaly_type_default_risk() {
        for (kind, expected) in [
            (AnomalyType::SessionSharing, 50.0),
            (AnomalyType::ImpossibleTravel, 25.0),
            (AnomalyType::Custom, 0.0),
        ] {
            assert_eq!(kind.default_risk(), expected);
        }
    }
}