1use futures_util::{SinkExt, StreamExt};
7use serde::{Deserialize, Serialize};
8use serde_json::json;
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
11use std::sync::Arc;
12use std::time::Duration;
13use thiserror::Error;
14use tokio::sync::mpsc;
15use tokio::sync::{broadcast, RwLock};
16use tokio::time::sleep;
17use tokio_tungstenite::connect_async;
18use tokio_tungstenite::tungstenite::Message;
19use url::Url;
20
21#[derive(Error, Debug)]
23pub enum RealtimeError {
24 #[error("WebSocket error: {0}")]
25 WebSocketError(#[from] tokio_tungstenite::tungstenite::Error),
26
27 #[error("URL parse error: {0}")]
28 UrlParseError(#[from] url::ParseError),
29
30 #[error("JSON serialization error: {0}")]
31 SerializationError(#[from] serde_json::Error),
32
33 #[error("Subscription error: {0}")]
34 SubscriptionError(String),
35
36 #[error("Channel error: {0}")]
37 ChannelError(String),
38
39 #[error("Connection error: {0}")]
40 ConnectionError(String),
41}
42
43impl RealtimeError {
44 pub fn new(message: String) -> Self {
45 Self::ChannelError(message)
46 }
47}
48
49#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
51#[serde(rename_all = "lowercase")]
52pub enum ChannelEvent {
53 Insert,
54 Update,
55 Delete,
56 All,
57}
58
59impl std::fmt::Display for ChannelEvent {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 Self::Insert => write!(f, "INSERT"),
63 Self::Update => write!(f, "UPDATE"),
64 Self::Delete => write!(f, "DELETE"),
65 Self::All => write!(f, "ALL"),
66 }
67 }
68}
69
70#[derive(Debug, Clone, Serialize)]
72pub struct DatabaseFilter {
73 pub column: String,
75 pub operator: FilterOperator,
77 pub value: serde_json::Value,
79}
80
81#[derive(Debug, Clone, PartialEq, Serialize)]
83pub enum FilterOperator {
84 Eq,
86 Neq,
88 Gt,
90 Gte,
92 Lt,
94 Lte,
96 In,
98 NotIn,
100 ContainedBy,
102 Contains,
104 ContainedByArray,
106 Like,
108 ILike,
110}
111
112impl std::fmt::Display for FilterOperator {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 let s = match self {
115 FilterOperator::Eq => "eq",
116 FilterOperator::Neq => "neq",
117 FilterOperator::Gt => "gt",
118 FilterOperator::Gte => "gte",
119 FilterOperator::Lt => "lt",
120 FilterOperator::Lte => "lte",
121 FilterOperator::In => "in",
122 FilterOperator::NotIn => "not.in",
123 FilterOperator::ContainedBy => "contained_by",
124 FilterOperator::Contains => "contains",
125 FilterOperator::ContainedByArray => "contained_by_array",
126 FilterOperator::Like => "like",
127 FilterOperator::ILike => "ilike",
128 };
129 write!(f, "{}", s)
130 }
131}
132
133#[derive(Debug, Clone, Serialize)]
135pub struct DatabaseChanges {
136 schema: String,
137 table: String,
138 events: Vec<ChannelEvent>,
139 filter: Option<Vec<DatabaseFilter>>,
140}
141
142impl DatabaseChanges {
143 pub fn new(table: &str) -> Self {
145 Self {
146 schema: "public".to_string(),
147 table: table.to_string(),
148 events: Vec::new(),
149 filter: None,
150 }
151 }
152
153 pub fn schema(mut self, schema: &str) -> Self {
155 self.schema = schema.to_string();
156 self
157 }
158
159 pub fn event(mut self, event: ChannelEvent) -> Self {
161 if !self.events.contains(&event) {
162 self.events.push(event);
163 }
164 self
165 }
166
167 pub fn filter(mut self, filter: DatabaseFilter) -> Self {
169 if self.filter.is_none() {
170 self.filter = Some(vec![filter]);
171 } else {
172 self.filter.as_mut().unwrap().push(filter);
173 }
174 self
175 }
176
177 pub fn eq<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
179 self.filter(DatabaseFilter {
180 column: column.to_string(),
181 operator: FilterOperator::Eq,
182 value: value.into(),
183 })
184 }
185
186 pub fn neq<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
188 self.filter(DatabaseFilter {
189 column: column.to_string(),
190 operator: FilterOperator::Neq,
191 value: value.into(),
192 })
193 }
194
195 pub fn gt<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
197 self.filter(DatabaseFilter {
198 column: column.to_string(),
199 operator: FilterOperator::Gt,
200 value: value.into(),
201 })
202 }
203
204 pub fn gte<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
206 self.filter(DatabaseFilter {
207 column: column.to_string(),
208 operator: FilterOperator::Gte,
209 value: value.into(),
210 })
211 }
212
213 pub fn lt<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
215 self.filter(DatabaseFilter {
216 column: column.to_string(),
217 operator: FilterOperator::Lt,
218 value: value.into(),
219 })
220 }
221
222 pub fn lte<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
224 self.filter(DatabaseFilter {
225 column: column.to_string(),
226 operator: FilterOperator::Lte,
227 value: value.into(),
228 })
229 }
230
231 pub fn in_values<T: Into<serde_json::Value>>(self, column: &str, values: Vec<T>) -> Self {
233 let json_values: Vec<serde_json::Value> = values.into_iter().map(|v| v.into()).collect();
234 self.filter(DatabaseFilter {
235 column: column.to_string(),
236 operator: FilterOperator::In,
237 value: serde_json::Value::Array(json_values),
238 })
239 }
240
241 pub fn contains<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
243 self.filter(DatabaseFilter {
244 column: column.to_string(),
245 operator: FilterOperator::Contains,
246 value: value.into(),
247 })
248 }
249
250 pub fn like(self, column: &str, pattern: &str) -> Self {
252 self.filter(DatabaseFilter {
253 column: column.to_string(),
254 operator: FilterOperator::Like,
255 value: serde_json::Value::String(pattern.to_string()),
256 })
257 }
258
259 pub fn ilike(self, column: &str, pattern: &str) -> Self {
261 self.filter(DatabaseFilter {
262 column: column.to_string(),
263 operator: FilterOperator::ILike,
264 value: serde_json::Value::String(pattern.to_string()),
265 })
266 }
267
268 #[allow(dead_code)]
269 fn to_channel_config(&self) -> serde_json::Value {
270 let mut events_str = String::new();
271
272 for (i, event) in self.events.iter().enumerate() {
274 if i > 0 {
275 events_str.push(',');
276 }
277 events_str.push_str(&event.to_string());
278 }
279
280 if events_str.is_empty() {
282 events_str = "*".to_string();
283 }
284
285 let mut config = serde_json::json!({
286 "schema": self.schema,
287 "table": self.table,
288 "event": events_str,
289 });
290
291 if let Some(filters) = &self.filter {
293 let mut filter_obj = serde_json::Map::new();
294
295 for filter in filters {
296 let filter_key = format!("{}:{}", filter.column, filter.operator);
297 filter_obj.insert(filter_key, filter.value.clone());
298 }
299
300 if !filter_obj.is_empty() {
301 config["filter"] = serde_json::Value::Object(filter_obj);
302 }
303 }
304
305 config
306 }
307}
308
309#[derive(Debug, Clone, Serialize)]
311pub struct BroadcastChanges {
312 event: String,
313}
314
315impl BroadcastChanges {
316 pub fn new(event: &str) -> Self {
318 Self {
319 event: event.to_string(),
320 }
321 }
322}
323
324#[derive(Debug, Clone, Serialize, Deserialize)]
326pub struct Payload {
327 pub data: serde_json::Value,
328 pub event_type: Option<String>,
329 pub timestamp: Option<i64>,
330}
331
332#[derive(Debug, Clone, Serialize, Deserialize)]
334pub struct PresenceChange {
335 pub joins: HashMap<String, serde_json::Value>,
336 pub leaves: HashMap<String, serde_json::Value>,
337}
338
339#[derive(Debug, Clone, Default)]
341pub struct PresenceState {
342 pub state: HashMap<String, serde_json::Value>,
343}
344
345impl PresenceState {
346 pub fn new() -> Self {
348 Self {
349 state: HashMap::new(),
350 }
351 }
352
353 pub fn sync(&mut self, presence_diff: &PresenceChange) {
355 for key in presence_diff.leaves.keys() {
357 self.state.remove(key);
358 }
359
360 for (key, value) in &presence_diff.joins {
362 self.state.insert(key.clone(), value.clone());
363 }
364 }
365
366 pub fn list(&self) -> Vec<(String, serde_json::Value)> {
368 self.state
369 .iter()
370 .map(|(k, v)| (k.clone(), v.clone()))
371 .collect()
372 }
373
374 pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
376 self.state.get(key)
377 }
378}
379
380#[derive(Debug, Clone, Serialize)]
382pub struct PresenceChanges {
383 event: String,
384}
385
386impl Default for PresenceChanges {
387 fn default() -> Self {
388 Self::new()
389 }
390}
391
392impl PresenceChanges {
393 pub fn new() -> Self {
395 Self {
396 event: "presence_state".to_string(),
397 }
398 }
399}
400
401pub struct Subscription {
403 id: String,
404 channel: Arc<Channel>,
405}
406
407impl Drop for Subscription {
408 fn drop(&mut self) {
409 let channel = self.channel.clone();
411 let id = self.id.clone();
412 tokio::spawn(async move {
413 let _ = channel.unsubscribe(&id).await;
414 });
415 }
416}
417
418type CallbackFn = Box<dyn Fn(Payload) + Send + Sync>;
420
421struct Channel {
422 topic: String,
423 socket: Arc<RwLock<Option<mpsc::Sender<Message>>>>,
424 callbacks: RwLock<HashMap<String, CallbackFn>>,
425}
426
/// Lifecycle state of the client's WebSocket connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No connection and none being attempted.
    Disconnected,
    /// An initial connection attempt is in progress.
    Connecting,
    /// The WebSocket is usable.
    Connected,
    /// A lost connection is being re-established.
    Reconnecting,
}
435
/// Tuning knobs for reconnection and heartbeats. All durations are in milliseconds.
#[derive(Debug, Clone)]
pub struct RealtimeClientOptions {
    /// Reconnect automatically after the connection is lost.
    pub auto_reconnect: bool,
    /// Give up after this many consecutive attempts; `None` retries forever.
    pub max_reconnect_attempts: Option<u32>,
    /// Delay before the first reconnect attempt.
    pub reconnect_interval: u64,
    /// Multiplier applied to the delay after each further attempt.
    pub reconnect_backoff_factor: f64,
    /// Upper bound on the reconnect delay.
    pub max_reconnect_interval: u64,
    /// Period between heartbeat messages.
    pub heartbeat_interval: u64,
}

impl Default for RealtimeClientOptions {
    /// Auto-reconnect enabled, at most 20 attempts, delays of 1 s growing by 1.5x up to 60 s,
    /// and a heartbeat every 30 s.
    fn default() -> Self {
        RealtimeClientOptions {
            heartbeat_interval: 30_000,
            max_reconnect_interval: 60_000,
            reconnect_backoff_factor: 1.5,
            reconnect_interval: 1_000,
            max_reconnect_attempts: Some(20),
            auto_reconnect: true,
        }
    }
}
465
466pub struct RealtimeClient {
468 url: String,
469 key: String,
470 next_ref: AtomicU32,
471 channels: Arc<RwLock<HashMap<String, Arc<Channel>>>>,
472 socket: Arc<RwLock<Option<mpsc::Sender<Message>>>>,
473 options: RealtimeClientOptions,
474 state: Arc<RwLock<ConnectionState>>,
475 reconnect_attempts: AtomicU32,
476 is_manually_closed: AtomicBool,
477 state_change: broadcast::Sender<ConnectionState>,
478}
479
480impl RealtimeClient {
481 pub fn new(url: &str, key: &str) -> Self {
483 let (state_sender, _) = broadcast::channel(100);
484
485 Self {
486 url: url.to_string(),
487 key: key.to_string(),
488 next_ref: AtomicU32::new(0),
489 channels: Arc::new(RwLock::new(HashMap::new())),
490 socket: Arc::new(RwLock::new(None)),
491 options: RealtimeClientOptions::default(),
492 state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
493 reconnect_attempts: AtomicU32::new(0),
494 is_manually_closed: AtomicBool::new(false),
495 state_change: state_sender,
496 }
497 }
498
499 pub fn new_with_options(url: &str, key: &str, options: RealtimeClientOptions) -> Self {
501 let (state_sender, _) = broadcast::channel(100);
502
503 Self {
504 url: url.to_string(),
505 key: key.to_string(),
506 next_ref: AtomicU32::new(0),
507 channels: Arc::new(RwLock::new(HashMap::new())),
508 socket: Arc::new(RwLock::new(None)),
509 options,
510 state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
511 reconnect_attempts: AtomicU32::new(0),
512 is_manually_closed: AtomicBool::new(false),
513 state_change: state_sender,
514 }
515 }
516
517 pub fn on_state_change(&self) -> broadcast::Receiver<ConnectionState> {
519 self.state_change.subscribe()
520 }
521
522 pub async fn get_connection_state(&self) -> ConnectionState {
524 *self.state.read().await
525 }
526
527 pub fn channel(&self, topic: &str) -> ChannelBuilder {
529 ChannelBuilder {
530 client: self,
531 topic: topic.to_string(),
532 callbacks: HashMap::new(),
533 }
534 }
535
536 fn next_ref(&self) -> String {
538 self.next_ref.fetch_add(1, Ordering::SeqCst).to_string()
539 }
540
541 async fn set_connection_state(&self, state: ConnectionState) {
543 let mut state_guard = self.state.write().await;
544 let old_state = *state_guard;
545 *state_guard = state;
546
547 if old_state != state {
548 let _ = self.state_change.send(state);
549 }
550 }
551
552 fn connect(
554 &self,
555 ) -> impl std::future::Future<Output = Result<(), RealtimeError>> + Send + 'static {
556 let client_clone = self.clone();
557 async move {
558 if client_clone.get_connection_state().await == ConnectionState::Connected {
559 return Ok(());
560 }
561
562 client_clone
563 .set_connection_state(ConnectionState::Connecting)
564 .await;
565
566 let mut url = Url::parse(&client_clone.url)?;
568 url.query_pairs_mut()
569 .append_pair("apikey", &client_clone.key);
570 url.query_pairs_mut().append_pair("vsn", "1.0.0");
571
572 let (ws_stream, _) = connect_async(url).await?;
574 let (mut write, read) = ws_stream.split();
575
576 let (tx, mut rx) = mpsc::channel::<Message>(32);
578
579 let client_state = client_clone.state.clone();
581 let auto_reconnect = client_clone.options.auto_reconnect;
582 let manual_close = client_clone.is_manually_closed.load(Ordering::SeqCst);
583 let reconnect_fn = client_clone.clone();
584
585 tokio::task::spawn(async move {
587 read.for_each(|message| async {
588 match message {
589 Ok(msg) => {
590 match msg {
591 Message::Text(text) => {
592 if let Ok(json) =
594 serde_json::from_str::<serde_json::Value>(&text)
595 {
596 let topic = json
598 .get("topic")
599 .and_then(|v| v.as_str())
600 .unwrap_or_default();
601 let event = json
602 .get("event")
603 .and_then(|v| v.as_str())
604 .unwrap_or_default();
605 let payload = json
606 .get("payload")
607 .cloned()
608 .unwrap_or(serde_json::json!({}));
609
610 if topic == "phoenix" && event == "phx_reply" {
612 let status = payload
613 .get("status")
614 .and_then(|v| v.as_str())
615 .unwrap_or_default();
616
617 if status == "ok" {
618 let mut state_guard = client_state.write().await;
619 *state_guard = ConnectionState::Connected;
620
621 reconnect_fn
623 .reconnect_attempts
624 .store(0, Ordering::SeqCst);
625
626 let _ = reconnect_fn
628 .state_change
629 .send(ConnectionState::Connected);
630 }
631 }
632 else if let Some(payload_data) = payload.get("data") {
634 let decoded_payload = Payload {
635 data: payload_data.clone(),
636 event_type: payload
637 .get("type")
638 .and_then(|v| v.as_str())
639 .map(|s| s.to_string()),
640 timestamp: payload
641 .get("timestamp")
642 .and_then(|v| v.as_i64()),
643 };
644
645 if let Ok(channels_guard) =
646 reconnect_fn.channels.try_read()
647 {
648 if let Some(channel) = channels_guard.get(topic) {
649 if let Ok(callbacks_guard) =
650 channel.callbacks.try_read()
651 {
652 for callback in callbacks_guard.values() {
653 callback(decoded_payload.clone());
654 }
655 }
656 }
657 }
658 }
659 }
660 }
661 Message::Close(_) => {
662 if !manual_close && auto_reconnect {
664 let mut state_guard = client_state.write().await;
665 *state_guard = ConnectionState::Reconnecting;
666
667 let reconnect_client = reconnect_fn.clone();
669 reconnect_client.reconnect().await;
670 } else {
671 let mut state_guard = client_state.write().await;
672 *state_guard = ConnectionState::Disconnected;
673 }
674 }
675 _ => {}
676 }
677 }
678 Err(e) => {
679 eprintln!("WebSocket error: {}", e);
680
681 if !manual_close && auto_reconnect {
683 let mut state_guard = client_state.write().await;
684 *state_guard = ConnectionState::Reconnecting;
685
686 let reconnect_client = reconnect_fn.clone();
688 reconnect_client.reconnect().await;
689 } else {
690 let mut state_guard = client_state.write().await;
691 *state_guard = ConnectionState::Disconnected;
692 }
693 }
694 }
695 })
696 .await;
697 });
698
699 tokio::task::spawn(async move {
701 while let Some(msg) = rx.recv().await {
702 if let Err(e) = write.send(msg).await {
703 eprintln!("Error sending message: {}", e);
704
705 if auto_reconnect && !manual_close {
707 break;
709 }
710 }
711 }
712 });
713
714 let mut socket_guard = client_clone.socket.write().await;
716 *socket_guard = Some(tx.clone());
717
718 let socket_clone = tx.clone();
720 let heartbeat_interval = client_clone.options.heartbeat_interval;
721 let is_manually_closed = Arc::new(AtomicBool::new(
722 client_clone.is_manually_closed.load(Ordering::SeqCst),
723 ));
724
725 tokio::task::spawn(async move {
726 loop {
727 sleep(Duration::from_millis(heartbeat_interval)).await;
728
729 if is_manually_closed.load(Ordering::SeqCst) {
730 break;
731 }
732
733 let heartbeat_msg = serde_json::json!({
735 "topic": "phoenix",
736 "event": "heartbeat",
737 "payload": {},
738 "ref": null
739 });
740
741 if socket_clone
742 .send(Message::Text(heartbeat_msg.to_string()))
743 .await
744 .is_err()
745 {
746 break;
747 }
748 }
749 });
750
751 Ok(())
752 }
753 }
754
755 pub async fn disconnect(&self) -> Result<(), RealtimeError> {
757 self.is_manually_closed.store(true, Ordering::SeqCst);
758
759 let mut socket_guard = self.socket.write().await;
760 if let Some(tx) = socket_guard.take() {
761 let close_msg = Message::Close(None);
763 let _ = tx.send(close_msg).await;
764 }
765
766 self.set_connection_state(ConnectionState::Disconnected)
767 .await;
768
769 Ok(())
770 }
771
772 fn reconnect(&self) -> impl std::future::Future<Output = ()> + Send + 'static {
774 let client_clone = self.clone();
775 async move {
776 if !client_clone.options.auto_reconnect
777 || client_clone.is_manually_closed.load(Ordering::SeqCst)
778 {
779 return;
780 }
781
782 client_clone
783 .set_connection_state(ConnectionState::Reconnecting)
784 .await;
785
786 let current_attempt = client_clone
788 .reconnect_attempts
789 .fetch_add(1, Ordering::SeqCst)
790 + 1;
791 if let Some(max) = client_clone.options.max_reconnect_attempts {
792 if current_attempt > max {
793 client_clone
794 .set_connection_state(ConnectionState::Disconnected)
795 .await;
796 return;
797 }
798 }
799
800 let base_interval = client_clone.options.reconnect_interval as f64;
802 let factor = client_clone
803 .options
804 .reconnect_backoff_factor
805 .powi(current_attempt as i32 - 1);
806 let interval = (base_interval * factor)
807 .min(client_clone.options.max_reconnect_interval as f64)
808 as u64;
809
810 sleep(Duration::from_millis(interval)).await;
812
813 let _ = client_clone.connect().await;
815
816 if client_clone.get_connection_state().await == ConnectionState::Connected {
818 let channels_guard = client_clone.channels.read().await;
819
820 for (topic, _channel) in channels_guard.iter() {
821 let join_msg = serde_json::json!({
822 "topic": topic,
823 "event": "phx_join",
824 "payload": {},
825 "ref": client_clone.next_ref()
826 });
827
828 let socket_guard = client_clone.socket.read().await;
829 if let Some(tx) = &*socket_guard {
830 let _ = tx.send(Message::Text(join_msg.to_string())).await;
831 }
832 }
833 }
834 }
835 }
836}
837
838impl Clone for RealtimeClient {
840 fn clone(&self) -> Self {
841 Self {
842 url: self.url.clone(),
843 key: self.key.clone(),
844 next_ref: AtomicU32::new(self.next_ref.load(Ordering::SeqCst)),
845 channels: self.channels.clone(),
846 socket: self.socket.clone(),
847 options: self.options.clone(),
848 state: self.state.clone(),
849 reconnect_attempts: AtomicU32::new(self.reconnect_attempts.load(Ordering::SeqCst)),
850 is_manually_closed: AtomicBool::new(self.is_manually_closed.load(Ordering::SeqCst)),
851 state_change: self.state_change.clone(),
852 }
853 }
854}
855
856pub struct ChannelBuilder<'a> {
858 client: &'a RealtimeClient,
859 topic: String,
860 callbacks: HashMap<String, Box<dyn Fn(Payload) + Send + Sync>>,
861}
862
863impl<'a> ChannelBuilder<'a> {
864 pub fn on<F>(mut self, changes: DatabaseChanges, callback: F) -> Self
866 where
867 F: Fn(Payload) + Send + Sync + 'static,
868 {
869 let topic_key = serde_json::to_string(&changes).unwrap_or_default();
870 self.callbacks.insert(topic_key, Box::new(callback));
871 self
872 }
873
874 pub fn on_broadcast<F>(mut self, changes: BroadcastChanges, callback: F) -> Self
876 where
877 F: Fn(Payload) + Send + Sync + 'static,
878 {
879 let topic_key = format!("broadcast:{}", changes.event);
880 self.callbacks.insert(topic_key, Box::new(callback));
881 self
882 }
883
884 pub fn on_presence<F>(mut self, callback: F) -> Self
886 where
887 F: Fn(PresenceChange) + Send + Sync + 'static,
888 {
889 let presence_callback = move |payload: Payload| {
891 if let Ok(presence_diff) =
892 serde_json::from_value::<PresenceChange>(payload.data.clone())
893 {
894 callback(presence_diff);
895 }
896 };
897
898 self.callbacks
899 .insert("presence".to_string(), Box::new(presence_callback));
900 self
901 }
902
903 pub async fn subscribe(self) -> Result<Subscription, RealtimeError> {
905 let state = self.client.get_connection_state().await;
907 match state {
908 ConnectionState::Disconnected | ConnectionState::Reconnecting => {
909 if self.client.options.auto_reconnect {
911 let connect_future = self.client.connect();
912 tokio::spawn(connect_future);
913
914 for _ in 0..10 {
916 tokio::time::sleep(Duration::from_millis(100)).await;
917 let new_state = self.client.get_connection_state().await;
918 if matches!(new_state, ConnectionState::Connected) {
919 break;
920 }
921 }
922 } else {
923 return Err(RealtimeError::ConnectionError(
924 "Client is disconnected and auto-reconnect is disabled".to_string(),
925 ));
926 }
927 }
928 ConnectionState::Connecting => {
929 for _ in 0..20 {
931 tokio::time::sleep(Duration::from_millis(100)).await;
932 let new_state = self.client.get_connection_state().await;
933 if matches!(new_state, ConnectionState::Connected) {
934 break;
935 }
936 }
937
938 let final_state = self.client.get_connection_state().await;
939 if !matches!(final_state, ConnectionState::Connected) {
940 return Err(RealtimeError::ConnectionError(
941 "Failed to connect to realtime server within timeout".to_string(),
942 ));
943 }
944 }
945 ConnectionState::Connected => {
946 }
948 }
949
950 let channels = self.client.channels.read().await;
952 if let Some(channel) = channels.get(&self.topic) {
953 let mut callbacks = channel.callbacks.write().await;
955 for (key, callback) in self.callbacks {
956 callbacks.insert(key, callback);
957 }
958
959 return Ok(Subscription {
960 id: self.client.next_ref(),
961 channel: channel.clone(),
962 });
963 }
964
965 let channel = Arc::new(Channel {
967 topic: self.topic.clone(),
968 socket: self.client.socket.clone(),
969 callbacks: RwLock::new(self.callbacks),
970 });
971
972 let socket_guard = self.client.socket.read().await;
974 if let Some(socket) = &*socket_guard {
975 let ref_id = self.client.next_ref();
976 let join_payload = json!({
977 "event": "phx_join",
978 "topic": self.topic,
979 "payload": {},
980 "ref": ref_id
981 });
982
983 let message = Message::Text(
984 serde_json::to_string(&join_payload).map_err(RealtimeError::SerializationError)?,
985 );
986
987 socket.send(message).await.map_err(|e| {
988 RealtimeError::SubscriptionError(format!("Failed to send join message: {}", e))
989 })?;
990
991 drop(socket_guard);
993 let mut channels = self.client.channels.write().await;
994 channels.insert(self.topic.clone(), channel.clone());
995 } else {
996 return Err(RealtimeError::ConnectionError(
997 "WebSocket connection not available".to_string(),
998 ));
999 }
1000
1001 Ok(Subscription {
1002 id: self.client.next_ref(),
1003 channel,
1004 })
1005 }
1006
1007 pub async fn track_presence(
1009 &self,
1010 user_id: &str,
1011 user_data: serde_json::Value,
1012 ) -> Result<(), RealtimeError> {
1013 let socket_guard = self.client.socket.read().await;
1014 if let Some(tx) = &*socket_guard {
1015 let presence_msg = serde_json::json!({
1016 "topic": self.topic,
1017 "event": "presence",
1018 "payload": {
1019 "user_id": user_id,
1020 "user_data": user_data
1021 },
1022 "ref": self.client.next_ref()
1023 });
1024
1025 tx.send(Message::Text(presence_msg.to_string()))
1026 .await
1027 .map_err(|_| {
1028 RealtimeError::ChannelError("Failed to send presence message".to_string())
1029 })?;
1030
1031 Ok(())
1032 } else {
1033 Err(RealtimeError::ConnectionError(
1034 "Socket not connected".to_string(),
1035 ))
1036 }
1037 }
1038}
1039
1040impl Channel {
1041 async fn unsubscribe(&self, id: &str) -> Result<(), RealtimeError> {
1043 let mut callbacks_guard = self.callbacks.write().await;
1045 callbacks_guard.remove(id);
1046
1047 if callbacks_guard.is_empty() {
1049 drop(callbacks_guard);
1050
1051 let unsubscribe_message = serde_json::json!({
1053 "topic": self.topic,
1054 "event": "phx_leave",
1055 "payload": {},
1056 "ref": id,
1057 });
1058
1059 let socket_guard = self.socket.read().await;
1060 if let Some(tx) = &*socket_guard {
1061 tx.send(Message::Text(unsubscribe_message.to_string()))
1062 .await
1063 .map_err(|_| {
1064 RealtimeError::SubscriptionError(
1065 "Failed to send unsubscription message".to_string(),
1066 )
1067 })?;
1068 } else {
1069 return Err(RealtimeError::ConnectionError(
1070 "WebSocket not connected".to_string(),
1071 ));
1072 }
1073 }
1074
1075 Ok(())
1076 }
1077}
1078
1079impl From<tokio::sync::mpsc::error::SendError<Message>> for RealtimeError {
1080 fn from(err: tokio::sync::mpsc::error::SendError<Message>) -> Self {
1081 RealtimeError::ChannelError(format!("Failed to send message: {}", err))
1082 }
1083}
1084
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh client starts disconnected and broadcasts only real state transitions.
    #[tokio::test]
    async fn connection_state_changes_are_broadcast_once() {
        let client = RealtimeClient::new("https://example.supabase.co", "test-key");
        let mut status_receiver = client.on_state_change();

        assert_eq!(
            client.get_connection_state().await,
            ConnectionState::Disconnected
        );

        client
            .set_connection_state(ConnectionState::Reconnecting)
            .await;
        assert_eq!(
            client.get_connection_state().await,
            ConnectionState::Reconnecting
        );
        assert_eq!(
            status_receiver.try_recv().ok(),
            Some(ConnectionState::Reconnecting)
        );

        // Setting the same state again must not produce a duplicate notification.
        client
            .set_connection_state(ConnectionState::Reconnecting)
            .await;
        assert!(status_receiver.try_recv().is_err());
    }

    /// Events are de-duplicated and joined; filters become `column:operator` keys.
    #[test]
    fn database_changes_builds_channel_config() {
        let config = DatabaseChanges::new("messages")
            .event(ChannelEvent::Insert)
            .event(ChannelEvent::Insert)
            .event(ChannelEvent::Update)
            .eq("room_id", 7)
            .to_channel_config();

        assert_eq!(
            config,
            serde_json::json!({
                "schema": "public",
                "table": "messages",
                "event": "INSERT,UPDATE",
                "filter": { "room_id:eq": 7 }
            })
        );
    }

    /// Without explicit events the config listens to `*`, and no filter key is emitted.
    #[test]
    fn database_changes_without_events_listens_to_all() {
        let config = DatabaseChanges::new("audit_log")
            .schema("audit")
            .to_channel_config();

        assert_eq!(config["event"], "*");
        assert_eq!(config["schema"], "audit");
        assert!(config.get("filter").is_none());
    }

    /// Leaves are applied before joins, so a key can be replaced within one diff.
    #[test]
    fn presence_state_sync_applies_leaves_then_joins() {
        let mut presence = PresenceState::new();

        presence.sync(&PresenceChange {
            joins: vec![
                ("a".to_string(), serde_json::json!(1)),
                ("b".to_string(), serde_json::json!(2)),
            ]
            .into_iter()
            .collect(),
            leaves: std::collections::HashMap::new(),
        });

        presence.sync(&PresenceChange {
            joins: vec![("a".to_string(), serde_json::json!(3))]
                .into_iter()
                .collect(),
            leaves: vec![
                ("a".to_string(), serde_json::json!(1)),
                ("b".to_string(), serde_json::json!(2)),
            ]
            .into_iter()
            .collect(),
        });

        assert_eq!(presence.get("a"), Some(&serde_json::json!(3)));
        assert!(presence.get("b").is_none());
        assert_eq!(presence.list().len(), 1);
    }
}