// synapse_pingora/correlation/detectors/attack_sequence.rs

//! Attack Sequence Detector
//!
//! Identifies coordinated attacks where multiple IPs send identical
//! or highly similar attack payloads (payloads are correlated by the hash
//! of their normalized form). Weight: 50 (highest signal).

use std::collections::HashSet;
use std::net::IpAddr;
use std::time::{Duration, Instant};

use dashmap::{DashMap, DashSet};

use super::common::TimeWindowedIndex;
use super::{Detector, DetectorResult};
use crate::correlation::{CampaignUpdate, CorrelationReason, CorrelationType, FingerprintIndex};

/// Configuration for attack sequence detection.
///
/// Start from [`Default::default`] and override individual fields with
/// struct-update syntax, e.g. `AttackSequenceConfig { min_ips: 3, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct AttackSequenceConfig {
    /// Minimum number of IPs sharing the same payload hash to trigger detection.
    pub min_ips: usize,
    /// Time window for attack correlation; used to build the payload index.
    pub window: Duration,
    /// Minimum payload similarity threshold (0.0 to 1.0).
    ///
    /// NOTE: `AttackSequenceDetector` does not currently read this value; it
    /// groups observations by exact payload hash only.
    pub similarity_threshold: f64,
    /// Base confidence multiplier (0.0 to 1.0). This is the highest
    /// confidence an emitted update can carry.
    pub base_confidence: f64,
    /// Divisor for scaling confidence by IP count. A group of `n` IPs is
    /// scored `min(n / confidence_scale_divisor, 1.0) * base_confidence`.
    pub confidence_scale_divisor: f64,
    /// Maximum entries per payload hash (0 = unlimited); used to build the
    /// payload index.
    pub max_entries_per_hash: usize,
}

32impl Default for AttackSequenceConfig {
33    fn default() -> Self {
34        Self {
35            min_ips: 2,
36            window: Duration::from_secs(300), // 5 minutes
37            similarity_threshold: 0.95,
38            base_confidence: 0.9,
39            confidence_scale_divisor: 10.0,
40            max_entries_per_hash: 1000,
41        }
42    }
43}
44
/// A single observation of an attack payload sent by some client.
///
/// `AttackSequenceDetector::record_attack` indexes an observation by its
/// `payload_hash` at its `timestamp`.
#[derive(Debug, Clone)]
pub struct AttackPayload {
    /// Hash of the payload after normalization; the correlation key.
    pub payload_hash: String,
    /// Attack classification label (for example `sqli`, `xss`, `path_traversal`).
    pub attack_type: String,
    /// Request path the payload was aimed at.
    pub target_path: String,
    /// Moment the payload was observed.
    pub timestamp: Instant,
}

58/// Detects campaigns based on shared attack payloads
59pub struct AttackSequenceDetector {
60    config: AttackSequenceConfig,
61    /// Payload hash -> IPs (using common TimeWindowedIndex)
62    payload_index: TimeWindowedIndex<String, IpAddr>,
63    /// Already detected payload groups
64    detected: DashSet<String>,
65}
66
67impl AttackSequenceDetector {
68    pub fn new(config: AttackSequenceConfig) -> Self {
69        let payload_index = TimeWindowedIndex::new(config.window, config.max_entries_per_hash);
70        Self {
71            config,
72            payload_index,
73            detected: DashSet::new(),
74        }
75    }
76
77    /// Record an attack payload observation
78    pub fn record_attack(&self, ip: IpAddr, payload: AttackPayload) {
79        self.payload_index
80            .insert_with_timestamp(payload.payload_hash, ip, payload.timestamp);
81    }
82
83    /// Get IPs sharing a specific payload
84    pub fn get_ips_for_payload(&self, payload_hash: &str) -> Vec<IpAddr> {
85        self.payload_index.get_unique(&payload_hash.to_string())
86    }
87
88    /// Get groups of IPs sharing payloads above threshold
89    fn get_correlated_groups(&self) -> Vec<(String, Vec<IpAddr>)> {
90        self.payload_index
91            .get_groups_with_min_unique_count(self.config.min_ips)
92            .into_iter()
93            .filter(|(hash, _)| !self.detected.contains(hash))
94            .collect()
95    }
96}
97
98impl Detector for AttackSequenceDetector {
99    fn name(&self) -> &'static str {
100        "attack_sequence"
101    }
102
103    fn analyze(&self, _index: &FingerprintIndex) -> DetectorResult<Vec<CampaignUpdate>> {
104        let groups = self.get_correlated_groups();
105        let mut updates = Vec::new();
106
107        for (payload_hash, ips) in groups {
108            let confidence = (ips.len() as f64 / self.config.confidence_scale_divisor).min(1.0)
109                * self.config.base_confidence;
110
111            updates.push(CampaignUpdate {
112                campaign_id: Some(format!(
113                    "attack-seq-{}",
114                    &payload_hash[..8.min(payload_hash.len())]
115                )),
116                status: None,
117                confidence: Some(confidence),
118                attack_types: Some(vec!["attack_sequence".to_string()]),
119                add_member_ips: Some(ips.iter().map(|ip| ip.to_string()).collect()),
120                add_correlation_reason: Some(CorrelationReason::new(
121                    CorrelationType::AttackSequence,
122                    confidence,
123                    format!("{} IPs sharing identical attack payload", ips.len()),
124                    ips.iter().map(|ip| ip.to_string()).collect(),
125                )),
126                ..Default::default()
127            });
128
129            // Mark as detected
130            self.detected.insert(payload_hash);
131        }
132
133        Ok(updates)
134    }
135
136    fn should_trigger(&self, ip: &IpAddr, _index: &FingerprintIndex) -> bool {
137        // Check if this IP is part of any payload group that's close to threshold
138        self.payload_index.any_key_has_value_with_min_count(
139            |entry_ip| entry_ip == ip,
140            self.config.min_ips.saturating_sub(1).max(1),
141        )
142    }
143
144    fn scan_interval_ms(&self) -> u64 {
145        3000
146    } // 3 seconds
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a constant IP literal; a failure here is a bug in the test.
    fn parse_ip(addr: &str) -> IpAddr {
        addr.parse().expect("test IP literal must be valid")
    }

    /// Builds a payload observation timestamped "now".
    fn payload(hash: &str, attack_type: &str, target_path: &str) -> AttackPayload {
        AttackPayload {
            payload_hash: hash.to_string(),
            attack_type: attack_type.to_string(),
            target_path: target_path.to_string(),
            timestamp: Instant::now(),
        }
    }

    #[test]
    fn test_config_default() {
        let config = AttackSequenceConfig::default();
        assert_eq!(config.min_ips, 2);
        assert_eq!(config.window, Duration::from_secs(300));
    }

    #[test]
    fn test_record_attack() {
        let detector = AttackSequenceDetector::new(AttackSequenceConfig::default());
        let ip = parse_ip("192.168.1.1");

        detector.record_attack(ip, payload("hash123", "sqli", "/api/login"));

        let ips = detector.get_ips_for_payload("hash123");
        assert_eq!(ips, vec![ip]);
    }

    #[test]
    fn test_detection_with_multiple_ips() {
        let detector = AttackSequenceDetector::new(AttackSequenceConfig::default());

        for i in 1..=3 {
            let ip = parse_ip(&format!("192.168.1.{}", i));
            detector.record_attack(ip, payload("shared_payload", "sqli", "/api"));
        }

        let index = FingerprintIndex::new();
        let updates = detector.analyze(&index).unwrap();

        assert_eq!(updates.len(), 1);
        assert_eq!(
            updates[0].add_member_ips.as_ref().map(|ips| ips.len()),
            Some(3)
        );
        // Campaign IDs are built from the first 8 characters of the hash.
        assert_eq!(
            updates[0].campaign_id.as_deref(),
            Some("attack-seq-shared_p")
        );
    }

    #[test]
    fn test_unchanged_group_is_not_reported_twice() {
        let detector = AttackSequenceDetector::new(AttackSequenceConfig::default());

        for addr in ["10.1.0.1", "10.1.0.2"] {
            detector.record_attack(parse_ip(addr), payload("dup", "xss", "/"));
        }

        let index = FingerprintIndex::new();
        assert_eq!(detector.analyze(&index).unwrap().len(), 1);
        assert!(detector.analyze(&index).unwrap().is_empty());
    }

    #[test]
    fn test_no_detection_below_threshold() {
        let detector = AttackSequenceDetector::new(AttackSequenceConfig {
            min_ips: 3,
            ..Default::default()
        });

        detector.record_attack(parse_ip("192.168.1.1"), payload("hash", "xss", "/"));

        let index = FingerprintIndex::new();
        let updates = detector.analyze(&index).unwrap();
        assert!(updates.is_empty());
    }

    #[test]
    fn test_should_trigger() {
        let detector = AttackSequenceDetector::new(AttackSequenceConfig::default());
        let ip1 = parse_ip("10.0.0.1");

        detector.record_attack(ip1, payload("test", "sqli", "/"));

        // Should trigger because one more IP would reach threshold
        let index = FingerprintIndex::new();
        assert!(detector.should_trigger(&ip1, &index));
    }
}