Skip to main content

sphereql_graphql/
subscription.rs

1use async_graphql::futures_util::Stream;
2use async_graphql::{Context, Result, Subscription};
3use tokio::sync::broadcast;
4
5use sphereql_core::{Contains, SphericalPoint};
6
7use crate::types::{RegionInput, SphericalPointOutput};
8
9#[derive(async_graphql::Enum, Copy, Clone, Eq, PartialEq, Debug)]
10pub enum SpatialEventType {
11    Entered,
12    Left,
13    Moved,
14}
15
16#[derive(async_graphql::SimpleObject, Debug, Clone)]
17pub struct SpatialEvent {
18    pub event_type: SpatialEventType,
19    pub point: SphericalPointOutput,
20    pub item_id: String,
21    /// Cached core [`SphericalPoint`]. Populated once when the event
22    /// is constructed so subscription filters don't reparse the point
23    /// per-event-per-subscriber. `#[graphql(skip)]` keeps it out of
24    /// the schema — callers interact with `point` (the GraphQL view).
25    #[graphql(skip)]
26    pub(crate) core_point: SphericalPoint,
27}
28
29pub struct SpatialEventBus {
30    sender: broadcast::Sender<SpatialEvent>,
31}
32
33impl SpatialEvent {
34    /// Construct an event with the GraphQL-visible `point` and the
35    /// cached `core_point` for subscription filters. Prefer this
36    /// constructor over struct-literal syntax so the core point
37    /// stays in sync with the output point.
38    pub fn new(event_type: SpatialEventType, point: SphericalPointOutput, item_id: String) -> Self {
39        let core_point = SphericalPoint::new_unchecked(point.r, point.theta, point.phi);
40        Self {
41            event_type,
42            point,
43            item_id,
44            core_point,
45        }
46    }
47}
48
49impl SpatialEventBus {
50    pub fn new(capacity: usize) -> Self {
51        let (sender, _) = broadcast::channel(capacity);
52        Self { sender }
53    }
54
55    pub fn publish(&self, event: SpatialEvent) {
56        if let Err(e) = self.sender.send(event) {
57            // No subscribers is the common case and not an error —
58            // `trace!` keeps the log pristine but makes the drop
59            // visible under tracing filters tuned for spatial events.
60            tracing::trace!(error = %e, "SpatialEventBus::publish: no subscribers");
61        }
62    }
63
64    pub fn subscribe(&self) -> broadcast::Receiver<SpatialEvent> {
65        self.sender.subscribe()
66    }
67}
68
/// Root object for the crate's GraphQL subscriptions.
///
/// Stateless: each resolver pulls the shared `SpatialEventBus` out of the
/// GraphQL context, so the root carries no data and is trivially
/// constructible and copyable.
#[derive(Debug, Default, Clone, Copy)]
pub struct SphericalSubscriptionRoot;
70
71#[Subscription]
72impl SphericalSubscriptionRoot {
73    async fn item_entered_region(
74        &self,
75        ctx: &Context<'_>,
76        region: RegionInput,
77    ) -> Result<impl Stream<Item = SpatialEvent>> {
78        let bus = ctx.data::<SpatialEventBus>()?;
79        let mut rx = bus.subscribe();
80        let region = region.to_core()?;
81
82        let stream = async_graphql::async_stream::stream! {
83            loop {
84                match rx.recv().await {
85                    Ok(event) => {
86                        if event.event_type == SpatialEventType::Entered
87                            && region.contains(&event.core_point)
88                        {
89                            yield event;
90                        }
91                    }
92                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
93                    Err(broadcast::error::RecvError::Closed) => break,
94                }
95            }
96        };
97
98        Ok(stream)
99    }
100
101    async fn item_left_region(
102        &self,
103        ctx: &Context<'_>,
104        region: RegionInput,
105    ) -> Result<impl Stream<Item = SpatialEvent>> {
106        let bus = ctx.data::<SpatialEventBus>()?;
107        let mut rx = bus.subscribe();
108        let region = region.to_core()?;
109
110        let stream = async_graphql::async_stream::stream! {
111            loop {
112                match rx.recv().await {
113                    Ok(event) => {
114                        if event.event_type == SpatialEventType::Left
115                            && region.contains(&event.core_point)
116                        {
117                            yield event;
118                        }
119                    }
120                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
121                    Err(broadcast::error::RecvError::Closed) => break,
122                }
123            }
124        };
125
126        Ok(stream)
127    }
128
129    async fn spatial_events(&self, ctx: &Context<'_>) -> Result<impl Stream<Item = SpatialEvent>> {
130        let bus = ctx.data::<SpatialEventBus>()?;
131        let mut rx = bus.subscribe();
132
133        let stream = async_graphql::async_stream::stream! {
134            loop {
135                match rx.recv().await {
136                    Ok(event) => { yield event; }
137                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
138                    Err(broadcast::error::RecvError::Closed) => break,
139                }
140            }
141        };
142
143        Ok(stream)
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_PI_4;

    /// Builds an event through `SpatialEvent::new`, deriving the degree fields
    /// from the radian inputs and encoding the coordinates into `item_id`.
    fn make_event(event_type: SpatialEventType, r: f64, theta: f64, phi: f64) -> SpatialEvent {
        SpatialEvent::new(
            event_type,
            SphericalPointOutput {
                r,
                theta,
                phi,
                theta_degrees: theta.to_degrees(),
                phi_degrees: phi.to_degrees(),
            },
            format!("item-{r}-{theta}-{phi}"),
        )
    }

    #[tokio::test]
    async fn event_bus_publish_subscribe() {
        let bus = SpatialEventBus::new(16);
        let mut rx = bus.subscribe();

        bus.publish(make_event(SpatialEventType::Entered, 1.0, 0.5, FRAC_PI_4));

        let received = rx.recv().await.unwrap();
        assert_eq!(received.item_id, "item-1-0.5-0.7853981633974483");
        assert_eq!(received.event_type, SpatialEventType::Entered);
        assert!((received.point.r - 1.0).abs() < 1e-12);
    }

    #[tokio::test]
    async fn multiple_subscribers_receive_events() {
        let bus = SpatialEventBus::new(16);
        let mut rx1 = bus.subscribe();
        let mut rx2 = bus.subscribe();

        bus.publish(make_event(SpatialEventType::Moved, 2.0, 1.0, 0.5));

        let r1 = rx1.recv().await.unwrap();
        let r2 = rx2.recv().await.unwrap();

        assert_eq!(r1.item_id, r2.item_id);
        assert_eq!(r1.event_type, SpatialEventType::Moved);
        assert_eq!(r2.event_type, SpatialEventType::Moved);
    }

    /// Mirrors the event-type check the subscription resolvers apply; it does
    /// not drive the GraphQL resolvers themselves.
    #[tokio::test]
    async fn event_type_filtering() {
        let bus = SpatialEventBus::new(16);
        let mut rx = bus.subscribe();

        bus.publish(make_event(SpatialEventType::Entered, 1.0, 0.5, 0.5));
        bus.publish(make_event(SpatialEventType::Left, 1.0, 0.6, 0.6));
        bus.publish(make_event(SpatialEventType::Moved, 1.0, 0.7, 0.7));
        bus.publish(make_event(SpatialEventType::Entered, 2.0, 0.8, 0.8));

        let mut entered = Vec::new();
        for _ in 0..4 {
            let event = rx.recv().await.unwrap();
            if event.event_type == SpatialEventType::Entered {
                entered.push(event);
            }
        }

        assert_eq!(entered.len(), 2);
        assert!((entered[0].point.r - 1.0).abs() < 1e-12);
        assert!((entered[1].point.r - 2.0).abs() < 1e-12);
    }

    #[test]
    fn publish_without_subscribers_is_dropped_silently() {
        let bus = SpatialEventBus::new(4);
        // Nobody is listening: `publish` must swallow the send error, not panic.
        bus.publish(make_event(SpatialEventType::Entered, 1.0, 0.1, 0.1));

        // A receiver created afterwards does not see the dropped event...
        let mut rx = bus.subscribe();
        assert!(matches!(
            rx.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));

        // ...but does see events published after it subscribed.
        bus.publish(make_event(SpatialEventType::Left, 1.0, 0.2, 0.2));
        let received = rx.try_recv().expect("event published after subscribe");
        assert_eq!(received.event_type, SpatialEventType::Left);
    }
}