Skip to main content

supabase_client_realtime/
channel.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use serde_json::Value;
6use tokio::sync::RwLock;
7
8use crate::callback::{Binding, CallbackRegistry};
9use crate::error::RealtimeError;
10use crate::types::{
11    BroadcastConfig, ChannelState, JoinConfig, JoinPayload, PostgresChangesEvent,
12    PostgresChangesFilter, PresenceConfig, PresenceState, SubscriptionStatus,
13};
14
15// ── ChannelBuilder ────────────────────────────────────────────────────────────
16
17/// Builder for configuring and subscribing to a realtime channel.
18///
19/// Created via `RealtimeClient::channel("name")`. Consumed by `subscribe()`.
20pub struct ChannelBuilder {
21    pub(crate) name: String,
22    pub(crate) topic: String,
23    pub(crate) broadcast_config: BroadcastConfig,
24    pub(crate) presence_key: String,
25    pub(crate) presence_enabled: bool,
26    pub(crate) postgres_changes: Vec<PostgresChangesFilter>,
27    pub(crate) bindings: Vec<Binding>,
28    pub(crate) is_private: bool,
29    pub(crate) subscribe_timeout: Duration,
30    pub(crate) access_token: Option<String>,
31    /// Back-reference to the client for sending messages.
32    pub(crate) client_sender: crate::client::ClientSender,
33}
34
35impl ChannelBuilder {
36    /// Listen for postgres database changes.
37    pub fn on_postgres_changes<F>(
38        mut self,
39        event: PostgresChangesEvent,
40        filter: PostgresChangesFilter,
41        callback: F,
42    ) -> Self
43    where
44        F: Fn(crate::types::PostgresChangePayload) + Send + Sync + 'static,
45    {
46        let filter_index = self.postgres_changes.len();
47        // Store the filter with the correct event type
48        let filter = filter.event(event);
49        self.postgres_changes.push(filter);
50        self.bindings.push(Binding::PostgresChanges {
51            filter_index,
52            event,
53            callback: Arc::new(callback),
54        });
55        self
56    }
57
58    /// Listen for broadcast messages with the given event name.
59    pub fn on_broadcast<F>(mut self, event: &str, callback: F) -> Self
60    where
61        F: Fn(Value) + Send + Sync + 'static,
62    {
63        self.bindings.push(Binding::Broadcast {
64            event: event.to_string(),
65            callback: Arc::new(callback),
66        });
67        self
68    }
69
70    /// Listen for presence sync events (full state).
71    pub fn on_presence_sync<F>(mut self, callback: F) -> Self
72    where
73        F: Fn(&PresenceState) + Send + Sync + 'static,
74    {
75        self.presence_enabled = true;
76        self.bindings.push(Binding::PresenceSync(Arc::new(callback)));
77        self
78    }
79
80    /// Listen for presence join events.
81    pub fn on_presence_join<F>(mut self, callback: F) -> Self
82    where
83        F: Fn(String, Vec<crate::types::PresenceMeta>) + Send + Sync + 'static,
84    {
85        self.presence_enabled = true;
86        self.bindings
87            .push(Binding::PresenceJoin(Arc::new(callback)));
88        self
89    }
90
91    /// Listen for presence leave events.
92    pub fn on_presence_leave<F>(mut self, callback: F) -> Self
93    where
94        F: Fn(String, Vec<crate::types::PresenceMeta>) + Send + Sync + 'static,
95    {
96        self.presence_enabled = true;
97        self.bindings
98            .push(Binding::PresenceLeave(Arc::new(callback)));
99        self
100    }
101
102    /// Enable broadcast acknowledgement from the server.
103    pub fn broadcast_ack(mut self, ack: bool) -> Self {
104        self.broadcast_config.ack = ack;
105        self
106    }
107
108    /// Enable receiving your own broadcast messages.
109    pub fn broadcast_self(mut self, self_send: bool) -> Self {
110        self.broadcast_config.self_send = self_send;
111        self
112    }
113
114    /// Set the presence key for this channel.
115    pub fn presence_key(mut self, key: &str) -> Self {
116        self.presence_enabled = true;
117        self.presence_key = key.to_string();
118        self
119    }
120
121    /// Mark this channel as private (requires RLS).
122    pub fn private(mut self) -> Self {
123        self.is_private = true;
124        self
125    }
126
127    /// Set the subscribe timeout for this channel.
128    pub fn timeout(mut self, timeout: Duration) -> Self {
129        self.subscribe_timeout = timeout;
130        self
131    }
132
133    /// Subscribe to the channel. Sends `phx_join` and waits for acknowledgement.
134    ///
135    /// The `status_callback` is called when subscription status changes.
136    pub async fn subscribe<F>(
137        self,
138        status_callback: F,
139    ) -> Result<RealtimeChannel, RealtimeError>
140    where
141        F: Fn(SubscriptionStatus, Option<RealtimeError>) + Send + Sync + 'static,
142    {
143        let join_payload = JoinPayload {
144            config: JoinConfig {
145                broadcast: self.broadcast_config.clone(),
146                presence: PresenceConfig {
147                    key: self.presence_key.clone(),
148                },
149                postgres_changes: self.postgres_changes.clone(),
150            },
151            access_token: self.access_token.clone(),
152        };
153
154        let registry = CallbackRegistry::new();
155        {
156            let mut bindings = registry.bindings.write().await;
157            for binding in self.bindings {
158                bindings.push(binding);
159            }
160        }
161        {
162            let mut status_cb = registry.status_callback.write().await;
163            *status_cb = Some(Arc::new(status_callback));
164        }
165
166        let inner = Arc::new(ChannelInner {
167            name: self.name.clone(),
168            topic: self.topic.clone(),
169            state: RwLock::new(ChannelState::Joining),
170            join_ref: RwLock::new(None),
171            join_payload: RwLock::new(join_payload.clone()),
172            registry,
173            presence_state: RwLock::new(PresenceState::new()),
174            pg_change_id_map: RwLock::new(HashMap::new()),
175            client_sender: self.client_sender.clone(),
176        });
177
178        let channel = RealtimeChannel {
179            inner: inner.clone(),
180        };
181
182        // Register channel with the client and send phx_join
183        self.client_sender
184            .subscribe_channel(channel.clone(), join_payload, self.subscribe_timeout)
185            .await?;
186
187        Ok(channel)
188    }
189}
190
191// ── RealtimeChannel ───────────────────────────────────────────────────────────
192
193/// A handle to a subscribed realtime channel.
194///
195/// This is cheaply cloneable and `Send + Sync`.
196#[derive(Clone)]
197pub struct RealtimeChannel {
198    pub(crate) inner: Arc<ChannelInner>,
199}
200
201pub(crate) struct ChannelInner {
202    pub(crate) name: String,
203    pub(crate) topic: String,
204    pub(crate) state: RwLock<ChannelState>,
205    pub(crate) join_ref: RwLock<Option<String>>,
206    pub(crate) join_payload: RwLock<JoinPayload>,
207    pub(crate) registry: CallbackRegistry,
208    pub(crate) presence_state: RwLock<PresenceState>,
209    /// Maps server-assigned postgres_changes subscription IDs → filter_index
210    pub(crate) pg_change_id_map: RwLock<HashMap<u64, usize>>,
211    pub(crate) client_sender: crate::client::ClientSender,
212}
213
214impl RealtimeChannel {
215    /// Get the channel topic (e.g., "realtime:db-changes").
216    pub fn topic(&self) -> &str {
217        &self.inner.topic
218    }
219
220    /// Get the channel name (user-provided name without prefix).
221    pub fn name(&self) -> &str {
222        &self.inner.name
223    }
224
225    /// Get the current channel state.
226    pub async fn state(&self) -> ChannelState {
227        *self.inner.state.read().await
228    }
229
230    /// Send a broadcast message on this channel.
231    pub async fn send_broadcast(
232        &self,
233        event: &str,
234        payload: Value,
235    ) -> Result<(), RealtimeError> {
236        let state = *self.inner.state.read().await;
237        if state != ChannelState::Joined {
238            return Err(RealtimeError::InvalidChannelState {
239                expected: ChannelState::Joined,
240                actual: state,
241            });
242        }
243        let join_ref = self.inner.join_ref.read().await;
244        let join_ref = join_ref
245            .as_deref()
246            .ok_or_else(|| RealtimeError::Internal("No join_ref".to_string()))?;
247        self.inner
248            .client_sender
249            .send_broadcast(&self.inner.topic, event, payload, join_ref)
250            .await
251    }
252
253    /// Track presence state on this channel.
254    pub async fn track(&self, payload: Value) -> Result<(), RealtimeError> {
255        let state = *self.inner.state.read().await;
256        if state != ChannelState::Joined {
257            return Err(RealtimeError::InvalidChannelState {
258                expected: ChannelState::Joined,
259                actual: state,
260            });
261        }
262        let join_ref = self.inner.join_ref.read().await;
263        let join_ref = join_ref
264            .as_deref()
265            .ok_or_else(|| RealtimeError::Internal("No join_ref".to_string()))?;
266        self.inner
267            .client_sender
268            .send_presence_track(&self.inner.topic, payload, join_ref)
269            .await
270    }
271
272    /// Stop tracking presence on this channel.
273    pub async fn untrack(&self) -> Result<(), RealtimeError> {
274        let state = *self.inner.state.read().await;
275        if state != ChannelState::Joined {
276            return Err(RealtimeError::InvalidChannelState {
277                expected: ChannelState::Joined,
278                actual: state,
279            });
280        }
281        let join_ref = self.inner.join_ref.read().await;
282        let join_ref = join_ref
283            .as_deref()
284            .ok_or_else(|| RealtimeError::Internal("No join_ref".to_string()))?;
285        self.inner
286            .client_sender
287            .send_presence_untrack(&self.inner.topic, join_ref)
288            .await
289    }
290
291    /// Get the current presence state for this channel.
292    pub async fn presence_state(&self) -> PresenceState {
293        self.inner.presence_state.read().await.clone()
294    }
295
296    /// Unsubscribe from this channel. Sends `phx_leave`.
297    pub async fn unsubscribe(&self) -> Result<(), RealtimeError> {
298        let state = *self.inner.state.read().await;
299        if state == ChannelState::Closed || state == ChannelState::Leaving {
300            return Ok(());
301        }
302        let join_ref = self.inner.join_ref.read().await;
303        let join_ref = join_ref
304            .as_deref()
305            .ok_or_else(|| RealtimeError::Internal("No join_ref for leave".to_string()))?;
306        self.inner
307            .client_sender
308            .send_leave(&self.inner.topic, join_ref)
309            .await?;
310        *self.inner.state.write().await = ChannelState::Leaving;
311        Ok(())
312    }
313
314    /// Update the access token for this channel (e.g., after token refresh).
315    pub async fn update_access_token(&self, token: &str) -> Result<(), RealtimeError> {
316        let state = *self.inner.state.read().await;
317        if state != ChannelState::Joined {
318            return Err(RealtimeError::InvalidChannelState {
319                expected: ChannelState::Joined,
320                actual: state,
321            });
322        }
323        // Update stored join payload
324        {
325            let mut jp = self.inner.join_payload.write().await;
326            jp.access_token = Some(token.to_string());
327        }
328        let join_ref = self.inner.join_ref.read().await;
329        let join_ref = join_ref
330            .as_deref()
331            .ok_or_else(|| RealtimeError::Internal("No join_ref".to_string()))?;
332        self.inner
333            .client_sender
334            .send_access_token(&self.inner.topic, token, join_ref)
335            .await
336    }
337}