synapse_pingora/telemetry/
auth_coverage_aggregator.rs1use parking_lot::RwLock;
2use std::collections::HashMap;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5use std::time::Duration;
6use tokio::time::interval;
7use tracing::warn;
8
9use crate::signals::auth_coverage::{
10 AuthCoverageSummary, EndpointCounts, EndpointSummary, ResponseClass,
11};
12use crate::telemetry::SignalEmitter;
13
14pub struct AuthCoverageAggregator {
16 sensor_id: String,
17 tenant_id: Option<String>,
18 counts: Arc<RwLock<HashMap<String, EndpointCounts>>>,
19 dropped_endpoints: AtomicU64,
20 emitter: Arc<dyn SignalEmitter>,
21 flush_interval: Duration,
22 max_endpoints: usize,
23}
24
25impl AuthCoverageAggregator {
26 pub fn new(
27 sensor_id: String,
28 tenant_id: Option<String>,
29 emitter: Arc<dyn SignalEmitter>,
30 flush_interval_secs: u64,
31 ) -> Self {
32 Self {
33 sensor_id,
34 tenant_id,
35 counts: Arc::new(RwLock::new(HashMap::new())),
36 dropped_endpoints: AtomicU64::new(0),
37 emitter,
38 flush_interval: Duration::from_secs(flush_interval_secs),
39 max_endpoints: 1000, }
41 }
42
43 pub fn with_max_endpoints(mut self, max_endpoints: usize) -> Self {
45 self.max_endpoints = max_endpoints;
46 self
47 }
48
49 pub fn record(&self, endpoint: &str, response_class: ResponseClass, has_auth_header: bool) {
51 let mut counts = self.counts.write();
52
53 let target_endpoint = if counts.contains_key(endpoint)
56 || counts.len() < self.max_endpoints.saturating_sub(1)
57 {
58 endpoint
59 } else {
60 self.dropped_endpoints.fetch_add(1, Ordering::Relaxed);
61 "OTHER"
62 };
63
64 let entry = counts.entry(target_endpoint.to_string()).or_default();
65
66 entry.total += 1;
67
68 match response_class {
69 ResponseClass::Success => entry.success += 1,
70 ResponseClass::Unauthorized => entry.unauthorized += 1,
71 ResponseClass::Forbidden => entry.forbidden += 1,
72 _ => entry.other_error += 1,
73 }
74
75 if has_auth_header {
76 entry.with_auth += 1;
77 } else {
78 entry.without_auth += 1;
79 }
80 }
81
82 pub fn start_flush_task(self: Arc<Self>) {
84 let Ok(handle) = tokio::runtime::Handle::try_current() else {
85 warn!("Auth coverage flush task skipped (no Tokio runtime)");
86 return;
87 };
88 let aggregator = self.clone();
89
90 handle.spawn(async move {
91 let mut ticker = interval(aggregator.flush_interval);
92
93 loop {
94 ticker.tick().await;
95 aggregator.flush().await;
96 }
97 });
98 }
99
100 async fn flush(&self) {
102 let counts = {
104 let mut guard = self.counts.write();
105 std::mem::take(&mut *guard)
106 };
107
108 let dropped_endpoints = self.dropped_endpoints.load(Ordering::Relaxed);
109
110 if counts.is_empty() && dropped_endpoints == 0 {
111 return; }
113
114 let summary = AuthCoverageSummary {
115 timestamp: std::time::SystemTime::now()
116 .duration_since(std::time::UNIX_EPOCH)
117 .unwrap()
118 .as_millis() as u64,
119 sensor_id: self.sensor_id.clone(),
120 tenant_id: self.tenant_id.clone(),
121 endpoints: counts
122 .into_iter()
123 .map(|(endpoint, counts)| EndpointSummary { endpoint, counts })
124 .collect(),
125 dropped_endpoints,
126 };
127
128 if let Ok(payload) = serde_json::to_value(&summary) {
129 self.emitter.emit("auth_coverage_summary", payload).await;
130 self.dropped_endpoints
131 .fetch_sub(dropped_endpoints, Ordering::SeqCst);
132 }
133 }
134
135 #[cfg(test)]
137 pub fn endpoint_count(&self) -> usize {
138 self.counts.read().len()
139 }
140
141 #[cfg(test)]
143 pub async fn force_flush(&self) {
144 self.flush().await;
145 }
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;

    /// Emitter double that counts emissions and remembers every signal type it receives.
    struct MockEmitter {
        emit_count: AtomicUsize,
        signal_types: Mutex<Vec<String>>,
    }

    impl MockEmitter {
        fn new() -> Arc<Self> {
            Arc::new(Self {
                emit_count: AtomicUsize::new(0),
                signal_types: Mutex::new(Vec::new()),
            })
        }

        fn count(&self) -> usize {
            self.emit_count.load(Ordering::SeqCst)
        }

        fn signal_types(&self) -> Vec<String> {
            self.signal_types.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl SignalEmitter for MockEmitter {
        async fn emit(&self, signal_type: &str, _payload: serde_json::Value) {
            self.signal_types
                .lock()
                .unwrap()
                .push(signal_type.to_string());
            self.emit_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Builds an aggregator wired to `emitter` with a 60s interval and default cap.
    fn new_aggregator(emitter: &Arc<MockEmitter>) -> AuthCoverageAggregator {
        AuthCoverageAggregator::new(
            "test-sensor".to_string(),
            None,
            emitter.clone() as Arc<dyn SignalEmitter>,
            60,
        )
    }

    #[test]
    fn test_record_increments_counts() {
        let emitter = MockEmitter::new();
        let aggregator = new_aggregator(&emitter);

        aggregator.record("GET /api/users/{id}", ResponseClass::Success, true);
        aggregator.record("GET /api/users/{id}", ResponseClass::Success, true);
        aggregator.record("GET /api/users/{id}", ResponseClass::Forbidden, true);
        aggregator.record("GET /api/users/{id}", ResponseClass::Unauthorized, false);

        assert_eq!(aggregator.endpoint_count(), 1);
        assert_eq!(aggregator.dropped_endpoints.load(Ordering::SeqCst), 0);

        let counts = aggregator.counts.read();
        let entry = &counts["GET /api/users/{id}"];
        assert_eq!(entry.total, 4);
        assert_eq!(entry.success, 2);
        assert_eq!(entry.forbidden, 1);
        assert_eq!(entry.unauthorized, 1);
        assert_eq!(entry.other_error, 0);
        assert_eq!(entry.with_auth, 3);
        assert_eq!(entry.without_auth, 1);
    }

    #[tokio::test]
    async fn test_flush_clears_counts() {
        let emitter = MockEmitter::new();
        let aggregator = new_aggregator(&emitter);

        aggregator.record("GET /api/users/{id}", ResponseClass::Success, true);
        assert_eq!(aggregator.endpoint_count(), 1);

        aggregator.flush().await;
        assert_eq!(aggregator.endpoint_count(), 0);
        assert_eq!(emitter.count(), 1);
        assert_eq!(
            emitter.signal_types(),
            vec!["auth_coverage_summary".to_string()]
        );
    }

    #[tokio::test]
    async fn test_empty_flush_no_emit() {
        let emitter = MockEmitter::new();
        let aggregator = new_aggregator(&emitter);

        aggregator.flush().await;
        assert_eq!(emitter.count(), 0);
        assert!(emitter.signal_types().is_empty());
    }

    #[test]
    fn test_max_endpoints_limit() {
        let emitter = MockEmitter::new();
        let aggregator = new_aggregator(&emitter).with_max_endpoints(2);

        aggregator.record("EP1", ResponseClass::Success, true);
        aggregator.record("EP2", ResponseClass::Success, true);
        aggregator.record("EP3", ResponseClass::Success, true);
        // Already-tracked endpoints keep their own key even after the cap is reached.
        aggregator.record("EP1", ResponseClass::Success, true);

        assert_eq!(aggregator.endpoint_count(), 2);
        assert_eq!(aggregator.dropped_endpoints.load(Ordering::SeqCst), 2);

        let counts = aggregator.counts.read();
        assert!(counts.contains_key("EP1"));
        assert!(counts.contains_key("OTHER"));
        assert!(!counts.contains_key("EP2"));
        assert!(!counts.contains_key("EP3"));
        assert_eq!(counts["EP1"].total, 2);
        assert_eq!(counts["OTHER"].total, 2);
    }

    #[tokio::test]
    async fn test_flush_resets_dropped_endpoints() {
        let emitter = MockEmitter::new();
        let aggregator = new_aggregator(&emitter).with_max_endpoints(1);

        aggregator.record("EP1", ResponseClass::Success, false);
        assert_eq!(aggregator.dropped_endpoints.load(Ordering::SeqCst), 1);

        aggregator.flush().await;
        assert_eq!(emitter.count(), 1);
        assert_eq!(aggregator.endpoint_count(), 0);
        assert_eq!(aggregator.dropped_endpoints.load(Ordering::SeqCst), 0);

        // Nothing left to report, so a second flush emits nothing.
        aggregator.flush().await;
        assert_eq!(emitter.count(), 1);
    }
}