1use crate::channel::{Channel, ChannelBuilder}; use crate::error::RealtimeError;
3use crate::message::RealtimeMessage; use futures_util::{SinkExt, StreamExt};
5use rand::Rng;
6use serde_json::json;
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::mpsc;
12use tokio::sync::{broadcast, RwLock};
13use tokio::time::sleep;
14use tokio_tungstenite::connect_async;
15use tokio_tungstenite::tungstenite::Message;
16use tracing::{debug, error, info, instrument, trace, warn};
17use url::Url; #[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Lifecycle of the client's WebSocket connection, as reported by
/// `get_connection_state` and broadcast to `on_state_change` subscribers.
pub enum ConnectionState {
    /// No socket is open: the initial state, after `disconnect`, or after a failure.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The WebSocket handshake completed.
    Connected,
    /// The connection dropped unexpectedly and recovery is being attempted.
    Reconnecting,
}
27
/// Tuning knobs for connection management. All durations are in milliseconds.
#[derive(Debug, Clone)]
pub struct RealtimeClientOptions {
    /// Whether to try to recover when the connection drops unexpectedly.
    pub auto_reconnect: bool,
    /// Maximum number of reconnect attempts; `None` means no limit.
    pub max_reconnect_attempts: Option<u32>,
    /// Base delay (ms) before a reconnect attempt.
    pub reconnect_interval: u64,
    /// Per-attempt multiplier: delay = `reconnect_interval * factor^attempt`.
    pub reconnect_backoff_factor: f64,
    /// Upper bound (ms) on the back-off delay.
    pub max_reconnect_interval: u64,
    /// Period (ms) between heartbeat messages. Must be non-zero, because
    /// tokio's `interval` panics on a zero period.
    pub heartbeat_interval: u64,
}
38
39impl Default for RealtimeClientOptions {
40 fn default() -> Self {
41 Self {
42 auto_reconnect: true,
43 max_reconnect_attempts: None, reconnect_interval: 1000, reconnect_backoff_factor: 1.5,
46 max_reconnect_interval: 30000, heartbeat_interval: 30000, }
49 }
50}
51
52pub struct RealtimeClient {
54 pub(crate) url: String,
55 pub(crate) key: String,
56 pub(crate) next_ref: AtomicU32,
57 pub(crate) channels: Arc<RwLock<HashMap<String, Arc<Channel>>>>,
59 pub(crate) socket: Arc<RwLock<Option<mpsc::Sender<Message>>>>,
61 pub(crate) options: RealtimeClientOptions,
62 state: Arc<RwLock<ConnectionState>>,
63 reconnect_attempts: AtomicU32,
64 is_manually_closed: Arc<AtomicBool>,
66 state_change: broadcast::Sender<ConnectionState>,
67 pub(crate) access_token: Arc<RwLock<Option<String>>>,
69}
70
71impl RealtimeClient {
72 #[instrument(skip(key))]
74 pub fn new(url: &str, key: &str) -> Self {
75 info!("Creating new RealtimeClient");
76 Self::new_with_options(url, key, RealtimeClientOptions::default())
77 }
78
79 #[instrument(skip(key))]
81 pub fn new_with_options(url: &str, key: &str, options: RealtimeClientOptions) -> Self {
82 info!("Creating new RealtimeClient with options: {:?}", options);
83 let (state_change_tx, _) = broadcast::channel(16); Self {
85 url: url.to_string(),
86 key: key.to_string(),
87 next_ref: AtomicU32::new(1),
88 channels: Arc::new(RwLock::new(HashMap::new())),
89 socket: Arc::new(RwLock::new(None)),
90 options,
91 state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
92 reconnect_attempts: AtomicU32::new(0),
93 is_manually_closed: Arc::new(AtomicBool::new(false)),
95 state_change: state_change_tx,
96 access_token: Arc::new(RwLock::new(None)),
98 }
99 }
100
101 #[instrument(skip(self, token))]
103 pub async fn set_auth(&self, token: Option<String>) {
104 info!("Setting auth token (is_some: {})", token.is_some());
105 let mut current_token = self.access_token.write().await;
106 *current_token = token;
107 }
109
110 #[instrument(skip(self))]
112 pub fn on_state_change(&self) -> broadcast::Receiver<ConnectionState> {
113 debug!("Subscribing to state changes");
114 self.state_change.subscribe()
115 }
116
117 #[instrument(skip(self))]
119 pub async fn get_connection_state(&self) -> ConnectionState {
120 let state = *self.state.read().await;
121 debug!(?state, "Getting current connection state");
122 state
123 }
124
125 #[instrument(skip(self))]
127 pub fn channel(&self, topic: &str) -> ChannelBuilder {
128 info!(?topic, "Creating channel builder");
129 ChannelBuilder::new(self, topic)
130 }
131
132 pub(crate) fn next_ref(&self) -> String {
134 let next = self.next_ref.fetch_add(1, Ordering::SeqCst);
135 trace!(next_ref = next, "Generated next ref");
136 next.to_string()
137 }
138
139 async fn set_connection_state(&self, state: ConnectionState) {
141 let mut current_state = self.state.write().await;
142 if *current_state != state {
143 info!(from = ?*current_state, to = ?state, "Client state changing");
144 *current_state = state;
145 if let Err(e) = self.state_change.send(state) {
147 warn!(error = %e, ?state, "Failed to broadcast state change");
148 }
149 } else {
150 trace!(?state, "Client state already set, not changing.");
151 }
152 }
153
154 #[instrument(skip(self))]
156 pub fn connect(
157 &self,
158 ) -> impl std::future::Future<Output = Result<(), RealtimeError>> + Send + 'static {
159 info!("Connect task initiated");
160 let url = self.url.clone();
162 let key = self.key.clone();
163 let socket_arc = self.socket.clone();
164 let state_arc = self.state.clone();
165 let state_change_tx = self.state_change.clone();
166 let _channels_arc = self.channels.clone();
167 let options = self.options.clone();
168 let is_manually_closed_arc = self.is_manually_closed.clone();
169 let token_arc = self.access_token.clone(); async move {
172 info!("Connect task initiated");
173 is_manually_closed_arc.store(false, Ordering::SeqCst);
174 debug!("Reset manual close flag");
175
176 let token_guard = token_arc.read().await;
177 let token_param = token_guard
178 .as_ref()
179 .map(|t| format!("&token={}", t))
180 .unwrap_or_default();
181 debug!(token_present = token_guard.is_some(), "Read auth token");
182 drop(token_guard); let base_url = match Url::parse(&url) {
185 Ok(u) => u,
186 Err(e) => {
187 error!(url = %url, error = %e, "Failed to parse base URL");
188 Self::set_connection_state_internal(
190 state_arc.clone(),
191 state_change_tx.clone(),
192 ConnectionState::Disconnected,
193 )
194 .await;
195 return Err(RealtimeError::UrlParseError(e));
196 }
197 };
198 debug!(url = %base_url, "Parsed base URL");
199 match base_url.scheme() {
201 "http" | "ws" | "https" | "wss" => { }
202 s => {
204 error!(scheme = %s, "Unsupported URL scheme");
205 Self::set_connection_state_internal(
206 state_arc.clone(),
207 state_change_tx.clone(),
208 ConnectionState::Disconnected,
209 )
210 .await;
211 return Err(RealtimeError::ConnectionError(format!(
212 "Unsupported URL scheme: {}",
213 s
214 )));
215 }
216 };
217
218 let _host = match base_url.host_str() {
220 Some(h) => h, None => {
222 error!(url = %base_url, "Failed to get host from URL (no host)");
224 Self::set_connection_state_internal(
225 state_arc.clone(),
226 state_change_tx.clone(),
227 ConnectionState::Disconnected,
228 )
229 .await;
230 return Err(RealtimeError::UrlParseError(url::ParseError::EmptyHost));
231 }
233 };
234 let ws_url = match base_url.join("/realtime/v1/websocket?vsn=2.0.0") {
235 Ok(mut joined_url) => {
236 joined_url
237 .query_pairs_mut()
238 .append_pair("apikey", &key)
239 .append_pair("token", token_param.trim_start_matches("&token=")); info!(url = %joined_url, "Constructed WebSocket URL");
241 joined_url.to_string()
242 }
243 Err(e) => {
244 error!(error = %e, base_url = %base_url, "Failed to join path to base URL");
245 Self::set_connection_state_internal(
246 state_arc.clone(),
247 state_change_tx.clone(),
248 ConnectionState::Disconnected,
249 )
250 .await;
251 return Err(RealtimeError::UrlParseError(e));
252 }
253 };
254
255 info!(url = %ws_url, "Attempting to connect to WebSocket");
256
257 Self::set_connection_state_internal(
258 state_arc.clone(),
259 state_change_tx.clone(),
260 ConnectionState::Connecting,
261 )
262 .await;
263
264 let connect_result = connect_async(&ws_url).await; let ws_stream = match connect_result {
266 Ok((stream, response)) => {
267 info!(response = ?response, "WebSocket connection successful");
268 stream
269 }
270 Err(e) => {
271 error!(error = %e, url = %ws_url, "WebSocket connection failed");
272 Self::set_connection_state_internal(
274 state_arc.clone(),
275 state_change_tx.clone(),
276 ConnectionState::Disconnected,
277 )
278 .await;
279 return Err(RealtimeError::ConnectionError(format!(
280 "WebSocket connection failed: {}",
281 e
282 )));
283 }
284 };
285
286 Self::set_connection_state_internal(
287 state_arc.clone(),
288 state_change_tx.clone(),
289 ConnectionState::Connected,
290 )
291 .await;
292
293 let (write, read) = ws_stream.split();
294 debug!("WebSocket stream split into writer and reader");
295
296 let (socket_tx, socket_rx) = mpsc::channel::<Message>(100);
297 *socket_arc.write().await = Some(socket_tx.clone()); debug!("Internal MPSC channel created, sender stored");
299
300 let writer_socket_arc = socket_arc.clone();
302 let writer_state_arc = state_arc.clone();
303 let writer_state_change_tx = state_change_tx.clone();
304 let _writer_handle = tokio::spawn(async move {
305 #[instrument(skip_all, name = "ws_writer")]
307 async fn writer_task(
308 mut write: impl SinkExt<Message, Error = tokio_tungstenite::tungstenite::Error>
309 + Unpin,
310 mut socket_rx: mpsc::Receiver<Message>,
311 writer_socket_arc: Arc<RwLock<Option<mpsc::Sender<Message>>>>,
312 writer_state_arc: Arc<RwLock<ConnectionState>>,
313 writer_state_change_tx: broadcast::Sender<ConnectionState>,
314 heartbeat_interval_ms: u64,
315 ) {
316 info!("Writer task started");
317 let heartbeat_interval = Duration::from_millis(heartbeat_interval_ms);
318 let mut heartbeat_timer = tokio::time::interval(heartbeat_interval);
319
320 loop {
321 tokio::select! {
322 Some(msg) = socket_rx.recv() => {
324 trace!(message = ?msg, "Sending message via WebSocket");
325 if let Err(e) = write.send(msg).await {
326 error!(error = %e, "Failed to send message via WebSocket");
327 {
329 let mut current_state = writer_state_arc.write().await;
330 if *current_state != ConnectionState::Disconnected {
331 info!(from = ?*current_state, to = ?ConnectionState::Disconnected, "Writer: Setting state Disconnected on send error");
332 *current_state = ConnectionState::Disconnected;
333 let _ = writer_state_change_tx.send(ConnectionState::Disconnected);
334 }
335 }
336 break;
337 }
338 }
339 _ = heartbeat_timer.tick() => {
341 let heartbeat_ref = format!("hb-{}", rand::thread_rng().gen::<u32>());
342 let heartbeat_msg = json!({
343 "topic": "phoenix",
344 "event": "heartbeat",
345 "payload": {},
346 "ref": heartbeat_ref
347 });
348 trace!(heartbeat_ref = %heartbeat_ref, "Sending heartbeat");
349 if let Err(e) = write.send(Message::Text(heartbeat_msg.to_string())).await {
350 error!(error = %e, "Failed to send heartbeat");
351 {
353 let mut current_state = writer_state_arc.write().await;
354 if *current_state != ConnectionState::Disconnected {
355 info!(from = ?*current_state, to = ?ConnectionState::Disconnected, "Writer: Setting state Disconnected on heartbeat error");
356 *current_state = ConnectionState::Disconnected;
357 let _ = writer_state_change_tx.send(ConnectionState::Disconnected);
358 }
359 }
360 break;
361 }
362 }
363 else => {
364 info!("Writer loop finished (select exhausted)");
365 break;
366 }
367 }
368 }
369 info!("Writer task finished");
370 *writer_socket_arc.write().await = None;
372 }
373 writer_task(
374 write,
375 socket_rx,
376 writer_socket_arc,
377 writer_state_arc,
378 writer_state_change_tx,
379 options.heartbeat_interval,
380 )
381 .await;
382 });
383
384 let reader_socket_arc = socket_arc.clone();
386 let reader_state_arc = state_arc.clone();
387 let reader_state_change_tx = state_change_tx.clone();
388 let reader_channels_arc = _channels_arc.clone(); let reader_reconnect_attempts = Arc::new(AtomicU32::new(0)); let reader_options = options.clone();
391 let reader_is_manually_closed = is_manually_closed_arc.clone();
392 let _reader_handle = tokio::spawn(async move {
393 #[allow(clippy::too_many_arguments)] async fn reader_task(
398 mut read: impl StreamExt<Item = Result<Message, tokio_tungstenite::tungstenite::Error>>
399 + Unpin,
400 reader_channels_arc: Arc<RwLock<HashMap<String, Arc<Channel>>>>,
401 reader_socket_arc: Arc<RwLock<Option<mpsc::Sender<Message>>>>, reader_state_arc: Arc<RwLock<ConnectionState>>,
403 reader_state_change_tx: broadcast::Sender<ConnectionState>,
404 _reader_reconnect_attempts: Arc<AtomicU32>, reader_options: RealtimeClientOptions, reader_is_manually_closed: Arc<AtomicBool>,
407 ) {
408 info!("Reader task started");
409 while let Some(result) = read.next().await {
410 match result {
411 Ok(msg) => {
412 trace!(message = ?msg, "Received message from WebSocket");
413 match msg {
414 Message::Text(text) => {
415 match serde_json::from_str::<RealtimeMessage>(&text) {
416 Ok(parsed_msg) => {
417 trace!(message = ?parsed_msg, "Parsed RealtimeMessage");
418 let channels = reader_channels_arc.read().await;
420 if let Some(channel) =
421 channels.get(&parsed_msg.topic)
422 {
423 channel.handle_message(parsed_msg).await;
424 }
425 }
427 Err(e) => {
428 error!(error = %e, raw_message = %text, "Failed to parse RealtimeMessage");
429 }
430 }
431 }
432 Message::Close(close_frame) => {
433 info!(frame = ?close_frame, "Received WebSocket Close frame");
434 break; }
436 Message::Ping(ping_data) => {
437 trace!(data = ?ping_data, "Received Ping, sending Pong");
438 if let Some(tx) = reader_socket_arc.read().await.as_ref() {
440 if let Err(e) = tx.send(Message::Pong(ping_data)).await
441 {
442 error!(error = %e, "Failed to queue Pong message");
443 }
444 } else {
445 warn!("Socket sender not available to send Pong");
446 }
447 }
448 Message::Pong(_) => {
449 trace!("Received Pong");
450 }
452 Message::Binary(_) => {
453 warn!("Received unexpected Binary message");
454 }
455 Message::Frame(_) => {
456 trace!("Received low-level Frame");
458 }
459 }
460 }
461 Err(e) => {
462 error!(error = %e, "WebSocket read error");
463 break; }
465 }
466 }
467 info!("Reader loop finished");
468
469 if !reader_is_manually_closed.load(Ordering::SeqCst)
471 && reader_options.auto_reconnect
472 {
473 warn!("WebSocket connection lost unexpectedly, attempting reconnect...");
474 {
476 let mut current_state = reader_state_arc.write().await;
477 if *current_state != ConnectionState::Reconnecting {
478 info!(from = ?*current_state, to = ?ConnectionState::Reconnecting, "Reader: Setting state Reconnecting");
479 *current_state = ConnectionState::Reconnecting;
480 let _ = reader_state_change_tx.send(ConnectionState::Reconnecting);
481 }
482 }
483 warn!(
487 "Reconnect logic not fully implemented, setting state to Disconnected"
488 );
489 {
491 let mut current_state = reader_state_arc.write().await;
492 if *current_state != ConnectionState::Disconnected {
493 info!(from = ?*current_state, to = ?ConnectionState::Disconnected, "Reader: Setting state Disconnected (reconnect N/A)");
494 *current_state = ConnectionState::Disconnected;
495 let _ = reader_state_change_tx.send(ConnectionState::Disconnected);
496 }
497 }
498 } else {
499 info!("WebSocket connection closed (manual or auto_reconnect=false)");
500 {
502 let mut current_state = reader_state_arc.write().await;
503 if *current_state != ConnectionState::Disconnected {
504 info!(from = ?*current_state, to = ?ConnectionState::Disconnected, "Reader: Setting state Disconnected (manual/no-reconnect)");
505 *current_state = ConnectionState::Disconnected;
506 let _ = reader_state_change_tx.send(ConnectionState::Disconnected);
507 }
508 }
509 }
510 *reader_socket_arc.write().await = None;
512 info!("Reader task finished");
513 }
514 reader_task(
515 read,
516 reader_channels_arc,
517 reader_socket_arc,
518 reader_state_arc,
519 reader_state_change_tx,
520 reader_reconnect_attempts,
521 reader_options,
522 reader_is_manually_closed,
523 )
524 .await;
525 });
526
527 info!("Connect task completed successfully (connection established, reader/writer tasks spawned)");
528 Ok(())
530 }
531 }
532
533 #[instrument(skip(state_arc, state_change_tx), fields(state = ?state))]
535 async fn set_connection_state_internal(
536 state_arc: Arc<RwLock<ConnectionState>>,
537 state_change_tx: broadcast::Sender<ConnectionState>,
538 state: ConnectionState,
539 ) {
540 let mut current_state = state_arc.write().await;
541 if *current_state != state {
542 info!(from = ?*current_state, to = ?state, "Internal state changing");
543 *current_state = state;
544 if let Err(e) = state_change_tx.send(state) {
545 warn!(error = %e, state = ?state, "Failed to broadcast internal state change");
546 }
547 } else {
548 trace!(state = ?state, "Internal state already set, not changing.");
549 }
550 }
551
552 #[instrument(skip(self))]
554 pub async fn disconnect(&self) -> Result<(), RealtimeError> {
555 info!("disconnect() called");
556 self.is_manually_closed.store(true, Ordering::SeqCst);
557 debug!("Set manual close flag");
558
559 self.set_connection_state(ConnectionState::Disconnected)
560 .await;
561
562 let mut socket_guard = self.socket.write().await;
563 if let Some(socket_tx) = socket_guard.take() {
564 info!("Closing WebSocket connection via MPSC channel");
565 drop(socket_tx); debug!("Dropped MPSC sender to signal writer task termination");
571 } else {
574 info!("disconnect() called but no active socket sender found (already disconnected?)");
575 }
576 drop(socket_guard);
577
578 info!("disconnect() finished");
584 Ok(())
585 }
586
587 #[instrument(skip(self))]
589 fn reconnect(&self) -> impl std::future::Future<Output = ()> + Send + 'static {
590 info!("reconnect() called");
591 let self_clone = self.clone(); async move {
593 let mut attempts = self_clone.reconnect_attempts.fetch_add(1, Ordering::SeqCst);
594 info!(attempts, "Reconnect attempt initiated");
595
596 if let Some(max_attempts) = self_clone.options.max_reconnect_attempts {
597 if attempts >= max_attempts {
598 error!(max_attempts, "Max reconnect attempts reached. Giving up.");
599 self_clone
600 .set_connection_state(ConnectionState::Disconnected)
601 .await;
602 return;
603 }
604 }
605
606 let interval_ms = std::cmp::min(
607 self_clone.options.max_reconnect_interval,
608 (self_clone.options.reconnect_interval as f64
609 * self_clone
610 .options
611 .reconnect_backoff_factor
612 .powi(attempts as i32)) as u64,
613 );
614 let interval = Duration::from_millis(interval_ms);
615 info!(interval = ?interval, "Waiting before next reconnect attempt");
616
617 sleep(interval).await;
618
619 info!("Attempting to reconnect...");
620 match self_clone.connect().await {
621 Ok(_) => {
622 info!("Reconnect successful!");
623 self_clone.reconnect_attempts.store(0, Ordering::SeqCst); }
626 Err(e) => {
627 error!(error = %e, attempts, "Reconnect attempt failed");
628 if self_clone.options.auto_reconnect {
631 attempts = self_clone.reconnect_attempts.load(Ordering::SeqCst); if let Some(max_attempts) = self_clone.options.max_reconnect_attempts {
634 if attempts >= max_attempts {
635 warn!("Max reconnect attempts reached after failed attempt.");
636 return; }
638 }
639 warn!("Scheduling next reconnect attempt...");
640 tokio::spawn(self_clone.reconnect());
641 }
642 }
643 }
644 info!("reconnect() task finished for this attempt");
645 }
646 }
647
648 #[instrument(skip(self, message))]
650 pub(crate) async fn send_message(
651 &self,
652 message: serde_json::Value,
653 ) -> Result<(), RealtimeError> {
654 let msg_text = message.to_string();
655 trace!(message = %msg_text, "Preparing to send message");
656 let ws_message = Message::Text(msg_text);
657
658 let socket_guard = self.socket.read().await;
659 if let Some(socket_tx) = socket_guard.as_ref() {
660 debug!("Sending message via MPSC channel to writer task");
661 socket_tx.send(ws_message).await.map_err(|e| {
662 error!(error = %e, "Failed to send message via MPSC channel");
663 RealtimeError::ConnectionError(format!(
664 "Failed to send message via MPSC channel: {}",
665 e
666 ))
667 })
668 } else {
669 error!("Cannot send message: WebSocket sender not available (not connected?)");
670 Err(RealtimeError::ConnectionError(
671 "WebSocket sender not available (not connected?)".to_string(),
672 ))
673 }
674 }
675}
676
677impl Clone for RealtimeClient {
678 fn clone(&self) -> Self {
679 Self {
681 url: self.url.clone(),
682 key: self.key.clone(),
683 next_ref: AtomicU32::new(self.next_ref.load(Ordering::SeqCst)), channels: self.channels.clone(),
685 socket: self.socket.clone(),
686 options: self.options.clone(),
687 state: self.state.clone(),
688 reconnect_attempts: AtomicU32::new(self.reconnect_attempts.load(Ordering::SeqCst)),
689 is_manually_closed: self.is_manually_closed.clone(),
690 state_change: self.state_change.clone(),
691 access_token: self.access_token.clone(),
692 }
693 }
694}
695
696impl From<tokio::sync::mpsc::error::SendError<Message>> for RealtimeError {
698 fn from(err: tokio::sync::mpsc::error::SendError<Message>) -> Self {
699 RealtimeError::ConnectionError(format!("Failed to send message to socket task: {}", err))
700 }
701}