1use std::collections::HashMap;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11
12use parking_lot::Mutex;
13use tokio::sync::mpsc;
14use tracing::{debug, trace, warn};
15use xds_core::{NodeHash, XdsError, XdsResult};
16
17use crate::Snapshot;
18
/// Opaque identifier assigned to each watch.
///
/// Identifiers are drawn from a single process-wide counter, so two calls to
/// [`WatchId::next`] never return the same value (short of `u64` wrap-around).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WatchId(u64);

impl WatchId {
    /// Allocates the next unused identifier; the first one handed out is `1`.
    fn next() -> Self {
        // Only the atomicity of the increment matters for uniqueness; no other
        // memory is published through this counter, hence `Relaxed`.
        static NEXT_ID: AtomicU64 = AtomicU64::new(1);
        let raw = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        WatchId(raw)
    }

    /// Returns the raw numeric value behind this identifier.
    #[inline]
    pub fn as_u64(&self) -> u64 {
        let WatchId(raw) = *self;
        raw
    }
}

impl std::fmt::Display for WatchId {
    /// Renders the identifier as `watch-<n>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("watch-{}", self.0))
    }
}
42
43#[derive(Debug)]
48pub struct Watch {
49 id: WatchId,
51 node_hash: NodeHash,
53 receiver: mpsc::Receiver<Arc<Snapshot>>,
55}
56
57impl Watch {
58 #[inline]
60 pub fn id(&self) -> WatchId {
61 self.id
62 }
63
64 #[inline]
66 pub fn node_hash(&self) -> NodeHash {
67 self.node_hash
68 }
69
70 pub async fn recv(&mut self) -> Option<Arc<Snapshot>> {
74 self.receiver.recv().await
75 }
76
77 pub fn try_recv(&mut self) -> Result<Arc<Snapshot>, mpsc::error::TryRecvError> {
84 self.receiver.try_recv()
85 }
86}
87
88#[derive(Debug, Clone)]
90#[allow(dead_code)] pub(crate) struct WatchSender {
92 id: WatchId,
93 node_hash: NodeHash,
94 sender: mpsc::Sender<Arc<Snapshot>>,
95}
96
97#[allow(dead_code)] impl WatchSender {
99 pub fn try_send(&self, snapshot: Arc<Snapshot>) -> XdsResult<()> {
104 match self.sender.try_send(snapshot) {
105 Ok(()) => Ok(()),
106 Err(mpsc::error::TrySendError::Full(_)) => {
107 trace!(watch_id = %self.id, "watch channel full, skipping update");
109 Ok(())
110 }
111 Err(mpsc::error::TrySendError::Closed(_)) => Err(XdsError::WatchClosed {
112 watch_id: self.id.0,
113 }),
114 }
115 }
116
117 #[inline]
119 pub fn id(&self) -> WatchId {
120 self.id
121 }
122}
123
124#[derive(Debug)]
129pub struct WatchManager {
130 watches: Mutex<HashMap<NodeHash, Vec<WatchSender>>>,
132 channel_buffer: usize,
134}
135
136impl Default for WatchManager {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142impl WatchManager {
143 pub fn new() -> Self {
145 Self::with_buffer_size(16)
146 }
147
148 pub fn with_buffer_size(buffer_size: usize) -> Self {
150 Self {
151 watches: Mutex::new(HashMap::new()),
152 channel_buffer: buffer_size,
153 }
154 }
155
156 pub fn create_watch(&self, node_hash: NodeHash) -> Watch {
160 let id = WatchId::next();
161 let (sender, receiver) = mpsc::channel(self.channel_buffer);
162
163 let watch_sender = WatchSender {
164 id,
165 node_hash,
166 sender,
167 };
168
169 {
171 let mut watches = self.watches.lock();
172 watches.entry(node_hash).or_default().push(watch_sender);
173 }
174
175 debug!(watch_id = %id, node = %node_hash, "created watch");
176
177 Watch {
178 id,
179 node_hash,
180 receiver,
181 }
182 }
183
184 pub fn cancel_watch(&self, watch_id: WatchId) {
188 let mut watches = self.watches.lock();
189
190 for senders in watches.values_mut() {
192 if let Some(pos) = senders.iter().position(|s| s.id == watch_id) {
193 senders.swap_remove(pos);
194 debug!(watch_id = %watch_id, "cancelled watch");
195 return;
196 }
197 }
198
199 warn!(watch_id = %watch_id, "attempted to cancel unknown watch");
200 }
201
202 pub fn notify(&self, node_hash: NodeHash, snapshot: Arc<Snapshot>) {
206 let senders: Vec<WatchSender> = {
208 let watches = self.watches.lock();
209 watches.get(&node_hash).cloned().unwrap_or_default()
210 };
211
212 if senders.is_empty() {
213 return;
214 }
215
216 let mut closed_ids = Vec::new();
218
219 for sender in &senders {
220 if let Err(XdsError::WatchClosed { watch_id }) = sender.try_send(Arc::clone(&snapshot))
221 {
222 closed_ids.push(WatchId(watch_id));
223 }
224 }
225
226 if !closed_ids.is_empty() {
228 let mut watches = self.watches.lock();
229 if let Some(senders) = watches.get_mut(&node_hash) {
230 senders.retain(|s| !closed_ids.contains(&s.id));
231 }
232 debug!(count = closed_ids.len(), "removed closed watches");
233 }
234
235 trace!(
236 node = %node_hash,
237 watch_count = senders.len() - closed_ids.len(),
238 "notified watches of snapshot update"
239 );
240 }
241
242 pub fn watch_count(&self, node_hash: NodeHash) -> usize {
244 let watches = self.watches.lock();
245 watches.get(&node_hash).map(|v| v.len()).unwrap_or(0)
246 }
247
248 pub fn total_watch_count(&self) -> usize {
250 let watches = self.watches.lock();
251 watches.values().map(|v| v.len()).sum()
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Builds a shared snapshot that carries only the given version string.
    fn snapshot_with_version(version: &'static str) -> Arc<Snapshot> {
        Arc::new(Snapshot::builder().version(version).build())
    }

    /// Two consecutive ids must differ.
    #[test]
    fn watch_id_unique() {
        let first = WatchId::next();
        let second = WatchId::next();
        assert_ne!(first, second);
    }

    /// The textual form starts with the `watch-` prefix.
    #[test]
    fn watch_id_display() {
        let rendered = WatchId::next().to_string();
        assert!(rendered.starts_with("watch-"));
    }

    /// Ids allocated from several threads at once never collide.
    #[test]
    fn watch_id_concurrent_uniqueness() {
        use std::collections::HashSet;
        use std::sync::Mutex;

        let seen = Arc::new(Mutex::new(HashSet::new()));

        // 10 threads x 100 ids each.
        let workers: Vec<_> = (0..10)
            .map(|_| {
                let seen = Arc::clone(&seen);
                thread::spawn(move || {
                    for _ in 0..100 {
                        let id = WatchId::next();
                        seen.lock().unwrap().insert(id.0);
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        // Every insert must have been a distinct value.
        assert_eq!(seen.lock().unwrap().len(), 1000);
    }

    /// A notification reaches a watch registered for the same node.
    #[tokio::test]
    async fn watch_manager_create_and_notify() {
        let manager = WatchManager::new();
        let node = NodeHash::from_id("test-node");

        let mut watch = manager.create_watch(node);
        assert_eq!(manager.watch_count(node), 1);

        let snapshot = snapshot_with_version("v1");
        manager.notify(node, Arc::clone(&snapshot));

        let delivered = watch.recv().await.unwrap();
        assert_eq!(delivered.version(), "v1");
    }

    /// Cancelling a watch removes it from the per-node count.
    #[test]
    fn watch_manager_cancel() {
        let manager = WatchManager::new();
        let node = NodeHash::from_id("test-node");

        let watch = manager.create_watch(node);
        assert_eq!(manager.watch_count(node), 1);

        manager.cancel_watch(watch.id());
        assert_eq!(manager.watch_count(node), 0);
    }

    /// Cancelling an id that was never registered must not panic.
    #[test]
    fn watch_manager_cancel_nonexistent() {
        let manager = WatchManager::new();
        let unknown = WatchId::next();
        manager.cancel_watch(unknown);
    }

    /// Every watch on a node receives the same notification.
    #[tokio::test]
    async fn watch_manager_multiple_watches_same_node() {
        let manager = WatchManager::new();
        let node = NodeHash::from_id("test-node");

        let mut first = manager.create_watch(node);
        let mut second = manager.create_watch(node);
        let mut third = manager.create_watch(node);

        assert_eq!(manager.watch_count(node), 3);
        assert_eq!(manager.total_watch_count(), 3);

        manager.notify(node, snapshot_with_version("v1"));

        let got_first = first.recv().await.unwrap();
        let got_second = second.recv().await.unwrap();
        let got_third = third.recv().await.unwrap();

        assert_eq!(got_first.version(), "v1");
        assert_eq!(got_second.version(), "v1");
        assert_eq!(got_third.version(), "v1");
    }

    /// Notifications are routed per node.
    #[tokio::test]
    async fn watch_manager_multiple_nodes() {
        let manager = WatchManager::new();
        let node_a = NodeHash::from_id("node-1");
        let node_b = NodeHash::from_id("node-2");

        let mut watch_a = manager.create_watch(node_a);
        let mut watch_b = manager.create_watch(node_b);

        assert_eq!(manager.total_watch_count(), 2);

        manager.notify(node_a, snapshot_with_version("v1"));
        let got_a = watch_a.recv().await.unwrap();
        assert_eq!(got_a.version(), "v1");

        manager.notify(node_b, snapshot_with_version("v2"));
        let got_b = watch_b.recv().await.unwrap();
        assert_eq!(got_b.version(), "v2");
    }

    /// Notifying a node without watches is a silent no-op.
    #[tokio::test]
    async fn watch_manager_notify_nonexistent_node() {
        let manager = WatchManager::new();
        let node = NodeHash::from_id("nonexistent");

        manager.notify(node, snapshot_with_version("v1"));
    }

    /// Cancelling watches one at a time decrements the count each time.
    #[test]
    fn watch_manager_cleanup_cancelled_watches() {
        let manager = WatchManager::new();
        let node = NodeHash::from_id("test-node");

        let watch1 = manager.create_watch(node);
        let watch2 = manager.create_watch(node);
        let watch3 = manager.create_watch(node);

        assert_eq!(manager.watch_count(node), 3);

        // Cancel out of creation order: middle, first, last.
        for (id, remaining) in [(watch2.id(), 2), (watch1.id(), 1), (watch3.id(), 0)] {
            manager.cancel_watch(id);
            assert_eq!(manager.watch_count(node), remaining);
        }
    }

    /// Without a notification, `recv` stays pending.
    #[tokio::test]
    async fn watch_receive_timeout() {
        use tokio::time::{timeout, Duration};

        let manager = WatchManager::new();
        let node = NodeHash::from_id("test-node");

        let mut watch = manager.create_watch(node);

        let outcome = timeout(Duration::from_millis(10), watch.recv()).await;
        assert!(outcome.is_err(), "Should timeout without notification");
    }

    /// Dropping the manager drops the sender, which closes the watch.
    #[tokio::test]
    async fn watch_dropped_sender_closes_watch() {
        let node = NodeHash::from_id("test-node");

        let manager = WatchManager::new();
        let mut watch = manager.create_watch(node);
        drop(manager);

        let outcome = watch.recv().await;
        assert!(
            outcome.is_none(),
            "Watch should close when manager is dropped"
        );
    }

    /// A custom buffer size is stored on the manager.
    #[test]
    fn watch_with_custom_buffer_size() {
        let manager = WatchManager::with_buffer_size(1);
        let node = NodeHash::from_id("test-node");

        let _watch = manager.create_watch(node);
        assert_eq!(manager.channel_buffer, 1);
    }
}