supabase_rust_realtime/
client.rs

1use crate::channel::Channel; // Assuming Channel is in channel.rs
2use crate::error::RealtimeError;
3use futures_util::{SinkExt, StreamExt};
4use serde_json::json;
5use std::collections::HashMap;
6use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::mpsc;
10use tokio::sync::{broadcast, RwLock};
11use tokio::time::sleep;
12use tokio_tungstenite::connect_async;
13use tokio_tungstenite::tungstenite::Message;
14use url::Url;
15
/// Lifecycle state of the client's WebSocket connection.
///
/// Changes are published to receivers obtained from `RealtimeClient::on_state_change`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No live connection: the initial state, and the state after a close or failure.
    Disconnected,
    /// The WebSocket handshake is in progress.
    Connecting,
    /// The handshake succeeded and the connection is being driven.
    Connected,
    /// An automatic reconnect attempt is pending.
    Reconnecting,
}
24
25/// RealtimeClient設定オプション
26#[derive(Debug, Clone)]
27pub struct RealtimeClientOptions {
28    pub auto_reconnect: bool,
29    pub max_reconnect_attempts: Option<u32>,
30    pub reconnect_interval: u64,
31    pub reconnect_backoff_factor: f64,
32    pub max_reconnect_interval: u64,
33    pub heartbeat_interval: u64,
34}
35
36impl Default for RealtimeClientOptions {
37    fn default() -> Self {
38        Self {
39            auto_reconnect: true,
40            max_reconnect_attempts: None, // Infinite attempts
41            reconnect_interval: 1000,     // 1 second
42            reconnect_backoff_factor: 1.5,
43            max_reconnect_interval: 30000, // 30 seconds
44            heartbeat_interval: 30000,    // 30 seconds
45        }
46    }
47}
48
49/// Realtimeクライアント本体
50pub struct RealtimeClient {
51    pub(crate) url: String,
52    pub(crate) key: String,
53    pub(crate) next_ref: AtomicU32,
54    // Shared map of active channels (topic -> Channel)
55    pub(crate) channels: Arc<RwLock<HashMap<String, Arc<Channel>>>>,
56    // Shared sender for the WebSocket task
57    pub(crate) socket: Arc<RwLock<Option<mpsc::Sender<Message>>>>,
58    pub(crate) options: RealtimeClientOptions,
59    state: Arc<RwLock<ConnectionState>>,
60    reconnect_attempts: AtomicU32,
61    // Wrap AtomicBool in Arc for sharing across tasks
62    is_manually_closed: Arc<AtomicBool>,
63    state_change: broadcast::Sender<ConnectionState>,
64}
65
66impl RealtimeClient {
67    /// デフォルトオプションで新しいクライアントを作成
68    pub fn new(url: &str, key: &str) -> Self {
69        Self::new_with_options(url, key, RealtimeClientOptions::default())
70    }
71
72    /// カスタムオプションで新しいクライアントを作成
73    pub fn new_with_options(url: &str, key: &str, options: RealtimeClientOptions) -> Self {
74        let (state_change_tx, _) = broadcast::channel(16); // Channel for state changes
75        Self {
76            url: url.to_string(),
77            key: key.to_string(),
78            next_ref: AtomicU32::new(1),
79            channels: Arc::new(RwLock::new(HashMap::new())),
80            socket: Arc::new(RwLock::new(None)),
81            options,
82            state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
83            reconnect_attempts: AtomicU32::new(0),
84            // Initialize the Arc<AtomicBool>
85            is_manually_closed: Arc::new(AtomicBool::new(false)),
86            state_change: state_change_tx,
87        }
88    }
89
90    /// 接続状態変更の通知を受け取るためのレシーバーを取得
91    pub fn on_state_change(&self) -> broadcast::Receiver<ConnectionState> {
92        self.state_change.subscribe()
93    }
94
95    /// 現在の接続状態を取得
96    pub async fn get_connection_state(&self) -> ConnectionState {
97        *self.state.read().await
98    }
99
100    /// 特定のトピックに対するチャンネルビルダーを作成
101    pub fn channel(&self, topic: &str) -> crate::channel::ChannelBuilder {
102        crate::channel::ChannelBuilder::new(self, topic)
103    }
104
105    /// 次のメッセージ参照番号を生成
106    pub(crate) fn next_ref(&self) -> String {
107        self.next_ref.fetch_add(1, Ordering::SeqCst).to_string()
108    }
109
110    /// 内部接続状態を設定し、変更を通知
111    async fn set_connection_state(&self, state: ConnectionState) {
112        let mut current_state = self.state.write().await;
113        if *current_state != state {
114            *current_state = state;
115            // Ignore send error if no receivers are listening
116            let _ = self.state_change.send(state);
117        }
118    }
119
120    /// WebSocket接続を開始および管理するタスク
121    pub fn connect(&self) -> impl std::future::Future<Output = Result<(), RealtimeError>> + Send + 'static {
122        // Clone necessary Arcs and fields for the async task
123        let url = self.url.clone();
124        let key = self.key.clone();
125        let socket_arc = self.socket.clone();
126        let state_arc = self.state.clone();
127        let state_change_tx = self.state_change.clone();
128        let _channels_arc = self.channels.clone();
129        let options = self.options.clone();
130        let is_manually_closed_arc = self.is_manually_closed.clone();
131
132        async move {
133            // Reset manual close flag using the cloned Arc
134            is_manually_closed_arc.store(false, Ordering::SeqCst);
135
136            let ws_url = Url::parse(&format!("{}/websocket?apikey={}&vsn=1.0.0", url, key))?;
137
138            Self::set_connection_state_internal(state_arc.clone(), state_change_tx.clone(), ConnectionState::Connecting).await;
139
140            let (ws_stream, _) = connect_async(ws_url).await.map_err(|e| {
141                RealtimeError::ConnectionError(format!("WebSocket connection failed: {}", e))
142            })?;
143
144            Self::set_connection_state_internal(state_arc.clone(), state_change_tx.clone(), ConnectionState::Connected).await;
145
146            let (mut write, mut read) = ws_stream.split();
147
148            // Create an MPSC channel for sending messages to the WebSocket writer task
149            let (socket_tx, mut socket_rx) = mpsc::channel::<Message>(100);
150
151            // Store the sender half in the shared state
152            *socket_arc.write().await = Some(socket_tx);
153
154            // --- WebSocket Writer Task ---
155            let writer_socket_arc = socket_arc.clone();
156            let writer_state_arc = state_arc.clone();
157            let writer_state_change_tx = state_change_tx.clone();
158            tokio::spawn(async move {
159                while let Some(message) = socket_rx.recv().await {
160                    if let Err(e) = write.send(message).await {
161                        eprintln!("WebSocket send error: {}. Closing connection.", e);
162                        *writer_socket_arc.write().await = None; // Clear sender on error
163                        Self::set_connection_state_internal(writer_state_arc, writer_state_change_tx, ConnectionState::Disconnected).await;
164                        socket_rx.close();
165                        break;
166                    }
167                }
168                // Writer task normally ends when socket_rx channel is closed
169                println!("WebSocket writer task finished.");
170            });
171
172            // --- WebSocket Reader Task (and heartbeat/rejoin logic) ---
173            let reader_socket_arc = socket_arc.clone();
174            let reader_state_arc = state_arc.clone();
175            let reader_state_change_tx = state_change_tx.clone();
176            let heartbeat_interval = Duration::from_millis(options.heartbeat_interval);
177
178            loop {
179                let socket_tx_ref = reader_socket_arc.read().await;
180                let current_socket_tx = if let Some(tx) = socket_tx_ref.as_ref() {
181                    tx.clone()
182                } else {
183                    // Socket was closed (likely by writer task error or disconnect)
184                    println!("Socket sender gone, exiting reader task.");
185                    break;
186                };
187                drop(socket_tx_ref); // Release read lock
188
189                tokio::select! {
190                    // Read messages from WebSocket
191                    msg_result = read.next() => {
192                        match msg_result {
193                            Some(Ok(msg)) => {
194                                // TODO: Process incoming message (phx_reply, events, presence)
195                                println!("Received WS message: {:?}", msg);
196                                // Example: Handle heartbeat replies
197                                if let Message::Text(text) = &msg {
198                                    if let Ok(json_msg) = serde_json::from_str::<serde_json::Value>(text) {
199                                        if json_msg["event"].as_str() == Some("phx_reply") && json_msg["payload"]["status"].as_str() == Some("ok") {
200                                            // Likely heartbeat response, do nothing specific for now
201                                        } else {
202                                             // TODO: Route other messages to relevant channel callbacks
203                                        }
204                                    }
205                                }
206                            }
207                            Some(Err(e)) => {
208                                eprintln!("WebSocket read error: {}", e);
209                                Self::set_connection_state_internal(reader_state_arc.clone(), reader_state_change_tx.clone(), ConnectionState::Disconnected).await;
210                                *reader_socket_arc.write().await = None;
211                                break; // Exit loop on read error
212                            }
213                            None => {
214                                println!("WebSocket stream closed by remote.");
215                                Self::set_connection_state_internal(reader_state_arc.clone(), reader_state_change_tx.clone(), ConnectionState::Disconnected).await;
216                                *reader_socket_arc.write().await = None;
217                                break; // Exit loop on stream close
218                            }
219                        }
220                    }
221                    // Send heartbeat periodically
222                    _ = sleep(heartbeat_interval) => {
223                         let heartbeat_ref = AtomicU32::new(0).fetch_add(1, Ordering::SeqCst).to_string(); // Simple ref for heartbeat
224                         let heartbeat_msg = json!({
225                             "topic": "phoenix",
226                             "event": "heartbeat",
227                             "payload": {},
228                             "ref": heartbeat_ref
229                         });
230                         if let Err(e) = current_socket_tx.send(Message::Text(heartbeat_msg.to_string())).await {
231                             eprintln!("Failed to send heartbeat: {}. Assuming connection lost.", e);
232                             Self::set_connection_state_internal(reader_state_arc.clone(), reader_state_change_tx.clone(), ConnectionState::Disconnected).await;
233                             *reader_socket_arc.write().await = None;
234                             break; // Exit loop if heartbeat send fails
235                         }
236                    }
237                }
238            }
239
240            // Connection closed, attempt reconnect if enabled and not manually closed
241            if options.auto_reconnect && !is_manually_closed_arc.load(Ordering::SeqCst) {
242                println!("Connection lost. Auto-reconnect is enabled but reconnect logic needs implementation.");
243                 // self.reconnect(); // This needs careful handling
244            }
245
246            Ok(())
247        }
248    }
249
250    /// Helper for setting state (avoids async recursion issues)
251    async fn set_connection_state_internal(
252        state_arc: Arc<RwLock<ConnectionState>>,
253        state_change_tx: broadcast::Sender<ConnectionState>,
254        state: ConnectionState,
255    ) {
256        let mut current_state = state_arc.write().await;
257        if *current_state != state {
258            *current_state = state;
259            let _ = state_change_tx.send(state);
260        }
261    }
262
263    /// 切断処理
264    pub async fn disconnect(&self) -> Result<(), RealtimeError> {
265        // Use the Arc<AtomicBool>
266        self.is_manually_closed.store(true, Ordering::SeqCst);
267        self.set_connection_state(ConnectionState::Disconnected).await;
268
269        let mut socket_guard = self.socket.write().await;
270        if let Some(socket_tx) = socket_guard.take() {
271            // Close the sender channel, which will cause the writer task to exit
272            // The reader task should exit due to stream closure or heartbeat failure.
273            drop(socket_tx);
274            println!("WebSocket connection closed manually.");
275        }
276        // Clear channels? Maybe not, allow re-connecting later?
277        // self.channels.write().await.clear();
278
279        Ok(())
280    }
281
282    /// 再接続処理 (TODO: Implement backoff logic)
283    #[allow(dead_code)]
284    fn reconnect(&self) -> impl std::future::Future<Output = ()> + Send + 'static {
285        let self_clone = self.clone(); // Clones the Arcs including is_manually_closed
286        async move {
287            let mut attempts = 0;
288            let mut interval = self_clone.options.reconnect_interval;
289
290            loop {
291                // Use the cloned Arc<AtomicBool>
292                if self_clone.is_manually_closed.load(Ordering::SeqCst) {
293                    println!("Manual disconnect requested, stopping reconnect attempts.");
294                    break;
295                }
296
297                if let Some(max_attempts) = self_clone.options.max_reconnect_attempts {
298                    if attempts >= max_attempts {
299                        println!("Max reconnect attempts ({}) reached.", max_attempts);
300                        self_clone.set_connection_state(ConnectionState::Disconnected).await;
301                        break;
302                    }
303                }
304
305                attempts += 1;
306                self_clone.reconnect_attempts.store(attempts, Ordering::SeqCst);
307                self_clone.set_connection_state(ConnectionState::Reconnecting).await;
308                println!("Attempting to reconnect... (Attempt #{})", attempts);
309
310                sleep(Duration::from_millis(interval)).await;
311
312                // Try connecting again
313                // Need to call the connection logic, maybe refactor connect?
314                match self_clone.connect().await {
315                    Ok(_) => {
316                        println!("Reconnection successful!");
317                        self_clone.reconnect_attempts.store(0, Ordering::SeqCst); // Reset attempts
318                        // TODO: Rejoin channels?
319                        break; // Exit reconnect loop
320                    }
321                    Err(e) => {
322                        eprintln!("Reconnect attempt #{} failed: {}", attempts, e);
323                        // Increase interval with backoff
324                        interval = (interval as f64 * self_clone.options.reconnect_backoff_factor) as u64;
325                        interval = interval.min(self_clone.options.max_reconnect_interval);
326                    }
327                }
328            }
329        }
330    }
331}
332
333// Implement Clone manually to handle Arc fields correctly
334impl Clone for RealtimeClient {
335    fn clone(&self) -> Self {
336        Self {
337            url: self.url.clone(),
338            key: self.key.clone(),
339            next_ref: AtomicU32::new(self.next_ref.load(Ordering::SeqCst)), // Clone value
340            channels: self.channels.clone(),
341            socket: self.socket.clone(),
342            options: self.options.clone(),
343            state: self.state.clone(),
344            reconnect_attempts: AtomicU32::new(self.reconnect_attempts.load(Ordering::SeqCst)), // Clone value
345            // Clone the Arc<AtomicBool>
346            is_manually_closed: self.is_manually_closed.clone(),
347            state_change: self.state_change.clone(),
348        }
349    }
350}
351
352// WebSocketメッセージ送信エラーからの変換
353impl From<tokio::sync::mpsc::error::SendError<Message>> for RealtimeError {
354    fn from(err: tokio::sync::mpsc::error::SendError<Message>) -> Self {
355        RealtimeError::ConnectionError(format!(
356            "Failed to send message to socket task: {}",
357            err
358        ))
359    }
360}