1use axum::{
9 extract::{Query, State},
10 response::sse::{Event, KeepAlive, Sse},
11 routing::get,
12 Router,
13};
14use futures::stream::Stream;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::convert::Infallible;
18use std::sync::{Arc, RwLock};
19use std::time::Duration;
20use tokio::sync::broadcast;
21
22#[derive(Clone)]
24pub struct EventSourceManager {
25 tx: broadcast::Sender<StateChange>,
27 states: Arc<RwLock<HashMap<String, String>>>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33#[serde(rename_all = "camelCase")]
34pub struct StateChange {
35 pub changed: HashMap<String, String>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41#[serde(rename_all = "camelCase")]
42pub struct PushSubscription {
43 pub url: String,
45 #[serde(skip_serializing_if = "Option::is_none")]
47 pub types: Option<Vec<String>>,
48}
49
50#[derive(Debug, Deserialize)]
52pub struct EventSourceQuery {
53 #[serde(default)]
55 pub types: Option<String>,
56 #[serde(default)]
58 pub closeafter: Option<u64>,
59 #[serde(default)]
61 pub ping: Option<u64>,
62}
63
64impl EventSourceManager {
65 pub fn new() -> Self {
67 let (tx, _rx) = broadcast::channel(100);
68 Self {
69 tx,
70 states: Arc::new(RwLock::new(HashMap::new())),
71 }
72 }
73
74 pub fn notify_change(&self, data_type: String, new_state: String) {
76 if let Ok(mut states) = self.states.write() {
78 states.insert(data_type.clone(), new_state.clone());
79 }
80
81 let mut changed = HashMap::new();
83 changed.insert(data_type, new_state);
84
85 let state_change = StateChange { changed };
86
87 let _ = self.tx.send(state_change);
89 }
90
91 pub fn get_state(&self, data_type: &str) -> Option<String> {
93 self.states
94 .read()
95 .ok()
96 .and_then(|states| states.get(data_type).cloned())
97 }
98
99 fn subscribe(&self) -> broadcast::Receiver<StateChange> {
101 self.tx.subscribe()
102 }
103}
104
105impl Default for EventSourceManager {
106 fn default() -> Self {
107 Self::new()
108 }
109}
110
111pub fn eventsource_routes() -> Router<EventSourceManager> {
113 Router::new().route("/eventsource", get(eventsource_handler))
114}
115
116async fn eventsource_handler(
118 Query(params): Query<EventSourceQuery>,
119 State(manager): State<EventSourceManager>,
120) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
121 let types_filter: Option<Vec<String>> = params
123 .types
124 .map(|t| t.split(',').map(|s| s.trim().to_string()).collect());
125
126 let mut rx = manager.subscribe();
128
129 let close_after = params.closeafter.map(Duration::from_secs);
131
132 let ping_interval = params
134 .ping
135 .map(Duration::from_secs)
136 .unwrap_or(Duration::from_secs(30));
137
138 let stream = async_stream::stream! {
140 let start_time = tokio::time::Instant::now();
141
142 loop {
143 if let Some(timeout) = close_after {
145 if start_time.elapsed() >= timeout {
146 break;
147 }
148 }
149
150 tokio::select! {
152 result = rx.recv() => {
153 match result {
154 Ok(state_change) => {
155 let filtered_changes: HashMap<String, String> = if let Some(ref filter) = types_filter {
157 state_change.changed.into_iter()
158 .filter(|(k, _)| filter.contains(k))
159 .collect()
160 } else {
161 state_change.changed
162 };
163
164 if !filtered_changes.is_empty() {
166 let event_data = StateChange { changed: filtered_changes };
167 if let Ok(json) = serde_json::to_string(&event_data) {
168 yield Ok(Event::default()
169 .event("state")
170 .data(json));
171 }
172 }
173 }
174 Err(broadcast::error::RecvError::Lagged(_)) => {
175 yield Ok(Event::default()
177 .event("error")
178 .data("Client lagged behind"));
179 }
180 Err(broadcast::error::RecvError::Closed) => {
181 break;
183 }
184 }
185 }
186 _ = tokio::time::sleep(ping_interval) => {
187 yield Ok(Event::default().comment("ping"));
189 }
190 }
191 }
192 };
193
194 Sse::new(stream).keep_alive(KeepAlive::default())
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_event_source_manager_new() {
        let manager = EventSourceManager::new();
        assert!(manager.get_state("Email").is_none());
    }

    #[test]
    fn test_notify_change() {
        let manager = EventSourceManager::new();

        manager.notify_change("Email".to_string(), "state1".to_string());

        assert_eq!(manager.get_state("Email"), Some("state1".to_string()));
    }

    #[test]
    fn test_notify_multiple_changes() {
        let manager = EventSourceManager::new();

        manager.notify_change("Email".to_string(), "state1".to_string());
        manager.notify_change("Mailbox".to_string(), "state2".to_string());
        manager.notify_change("Thread".to_string(), "state3".to_string());

        assert_eq!(manager.get_state("Email"), Some("state1".to_string()));
        assert_eq!(manager.get_state("Mailbox"), Some("state2".to_string()));
        assert_eq!(manager.get_state("Thread"), Some("state3".to_string()));
    }

    #[test]
    fn test_state_update() {
        let manager = EventSourceManager::new();

        manager.notify_change("Email".to_string(), "state1".to_string());
        assert_eq!(manager.get_state("Email"), Some("state1".to_string()));

        manager.notify_change("Email".to_string(), "state2".to_string());
        assert_eq!(manager.get_state("Email"), Some("state2".to_string()));
    }

    #[test]
    fn test_subscribe() {
        let manager = EventSourceManager::new();
        let mut rx = manager.subscribe();

        manager.notify_change("Email".to_string(), "state1".to_string());

        let change = rx.try_recv().unwrap();
        assert_eq!(change.changed.get("Email"), Some(&"state1".to_string()));
    }

    #[test]
    fn test_multiple_subscribers() {
        let manager = EventSourceManager::new();
        let mut rx1 = manager.subscribe();
        let mut rx2 = manager.subscribe();

        manager.notify_change("Email".to_string(), "state1".to_string());

        let change1 = rx1.try_recv().unwrap();
        let change2 = rx2.try_recv().unwrap();

        assert_eq!(change1.changed.get("Email"), Some(&"state1".to_string()));
        assert_eq!(change2.changed.get("Email"), Some(&"state1".to_string()));
    }

    #[test]
    fn test_state_change_serialization() {
        let mut changed = HashMap::new();
        changed.insert("Email".to_string(), "state123".to_string());
        changed.insert("Mailbox".to_string(), "state456".to_string());

        let state_change = StateChange { changed };

        let json = serde_json::to_string(&state_change).unwrap();
        assert!(json.contains("Email"));
        assert!(json.contains("state123"));
    }

    #[test]
    fn test_push_subscription_serialization() {
        let subscription = PushSubscription {
            url: "https://push.example.com/abc123".to_string(),
            types: Some(vec!["Email".to_string(), "Mailbox".to_string()]),
        };

        let json = serde_json::to_string(&subscription).unwrap();
        assert!(json.contains("push.example.com"));
    }

    #[test]
    fn test_event_source_manager_default() {
        let manager = EventSourceManager::default();
        assert!(manager.get_state("any").is_none());
    }

    #[test]
    fn test_event_source_manager_clone() {
        let manager1 = EventSourceManager::new();
        manager1.notify_change("Email".to_string(), "state1".to_string());

        let manager2 = manager1.clone();
        assert_eq!(manager2.get_state("Email"), Some("state1".to_string()));
    }

    #[test]
    fn test_get_nonexistent_state() {
        let manager = EventSourceManager::new();
        assert_eq!(manager.get_state("NonExistent"), None);
    }

    #[test]
    fn test_notify_empty_state() {
        let manager = EventSourceManager::new();
        manager.notify_change("Email".to_string(), "".to_string());

        assert_eq!(manager.get_state("Email"), Some("".to_string()));
    }

    #[test]
    fn test_subscribe_before_notify() {
        let manager = EventSourceManager::new();
        let mut rx = manager.subscribe();

        assert!(rx.try_recv().is_err());

        manager.notify_change("Email".to_string(), "state1".to_string());

        assert!(rx.try_recv().is_ok());
    }

    #[test]
    fn test_subscribe_after_notify() {
        let manager = EventSourceManager::new();

        manager.notify_change("Email".to_string(), "state1".to_string());

        let mut rx = manager.subscribe();

        // A late subscriber does not see earlier broadcasts...
        assert!(rx.try_recv().is_err());

        // ...but the recorded state is still available.
        assert_eq!(manager.get_state("Email"), Some("state1".to_string()));
    }

    #[test]
    fn test_multiple_data_types() {
        let manager = EventSourceManager::new();

        manager.notify_change("Email".to_string(), "email_state".to_string());
        manager.notify_change("Mailbox".to_string(), "mailbox_state".to_string());
        manager.notify_change("Thread".to_string(), "thread_state".to_string());
        manager.notify_change("Identity".to_string(), "identity_state".to_string());

        assert_eq!(manager.get_state("Email"), Some("email_state".to_string()));
        assert_eq!(
            manager.get_state("Mailbox"),
            Some("mailbox_state".to_string())
        );
        assert_eq!(
            manager.get_state("Thread"),
            Some("thread_state".to_string())
        );
        assert_eq!(
            manager.get_state("Identity"),
            Some("identity_state".to_string())
        );
    }

    #[test]
    fn test_state_change_empty_changed() {
        let state_change = StateChange {
            changed: HashMap::new(),
        };

        let json = serde_json::to_string(&state_change).unwrap();
        assert!(json.contains("changed"));
    }

    #[test]
    fn test_push_subscription_without_types() {
        let subscription = PushSubscription {
            url: "https://push.example.com/def456".to_string(),
            types: None,
        };

        let json = serde_json::to_string(&subscription).unwrap();
        assert!(!json.contains("types"));
    }

    #[test]
    fn test_concurrent_notifications() {
        let manager = EventSourceManager::new();
        let mut rx = manager.subscribe();

        for i in 0..10 {
            manager.notify_change(format!("Type{}", i), format!("state{}", i));
        }

        // Ten changes fit well within the channel capacity, so every one must
        // arrive, in send order, and nothing more.
        for i in 0..10 {
            let change = rx.try_recv().expect("change should still be buffered");
            assert_eq!(
                change.changed.get(&format!("Type{}", i)),
                Some(&format!("state{}", i))
            );
        }
        assert!(matches!(
            rx.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));
    }

    #[test]
    fn test_state_persistence_across_notifications() {
        let manager = EventSourceManager::new();

        manager.notify_change("Email".to_string(), "state1".to_string());
        manager.notify_change("Mailbox".to_string(), "state2".to_string());

        manager.notify_change("Email".to_string(), "state3".to_string());

        assert_eq!(manager.get_state("Email"), Some("state3".to_string()));
        assert_eq!(manager.get_state("Mailbox"), Some("state2".to_string()));
    }

    #[test]
    fn test_subscriber_receives_only_new_changes() {
        let manager = EventSourceManager::new();

        manager.notify_change("Email".to_string(), "old_state".to_string());

        let mut rx = manager.subscribe();

        manager.notify_change("Email".to_string(), "new_state".to_string());

        let change = rx.try_recv().unwrap();
        assert_eq!(change.changed.get("Email"), Some(&"new_state".to_string()));
    }

    #[test]
    fn test_broadcast_channel_capacity() {
        let manager = EventSourceManager::new();
        let mut rx = manager.subscribe();

        for i in 0..200 {
            manager.notify_change(format!("Type{}", i), format!("state{}", i));
        }

        // 200 sends overflow the buffer. The receiver must report a lag, then
        // deliver whatever is still retained; together these account for
        // every send. The effective capacity is not hard-coded because tokio
        // may round it up.
        let mut received: u64 = 0;
        let mut skipped: u64 = 0;
        loop {
            match rx.try_recv() {
                Ok(_) => received += 1,
                Err(broadcast::error::TryRecvError::Lagged(n)) => skipped += n,
                Err(broadcast::error::TryRecvError::Empty) => break,
                Err(broadcast::error::TryRecvError::Closed) => {
                    panic!("channel closed while the manager still holds the sender")
                }
            }
        }

        assert!(skipped > 0, "receiver should have lagged");
        assert!(received > 0, "retained changes should still be delivered");
        assert_eq!(received + skipped, 200);

        // The state table is not bounded by the channel capacity.
        assert_eq!(manager.get_state("Type0"), Some("state0".to_string()));
        assert_eq!(manager.get_state("Type199"), Some("state199".to_string()));
    }

    #[test]
    fn test_state_change_deserialization() {
        let json = r#"{"changed":{"Email":"state1","Mailbox":"state2"}}"#;
        let state_change: StateChange = serde_json::from_str(json).unwrap();

        assert_eq!(
            state_change.changed.get("Email"),
            Some(&"state1".to_string())
        );
        assert_eq!(
            state_change.changed.get("Mailbox"),
            Some(&"state2".to_string())
        );
    }

    #[test]
    fn test_push_subscription_deserialization() {
        let json = r#"{"url":"https://example.com","types":["Email"]}"#;
        let subscription: PushSubscription = serde_json::from_str(json).unwrap();

        assert_eq!(subscription.url, "https://example.com");
        assert_eq!(subscription.types, Some(vec!["Email".to_string()]));
    }
}