//! `synapse_pingora/telemetry/auth_coverage_aggregator.rs`
//!
//! Edge-side aggregation of per-endpoint auth coverage counts, periodically flushed to the Hub.

1use parking_lot::RwLock;
2use std::collections::HashMap;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::time::interval;
7use tracing::warn;
8
9use crate::signals::auth_coverage::{
10    AuthCoverageSummary, EndpointCounts, EndpointSummary, ResponseClass,
11};
12use crate::telemetry::SignalEmitter;
13
14/// Edge aggregator - maintains local counts, flushes to Hub periodically
15pub struct AuthCoverageAggregator {
16    sensor_id: String,
17    tenant_id: Option<String>,
18    counts: Arc<RwLock<HashMap<String, EndpointCounts>>>,
19    dropped_endpoints: AtomicU64,
20    emitter: Arc<dyn SignalEmitter>,
21    flush_interval: Duration,
22    max_endpoints: usize,
23}
24
25impl AuthCoverageAggregator {
26    pub fn new(
27        sensor_id: String,
28        tenant_id: Option<String>,
29        emitter: Arc<dyn SignalEmitter>,
30        flush_interval_secs: u64,
31    ) -> Self {
32        Self {
33            sensor_id,
34            tenant_id,
35            counts: Arc::new(RwLock::new(HashMap::new())),
36            dropped_endpoints: AtomicU64::new(0),
37            emitter,
38            flush_interval: Duration::from_secs(flush_interval_secs),
39            max_endpoints: 1000, // Default limit
40        }
41    }
42
43    /// Set the maximum number of endpoints to track
44    pub fn with_max_endpoints(mut self, max_endpoints: usize) -> Self {
45        self.max_endpoints = max_endpoints;
46        self
47    }
48
49    /// Record a request (called from response filter, must be fast)
50    pub fn record(&self, endpoint: &str, response_class: ResponseClass, has_auth_header: bool) {
51        let mut counts = self.counts.write();
52
53        // If at limit and endpoint is new, merge into "OTHER"
54        // Account for "OTHER" entry by using saturating_sub(1)
55        let target_endpoint = if counts.contains_key(endpoint)
56            || counts.len() < self.max_endpoints.saturating_sub(1)
57        {
58            endpoint
59        } else {
60            self.dropped_endpoints.fetch_add(1, Ordering::Relaxed);
61            "OTHER"
62        };
63
64        let entry = counts.entry(target_endpoint.to_string()).or_default();
65
66        entry.total += 1;
67
68        match response_class {
69            ResponseClass::Success => entry.success += 1,
70            ResponseClass::Unauthorized => entry.unauthorized += 1,
71            ResponseClass::Forbidden => entry.forbidden += 1,
72            _ => entry.other_error += 1,
73        }
74
75        if has_auth_header {
76            entry.with_auth += 1;
77        } else {
78            entry.without_auth += 1;
79        }
80    }
81
82    /// Start background flush task
83    pub fn start_flush_task(self: Arc<Self>) {
84        let Ok(handle) = tokio::runtime::Handle::try_current() else {
85            warn!("Auth coverage flush task skipped (no Tokio runtime)");
86            return;
87        };
88        let aggregator = self.clone();
89
90        handle.spawn(async move {
91            let mut ticker = interval(aggregator.flush_interval);
92
93            loop {
94                ticker.tick().await;
95                aggregator.flush().await;
96            }
97        });
98    }
99
100    /// Flush current counts to Hub and reset
101    async fn flush(&self) {
102        // Swap out current counts atomically
103        let counts = {
104            let mut guard = self.counts.write();
105            std::mem::take(&mut *guard)
106        };
107
108        let dropped_endpoints = self.dropped_endpoints.load(Ordering::Relaxed);
109
110        if counts.is_empty() && dropped_endpoints == 0 {
111            return; // Nothing to send
112        }
113
114        let summary = AuthCoverageSummary {
115            timestamp: std::time::SystemTime::now()
116                .duration_since(std::time::UNIX_EPOCH)
117                .unwrap()
118                .as_millis() as u64,
119            sensor_id: self.sensor_id.clone(),
120            tenant_id: self.tenant_id.clone(),
121            endpoints: counts
122                .into_iter()
123                .map(|(endpoint, counts)| EndpointSummary { endpoint, counts })
124                .collect(),
125            dropped_endpoints,
126        };
127
128        if let Ok(payload) = serde_json::to_value(&summary) {
129            self.emitter.emit("auth_coverage_summary", payload).await;
130            self.dropped_endpoints
131                .fetch_sub(dropped_endpoints, Ordering::SeqCst);
132        }
133    }
134
135    /// Get current endpoint count (for testing/debugging)
136    #[cfg(test)]
137    pub fn endpoint_count(&self) -> usize {
138        self.counts.read().len()
139    }
140
141    /// Force flush (for testing)
142    #[cfg(test)]
143    pub async fn force_flush(&self) {
144        self.flush().await;
145    }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Emitter stub that only counts how many signals it was asked to emit.
    struct MockEmitter {
        emit_count: AtomicUsize,
    }

    impl MockEmitter {
        fn new() -> Arc<Self> {
            Arc::new(Self {
                emit_count: AtomicUsize::new(0),
            })
        }

        fn count(&self) -> usize {
            self.emit_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl SignalEmitter for MockEmitter {
        async fn emit(&self, _signal_type: &str, _payload: serde_json::Value) {
            self.emit_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Build an aggregator reporting through `emitter` with a 60s flush interval.
    fn aggregator_with(emitter: &Arc<MockEmitter>) -> AuthCoverageAggregator {
        AuthCoverageAggregator::new(
            "test-sensor".to_string(),
            None,
            emitter.clone() as Arc<dyn SignalEmitter>,
            60,
        )
    }

    #[test]
    fn test_record_increments_counts() {
        let emitter = MockEmitter::new();
        let aggregator = aggregator_with(&emitter);

        aggregator.record("GET /api/users/{id}", ResponseClass::Success, true);
        aggregator.record("GET /api/users/{id}", ResponseClass::Success, true);
        aggregator.record("GET /api/users/{id}", ResponseClass::Forbidden, false);

        assert_eq!(aggregator.endpoint_count(), 1);

        let counts = aggregator.counts.read();
        let entry = counts
            .get("GET /api/users/{id}")
            .expect("endpoint should be tracked");
        assert_eq!(entry.total, 3);
        assert_eq!(entry.success, 2);
        assert_eq!(entry.forbidden, 1);
        assert_eq!(entry.unauthorized, 0);
        assert_eq!(entry.other_error, 0);
        assert_eq!(entry.with_auth, 2);
        assert_eq!(entry.without_auth, 1);
    }

    #[tokio::test]
    async fn test_flush_clears_counts() {
        let emitter = MockEmitter::new();
        let aggregator = aggregator_with(&emitter);

        aggregator.record("GET /api/users/{id}", ResponseClass::Success, true);
        assert_eq!(aggregator.endpoint_count(), 1);

        aggregator.flush().await;
        assert_eq!(aggregator.endpoint_count(), 0);
        assert_eq!(emitter.count(), 1);
    }

    #[tokio::test]
    async fn test_empty_flush_no_emit() {
        let emitter = MockEmitter::new();
        let aggregator = aggregator_with(&emitter);

        aggregator.flush().await;
        assert_eq!(emitter.count(), 0);
    }

    #[test]
    fn test_max_endpoints_limit() {
        let emitter = MockEmitter::new();
        let aggregator = aggregator_with(&emitter).with_max_endpoints(2);

        aggregator.record("EP1", ResponseClass::Success, true);
        aggregator.record("EP2", ResponseClass::Success, true);
        aggregator.record("EP3", ResponseClass::Success, true);

        // One slot is reserved for "OTHER", so only EP1 is tracked by name.
        assert_eq!(aggregator.endpoint_count(), 2);

        let counts = aggregator.counts.read();
        assert!(counts.contains_key("EP1"));
        assert!(counts.contains_key("OTHER"));
        assert!(!counts.contains_key("EP2"));
        assert!(!counts.contains_key("EP3"));
        assert_eq!(counts.get("OTHER").map(|c| c.total), Some(2));
    }

    #[tokio::test]
    async fn test_flush_reports_and_resets_dropped_endpoints() {
        let emitter = MockEmitter::new();
        let aggregator = aggregator_with(&emitter).with_max_endpoints(2);

        aggregator.record("EP1", ResponseClass::Success, true);
        aggregator.record("EP2", ResponseClass::Success, true);
        aggregator.record("EP3", ResponseClass::Unauthorized, false);
        assert_eq!(aggregator.dropped_endpoints.load(Ordering::Relaxed), 2);

        aggregator.flush().await;
        assert_eq!(emitter.count(), 1);
        assert_eq!(aggregator.dropped_endpoints.load(Ordering::Relaxed), 0);

        // Nothing new was recorded: the old overflow tally must not be re-sent.
        aggregator.flush().await;
        assert_eq!(emitter.count(), 1);
    }
}