1use dashmap::DashMap;
4use parking_lot::RwLock;
5use serde::{Deserialize, Serialize};
6use std::sync::atomic::{AtomicU64, Ordering};
7
8use super::anomaly::{
9 PayloadAnomaly, PayloadAnomalyMetadata, PayloadAnomalySeverity, PayloadAnomalyType,
10};
11use super::config::PayloadConfig;
12use super::endpoint_stats::{EndpointPayloadStats, EndpointPayloadStatsSnapshot};
13use super::entity_bandwidth::{EntityBandwidth, EntityBandwidthSnapshot};
14
/// Sort key accepted by [`PayloadManager::list_top_endpoints`].
///
/// Every variant sorts in descending order (largest / most recent first).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EndpointSortBy {
    /// Sort by the snapshot's `request.total_bytes`.
    RequestBytes,
    /// Sort by the snapshot's `response.total_bytes`.
    ResponseBytes,
    /// Sort by the snapshot's `request_count`.
    RequestCount,
    /// Sort by the snapshot's `last_seen_ms`.
    LastSeen,
}
23
/// Aggregate snapshot produced by [`PayloadManager::get_summary`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PayloadSummary {
    /// Endpoints currently tracked (endpoints removed by eviction are not counted).
    pub total_endpoints: usize,
    /// Entities currently tracked (entities removed by eviction are not counted).
    pub total_entities: usize,
    /// Requests recorded since the manager was created, including those whose
    /// endpoint/entity has since been evicted.
    pub total_requests: u64,
    /// Sum of request payload bytes over all recorded requests.
    pub total_request_bytes: u64,
    /// Sum of response payload bytes over all recorded requests.
    pub total_response_bytes: u64,
    /// `total_request_bytes / total_requests`, or 0.0 when nothing was recorded.
    pub avg_request_size: f64,
    /// `total_response_bytes / total_requests`, or 0.0 when nothing was recorded.
    pub avg_response_size: f64,
    /// Length of the retained anomaly history.
    /// NOTE(review): despite the name this is not de-duplicated; it counts every
    /// detection still held in the capped history, including repeats.
    pub active_anomalies: usize,
}
36
/// Tracks request/response payload sizes per endpoint template and per entity,
/// and keeps a bounded history of detected payload anomalies.
///
/// Every method takes `&self`; all mutation goes through `DashMap`, `RwLock`
/// and atomics, so a single instance can be used through a shared reference.
pub struct PayloadManager {
    /// Window sizing, table limits and detection thresholds.
    config: PayloadConfig,
    /// Per-endpoint statistics, keyed by endpoint template.
    endpoints: DashMap<String, RwLock<EndpointPayloadStats>>,
    /// Per-entity bandwidth, keyed by entity id.
    entities: DashMap<String, RwLock<EntityBandwidth>>,
    /// Detected anomalies in detection order (oldest first); trimmed from the
    /// front by `check_anomalies` to cap its length.
    anomalies: RwLock<Vec<PayloadAnomaly>>,
    /// Number of requests recorded since construction.
    total_requests: AtomicU64,
    /// Sum of all recorded request payload bytes.
    total_request_bytes: AtomicU64,
    /// Sum of all recorded response payload bytes.
    total_response_bytes: AtomicU64,
}
51
52impl PayloadManager {
53 pub fn new(config: PayloadConfig) -> Self {
55 Self {
56 config,
57 endpoints: DashMap::new(),
58 entities: DashMap::new(),
59 anomalies: RwLock::new(Vec::new()),
60 total_requests: AtomicU64::new(0),
61 total_request_bytes: AtomicU64::new(0),
62 total_response_bytes: AtomicU64::new(0),
63 }
64 }
65
66 pub fn record_request(
68 &self,
69 template: &str,
70 entity_id: &str,
71 request_bytes: u64,
72 response_bytes: u64,
73 ) {
74 self.total_requests.fetch_add(1, Ordering::Relaxed);
76 self.total_request_bytes
77 .fetch_add(request_bytes, Ordering::Relaxed);
78 self.total_response_bytes
79 .fetch_add(response_bytes, Ordering::Relaxed);
80
81 self.record_endpoint(template, request_bytes, response_bytes);
83
84 self.record_entity(entity_id, request_bytes, response_bytes);
86
87 self.maybe_evict();
89 }
90
91 fn record_endpoint(&self, template: &str, request_bytes: u64, response_bytes: u64) {
92 let entry = self
93 .endpoints
94 .entry(template.to_string())
95 .or_insert_with(|| {
96 RwLock::new(EndpointPayloadStats::new(
97 template.to_string(),
98 self.config.window_duration_ms,
99 self.config.max_windows,
100 ))
101 });
102 entry.write().record(request_bytes, response_bytes);
103 }
104
105 fn record_entity(&self, entity_id: &str, request_bytes: u64, response_bytes: u64) {
106 let entry = self
107 .entities
108 .entry(entity_id.to_string())
109 .or_insert_with(|| {
110 RwLock::new(EntityBandwidth::new(
111 entity_id.to_string(),
112 self.config.window_duration_ms,
113 self.config.max_windows,
114 ))
115 });
116 entry.write().record(request_bytes, response_bytes);
117 }
118
119 fn maybe_evict(&self) {
120 if self.endpoints.len() > self.config.max_endpoints {
122 let mut min_access = u64::MAX;
124 let mut min_key = None;
125 for entry in self.endpoints.iter() {
126 let access = entry.value().read().access_count;
127 if access < min_access {
128 min_access = access;
129 min_key = Some(entry.key().clone());
130 }
131 }
132 if let Some(key) = min_key {
133 self.endpoints.remove(&key);
134 }
135 }
136
137 if self.entities.len() > self.config.max_entities {
138 let mut min_access = u64::MAX;
139 let mut min_key = None;
140 for entry in self.entities.iter() {
141 let access = entry.value().read().access_count;
142 if access < min_access {
143 min_access = access;
144 min_key = Some(entry.key().clone());
145 }
146 }
147 if let Some(key) = min_key {
148 self.entities.remove(&key);
149 }
150 }
151 }
152
153 pub fn check_anomalies(&self) -> Vec<PayloadAnomaly> {
155 let mut detected = Vec::new();
156
157 for entry in self.endpoints.iter() {
159 let stats = entry.read();
160 if stats.request_count() < self.config.warmup_requests as u64 {
161 continue;
162 }
163
164 let req_stats = stats.request_stats();
165 let resp_stats = stats.response_stats();
166
167 let req_threshold = req_stats.p99_bytes * self.config.oversize_threshold;
169 let resp_threshold = resp_stats.p99_bytes * self.config.oversize_threshold;
170
171 if req_stats.max_bytes as f64 > req_threshold
174 && req_stats.max_bytes > self.config.min_large_payload_bytes
175 {
176 detected.push(PayloadAnomaly::new(
177 PayloadAnomalyType::OversizedRequest,
178 PayloadAnomalySeverity::Medium,
179 stats.template.clone(),
180 "unknown".to_string(),
181 format!(
182 "Oversized request detected: {} bytes (p99: {} bytes)",
183 req_stats.max_bytes, req_stats.p99_bytes as u64
184 ),
185 PayloadAnomalyMetadata::Oversize {
186 actual_bytes: req_stats.max_bytes,
187 expected_bytes: req_stats.p99_bytes as u64,
188 threshold: self.config.oversize_threshold,
189 percentile: 99.0,
190 },
191 ));
192 }
193
194 if resp_stats.max_bytes as f64 > resp_threshold
195 && resp_stats.max_bytes > self.config.min_large_payload_bytes
196 {
197 detected.push(PayloadAnomaly::new(
198 PayloadAnomalyType::OversizedResponse,
199 PayloadAnomalySeverity::Low,
200 stats.template.clone(),
201 "unknown".to_string(),
202 format!(
203 "Oversized response detected: {} bytes (p99: {} bytes)",
204 resp_stats.max_bytes, resp_stats.p99_bytes as u64
205 ),
206 PayloadAnomalyMetadata::Oversize {
207 actual_bytes: resp_stats.max_bytes,
208 expected_bytes: resp_stats.p99_bytes as u64,
209 threshold: self.config.oversize_threshold,
210 percentile: 99.0,
211 },
212 ));
213 }
214 }
215
216 for entry in self.entities.iter() {
218 let entity = entry.read();
219 let current = entity.current_bytes_per_minute();
220 let avg = entity.avg_bytes_per_minute();
221
222 if avg > 0 && current as f64 > avg as f64 * self.config.bandwidth_spike_threshold {
223 detected.push(PayloadAnomaly::new(
224 PayloadAnomalyType::BandwidthSpike,
225 PayloadAnomalySeverity::High,
226 "".to_string(),
227 entity.entity_id.clone(),
228 format!(
229 "Bandwidth spike: {} bytes/min (avg: {} bytes/min)",
230 current, avg
231 ),
232 PayloadAnomalyMetadata::BandwidthSpike {
233 current_bytes_per_min: current,
234 avg_bytes_per_min: avg,
235 threshold: self.config.bandwidth_spike_threshold,
236 },
237 ));
238 }
239
240 if entity.total_request_count > self.config.warmup_requests as u64 {
242 let avg_req = entity.total_request_bytes / entity.total_request_count;
243 let avg_resp = entity.total_response_bytes / entity.total_request_count;
244
245 if avg_req > 0 && avg_resp > self.config.min_large_payload_bytes {
246 let ratio = avg_resp as f64 / avg_req as f64;
247 if ratio > self.config.exfiltration_ratio_threshold {
248 detected.push(PayloadAnomaly::new(
249 PayloadAnomalyType::ExfiltrationPattern,
250 PayloadAnomalySeverity::Critical,
251 "".to_string(),
252 entity.entity_id.clone(),
253 format!("Exfiltration pattern: response/request ratio {:.1}x", ratio),
254 PayloadAnomalyMetadata::DataPattern {
255 request_bytes: avg_req,
256 response_bytes: avg_resp,
257 ratio,
258 threshold: self.config.exfiltration_ratio_threshold,
259 },
260 ));
261 }
262 }
263
264 if avg_resp > 0 && avg_req > self.config.min_large_payload_bytes {
266 let ratio = avg_req as f64 / avg_resp as f64;
267 if ratio > self.config.upload_ratio_threshold {
268 detected.push(PayloadAnomaly::new(
269 PayloadAnomalyType::UploadPattern,
270 PayloadAnomalySeverity::High,
271 "".to_string(),
272 entity.entity_id.clone(),
273 format!("Upload pattern: request/response ratio {:.1}x", ratio),
274 PayloadAnomalyMetadata::DataPattern {
275 request_bytes: avg_req,
276 response_bytes: avg_resp,
277 ratio,
278 threshold: self.config.upload_ratio_threshold,
279 },
280 ));
281 }
282 }
283 }
284 }
285
286 {
288 let mut anomalies = self.anomalies.write();
289 anomalies.extend(detected.clone());
290 let len = anomalies.len();
292 if len > 1000 {
293 anomalies.drain(0..len - 1000);
294 }
295 }
296
297 detected
298 }
299
300 pub fn get_summary(&self) -> PayloadSummary {
302 let total_requests = self.total_requests.load(Ordering::Relaxed);
303 let total_request_bytes = self.total_request_bytes.load(Ordering::Relaxed);
304 let total_response_bytes = self.total_response_bytes.load(Ordering::Relaxed);
305
306 PayloadSummary {
307 total_endpoints: self.endpoints.len(),
308 total_entities: self.entities.len(),
309 total_requests,
310 total_request_bytes,
311 total_response_bytes,
312 avg_request_size: if total_requests > 0 {
313 total_request_bytes as f64 / total_requests as f64
314 } else {
315 0.0
316 },
317 avg_response_size: if total_requests > 0 {
318 total_response_bytes as f64 / total_requests as f64
319 } else {
320 0.0
321 },
322 active_anomalies: self.anomalies.read().len(),
323 }
324 }
325
326 pub fn get_endpoint_stats(&self, template: &str) -> Option<EndpointPayloadStatsSnapshot> {
328 self.endpoints
329 .get(template)
330 .map(|e| EndpointPayloadStatsSnapshot::from(&*e.read()))
331 }
332
333 pub fn get_entity_bandwidth(&self, entity_id: &str) -> Option<EntityBandwidthSnapshot> {
335 self.entities
336 .get(entity_id)
337 .map(|e| EntityBandwidthSnapshot::from(&*e.read()))
338 }
339
340 pub fn list_top_endpoints(
342 &self,
343 limit: usize,
344 sort_by: EndpointSortBy,
345 ) -> Vec<EndpointPayloadStatsSnapshot> {
346 let mut endpoints: Vec<_> = self
347 .endpoints
348 .iter()
349 .map(|e| EndpointPayloadStatsSnapshot::from(&*e.read()))
350 .collect();
351
352 match sort_by {
353 EndpointSortBy::RequestBytes => {
354 endpoints.sort_by(|a, b| b.request.total_bytes.cmp(&a.request.total_bytes));
355 }
356 EndpointSortBy::ResponseBytes => {
357 endpoints.sort_by(|a, b| b.response.total_bytes.cmp(&a.response.total_bytes));
358 }
359 EndpointSortBy::RequestCount => {
360 endpoints.sort_by(|a, b| b.request_count.cmp(&a.request_count));
361 }
362 EndpointSortBy::LastSeen => {
363 endpoints.sort_by(|a, b| b.last_seen_ms.cmp(&a.last_seen_ms));
364 }
365 }
366
367 endpoints.truncate(limit);
368 endpoints
369 }
370
371 pub fn list_top_entities(&self, limit: usize) -> Vec<EntityBandwidthSnapshot> {
373 let mut entities: Vec<_> = self
374 .entities
375 .iter()
376 .map(|e| EntityBandwidthSnapshot::from(&*e.read()))
377 .collect();
378
379 entities.sort_by(|a, b| {
380 let a_total = a.total_request_bytes + a.total_response_bytes;
381 let b_total = b.total_request_bytes + b.total_response_bytes;
382 b_total.cmp(&a_total)
383 });
384
385 entities.truncate(limit);
386 entities
387 }
388
389 pub fn get_anomalies(&self, limit: usize) -> Vec<PayloadAnomaly> {
391 let anomalies = self.anomalies.read();
392 anomalies.iter().rev().take(limit).cloned().collect()
393 }
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh manager built from the default configuration.
    fn default_manager() -> PayloadManager {
        PayloadManager::new(PayloadConfig::default())
    }

    #[test]
    fn test_record_and_stats() {
        let manager = default_manager();
        let traffic: [(&str, &str, u64, u64); 3] = [
            ("/api/users", "192.168.1.1", 100, 500),
            ("/api/users", "192.168.1.1", 150, 600),
            ("/api/users", "192.168.1.2", 200, 400),
        ];
        for &(template, entity, req_bytes, resp_bytes) in traffic.iter() {
            manager.record_request(template, entity, req_bytes, resp_bytes);
        }

        let summary = manager.get_summary();
        assert_eq!(summary.total_requests, 3);
        assert_eq!(summary.total_request_bytes, 450);
        assert_eq!(summary.total_response_bytes, 1500);
        assert_eq!(summary.total_endpoints, 1);
        assert_eq!(summary.total_entities, 2);
    }

    #[test]
    fn test_endpoint_stats() {
        let manager = default_manager();
        (0..10u64).for_each(|step| {
            manager.record_request("/api/test", "10.0.0.1", 100 * step, 200 * step)
        });

        let snapshot = manager.get_endpoint_stats("/api/test").unwrap();
        assert_eq!(snapshot.template, "/api/test");
        assert_eq!(snapshot.request_count, 10);
    }

    #[test]
    fn test_entity_bandwidth() {
        let manager = default_manager();
        let traffic: [(&str, u64, u64); 2] = [("/api/a", 1000, 2000), ("/api/b", 500, 1000)];
        for &(template, req_bytes, resp_bytes) in traffic.iter() {
            manager.record_request(template, "1.1.1.1", req_bytes, resp_bytes);
        }

        let bandwidth = manager.get_entity_bandwidth("1.1.1.1").unwrap();
        assert_eq!(bandwidth.entity_id, "1.1.1.1");
        assert_eq!(bandwidth.total_request_bytes, 1500);
        assert_eq!(bandwidth.total_response_bytes, 3000);
    }
}