supabase_rust_realtime/
channel.rs

1use crate::client::RealtimeClient; // Removed unused ConnectionState
2use crate::error::RealtimeError;
3use crate::filters::{DatabaseFilter, FilterOperator};
4use crate::message::{ChannelEvent, Payload, PresenceChange, RealtimeMessage};
5use log::{debug, error, info, trace}; // Removed unused warn
6use serde::Serialize;
7use serde_json::json;
8use std::collections::HashMap;
9use std::sync::Arc;
10// use tokio::sync::mpsc; // Unused import after commenting out `socket` field
11use tokio::sync::RwLock;
12use tokio::time::{timeout, Duration};
13// use tokio_tungstenite::tungstenite::Message; // Removed unused import
14
15/// データベース変更監視設定
16#[derive(Debug, Clone, Serialize)]
17pub struct DatabaseChanges {
18    schema: String,
19    table: String,
20    events: Vec<ChannelEvent>,
21    filter: Option<Vec<DatabaseFilter>>,
22}
23
24impl DatabaseChanges {
25    /// 新しいデータベース変更監視設定を作成
26    pub fn new(table: &str) -> Self {
27        Self {
28            schema: "public".to_string(),
29            table: table.to_string(),
30            events: Vec::new(),
31            filter: None,
32        }
33    }
34
35    /// スキーマを設定
36    pub fn schema(mut self, schema: &str) -> Self {
37        self.schema = schema.to_string();
38        self
39    }
40
41    /// イベントを追加
42    pub fn event(mut self, event: ChannelEvent) -> Self {
43        if !self.events.contains(&event) {
44            self.events.push(event);
45        }
46        self
47    }
48
49    /// フィルター条件を追加
50    pub fn filter(mut self, filter: DatabaseFilter) -> Self {
51        self.filter.get_or_insert_with(Vec::new).push(filter);
52        self
53    }
54
55    // --- Filter convenience methods ---
56
57    pub fn eq<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
58        self.filter(DatabaseFilter {
59            column: column.to_string(),
60            operator: FilterOperator::Eq,
61            value: value.into(),
62        })
63    }
64
65    pub fn neq<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
66        self.filter(DatabaseFilter {
67            column: column.to_string(),
68            operator: FilterOperator::Neq,
69            value: value.into(),
70        })
71    }
72
73    pub fn gt<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
74        self.filter(DatabaseFilter {
75            column: column.to_string(),
76            operator: FilterOperator::Gt,
77            value: value.into(),
78        })
79    }
80
81    pub fn gte<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
82        self.filter(DatabaseFilter {
83            column: column.to_string(),
84            operator: FilterOperator::Gte,
85            value: value.into(),
86        })
87    }
88
89    pub fn lt<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
90        self.filter(DatabaseFilter {
91            column: column.to_string(),
92            operator: FilterOperator::Lt,
93            value: value.into(),
94        })
95    }
96
97    pub fn lte<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
98        self.filter(DatabaseFilter {
99            column: column.to_string(),
100            operator: FilterOperator::Lte,
101            value: value.into(),
102        })
103    }
104
105    pub fn in_values<T: Into<serde_json::Value>>(self, column: &str, values: Vec<T>) -> Self {
106        self.filter(DatabaseFilter {
107            column: column.to_string(),
108            operator: FilterOperator::In,
109            value: values
110                .into_iter()
111                .map(|v| v.into())
112                .collect::<Vec<_>>()
113                .into(),
114        })
115    }
116
117    // Add other filter methods (like, ilike, contains) if needed
118
119    // --- Internal methods ---
120
121    // /// Convert config to JSON for the websocket message
122    // pub(crate) fn to_channel_config(&self) -> serde_json::Value {
123    //     // ... implementation ...
124    // }
125}
126
127/// ブロードキャストイベント監視設定
128#[derive(Debug, Clone, Serialize)]
129pub struct BroadcastChanges {
130    event: String, // Specific event name to listen for
131}
132
133impl BroadcastChanges {
134    pub fn new(event: &str) -> Self {
135        Self {
136            event: event.to_string(),
137        }
138    }
139
140    #[allow(dead_code)] // Mark as allowed since it might be useful later
141    pub(crate) fn get_event_name(&self) -> &str {
142        &self.event
143    }
144}
145
146/// プレゼンスイベント監視設定 (シンプルなマーカー型)
147#[derive(Debug, Clone, Default, Serialize)]
148pub struct PresenceChanges;
149
150impl PresenceChanges {
151    pub fn new() -> Self {
152        // Self::default() // Clippy: default_constructed_unit_structs
153        PresenceChanges // Create directly
154    }
155}
156
157/// アクティブなチャンネル購読を表す
158pub struct Subscription {
159    id: String, // Internal subscription identifier
160    channel: Arc<Channel>,
161}
162
163impl Drop for Subscription {
164    fn drop(&mut self) {
165        let id_clone = self.id.clone();
166        let channel_clone = self.channel.clone();
167        tokio::spawn(async move {
168            if let Err(e) = channel_clone.unsubscribe(&id_clone).await {
169                // TODO: Log unsubscribe error properly
170                eprintln!("Error unsubscribing from channel: {}", e);
171            }
172        });
173    }
174}
175
176type CallbackFn = Box<dyn Fn(Payload) + Send + Sync>;
177type PresenceCallbackFn = Box<dyn Fn(PresenceChange) + Send + Sync>;
178
179/// 内部チャンネル表現
180pub(crate) struct Channel {
181    topic: String,
182    client: Arc<RealtimeClient>, // Store Arc<RealtimeClient> for sending messages
183    callbacks: Arc<RwLock<HashMap<String, CallbackFn>>>,
184    presence_callbacks: Arc<RwLock<Vec<PresenceCallbackFn>>>,
185    // Add channel state
186    state: Arc<RwLock<ChannelState>>,
187}
188
/// Lifecycle state of a `Channel`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum ChannelState {
    /// Not joined: the initial state, also entered on `phx_close` or when a
    /// reply arrives while `Leaving`.
    Closed,
    /// A join message has been sent and no reply has been handled yet.
    Joining,
    /// A reply arrived while the channel was `Joining`.
    Joined,
    /// A leave is in progress; the next reply moves the channel to `Closed`.
    /// NOTE(review): no code visible here sets this state.
    Leaving,
    /// Set on `phx_error`, when sending the join fails, or when the join
    /// confirmation times out.
    Errored,
}
197
198impl Channel {
199    pub(crate) fn new(topic: String, client: Arc<RealtimeClient>) -> Self {
200        debug!("Channel::new created for topic: {}", topic);
201        Self {
202            topic,
203            client,
204            callbacks: Arc::new(RwLock::new(HashMap::new())),
205            presence_callbacks: Arc::new(RwLock::new(Vec::new())),
206            state: Arc::new(RwLock::new(ChannelState::Closed)),
207        }
208    }
209
210    async fn set_state(&self, state: ChannelState) {
211        let mut current_state = self.state.write().await;
212        if *current_state != state {
213            info!(
214                "Channel '{}' state changing from {:?} to {:?}",
215                self.topic, *current_state, state
216            );
217            *current_state = state;
218        } else {
219            trace!(
220                "Channel '{}' state already {:?}, not changing.",
221                self.topic,
222                state
223            );
224        }
225    }
226
227    // Simplified join - just sends the message
228    async fn join(&self) -> Result<(), RealtimeError> {
229        self.set_state(ChannelState::Joining).await;
230        let join_ref = self.client.next_ref();
231        info!(
232            "Channel '{}' sending join message with ref {}",
233            self.topic, join_ref
234        );
235        let join_msg = json!({
236            "topic": self.topic,
237            "event": ChannelEvent::PhoenixJoin,
238            "payload": {},
239            "ref": join_ref
240        });
241        // TODO: Add timeout for join reply
242        self.client.send_message(join_msg).await
243        // Need mechanism to wait for phx_reply with matching ref
244    }
245
246    // async fn send_message(&self, payload: serde_json::Value) -> Result<(), RealtimeError> {
247    //    // ... implementation ...
248    // }
249
250    async fn unsubscribe(&self, id: &str) -> Result<(), RealtimeError> {
251        // Remove callback
252        self.callbacks.write().await.remove(id);
253        // TODO: Unsubscribe presence if needed
254
255        // Send unsubscribe message if this was the last callback? Requires tracking.
256        // For simplicity, assume client handles full channel leave when all subscriptions drop.
257        println!(
258            "Subscription {} dropped. Channel {} might need explicit leave.",
259            id, self.topic
260        );
261        Ok(())
262    }
263
264    // Adjusted to accept RealtimeMessage
265    pub(crate) async fn handle_message(&self, message: RealtimeMessage) {
266        debug!(
267            "Channel '{}' handling message: event={:?}, ref={:?}",
268            self.topic, message.event, message.message_ref
269        );
270
271        match message.event {
272            ChannelEvent::PhoenixReply => {
273                // TODO: Check ref against pending joins/leaves
274                info!(
275                    "Channel '{}' received PhoenixReply: {:?}",
276                    self.topic, message.payload
277                );
278                if *self.state.read().await == ChannelState::Joining {
279                    // Basic assumption: any reply means join succeeded for now
280                    self.set_state(ChannelState::Joined).await;
281                } else if *self.state.read().await == ChannelState::Leaving {
282                    self.set_state(ChannelState::Closed).await;
283                }
284            }
285            ChannelEvent::PhoenixClose => {
286                info!(
287                    "Channel '{}' received PhoenixClose. Setting state to Closed.",
288                    self.topic
289                );
290                self.set_state(ChannelState::Closed).await;
291            }
292            ChannelEvent::PhoenixError => {
293                error!(
294                    "Channel '{}' received PhoenixError: {:?}",
295                    self.topic, message.payload
296                );
297                self.set_state(ChannelState::Errored).await;
298            }
299            ChannelEvent::PostgresChanges | ChannelEvent::Broadcast | ChannelEvent::Presence => {
300                // These events have nested data we need to pass to callbacks
301                let payload = Payload {
302                    data: message.payload.clone(), // Pass the whole payload as data for now
303                    event_type: Some(message.event.to_string()), // Reflect the event type
304                    timestamp: None, // Timestamp might be deeper in payload, needs parsing
305                };
306                trace!(
307                    "Channel '{}' dispatching event {:?} to callbacks",
308                    self.topic,
309                    message.event
310                );
311                let callbacks_guard = self.callbacks.read().await;
312                for callback in callbacks_guard.values() {
313                    // Execute callback - Consider spawning if long-running
314                    callback(payload.clone());
315                }
316                // TODO: Handle presence callbacks separately if event is Presence
317            }
318            // Ignore other events like Heartbeat, Insert, Update, Delete, All at the channel level
319            // (Those might be relevant *inside* a PostgresChanges payload)
320            _ => {
321                trace!(
322                    "Channel '{}' ignored event: {:?}",
323                    self.topic,
324                    message.event
325                );
326            }
327        }
328    }
329}
330
331/// チャンネル作成と購読設定のためのビルダー
332pub struct ChannelBuilder<'a> {
333    client: &'a RealtimeClient,
334    topic: String,
335    db_callbacks: HashMap<String, (DatabaseChanges, CallbackFn)>,
336    broadcast_callbacks: HashMap<String, (BroadcastChanges, CallbackFn)>,
337    presence_callbacks: Vec<PresenceCallbackFn>,
338}
339
340impl<'a> ChannelBuilder<'a> {
341    pub(crate) fn new(client: &'a RealtimeClient, topic: &str) -> Self {
342        debug!("ChannelBuilder::new for topic: {}", topic);
343        Self {
344            client,
345            topic: topic.to_string(),
346            db_callbacks: HashMap::new(),
347            broadcast_callbacks: HashMap::new(),
348            presence_callbacks: Vec::new(),
349        }
350    }
351
352    /// データベース変更イベントのコールバックを登録
353    pub fn on<F>(mut self, changes: DatabaseChanges, callback: F) -> Self
354    where
355        F: Fn(Payload) + Send + Sync + 'static,
356    {
357        // Use a unique identifier for the subscription
358        let id = uuid::Uuid::new_v4().to_string();
359        self.db_callbacks.insert(id, (changes, Box::new(callback)));
360        self
361    }
362
363    /// ブロードキャストイベントのコールバックを登録
364    pub fn on_broadcast<F>(mut self, changes: BroadcastChanges, callback: F) -> Self
365    where
366        F: Fn(Payload) + Send + Sync + 'static,
367    {
368        let id = uuid::Uuid::new_v4().to_string();
369        self.broadcast_callbacks
370            .insert(id, (changes, Box::new(callback)));
371        self
372    }
373
374    /// プレゼンス変更イベントのコールバックを登録
375    pub fn on_presence<F>(mut self, callback: F) -> Self
376    where
377        F: Fn(PresenceChange) + Send + Sync + 'static,
378    {
379        self.presence_callbacks.push(Box::new(callback));
380        self
381    }
382
383    /// チャンネルへの接続と購読を開始
384    pub async fn subscribe(self) -> Result<Vec<Subscription>, RealtimeError> {
385        info!("ChannelBuilder subscribing for topic: {}", self.topic);
386        let client_arc = Arc::new(self.client.clone()); // Clone client Arcs into a new Arc for the Channel
387
388        // Get or create the channel instance
389        let mut channels_guard = client_arc.channels.write().await;
390        let channel = channels_guard
391            .entry(self.topic.clone())
392            .or_insert_with(|| Arc::new(Channel::new(self.topic.clone(), client_arc.clone())))
393            .clone();
394        drop(channels_guard); // Release write lock
395        debug!("Got or created Channel Arc for topic: {}", self.topic);
396
397        let mut subscriptions = Vec::new();
398        let mut callbacks_guard = channel.callbacks.write().await;
399        let mut presence_callbacks_guard = channel.presence_callbacks.write().await;
400
401        // Add database change callbacks
402        for (id, (_changes, callback)) in self.db_callbacks {
403            debug!("Adding DB callback ID {} to channel {}", id, self.topic);
404            callbacks_guard.insert(id.clone(), callback);
405            subscriptions.push(Subscription {
406                id,
407                channel: channel.clone(),
408            });
409        }
410
411        // Add broadcast callbacks
412        for (id, (_changes, callback)) in self.broadcast_callbacks {
413            debug!(
414                "Adding Broadcast callback ID {} to channel {}",
415                id, self.topic
416            );
417            // Assuming broadcast uses the same callback mechanism for now
418            callbacks_guard.insert(id.clone(), callback);
419            subscriptions.push(Subscription {
420                id,
421                channel: channel.clone(),
422            });
423        }
424
425        // Add presence callbacks
426        for callback in self.presence_callbacks {
427            debug!("Adding Presence callback to channel {}", self.topic);
428            presence_callbacks_guard.push(callback);
429            // How to represent presence subscription? Use a fixed ID?
430            let id = format!("presence_{}", self.topic); // Example ID
431            subscriptions.push(Subscription {
432                id,
433                channel: channel.clone(),
434            });
435        }
436
437        drop(callbacks_guard);
438        drop(presence_callbacks_guard);
439
440        // Only send join if channel wasn't already joined/joining
441        let current_state = *channel.state.read().await;
442        if current_state == ChannelState::Closed || current_state == ChannelState::Errored {
443            info!(
444                "Channel '{}' is {:?}, attempting to join.",
445                self.topic, current_state
446            );
447            match channel.join().await {
448                Ok(_) => {
449                    // Join message sent, now wait for reply (handled by reader task)
450                    debug!(
451                        "Join message sent for channel '{}'. Waiting for reply.",
452                        self.topic
453                    );
454                    // We might want a timeout here to ensure the join completes
455                    match timeout(Duration::from_secs(10), async {
456                        while *channel.state.read().await != ChannelState::Joined {
457                            tokio::time::sleep(Duration::from_millis(50)).await;
458                            // Add a check for Errored or Closed state too
459                            let check_state = *channel.state.read().await;
460                            if check_state == ChannelState::Errored
461                                || check_state == ChannelState::Closed
462                            {
463                                return Err(RealtimeError::SubscriptionError(format!(
464                                    "Channel '{}' entered state {:?} while waiting for join reply",
465                                    self.topic, check_state
466                                )));
467                            }
468                        }
469                        Ok(())
470                    })
471                    .await
472                    {
473                        Ok(Ok(_)) => info!("Channel '{}' successfully joined.", self.topic),
474                        Ok(Err(e)) => {
475                            error!(
476                                "Error waiting for join confirmation for channel '{}': {:?}",
477                                self.topic, e
478                            );
479                            return Err(e);
480                        }
481                        Err(_) => {
482                            error!(
483                                "Timed out waiting for join confirmation for channel '{}'",
484                                self.topic
485                            );
486                            channel.set_state(ChannelState::Errored).await;
487                            return Err(RealtimeError::SubscriptionError(format!(
488                                "Timed out waiting for join confirmation for channel '{}'",
489                                self.topic
490                            )));
491                        }
492                    }
493                }
494                Err(e) => {
495                    error!(
496                        "Failed to send join message for channel '{}': {}",
497                        self.topic, e
498                    );
499                    channel.set_state(ChannelState::Errored).await;
500                    return Err(e);
501                }
502            }
503        } else {
504            info!(
505                "Channel '{}' is already {:?}, not sending join message.",
506                self.topic, current_state
507            );
508        }
509
510        info!(
511            "ChannelBuilder subscribe finished for topic '{}', returning {} subscriptions.",
512            self.topic,
513            subscriptions.len()
514        );
515        Ok(subscriptions)
516    }
517
518    // Method to track presence - might belong on RealtimeClient or Channel directly?
519    pub async fn track_presence(
520        &self,
521        _user_id: &str,
522        _user_data: serde_json::Value,
523    ) -> Result<(), RealtimeError> {
524        // TODO: Implement sending presence track message
525        Err(RealtimeError::ChannelError(
526            "track_presence not implemented".to_string(),
527        ))
528    }
529}