1use std::collections::HashSet;
26use std::sync::Arc;
27use std::time::Instant;
28
29use dashmap::DashMap;
30
31use crate::profiler::entropy::shannon_entropy;
32use crate::profiler::header_types::{
33 HeaderAnomaly, HeaderAnomalyResult, HeaderBaseline, ValueStats,
34};
35
// ---- Capacity limits -------------------------------------------------------

/// Endpoint-baseline cap used by `HeaderProfiler::new`.
const DEFAULT_MAX_ENDPOINTS: usize = 10_000;

/// Maximum number of distinct header names that get value statistics per
/// endpoint; once reached, previously unseen names are not learned.
const MAX_HEADERS_PER_ENDPOINT: usize = 100;

// ---- Learning / classification --------------------------------------------

/// Maturity threshold (learned samples per endpoint) used by `HeaderProfiler::new`.
const DEFAULT_MIN_SAMPLES: u64 = 50;

/// A header seen in at least this fraction of an endpoint's samples is
/// classified as required.
const REQUIRED_HEADER_THRESHOLD: f64 = 0.95;

// ---- Anomaly thresholds -----------------------------------------------------

/// An entropy anomaly is reported when the absolute entropy z-score exceeds this.
const ENTROPY_Z_THRESHOLD: f64 = 3.0;

/// Tolerance handed to `ValueStats::is_length_in_range` when checking value
/// lengths; presumably a multiplicative slack on the learned range — TODO
/// confirm against `ValueStats`.
const LENGTH_TOLERANCE_FACTOR: f64 = 1.5;
57
/// Learns per-endpoint header baselines and flags requests that deviate from them.
///
/// Baselines are keyed by endpoint string and stored in a `DashMap` behind an
/// `Arc`, so clones of a profiler share the same learned state.
#[derive(Debug)]
pub struct HeaderProfiler {
    /// Endpoint -> learned header baseline; the `Arc` is shared by all clones.
    baselines: Arc<DashMap<String, HeaderBaseline>>,

    /// Maximum number of endpoint baselines kept; learning a new endpoint at
    /// capacity evicts the least recently updated baseline first.
    max_endpoints: usize,

    /// Samples an endpoint must accumulate before `analyze` reports anything.
    min_samples: u64,
}
77
78impl HeaderProfiler {
79 pub fn new() -> Self {
81 Self {
82 baselines: Arc::new(DashMap::with_capacity(1000)),
83 max_endpoints: DEFAULT_MAX_ENDPOINTS,
84 min_samples: DEFAULT_MIN_SAMPLES,
85 }
86 }
87
88 pub fn with_config(max_endpoints: usize, min_samples: u64) -> Self {
94 Self {
95 baselines: Arc::new(DashMap::with_capacity(max_endpoints.min(10000))),
96 max_endpoints,
97 min_samples,
98 }
99 }
100
101 pub fn learn(&self, endpoint: &str, headers: &[(String, String)]) {
114 if self.baselines.len() >= self.max_endpoints && !self.baselines.contains_key(endpoint) {
116 self.evict_oldest();
117 }
118
119 let mut baseline = self
121 .baselines
122 .entry(endpoint.to_string())
123 .or_insert_with(|| HeaderBaseline::new(endpoint.to_string()));
124
125 let present_headers: HashSet<String> =
127 headers.iter().map(|(k, _)| k.to_lowercase()).collect();
128
129 for (header_name, header_value) in headers {
131 let header_name = header_name.to_lowercase();
132
133 if baseline.header_value_stats.len() >= MAX_HEADERS_PER_ENDPOINT
135 && !baseline.header_value_stats.contains_key(&header_name)
136 {
137 continue;
138 }
139
140 let entropy = shannon_entropy(header_value);
141 let length = header_value.len();
142
143 baseline
144 .header_value_stats
145 .entry(header_name.clone())
146 .or_insert_with(ValueStats::new)
147 .update(length, entropy);
148 }
149
150 baseline.sample_count += 1;
152 baseline.last_updated = Instant::now();
153
154 if baseline.sample_count >= self.min_samples && baseline.sample_count % 10 == 0 {
156 self.recalculate_header_categories(&mut baseline, &present_headers);
157 }
158 }
159
160 pub fn analyze(&self, endpoint: &str, headers: &[(String, String)]) -> HeaderAnomalyResult {
175 let baseline = match self.baselines.get(endpoint) {
177 Some(b) => b,
178 None => return HeaderAnomalyResult::none(),
179 };
180
181 if !baseline.is_mature(self.min_samples) {
183 return HeaderAnomalyResult::none();
184 }
185
186 let mut result = HeaderAnomalyResult::new();
187
188 let present_headers: HashSet<String> =
190 headers.iter().map(|(k, _)| k.to_lowercase()).collect();
191
192 for required_header in &baseline.required_headers {
194 if !present_headers.contains(required_header) {
195 result.add(HeaderAnomaly::MissingRequired {
196 header: required_header.clone(),
197 });
198 }
199 }
200
201 for (header_name, _) in headers {
203 let header_name = header_name.to_lowercase();
204 if !baseline.is_known(&header_name) {
205 result.add(HeaderAnomaly::UnexpectedHeader {
206 header: header_name.clone(),
207 });
208 }
209 }
210
211 for (header_name, header_value) in headers {
213 let header_name = header_name.to_lowercase();
214 if let Some(stats) = baseline.get_stats(&header_name) {
215 if stats.is_mature(self.min_samples / 2) {
216 let length = header_value.len();
218 if !stats.is_length_in_range(length, LENGTH_TOLERANCE_FACTOR) {
219 result.add(HeaderAnomaly::LengthAnomaly {
220 header: header_name.clone(),
221 length,
222 expected_range: (stats.min_length, stats.max_length),
223 });
224 }
225
226 let entropy = shannon_entropy(header_value);
228 let z_score = stats.entropy_z_score(entropy);
229 if z_score.abs() > ENTROPY_Z_THRESHOLD {
230 result.add(HeaderAnomaly::EntropyAnomaly {
231 header: header_name.clone(),
232 entropy,
233 expected_mean: stats.entropy_mean,
234 });
235 }
236 }
237 }
238 }
239
240 result
241 }
242
243 pub fn get_baseline(&self, endpoint: &str) -> Option<HeaderBaseline> {
247 self.baselines.get(endpoint).map(|b| b.clone())
248 }
249
250 #[inline]
252 pub fn endpoint_count(&self) -> usize {
253 self.baselines.len()
254 }
255
256 #[inline]
258 pub fn max_endpoints(&self) -> usize {
259 self.max_endpoints
260 }
261
262 #[inline]
264 pub fn min_samples(&self) -> u64 {
265 self.min_samples
266 }
267
268 pub fn clear(&self) {
270 self.baselines.clear();
271 }
272
273 pub fn stats(&self) -> HeaderProfilerStats {
275 let mut total_samples = 0u64;
276 let mut total_headers = 0usize;
277 let mut mature_endpoints = 0usize;
278
279 for entry in self.baselines.iter() {
280 total_samples += entry.sample_count;
281 total_headers += entry.header_value_stats.len();
282 if entry.is_mature(self.min_samples) {
283 mature_endpoints += 1;
284 }
285 }
286
287 HeaderProfilerStats {
288 endpoint_count: self.baselines.len(),
289 mature_endpoints,
290 total_samples,
291 total_headers,
292 max_endpoints: self.max_endpoints,
293 }
294 }
295
296 fn recalculate_header_categories(
302 &self,
303 baseline: &mut HeaderBaseline,
304 current_headers: &HashSet<String>,
305 ) {
306 let sample_count = baseline.sample_count;
307
308 let mut new_required = HashSet::with_capacity(baseline.header_value_stats.len());
310 let mut new_optional = HashSet::with_capacity(baseline.header_value_stats.len());
311
312 for (header_name, stats) in &baseline.header_value_stats {
313 let frequency = stats.total_samples as f64 / sample_count as f64;
314
315 if frequency >= REQUIRED_HEADER_THRESHOLD {
316 new_required.insert(header_name.clone());
317 } else {
318 new_optional.insert(header_name.clone());
319 }
320 }
321
322 for header in current_headers {
324 if !new_required.contains(header) && !new_optional.contains(header) {
325 new_optional.insert(header.to_string());
326 }
327 }
328
329 baseline.required_headers = new_required;
330 baseline.optional_headers = new_optional;
331 }
332
333 fn evict_oldest(&self) {
335 let mut oldest_key: Option<String> = None;
337 let mut oldest_time = Instant::now();
338
339 for entry in self.baselines.iter() {
340 if entry.last_updated < oldest_time {
341 oldest_time = entry.last_updated;
342 oldest_key = Some(entry.key().clone());
343 }
344 }
345
346 if let Some(key) = oldest_key {
347 self.baselines.remove(&key);
348 }
349 }
350}
351
352impl Default for HeaderProfiler {
353 fn default() -> Self {
354 Self::new()
355 }
356}
357
358impl Clone for HeaderProfiler {
359 fn clone(&self) -> Self {
360 Self {
361 baselines: Arc::clone(&self.baselines),
362 max_endpoints: self.max_endpoints,
363 min_samples: self.min_samples,
364 }
365 }
366}
367
/// Aggregate counters describing a `HeaderProfiler`, as produced by its `stats()` method.
///
/// `PartialEq`/`Eq` allow snapshots to be compared directly. `Default` gives an
/// all-zero snapshot.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct HeaderProfilerStats {
    /// Number of endpoints that currently have a baseline.
    pub endpoint_count: usize,

    /// How many of those baselines are mature under the profiler's `min_samples`.
    pub mature_endpoints: usize,

    /// Sum of `sample_count` across all baselines.
    pub total_samples: u64,

    /// Sum, across all baselines, of the number of header names with value statistics.
    pub total_headers: usize,

    /// The profiler's configured endpoint cap.
    pub max_endpoints: usize,
}
390
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts borrowed `(name, value)` pairs into the owned form `learn`/`analyze` take.
    fn make_headers(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    #[test]
    fn test_profiler_new() {
        let profiler = HeaderProfiler::new();
        assert_eq!(profiler.endpoint_count(), 0);
        assert_eq!(profiler.max_endpoints(), DEFAULT_MAX_ENDPOINTS);
        assert_eq!(profiler.min_samples(), DEFAULT_MIN_SAMPLES);
    }

    #[test]
    fn test_profiler_with_config() {
        let profiler = HeaderProfiler::with_config(100, 10);
        assert_eq!(profiler.max_endpoints(), 100);
        assert_eq!(profiler.min_samples(), 10);
    }

    #[test]
    fn test_profiler_learn_creates_baseline() {
        let profiler = HeaderProfiler::new();
        let headers = make_headers(&[
            ("Content-Type", "application/json"),
            ("Authorization", "Bearer token123"),
        ]);

        profiler.learn("/api/users", &headers);

        assert_eq!(profiler.endpoint_count(), 1);
        let baseline = profiler.get_baseline("/api/users").unwrap();
        assert_eq!(baseline.sample_count, 1);
        assert_eq!(baseline.header_value_stats.len(), 2);
    }

    #[test]
    fn test_profiler_learn_accumulates() {
        let profiler = HeaderProfiler::new();

        for i in 0..10 {
            let headers = make_headers(&[
                ("Content-Type", "application/json"),
                ("X-Request-ID", &format!("req-{}", i)),
            ]);
            profiler.learn("/api/test", &headers);
        }

        let baseline = profiler.get_baseline("/api/test").unwrap();
        assert_eq!(baseline.sample_count, 10);

        // Stats are keyed by the lowercased header name.
        let ct_stats = baseline.get_stats("content-type").unwrap();
        assert_eq!(ct_stats.total_samples, 10);
    }

    #[test]
    fn test_profiler_analyze_no_baseline() {
        let profiler = HeaderProfiler::new();
        let headers = make_headers(&[("Content-Type", "application/json")]);

        let result = profiler.analyze("/unknown", &headers);
        assert!(!result.has_anomalies());
    }

    #[test]
    fn test_profiler_analyze_immature_baseline() {
        let profiler = HeaderProfiler::with_config(100, 10);

        // 5 samples < min_samples (10): the baseline must not be used yet.
        for _ in 0..5 {
            let headers = make_headers(&[("Content-Type", "application/json")]);
            profiler.learn("/api/test", &headers);
        }

        let headers = make_headers(&[("Content-Type", "application/json")]);
        let result = profiler.analyze("/api/test", &headers);

        assert!(!result.has_anomalies());
    }

    #[test]
    fn test_detect_missing_required_header() {
        let profiler = HeaderProfiler::with_config(100, 10);

        for _ in 0..50 {
            let headers = make_headers(&[
                ("Content-Type", "application/json"),
                ("Authorization", "Bearer token"),
            ]);
            profiler.learn("/api/secure", &headers);
        }

        // Authorization was present in every learned request.
        let headers = make_headers(&[("Content-Type", "application/json")]);
        let result = profiler.analyze("/api/secure", &headers);

        assert!(result.has_anomalies());
        let missing = result.anomalies.iter().find(
            |a| matches!(a, HeaderAnomaly::MissingRequired { header } if header == "authorization"),
        );
        assert!(missing.is_some());
    }

    #[test]
    fn test_detect_unexpected_header() {
        let profiler = HeaderProfiler::with_config(100, 10);

        for _ in 0..50 {
            let headers = make_headers(&[("Content-Type", "application/json")]);
            profiler.learn("/api/test", &headers);
        }

        let headers = make_headers(&[
            ("Content-Type", "application/json"),
            ("X-Evil-Header", "malicious"),
        ]);
        let result = profiler.analyze("/api/test", &headers);

        assert!(result.has_anomalies());
        let unexpected = result.anomalies.iter().find(|a| {
            matches!(a, HeaderAnomaly::UnexpectedHeader { header } if header == "x-evil-header")
        });
        assert!(unexpected.is_some());
    }

    #[test]
    fn test_detect_length_anomaly() {
        let profiler = HeaderProfiler::with_config(100, 20);

        for _ in 0..50 {
            let headers = make_headers(&[("Authorization", "Bearer short_token")]);
            profiler.learn("/api/auth", &headers);
        }

        // Orders of magnitude longer than anything learned.
        let long_token = "a".repeat(10000);
        let headers = make_headers(&[("Authorization", &format!("Bearer {}", long_token))]);
        let result = profiler.analyze("/api/auth", &headers);

        assert!(result.has_anomalies());
        let length_anomaly = result.anomalies.iter().find(|a| {
            matches!(a, HeaderAnomaly::LengthAnomaly { header, .. } if header == "authorization")
        });
        assert!(length_anomaly.is_some());
    }

    #[test]
    fn test_detect_entropy_anomaly() {
        let profiler = HeaderProfiler::with_config(100, 30);

        for i in 0..60 {
            let headers = make_headers(&[("X-Token", &format!("user-token-{:05}", i))]);
            profiler.learn("/api/token", &headers);
        }

        let high_entropy = "xK9mNqR5vL8jYpW2eTfGhIuB7cDaZoS4";
        let headers = make_headers(&[("X-Token", high_entropy)]);
        let result = profiler.analyze("/api/token", &headers);

        // Whether the z-score crosses ENTROPY_Z_THRESHOLD depends on the variance that
        // ValueStats learned, so detection itself is not pinned here. What must hold:
        // x-token was present in every sample and is present now, so no structural
        // anomaly (missing/unexpected) may appear. Anything reported is a value anomaly on it.
        assert!(result.anomalies.iter().all(|a| matches!(
            a,
            HeaderAnomaly::EntropyAnomaly { header, .. }
                | HeaderAnomaly::LengthAnomaly { header, .. }
                if header == "x-token"
        )));
    }

    #[test]
    fn test_risk_contribution_accumulates() {
        let profiler = HeaderProfiler::with_config(100, 10);

        for _ in 0..50 {
            let headers = make_headers(&[
                ("Content-Type", "application/json"),
                ("Authorization", "Bearer token"),
            ]);
            profiler.learn("/api/risk", &headers);
        }

        // Two missing required headers plus two unexpected ones.
        let headers = make_headers(&[("X-Unexpected-1", "value"), ("X-Unexpected-2", "value")]);
        let result = profiler.analyze("/api/risk", &headers);

        assert!(result.has_anomalies());
        assert!(result.risk_contribution > 0);
        assert!(result.risk_contribution <= 50);
    }

    #[test]
    fn test_lru_eviction() {
        let profiler = HeaderProfiler::with_config(3, 10);

        // Sleeps guarantee strictly increasing last_updated timestamps.
        profiler.learn("/api/1", &make_headers(&[("X", "1")]));
        std::thread::sleep(std::time::Duration::from_millis(10));
        profiler.learn("/api/2", &make_headers(&[("X", "2")]));
        std::thread::sleep(std::time::Duration::from_millis(10));
        profiler.learn("/api/3", &make_headers(&[("X", "3")]));

        assert_eq!(profiler.endpoint_count(), 3);

        profiler.learn("/api/4", &make_headers(&[("X", "4")]));

        assert_eq!(profiler.endpoint_count(), 3);
        assert!(profiler.get_baseline("/api/1").is_none());
        assert!(profiler.get_baseline("/api/4").is_some());
    }

    #[test]
    fn test_concurrent_learn() {
        use std::thread;

        let profiler = Arc::new(HeaderProfiler::new());

        let handles: Vec<_> = (0..4)
            .map(|i| {
                let p = Arc::clone(&profiler);
                thread::spawn(move || {
                    for j in 0..100 {
                        let headers = make_headers(&[
                            ("Thread", &format!("{}", i)),
                            ("Request", &format!("{}", j)),
                        ]);
                        p.learn(&format!("/api/thread-{}", i), &headers);
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(profiler.endpoint_count(), 4);
        // No learn may be lost under concurrency: each thread owns one endpoint.
        for i in 0..4 {
            let baseline = profiler
                .get_baseline(&format!("/api/thread-{}", i))
                .expect("each thread's endpoint should have a baseline");
            assert_eq!(baseline.sample_count, 100);
        }
    }

    #[test]
    fn test_concurrent_learn_and_analyze() {
        use std::thread;

        let profiler = Arc::new(HeaderProfiler::with_config(100, 10));

        for _ in 0..20 {
            profiler.learn(
                "/api/concurrent",
                &make_headers(&[("Content-Type", "application/json")]),
            );
        }

        let handles: Vec<_> = (0..4)
            .map(|i| {
                let p = Arc::clone(&profiler);
                thread::spawn(move || {
                    for _ in 0..50 {
                        if i % 2 == 0 {
                            p.learn(
                                "/api/concurrent",
                                &make_headers(&[("Content-Type", "application/json")]),
                            );
                        } else {
                            let _ = p.analyze(
                                "/api/concurrent",
                                &make_headers(&[("Content-Type", "application/json")]),
                            );
                        }
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }

        // 20 warm-up samples + 2 learning threads x 50 samples each.
        let baseline = profiler.get_baseline("/api/concurrent").unwrap();
        assert_eq!(baseline.sample_count, 120);
    }

    #[test]
    fn test_profiler_stats() {
        let profiler = HeaderProfiler::with_config(100, 10);

        for _ in 0..50 {
            profiler.learn(
                "/api/mature",
                &make_headers(&[("Content-Type", "application/json")]),
            );
        }
        for _ in 0..5 {
            profiler.learn("/api/immature", &make_headers(&[("X-Token", "test")]));
        }

        let stats = profiler.stats();
        assert_eq!(stats.endpoint_count, 2);
        // Only /api/mature reached min_samples (50 >= 10, 5 < 10).
        assert_eq!(stats.mature_endpoints, 1);
        assert_eq!(stats.total_samples, 55);
        // One tracked header name per endpoint.
        assert_eq!(stats.total_headers, 2);
    }

    #[test]
    fn test_profiler_clear() {
        let profiler = HeaderProfiler::new();

        profiler.learn("/api/1", &make_headers(&[("X", "1")]));
        profiler.learn("/api/2", &make_headers(&[("X", "2")]));
        assert_eq!(profiler.endpoint_count(), 2);

        profiler.clear();
        assert_eq!(profiler.endpoint_count(), 0);
    }

    #[test]
    fn test_profiler_clone_shares_state() {
        let profiler1 = HeaderProfiler::new();
        profiler1.learn("/api/shared", &make_headers(&[("X", "1")]));

        let profiler2 = profiler1.clone();
        profiler2.learn("/api/shared", &make_headers(&[("X", "2")]));

        // The clone's sample is visible through the original handle.
        let baseline = profiler1.get_baseline("/api/shared").unwrap();
        assert_eq!(baseline.sample_count, 2);
    }

    #[test]
    fn test_header_ordering_is_ignored() {
        let profiler = HeaderProfiler::with_config(100, 10);

        for _ in 0..20 {
            profiler.learn(
                "/api/order",
                &make_headers(&[("A", "1"), ("B", "2"), ("C", "3")]),
            );
        }

        // Same headers, different order.
        let result = profiler.analyze(
            "/api/order",
            &make_headers(&[("C", "3"), ("A", "1"), ("B", "2")]),
        );

        assert!(!result.has_anomalies());
    }

    #[test]
    fn test_header_case_sensitivity() {
        let profiler = HeaderProfiler::with_config(100, 10);

        for _ in 0..20 {
            profiler.learn("/api/case", &make_headers(&[("X-Custom", "value")]));
        }

        let result = profiler.analyze("/api/case", &make_headers(&[("x-custom", "value")]));

        assert!(
            !result.has_anomalies(),
            "Header analysis should be case-insensitive"
        );
    }
}