Skip to main content

synapse_pingora/payload/
manager.rs

1//! Payload profiling manager - coordinates endpoint and entity tracking.
2
3use dashmap::DashMap;
4use parking_lot::RwLock;
5use serde::{Deserialize, Serialize};
6use std::sync::atomic::{AtomicU64, Ordering};
7
8use super::anomaly::{
9    PayloadAnomaly, PayloadAnomalyMetadata, PayloadAnomalySeverity, PayloadAnomalyType,
10};
11use super::config::PayloadConfig;
12use super::endpoint_stats::{EndpointPayloadStats, EndpointPayloadStatsSnapshot};
13use super::entity_bandwidth::{EntityBandwidth, EntityBandwidthSnapshot};
14
/// Sort order for endpoint listings.
///
/// Listings produced with any of these keys are ordered descending
/// (largest / most recent first).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EndpointSortBy {
    /// Total request bytes recorded for the endpoint.
    RequestBytes,
    /// Total response bytes recorded for the endpoint.
    ResponseBytes,
    /// Number of requests recorded for the endpoint.
    RequestCount,
    /// The endpoint's last-seen timestamp (`last_seen_ms`).
    LastSeen,
}
23
24/// Summary statistics for the payload profiler.
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct PayloadSummary {
27    pub total_endpoints: usize,
28    pub total_entities: usize,
29    pub total_requests: u64,
30    pub total_request_bytes: u64,
31    pub total_response_bytes: u64,
32    pub avg_request_size: f64,
33    pub avg_response_size: f64,
34    pub active_anomalies: usize,
35}
36
37/// Main payload profiling manager.
38pub struct PayloadManager {
39    config: PayloadConfig,
40    /// Per-endpoint statistics
41    endpoints: DashMap<String, RwLock<EndpointPayloadStats>>,
42    /// Per-entity bandwidth tracking
43    entities: DashMap<String, RwLock<EntityBandwidth>>,
44    /// Recent anomalies
45    anomalies: RwLock<Vec<PayloadAnomaly>>,
46    /// Global counters
47    total_requests: AtomicU64,
48    total_request_bytes: AtomicU64,
49    total_response_bytes: AtomicU64,
50}
51
52impl PayloadManager {
53    /// Create a new payload manager.
54    pub fn new(config: PayloadConfig) -> Self {
55        Self {
56            config,
57            endpoints: DashMap::new(),
58            entities: DashMap::new(),
59            anomalies: RwLock::new(Vec::new()),
60            total_requests: AtomicU64::new(0),
61            total_request_bytes: AtomicU64::new(0),
62            total_response_bytes: AtomicU64::new(0),
63        }
64    }
65
66    /// Record a request/response pair.
67    pub fn record_request(
68        &self,
69        template: &str,
70        entity_id: &str,
71        request_bytes: u64,
72        response_bytes: u64,
73    ) {
74        // Update global counters
75        self.total_requests.fetch_add(1, Ordering::Relaxed);
76        self.total_request_bytes
77            .fetch_add(request_bytes, Ordering::Relaxed);
78        self.total_response_bytes
79            .fetch_add(response_bytes, Ordering::Relaxed);
80
81        // Update endpoint stats
82        self.record_endpoint(template, request_bytes, response_bytes);
83
84        // Update entity bandwidth
85        self.record_entity(entity_id, request_bytes, response_bytes);
86
87        // Check for LRU eviction
88        self.maybe_evict();
89    }
90
91    fn record_endpoint(&self, template: &str, request_bytes: u64, response_bytes: u64) {
92        let entry = self
93            .endpoints
94            .entry(template.to_string())
95            .or_insert_with(|| {
96                RwLock::new(EndpointPayloadStats::new(
97                    template.to_string(),
98                    self.config.window_duration_ms,
99                    self.config.max_windows,
100                ))
101            });
102        entry.write().record(request_bytes, response_bytes);
103    }
104
105    fn record_entity(&self, entity_id: &str, request_bytes: u64, response_bytes: u64) {
106        let entry = self
107            .entities
108            .entry(entity_id.to_string())
109            .or_insert_with(|| {
110                RwLock::new(EntityBandwidth::new(
111                    entity_id.to_string(),
112                    self.config.window_duration_ms,
113                    self.config.max_windows,
114                ))
115            });
116        entry.write().record(request_bytes, response_bytes);
117    }
118
119    fn maybe_evict(&self) {
120        // Simple eviction: remove oldest if over capacity
121        if self.endpoints.len() > self.config.max_endpoints {
122            // Find entry with lowest access count
123            let mut min_access = u64::MAX;
124            let mut min_key = None;
125            for entry in self.endpoints.iter() {
126                let access = entry.value().read().access_count;
127                if access < min_access {
128                    min_access = access;
129                    min_key = Some(entry.key().clone());
130                }
131            }
132            if let Some(key) = min_key {
133                self.endpoints.remove(&key);
134            }
135        }
136
137        if self.entities.len() > self.config.max_entities {
138            let mut min_access = u64::MAX;
139            let mut min_key = None;
140            for entry in self.entities.iter() {
141                let access = entry.value().read().access_count;
142                if access < min_access {
143                    min_access = access;
144                    min_key = Some(entry.key().clone());
145                }
146            }
147            if let Some(key) = min_key {
148                self.entities.remove(&key);
149            }
150        }
151    }
152
153    /// Check for anomalies across all endpoints and entities.
154    pub fn check_anomalies(&self) -> Vec<PayloadAnomaly> {
155        let mut detected = Vec::new();
156
157        // Check for oversized payloads
158        for entry in self.endpoints.iter() {
159            let stats = entry.read();
160            if stats.request_count() < self.config.warmup_requests as u64 {
161                continue;
162            }
163
164            let req_stats = stats.request_stats();
165            let resp_stats = stats.response_stats();
166
167            // Check current requests against p99
168            let req_threshold = req_stats.p99_bytes * self.config.oversize_threshold;
169            let resp_threshold = resp_stats.p99_bytes * self.config.oversize_threshold;
170
171            // We'd need to track individual requests to detect oversized ones
172            // For now, detect if max >> p99 (indicating outliers exist)
173            if req_stats.max_bytes as f64 > req_threshold
174                && req_stats.max_bytes > self.config.min_large_payload_bytes
175            {
176                detected.push(PayloadAnomaly::new(
177                    PayloadAnomalyType::OversizedRequest,
178                    PayloadAnomalySeverity::Medium,
179                    stats.template.clone(),
180                    "unknown".to_string(),
181                    format!(
182                        "Oversized request detected: {} bytes (p99: {} bytes)",
183                        req_stats.max_bytes, req_stats.p99_bytes as u64
184                    ),
185                    PayloadAnomalyMetadata::Oversize {
186                        actual_bytes: req_stats.max_bytes,
187                        expected_bytes: req_stats.p99_bytes as u64,
188                        threshold: self.config.oversize_threshold,
189                        percentile: 99.0,
190                    },
191                ));
192            }
193
194            if resp_stats.max_bytes as f64 > resp_threshold
195                && resp_stats.max_bytes > self.config.min_large_payload_bytes
196            {
197                detected.push(PayloadAnomaly::new(
198                    PayloadAnomalyType::OversizedResponse,
199                    PayloadAnomalySeverity::Low,
200                    stats.template.clone(),
201                    "unknown".to_string(),
202                    format!(
203                        "Oversized response detected: {} bytes (p99: {} bytes)",
204                        resp_stats.max_bytes, resp_stats.p99_bytes as u64
205                    ),
206                    PayloadAnomalyMetadata::Oversize {
207                        actual_bytes: resp_stats.max_bytes,
208                        expected_bytes: resp_stats.p99_bytes as u64,
209                        threshold: self.config.oversize_threshold,
210                        percentile: 99.0,
211                    },
212                ));
213            }
214        }
215
216        // Check for bandwidth spikes per entity
217        for entry in self.entities.iter() {
218            let entity = entry.read();
219            let current = entity.current_bytes_per_minute();
220            let avg = entity.avg_bytes_per_minute();
221
222            if avg > 0 && current as f64 > avg as f64 * self.config.bandwidth_spike_threshold {
223                detected.push(PayloadAnomaly::new(
224                    PayloadAnomalyType::BandwidthSpike,
225                    PayloadAnomalySeverity::High,
226                    "".to_string(),
227                    entity.entity_id.clone(),
228                    format!(
229                        "Bandwidth spike: {} bytes/min (avg: {} bytes/min)",
230                        current, avg
231                    ),
232                    PayloadAnomalyMetadata::BandwidthSpike {
233                        current_bytes_per_min: current,
234                        avg_bytes_per_min: avg,
235                        threshold: self.config.bandwidth_spike_threshold,
236                    },
237                ));
238            }
239
240            // Check for exfiltration pattern
241            if entity.total_request_count > self.config.warmup_requests as u64 {
242                let avg_req = entity.total_request_bytes / entity.total_request_count;
243                let avg_resp = entity.total_response_bytes / entity.total_request_count;
244
245                if avg_req > 0 && avg_resp > self.config.min_large_payload_bytes {
246                    let ratio = avg_resp as f64 / avg_req as f64;
247                    if ratio > self.config.exfiltration_ratio_threshold {
248                        detected.push(PayloadAnomaly::new(
249                            PayloadAnomalyType::ExfiltrationPattern,
250                            PayloadAnomalySeverity::Critical,
251                            "".to_string(),
252                            entity.entity_id.clone(),
253                            format!("Exfiltration pattern: response/request ratio {:.1}x", ratio),
254                            PayloadAnomalyMetadata::DataPattern {
255                                request_bytes: avg_req,
256                                response_bytes: avg_resp,
257                                ratio,
258                                threshold: self.config.exfiltration_ratio_threshold,
259                            },
260                        ));
261                    }
262                }
263
264                // Check for upload pattern
265                if avg_resp > 0 && avg_req > self.config.min_large_payload_bytes {
266                    let ratio = avg_req as f64 / avg_resp as f64;
267                    if ratio > self.config.upload_ratio_threshold {
268                        detected.push(PayloadAnomaly::new(
269                            PayloadAnomalyType::UploadPattern,
270                            PayloadAnomalySeverity::High,
271                            "".to_string(),
272                            entity.entity_id.clone(),
273                            format!("Upload pattern: request/response ratio {:.1}x", ratio),
274                            PayloadAnomalyMetadata::DataPattern {
275                                request_bytes: avg_req,
276                                response_bytes: avg_resp,
277                                ratio,
278                                threshold: self.config.upload_ratio_threshold,
279                            },
280                        ));
281                    }
282                }
283            }
284        }
285
286        // Store detected anomalies
287        {
288            let mut anomalies = self.anomalies.write();
289            anomalies.extend(detected.clone());
290            // Keep only recent anomalies (last 1000)
291            let len = anomalies.len();
292            if len > 1000 {
293                anomalies.drain(0..len - 1000);
294            }
295        }
296
297        detected
298    }
299
300    /// Get summary statistics.
301    pub fn get_summary(&self) -> PayloadSummary {
302        let total_requests = self.total_requests.load(Ordering::Relaxed);
303        let total_request_bytes = self.total_request_bytes.load(Ordering::Relaxed);
304        let total_response_bytes = self.total_response_bytes.load(Ordering::Relaxed);
305
306        PayloadSummary {
307            total_endpoints: self.endpoints.len(),
308            total_entities: self.entities.len(),
309            total_requests,
310            total_request_bytes,
311            total_response_bytes,
312            avg_request_size: if total_requests > 0 {
313                total_request_bytes as f64 / total_requests as f64
314            } else {
315                0.0
316            },
317            avg_response_size: if total_requests > 0 {
318                total_response_bytes as f64 / total_requests as f64
319            } else {
320                0.0
321            },
322            active_anomalies: self.anomalies.read().len(),
323        }
324    }
325
326    /// Get statistics for a specific endpoint.
327    pub fn get_endpoint_stats(&self, template: &str) -> Option<EndpointPayloadStatsSnapshot> {
328        self.endpoints
329            .get(template)
330            .map(|e| EndpointPayloadStatsSnapshot::from(&*e.read()))
331    }
332
333    /// Get bandwidth for a specific entity.
334    pub fn get_entity_bandwidth(&self, entity_id: &str) -> Option<EntityBandwidthSnapshot> {
335        self.entities
336            .get(entity_id)
337            .map(|e| EntityBandwidthSnapshot::from(&*e.read()))
338    }
339
340    /// List top endpoints by specified metric.
341    pub fn list_top_endpoints(
342        &self,
343        limit: usize,
344        sort_by: EndpointSortBy,
345    ) -> Vec<EndpointPayloadStatsSnapshot> {
346        let mut endpoints: Vec<_> = self
347            .endpoints
348            .iter()
349            .map(|e| EndpointPayloadStatsSnapshot::from(&*e.read()))
350            .collect();
351
352        match sort_by {
353            EndpointSortBy::RequestBytes => {
354                endpoints.sort_by(|a, b| b.request.total_bytes.cmp(&a.request.total_bytes));
355            }
356            EndpointSortBy::ResponseBytes => {
357                endpoints.sort_by(|a, b| b.response.total_bytes.cmp(&a.response.total_bytes));
358            }
359            EndpointSortBy::RequestCount => {
360                endpoints.sort_by(|a, b| b.request_count.cmp(&a.request_count));
361            }
362            EndpointSortBy::LastSeen => {
363                endpoints.sort_by(|a, b| b.last_seen_ms.cmp(&a.last_seen_ms));
364            }
365        }
366
367        endpoints.truncate(limit);
368        endpoints
369    }
370
371    /// List top entities by bandwidth.
372    pub fn list_top_entities(&self, limit: usize) -> Vec<EntityBandwidthSnapshot> {
373        let mut entities: Vec<_> = self
374            .entities
375            .iter()
376            .map(|e| EntityBandwidthSnapshot::from(&*e.read()))
377            .collect();
378
379        entities.sort_by(|a, b| {
380            let a_total = a.total_request_bytes + a.total_response_bytes;
381            let b_total = b.total_request_bytes + b.total_response_bytes;
382            b_total.cmp(&a_total)
383        });
384
385        entities.truncate(limit);
386        entities
387    }
388
389    /// Get recent anomalies.
390    pub fn get_anomalies(&self, limit: usize) -> Vec<PayloadAnomaly> {
391        let anomalies = self.anomalies.read();
392        anomalies.iter().rev().take(limit).cloned().collect()
393    }
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a manager backed by `PayloadConfig::default()`.
    fn default_manager() -> PayloadManager {
        PayloadManager::new(PayloadConfig::default())
    }

    #[test]
    fn test_record_and_stats() {
        let manager = default_manager();
        let traffic = [
            ("/api/users", "192.168.1.1", 100, 500),
            ("/api/users", "192.168.1.1", 150, 600),
            ("/api/users", "192.168.1.2", 200, 400),
        ];
        for &(template, entity, req_bytes, resp_bytes) in traffic.iter() {
            manager.record_request(template, entity, req_bytes, resp_bytes);
        }

        let summary = manager.get_summary();
        assert_eq!(summary.total_requests, 3);
        assert_eq!(summary.total_request_bytes, 450);
        assert_eq!(summary.total_response_bytes, 1500);
        assert_eq!(summary.total_endpoints, 1);
        assert_eq!(summary.total_entities, 2);
    }

    #[test]
    fn test_endpoint_stats() {
        let manager = default_manager();
        (0..10).for_each(|step| {
            manager.record_request("/api/test", "10.0.0.1", 100 * step, 200 * step)
        });

        let snapshot = manager
            .get_endpoint_stats("/api/test")
            .expect("endpoint should be tracked after recording");
        assert_eq!(snapshot.template, "/api/test");
        assert_eq!(snapshot.request_count, 10);
    }

    #[test]
    fn test_entity_bandwidth() {
        let manager = default_manager();
        let client = "1.1.1.1";
        manager.record_request("/api/a", client, 1000, 2000);
        manager.record_request("/api/b", client, 500, 1000);

        let usage = manager
            .get_entity_bandwidth(client)
            .expect("entity should be tracked after recording");
        assert_eq!(usage.entity_id, "1.1.1.1");
        assert_eq!(usage.total_request_bytes, 1500);
        assert_eq!(usage.total_response_bytes, 3000);
    }
}
444}