1use std::collections::HashMap;
22use std::sync::Arc;
23use std::time::Duration;
24
25use async_trait::async_trait;
26use serde::{Deserialize, Serialize};
27use tokio::sync::{broadcast, RwLock};
28
29use crate::error::PodError;
30use crate::storage::StorageEvent;
31
32#[cfg(feature = "legacy-notifications")]
35pub mod legacy;
36
37#[cfg(feature = "webhook-signing")]
42pub mod signing;
43
/// ActivityStreams vocabulary used when building notification payloads.
pub mod as_ns {
    /// JSON-LD `@context` value placed on every notification.
    pub const CONTEXT: &str = "https://www.w3.org/ns/activitystreams";

    /// Activity type reported for a created resource.
    pub const CREATE: &str = "Create";

    /// Activity type reported for an updated resource.
    pub const UPDATE: &str = "Update";

    /// Activity type reported for a deleted resource.
    pub const DELETE: &str = "Delete";
}
51
52#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
54#[serde(rename_all = "PascalCase")]
55pub enum ChannelType {
56 WebSocketChannel2023,
57 WebhookChannel2023,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct Subscription {
63 pub id: String,
65 pub topic: String,
67 pub channel_type: ChannelType,
69 pub receive_from: String,
73}
74
75#[derive(Debug, Clone, Serialize, Deserialize)]
77pub struct ChangeNotification {
78 #[serde(rename = "@context")]
79 pub context: String,
80 pub id: String,
81 #[serde(rename = "type")]
82 pub kind: String,
83 pub object: String,
84 pub published: String,
85}
86
87impl ChangeNotification {
88 pub fn from_storage_event(event: &StorageEvent, pod_base: &str) -> Self {
90 let (kind, path) = match event {
91 StorageEvent::Created(p) => (as_ns::CREATE, p),
92 StorageEvent::Updated(p) => (as_ns::UPDATE, p),
93 StorageEvent::Deleted(p) => (as_ns::DELETE, p),
94 };
95 let object = format!("{}{}", pod_base.trim_end_matches('/'), path);
96 Self {
97 context: as_ns::CONTEXT.to_string(),
98 id: format!("urn:uuid:{}", uuid::Uuid::new_v4()),
99 kind: kind.to_string(),
100 object,
101 published: chrono::Utc::now().to_rfc3339(),
102 }
103 }
104}
105
106#[async_trait]
108pub trait Notifications: Send + Sync {
109 async fn subscribe(&self, subscription: Subscription) -> Result<(), PodError>;
111
112 async fn unsubscribe(&self, id: &str) -> Result<(), PodError>;
114
115 async fn publish(
117 &self,
118 topic: &str,
119 notification: ChangeNotification,
120 ) -> Result<(), PodError>;
121}
122
123#[derive(Default, Clone)]
128pub struct InMemoryNotifications {
129 inner: Arc<RwLock<HashMap<String, Vec<Subscription>>>>,
130}
131
132impl InMemoryNotifications {
133 pub fn new() -> Self {
134 Self::default()
135 }
136}
137
138#[async_trait]
139impl Notifications for InMemoryNotifications {
140 async fn subscribe(&self, subscription: Subscription) -> Result<(), PodError> {
141 let mut guard = self.inner.write().await;
142 guard
143 .entry(subscription.topic.clone())
144 .or_default()
145 .push(subscription);
146 Ok(())
147 }
148
149 async fn unsubscribe(&self, id: &str) -> Result<(), PodError> {
150 let mut guard = self.inner.write().await;
151 for subs in guard.values_mut() {
152 subs.retain(|s| s.id != id);
153 }
154 Ok(())
155 }
156
157 async fn publish(
158 &self,
159 topic: &str,
160 _notification: ChangeNotification,
161 ) -> Result<(), PodError> {
162 let guard = self.inner.read().await;
163 let _ = guard.get(topic);
164 Ok(())
165 }
166}
167
168#[derive(Clone)]
177pub struct WebSocketChannelManager {
178 subscriptions: Arc<RwLock<HashMap<String, Subscription>>>,
179 sender: broadcast::Sender<ChangeNotification>,
180 heartbeat_interval: Duration,
181}
182
183impl Default for WebSocketChannelManager {
184 fn default() -> Self {
185 Self::new()
186 }
187}
188
189impl WebSocketChannelManager {
190 pub fn new() -> Self {
191 let (tx, _) = broadcast::channel(1024);
192 Self {
193 subscriptions: Arc::new(RwLock::new(HashMap::new())),
194 sender: tx,
195 heartbeat_interval: Duration::from_secs(30),
196 }
197 }
198
199 pub fn with_heartbeat(mut self, interval: Duration) -> Self {
201 self.heartbeat_interval = interval;
202 self
203 }
204
205 pub fn heartbeat_interval(&self) -> Duration {
207 self.heartbeat_interval
208 }
209
210 pub async fn subscribe(&self, topic: &str, base_url: &str) -> Subscription {
213 let id = uuid::Uuid::new_v4().to_string();
214 let receive_from = format!(
215 "{}/subscription/{}",
216 base_url.trim_end_matches('/'),
217 urlencoding(topic)
218 );
219 let sub = Subscription {
220 id: id.clone(),
221 topic: topic.to_string(),
222 channel_type: ChannelType::WebSocketChannel2023,
223 receive_from,
224 };
225 self.subscriptions.write().await.insert(id, sub.clone());
226 sub
227 }
228
229 pub async fn unsubscribe(&self, id: &str) {
231 self.subscriptions.write().await.remove(id);
232 }
233
234 pub fn stream(&self) -> broadcast::Receiver<ChangeNotification> {
238 self.sender.subscribe()
239 }
240
241 pub async fn active_subscriptions(&self) -> usize {
243 self.subscriptions.read().await.len()
244 }
245
246 pub async fn pump_from_storage(
251 self,
252 mut rx: tokio::sync::mpsc::Receiver<StorageEvent>,
253 pod_base: String,
254 ) {
255 while let Some(event) = rx.recv().await {
256 let note = ChangeNotification::from_storage_event(&event, &pod_base);
257 let _ = self.sender.send(note);
258 }
259 }
260}
261
262#[async_trait]
263impl Notifications for WebSocketChannelManager {
264 async fn subscribe(&self, subscription: Subscription) -> Result<(), PodError> {
265 self.subscriptions
266 .write()
267 .await
268 .insert(subscription.id.clone(), subscription);
269 Ok(())
270 }
271
272 async fn unsubscribe(&self, id: &str) -> Result<(), PodError> {
273 self.subscriptions.write().await.remove(id);
274 Ok(())
275 }
276
277 async fn publish(
278 &self,
279 _topic: &str,
280 notification: ChangeNotification,
281 ) -> Result<(), PodError> {
282 let _ = self.sender.send(notification);
283 Ok(())
284 }
285}
286
/// Outcome of trying to deliver one notification to one webhook.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WebhookDelivery {
    /// The endpoint answered with a success status.
    Delivered { status: u16 },
    /// The endpoint answered `410`; `deliver_all` removes the subscription.
    FatalDrop { status: u16 },
    /// Delivery did not succeed: the circuit was open, or every permitted
    /// attempt failed. `reason` describes the last failure.
    TransientRetry { reason: String },
}
301
302#[derive(Clone)]
317pub struct WebhookChannelManager {
318 client: reqwest::Client,
319 subscriptions: Arc<RwLock<HashMap<String, Subscription>>>,
320 pub retry_base: Duration,
322 pub max_retries: u32,
326 pub max_backoff: Duration,
328 pub circuit_threshold: u32,
330 consecutive_failures: Arc<std::sync::atomic::AtomicU32>,
333 #[cfg(feature = "webhook-signing")]
337 signer: Option<signing::SignerConfig>,
338}
339
340impl Default for WebhookChannelManager {
341 fn default() -> Self {
342 Self::new()
343 }
344}
345
346impl WebhookChannelManager {
347 pub fn new() -> Self {
348 Self {
349 client: reqwest::Client::builder()
350 .timeout(Duration::from_secs(10))
351 .build()
352 .unwrap_or_default(),
353 subscriptions: Arc::new(RwLock::new(HashMap::new())),
354 retry_base: Duration::from_millis(500),
355 max_retries: 3,
356 max_backoff: Duration::from_secs(3600),
357 circuit_threshold: 10,
358 consecutive_failures: Arc::new(std::sync::atomic::AtomicU32::new(0)),
359 #[cfg(feature = "webhook-signing")]
360 signer: None,
361 }
362 }
363
364 pub fn with_client(client: reqwest::Client) -> Self {
367 let mut m = Self::new();
368 m.client = client;
369 m
370 }
371
372 #[cfg(feature = "webhook-signing")]
375 pub fn with_signer(mut self, signer: signing::SignerConfig) -> Self {
376 self.signer = Some(signer);
377 self
378 }
379
380 pub fn with_max_attempts(mut self, attempts: u32) -> Self {
382 self.max_retries = attempts.saturating_sub(1);
386 self
387 }
388
389 pub fn with_max_backoff(mut self, max: Duration) -> Self {
391 self.max_backoff = max;
392 self
393 }
394
395 pub fn with_circuit_threshold(mut self, threshold: u32) -> Self {
398 self.circuit_threshold = threshold;
399 self
400 }
401
402 pub fn circuit_open(&self) -> bool {
404 self.consecutive_failures
405 .load(std::sync::atomic::Ordering::Relaxed)
406 >= self.circuit_threshold
407 }
408
409 pub fn consecutive_failures(&self) -> u32 {
412 self.consecutive_failures
413 .load(std::sync::atomic::Ordering::Relaxed)
414 }
415
416 pub fn reset_circuit(&self) {
419 self.consecutive_failures
420 .store(0, std::sync::atomic::Ordering::Relaxed);
421 }
422
423 pub async fn subscribe(&self, topic: &str, target_url: &str) -> Subscription {
424 let sub = Subscription {
425 id: uuid::Uuid::new_v4().to_string(),
426 topic: topic.to_string(),
427 channel_type: ChannelType::WebhookChannel2023,
428 receive_from: target_url.to_string(),
429 };
430 self.subscriptions
431 .write()
432 .await
433 .insert(sub.id.clone(), sub.clone());
434 sub
435 }
436
437 pub async fn unsubscribe(&self, id: &str) {
438 self.subscriptions.write().await.remove(id);
439 }
440
441 pub async fn active_subscriptions(&self) -> usize {
442 self.subscriptions.read().await.len()
443 }
444
445 fn parse_retry_after(raw: &str) -> Option<Duration> {
448 if let Ok(secs) = raw.trim().parse::<u64>() {
449 return Some(Duration::from_secs(secs));
450 }
451 #[cfg(feature = "webhook-signing")]
452 {
453 if let Ok(when) = httpdate::parse_http_date(raw.trim()) {
454 if let Ok(delta) = when.duration_since(std::time::SystemTime::now()) {
455 return Some(delta);
456 }
457 }
458 }
459 None
460 }
461
462 #[doc(hidden)]
467 pub fn compute_backoff(&self, attempt: u32) -> Duration {
468 let exp = self
469 .retry_base
470 .saturating_mul(2u32.saturating_pow(attempt.min(20)));
471 let cap = std::cmp::min(exp, self.max_backoff);
472 let factor = jitter_factor();
478 let nanos = (cap.as_nanos() as f64 * factor) as u128;
479 Duration::from_nanos(nanos.min(u64::MAX as u128) as u64)
480 }
481
482 async fn send_once(
484 &self,
485 url: &str,
486 note: &ChangeNotification,
487 ) -> Result<reqwest::Response, reqwest::Error> {
488 let body = serde_json::to_vec(note).unwrap_or_default();
489 #[cfg(feature = "webhook-signing")]
490 let notification_id = note.id.clone();
491 #[cfg_attr(not(feature = "webhook-signing"), allow(unused_mut))]
492 let mut req = self
493 .client
494 .post(url)
495 .header("Content-Type", "application/ld+json");
496
497 #[cfg(feature = "webhook-signing")]
498 {
499 if let Some(cfg) = &self.signer {
500 let now = std::time::SystemTime::now()
501 .duration_since(std::time::UNIX_EPOCH)
502 .map(|d| d.as_secs())
503 .unwrap_or_default();
504 let signed = signing::sign_request(
505 cfg,
506 "POST",
507 url,
508 "application/ld+json",
509 &body,
510 ¬ification_id,
511 now,
512 );
513 for (name, value) in &signed.headers {
516 if name.eq_ignore_ascii_case("content-type") {
517 continue;
518 }
519 req = req.header(name.as_str(), value.as_str());
520 }
521 } else {
522 tracing::warn!(
523 "webhook manager delivering {} unsigned — consider configuring a SignerConfig",
524 url
525 );
526 }
527 }
528
529 req.body(body).send().await
530 }
531
532 pub async fn deliver_one(
535 &self,
536 url: &str,
537 note: &ChangeNotification,
538 ) -> WebhookDelivery {
539 if self.circuit_open() {
541 return WebhookDelivery::TransientRetry {
542 reason: "circuit open".to_string(),
543 };
544 }
545
546 let total_attempts = self.max_retries.saturating_add(1);
547 let mut attempt = 0u32;
548 loop {
549 let resp = self.send_once(url, note).await;
550 match resp {
551 Ok(r) => {
552 let status = r.status().as_u16();
553 if r.status().is_success() {
555 self.consecutive_failures
556 .store(0, std::sync::atomic::Ordering::Relaxed);
557 return WebhookDelivery::Delivered { status };
558 }
559 if status == 410 {
561 self.consecutive_failures
562 .store(0, std::sync::atomic::Ordering::Relaxed);
563 return WebhookDelivery::FatalDrop { status };
564 }
565 if status == 429 {
567 let retry_after = r
568 .headers()
569 .get("retry-after")
570 .and_then(|v| v.to_str().ok())
571 .and_then(Self::parse_retry_after)
572 .unwrap_or_else(|| self.compute_backoff(attempt));
573 attempt += 1;
574 if attempt >= total_attempts {
575 self.record_failure();
576 return WebhookDelivery::TransientRetry {
577 reason: format!("429 after {attempt} attempts"),
578 };
579 }
580 tokio::time::sleep(
581 retry_after.min(self.max_backoff),
582 )
583 .await;
584 continue;
585 }
586 if r.status().is_server_error() {
589 let wait = r
590 .headers()
591 .get("retry-after")
592 .and_then(|v| v.to_str().ok())
593 .and_then(Self::parse_retry_after)
594 .unwrap_or_else(|| self.compute_backoff(attempt));
595 attempt += 1;
596 if attempt >= total_attempts {
597 self.record_failure();
598 return WebhookDelivery::TransientRetry {
599 reason: format!("5xx after {attempt} attempts"),
600 };
601 }
602 tokio::time::sleep(wait.min(self.max_backoff)).await;
603 continue;
604 }
605 if r.status().is_client_error() {
608 let wait = self.compute_backoff(attempt);
609 attempt += 1;
610 if attempt >= total_attempts {
611 self.record_failure();
612 return WebhookDelivery::TransientRetry {
613 reason: format!("{status} after {attempt} attempts"),
614 };
615 }
616 tokio::time::sleep(wait.min(self.max_backoff)).await;
617 continue;
618 }
619 let wait = self.compute_backoff(attempt);
621 attempt += 1;
622 if attempt >= total_attempts {
623 self.record_failure();
624 return WebhookDelivery::TransientRetry {
625 reason: format!("status {status} after {attempt} attempts"),
626 };
627 }
628 tokio::time::sleep(wait.min(self.max_backoff)).await;
629 }
630 Err(e) => {
631 let wait = self.compute_backoff(attempt);
633 attempt += 1;
634 if attempt >= total_attempts {
635 self.record_failure();
636 return WebhookDelivery::TransientRetry {
637 reason: format!("network error: {e}"),
638 };
639 }
640 tokio::time::sleep(wait.min(self.max_backoff)).await;
641 }
642 }
643 }
644 }
645
646 fn record_failure(&self) {
647 self.consecutive_failures
648 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
649 }
650
651 pub async fn deliver_all(
654 &self,
655 note: &ChangeNotification,
656 topic_matches: impl Fn(&str) -> bool,
657 ) -> Vec<(String, WebhookDelivery)> {
658 let subs: Vec<Subscription> = {
659 let guard = self.subscriptions.read().await;
660 guard
661 .values()
662 .filter(|s| topic_matches(&s.topic))
663 .cloned()
664 .collect()
665 };
666 let mut out = Vec::with_capacity(subs.len());
667 let mut dropped = Vec::new();
668 for sub in subs {
669 let result = self.deliver_one(&sub.receive_from, note).await;
670 if matches!(result, WebhookDelivery::FatalDrop { .. }) {
671 dropped.push(sub.id.clone());
672 }
673 out.push((sub.id, result));
674 }
675 if !dropped.is_empty() {
676 let mut guard = self.subscriptions.write().await;
677 for id in dropped {
678 guard.remove(&id);
679 }
680 }
681 out
682 }
683
684 pub async fn pump_from_storage(
689 self,
690 mut rx: tokio::sync::mpsc::Receiver<StorageEvent>,
691 pod_base: String,
692 ) {
693 while let Some(event) = rx.recv().await {
694 let path = match &event {
695 StorageEvent::Created(p) | StorageEvent::Updated(p) | StorageEvent::Deleted(p) => {
696 p.clone()
697 }
698 };
699 let note = ChangeNotification::from_storage_event(&event, &pod_base);
700 self.deliver_all(¬e, |topic| path.starts_with(topic)).await;
701 }
702 }
703}
704
705#[async_trait]
706impl Notifications for WebhookChannelManager {
707 async fn subscribe(&self, subscription: Subscription) -> Result<(), PodError> {
708 self.subscriptions
709 .write()
710 .await
711 .insert(subscription.id.clone(), subscription);
712 Ok(())
713 }
714
715 async fn unsubscribe(&self, id: &str) -> Result<(), PodError> {
716 self.subscriptions.write().await.remove(id);
717 Ok(())
718 }
719
720 async fn publish(
721 &self,
722 topic: &str,
723 notification: ChangeNotification,
724 ) -> Result<(), PodError> {
725 let matches_topic = |t: &str| topic.starts_with(t) || t == topic;
726 self.deliver_all(¬ification, matches_topic).await;
727 Ok(())
728 }
729}
730
731pub fn discovery_document(pod_base: &str) -> serde_json::Value {
738 let base = pod_base.trim_end_matches('/');
739 serde_json::json!({
740 "@context": ["https://www.w3.org/ns/solid/notifications-context/v1"],
741 "id": format!("{base}/.notifications"),
742 "channelTypes": [
743 {
744 "id": "WebSocketChannel2023",
745 "endpoint": format!("{base}/.notifications/websocket"),
746 "features": ["as:Create", "as:Update", "as:Delete"]
747 },
748 {
749 "id": "WebhookChannel2023",
750 "endpoint": format!("{base}/.notifications/webhook"),
751 "features": ["as:Create", "as:Update", "as:Delete"]
752 }
753 ]
754 })
755}
756
/// Returns a random multiplier in `[0.8, 1.0)` used to jitter backoff delays.
///
/// This variant is compiled with the `webhook-signing` feature and draws
/// from the thread-local `rand` generator.
#[cfg(feature = "webhook-signing")]
fn jitter_factor() -> f64 {
    use rand::Rng;

    let mut rng = rand::thread_rng();
    rng.gen_range(0.8_f64..1.0_f64)
}
771
/// Returns a pseudo-random multiplier starting at `0.8` and approaching `1.0`,
/// used to jitter backoff delays.
///
/// This variant is compiled without the `webhook-signing` feature. It
/// advances a process-wide counter and scrambles the result with
/// multiply/xor-shift rounds.
#[cfg(not(feature = "webhook-signing"))]
fn jitter_factor() -> f64 {
    use std::sync::atomic::{AtomicU64, Ordering};

    static SEED: AtomicU64 = AtomicU64::new(0);
    const INCREMENT: u64 = 0x9E3779B97F4A7C15;
    const MUL_A: u64 = 0xBF58476D1CE4E5B9;
    const MUL_B: u64 = 0x94D049BB133111EB;

    // NOTE(review): this measures the time since an `Instant` created on the
    // same line, so `tick` is very small; most of the variation between calls
    // comes from the counter below.
    let tick = std::time::Instant::now().elapsed().as_nanos() as u64;
    // `| 1` makes the counter advance even when `tick` is zero.
    let previous = SEED.fetch_add(tick | 1, Ordering::Relaxed);

    let mut state = previous.wrapping_add(tick).wrapping_add(INCREMENT);
    state = (state ^ (state >> 30)).wrapping_mul(MUL_A);
    state = (state ^ (state >> 27)).wrapping_mul(MUL_B);
    state ^= state >> 31;

    // Keep the top 53 bits and scale them into [0, 1).
    let unit = (state >> 11) as f64 / (1u64 << 53) as f64;
    0.8 + unit * 0.2
}
790
/// Percent-encodes `s` for use in a URL path.
///
/// ASCII letters, digits, and `-`, `_`, `.`, `~`, `/` are copied unchanged;
/// every other byte of the UTF-8 encoding becomes `%XX` with uppercase hex
/// digits.
fn urlencoding(s: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(s.len());
    for byte in s.bytes() {
        let keep =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~' | b'/');
        if keep {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
809
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts `event` against a fixed pod base and returns the activity
    /// type of the resulting notification.
    fn kind_of(event: StorageEvent) -> String {
        ChangeNotification::from_storage_event(&event, "https://p.example").kind
    }

    /// Subscribe, unsubscribe and publish all succeed on the in-memory
    /// back-end.
    #[tokio::test]
    async fn subscribe_unsubscribe_roundtrip() {
        let n = InMemoryNotifications::new();
        let sub = Subscription {
            id: "sub-1".into(),
            topic: "/public/".into(),
            channel_type: ChannelType::WebhookChannel2023,
            receive_from: "https://example.com/hook".into(),
        };
        let notification = ChangeNotification {
            context: as_ns::CONTEXT.into(),
            id: "urn:uuid:test".into(),
            kind: "Update".into(),
            object: "/public/x".into(),
            published: chrono::Utc::now().to_rfc3339(),
        };

        n.subscribe(sub).await.unwrap();
        n.unsubscribe("sub-1").await.unwrap();
        n.publish("/public/", notification).await.unwrap();
    }

    /// A published notification reaches a receiver obtained from `stream`.
    #[tokio::test]
    async fn websocket_manager_broadcasts_events() {
        let m = WebSocketChannelManager::new();
        let mut rx = m.stream();

        let sub = m.subscribe("/public/", "wss://pod.example").await;
        assert_eq!(sub.channel_type, ChannelType::WebSocketChannel2023);
        assert!(sub.receive_from.contains("/subscription/"));

        let event = StorageEvent::Created("/public/x".into());
        let note = ChangeNotification::from_storage_event(&event, "https://pod.example");
        m.publish("/public/", note).await.unwrap();

        let received = tokio::time::timeout(Duration::from_secs(1), rx.recv())
            .await
            .unwrap()
            .unwrap();
        assert_eq!(received.kind, "Create");
        assert_eq!(received.object, "https://pod.example/public/x");
    }

    /// Each storage event variant maps to its activity type.
    #[tokio::test]
    async fn change_notification_maps_event_types() {
        assert_eq!(kind_of(StorageEvent::Created("/x".into())), "Create");
        assert_eq!(kind_of(StorageEvent::Updated("/x".into())), "Update");
        assert_eq!(kind_of(StorageEvent::Deleted("/x".into())), "Delete");
    }

    /// The discovery document advertises both channel types.
    #[test]
    fn discovery_lists_both_channels() {
        let doc = discovery_document("https://pod.example");
        let channels = doc["channelTypes"].as_array().unwrap();
        assert_eq!(channels.len(), 2);

        let ids: Vec<&str> = channels
            .iter()
            .map(|entry| entry["id"].as_str().unwrap())
            .collect();
        assert!(ids.contains(&"WebSocketChannel2023"));
        assert!(ids.contains(&"WebhookChannel2023"));
    }

    /// A new webhook manager allows three retries.
    #[test]
    fn webhook_manager_default_retries() {
        let m = WebhookChannelManager::new();
        assert_eq!(m.max_retries, 3);
    }

    /// The subscription count follows subscribe and unsubscribe calls.
    #[tokio::test]
    async fn websocket_active_subscriptions_count() {
        let m = WebSocketChannelManager::new();
        assert_eq!(m.active_subscriptions().await, 0);

        let s = m.subscribe("/a/", "wss://p").await;
        assert_eq!(m.active_subscriptions().await, 1);

        m.unsubscribe(&s.id).await;
        assert_eq!(m.active_subscriptions().await, 0);
    }
}
908}