1use axum::{
19 extract::{Query, State},
20 response::sse::{Event, KeepAlive, Sse},
21 routing::get,
22 Router,
23};
24use futures::stream::Stream;
25use serde::{Deserialize, Serialize};
26use std::collections::HashMap;
27use std::convert::Infallible;
28use std::sync::{Arc, RwLock};
29use std::time::Duration;
30use tokio::sync::broadcast;
31
32#[derive(Clone)]
34pub struct EventSourceManager {
35 tx: broadcast::Sender<StateChange>,
37 states: Arc<RwLock<HashMap<String, String>>>,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43#[serde(rename_all = "camelCase")]
44pub struct StateChange {
45 pub changed: HashMap<String, String>,
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
55#[serde(rename_all = "camelCase")]
56pub struct EventSourcePushHint {
57 pub url: String,
59 #[serde(skip_serializing_if = "Option::is_none")]
61 pub types: Option<Vec<String>>,
62}
63
64#[derive(Debug, Deserialize)]
66pub struct EventSourceQuery {
67 #[serde(default)]
69 pub types: Option<String>,
70 #[serde(default)]
72 pub closeafter: Option<u64>,
73 #[serde(default)]
75 pub ping: Option<u64>,
76}
77
78impl EventSourceManager {
79 pub fn new() -> Self {
81 let (tx, _rx) = broadcast::channel(100);
82 Self {
83 tx,
84 states: Arc::new(RwLock::new(HashMap::new())),
85 }
86 }
87
88 pub fn notify_change(&self, data_type: String, new_state: String) {
90 if let Ok(mut states) = self.states.write() {
92 states.insert(data_type.clone(), new_state.clone());
93 }
94
95 let mut changed = HashMap::new();
97 changed.insert(data_type, new_state);
98
99 let state_change = StateChange { changed };
100
101 let _ = self.tx.send(state_change);
103 }
104
105 pub fn get_state(&self, data_type: &str) -> Option<String> {
107 self.states
108 .read()
109 .ok()
110 .and_then(|states| states.get(data_type).cloned())
111 }
112
113 fn subscribe(&self) -> broadcast::Receiver<StateChange> {
115 self.tx.subscribe()
116 }
117}
118
119impl Default for EventSourceManager {
120 fn default() -> Self {
121 Self::new()
122 }
123}
124
125pub fn eventsource_routes() -> Router<EventSourceManager> {
130 Router::new().route("/eventsource", get(eventsource_handler))
131}
132
133async fn eventsource_handler(
135 Query(params): Query<EventSourceQuery>,
136 State(manager): State<EventSourceManager>,
137) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
138 let types_filter: Option<Vec<String>> = params
140 .types
141 .map(|t| t.split(',').map(|s| s.trim().to_string()).collect());
142
143 let mut rx = manager.subscribe();
145
146 let close_after = params.closeafter.map(Duration::from_secs);
148
149 let ping_interval = params
151 .ping
152 .map(Duration::from_secs)
153 .unwrap_or(Duration::from_secs(30));
154
155 let stream = async_stream::stream! {
157 let start_time = tokio::time::Instant::now();
158
159 loop {
160 if let Some(timeout) = close_after {
162 if start_time.elapsed() >= timeout {
163 break;
164 }
165 }
166
167 tokio::select! {
169 result = rx.recv() => {
170 match result {
171 Ok(state_change) => {
172 let filtered_changes: HashMap<String, String> = if let Some(ref filter) = types_filter {
174 state_change.changed.into_iter()
175 .filter(|(k, _)| filter.contains(k))
176 .collect()
177 } else {
178 state_change.changed
179 };
180
181 if !filtered_changes.is_empty() {
183 let event_data = StateChange { changed: filtered_changes };
184 if let Ok(json) = serde_json::to_string(&event_data) {
185 yield Ok(Event::default()
186 .event("state")
187 .data(json));
188 }
189 }
190 }
191 Err(broadcast::error::RecvError::Lagged(_)) => {
192 yield Ok(Event::default()
194 .event("error")
195 .data("Client lagged behind"));
196 }
197 Err(broadcast::error::RecvError::Closed) => {
198 break;
200 }
201 }
202 }
203 _ = tokio::time::sleep(ping_interval) => {
204 yield Ok(Event::default().comment("ping"));
206 }
207 }
208 }
209 };
210
211 Sse::new(stream).keep_alive(KeepAlive::default())
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_event_source_manager_new() {
        let manager = EventSourceManager::new();
        assert!(manager.get_state("Email").is_none());
    }

    #[test]
    fn test_notify_change() {
        let manager = EventSourceManager::new();

        manager.notify_change("Email".to_string(), "state1".to_string());

        assert_eq!(manager.get_state("Email"), Some("state1".to_string()));
    }

    #[test]
    fn test_notify_multiple_changes() {
        let manager = EventSourceManager::new();

        manager.notify_change("Email".to_string(), "state1".to_string());
        manager.notify_change("Mailbox".to_string(), "state2".to_string());
        manager.notify_change("Thread".to_string(), "state3".to_string());

        assert_eq!(manager.get_state("Email"), Some("state1".to_string()));
        assert_eq!(manager.get_state("Mailbox"), Some("state2".to_string()));
        assert_eq!(manager.get_state("Thread"), Some("state3".to_string()));
    }

    #[test]
    fn test_state_update() {
        let manager = EventSourceManager::new();

        manager.notify_change("Email".to_string(), "state1".to_string());
        assert_eq!(manager.get_state("Email"), Some("state1".to_string()));

        manager.notify_change("Email".to_string(), "state2".to_string());
        assert_eq!(manager.get_state("Email"), Some("state2".to_string()));
    }

    #[test]
    fn test_subscribe() {
        let manager = EventSourceManager::new();
        let mut rx = manager.subscribe();

        manager.notify_change("Email".to_string(), "state1".to_string());

        let change = rx.try_recv().unwrap();
        assert_eq!(change.changed.get("Email"), Some(&"state1".to_string()));
        // Each notification carries exactly the one changed type.
        assert_eq!(change.changed.len(), 1);
    }

    #[test]
    fn test_multiple_subscribers() {
        let manager = EventSourceManager::new();
        let mut rx1 = manager.subscribe();
        let mut rx2 = manager.subscribe();

        manager.notify_change("Email".to_string(), "state1".to_string());

        let change1 = rx1.try_recv().unwrap();
        let change2 = rx2.try_recv().unwrap();

        assert_eq!(change1.changed.get("Email"), Some(&"state1".to_string()));
        assert_eq!(change2.changed.get("Email"), Some(&"state1".to_string()));
    }

    #[test]
    fn test_state_change_serialization() {
        let mut changed = HashMap::new();
        changed.insert("Email".to_string(), "state123".to_string());
        changed.insert("Mailbox".to_string(), "state456".to_string());

        let state_change = StateChange { changed };

        let json = serde_json::to_string(&state_change).unwrap();
        // HashMap iteration order is unspecified, so verify via a round trip
        // rather than comparing against a fixed string.
        let round_trip: StateChange = serde_json::from_str(&json).unwrap();
        assert_eq!(round_trip.changed, state_change.changed);
    }

    #[test]
    fn test_push_subscription_serialization() {
        let subscription = EventSourcePushHint {
            url: "https://push.example.com/abc123".to_string(),
            types: Some(vec!["Email".to_string(), "Mailbox".to_string()]),
        };

        let json = serde_json::to_string(&subscription).unwrap();
        assert_eq!(
            json,
            r#"{"url":"https://push.example.com/abc123","types":["Email","Mailbox"]}"#
        );
    }

    #[test]
    fn test_event_source_manager_default() {
        let manager = EventSourceManager::default();
        assert!(manager.get_state("any").is_none());
    }

    #[test]
    fn test_event_source_manager_clone() {
        let manager1 = EventSourceManager::new();
        manager1.notify_change("Email".to_string(), "state1".to_string());

        let manager2 = manager1.clone();
        assert_eq!(manager2.get_state("Email"), Some("state1".to_string()));

        // Clones share the cache: updates through one are visible in the other.
        manager2.notify_change("Email".to_string(), "state2".to_string());
        assert_eq!(manager1.get_state("Email"), Some("state2".to_string()));
    }

    #[test]
    fn test_get_nonexistent_state() {
        let manager = EventSourceManager::new();
        assert_eq!(manager.get_state("NonExistent"), None);
    }

    #[test]
    fn test_notify_empty_state() {
        let manager = EventSourceManager::new();
        manager.notify_change("Email".to_string(), "".to_string());

        assert_eq!(manager.get_state("Email"), Some("".to_string()));
    }

    #[test]
    fn test_subscribe_before_notify() {
        let manager = EventSourceManager::new();
        let mut rx = manager.subscribe();

        // Nothing sent yet: the channel is open but empty (not closed).
        assert!(matches!(
            rx.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));

        manager.notify_change("Email".to_string(), "state1".to_string());

        assert!(rx.try_recv().is_ok());
    }

    #[test]
    fn test_subscribe_after_notify() {
        let manager = EventSourceManager::new();

        manager.notify_change("Email".to_string(), "state1".to_string());

        let mut rx = manager.subscribe();

        // Changes sent before subscribing are not replayed...
        assert!(matches!(
            rx.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));

        // ...but the cached state is still available.
        assert_eq!(manager.get_state("Email"), Some("state1".to_string()));
    }

    #[test]
    fn test_multiple_data_types() {
        let manager = EventSourceManager::new();

        manager.notify_change("Email".to_string(), "email_state".to_string());
        manager.notify_change("Mailbox".to_string(), "mailbox_state".to_string());
        manager.notify_change("Thread".to_string(), "thread_state".to_string());
        manager.notify_change("Identity".to_string(), "identity_state".to_string());

        assert_eq!(manager.get_state("Email"), Some("email_state".to_string()));
        assert_eq!(
            manager.get_state("Mailbox"),
            Some("mailbox_state".to_string())
        );
        assert_eq!(
            manager.get_state("Thread"),
            Some("thread_state".to_string())
        );
        assert_eq!(
            manager.get_state("Identity"),
            Some("identity_state".to_string())
        );
    }

    #[test]
    fn test_state_change_empty_changed() {
        let state_change = StateChange {
            changed: HashMap::new(),
        };

        let json = serde_json::to_string(&state_change).unwrap();
        assert_eq!(json, r#"{"changed":{}}"#);
    }

    #[test]
    fn test_push_subscription_without_types() {
        let subscription = EventSourcePushHint {
            url: "https://push.example.com/def456".to_string(),
            types: None,
        };

        let json = serde_json::to_string(&subscription).unwrap();
        assert_eq!(json, r#"{"url":"https://push.example.com/def456"}"#);
    }

    #[test]
    fn test_concurrent_notifications() {
        // Sequential sends well below the channel capacity: every one must
        // be delivered, in order.
        let manager = EventSourceManager::new();
        let mut rx = manager.subscribe();

        for i in 0..10 {
            manager.notify_change(format!("Type{}", i), format!("state{}", i));
        }

        let mut received = 0;
        while let Ok(change) = rx.try_recv() {
            assert_eq!(
                change.changed.get(&format!("Type{}", received)),
                Some(&format!("state{}", received))
            );
            received += 1;
        }

        assert_eq!(received, 10);
    }

    #[test]
    fn test_state_persistence_across_notifications() {
        let manager = EventSourceManager::new();

        manager.notify_change("Email".to_string(), "state1".to_string());
        manager.notify_change("Mailbox".to_string(), "state2".to_string());

        manager.notify_change("Email".to_string(), "state3".to_string());

        assert_eq!(manager.get_state("Email"), Some("state3".to_string()));
        assert_eq!(manager.get_state("Mailbox"), Some("state2".to_string()));
    }

    #[test]
    fn test_subscriber_receives_only_new_changes() {
        let manager = EventSourceManager::new();

        manager.notify_change("Email".to_string(), "old_state".to_string());

        let mut rx = manager.subscribe();

        manager.notify_change("Email".to_string(), "new_state".to_string());

        let change = rx.try_recv().unwrap();
        assert_eq!(change.changed.get("Email"), Some(&"new_state".to_string()));
    }

    #[test]
    fn test_broadcast_channel_capacity() {
        let manager = EventSourceManager::new();
        let mut rx = manager.subscribe();

        // Far more sends than the default channel buffers.
        for i in 0..200 {
            manager.notify_change(format!("Type{}", i), format!("state{}", i));
        }

        let mut received = 0;
        let mut lagged = false;
        let mut last = None;
        loop {
            match rx.try_recv() {
                Ok(change) => {
                    received += 1;
                    last = Some(change);
                }
                // After reporting the lag the receiver resumes at the oldest
                // retained message, so keep draining.
                Err(broadcast::error::TryRecvError::Lagged(_)) => lagged = true,
                Err(broadcast::error::TryRecvError::Empty) => break,
                Err(broadcast::error::TryRecvError::Closed) => {
                    panic!("channel closed while the manager is alive")
                }
            }
        }

        assert!(lagged, "overflowing the channel must report a lag");
        assert!(received > 0 && received < 200);
        // The most recent change is always retained.
        let last = last.expect("at least one change retained");
        assert_eq!(last.changed.get("Type199"), Some(&"state199".to_string()));
    }

    #[test]
    fn test_state_change_deserialization() {
        let json = r#"{"changed":{"Email":"state1","Mailbox":"state2"}}"#;
        let state_change: StateChange = serde_json::from_str(json).unwrap();

        assert_eq!(
            state_change.changed.get("Email"),
            Some(&"state1".to_string())
        );
        assert_eq!(
            state_change.changed.get("Mailbox"),
            Some(&"state2".to_string())
        );
    }

    #[test]
    fn test_push_subscription_deserialization() {
        let json = r#"{"url":"https://example.com","types":["Email"]}"#;
        let subscription: EventSourcePushHint = serde_json::from_str(json).unwrap();

        assert_eq!(subscription.url, "https://example.com");
        assert_eq!(subscription.types, Some(vec!["Email".to_string()]));
    }
}