1use std::sync::Arc;
7use std::collections::HashMap;
8use tokio_tungstenite::connect_async;
9use tokio_tungstenite::tungstenite::Message;
10use serde::{Serialize, Deserialize};
11use thiserror::Error;
12use futures_util::{StreamExt, SinkExt};
13use tokio::sync::mpsc;
14use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
15use tokio::sync::{broadcast, RwLock};
16use std::time::Duration;
17use tokio::time::sleep;
18use url::Url;
19
/// Errors produced by the realtime client.
///
/// The `#[from]` variants let `?` convert tungstenite, URL-parsing and `serde_json`
/// errors directly into a `RealtimeError`.
#[derive(Error, Debug)]
pub enum RealtimeError {
    /// Transport-level failure reported by tungstenite (handshake, read or write).
    #[error("WebSocket error: {0}")]
    WebSocketError(#[from] tokio_tungstenite::tungstenite::Error),

    /// The configured endpoint URL could not be parsed.
    #[error("URL parse error: {0}")]
    UrlParseError(#[from] url::ParseError),

    /// Error reported by `serde_json`.
    #[error("JSON serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),

    /// Failure while managing a subscription (e.g. the `phx_leave` message could not be queued).
    #[error("Subscription error: {0}")]
    SubscriptionError(String),

    /// Failure on an internal message queue (e.g. the socket's send queue is closed).
    #[error("Channel error: {0}")]
    ChannelError(String),

    /// No usable socket (e.g. "Socket not connected").
    #[error("Connection error: {0}")]
    ConnectionError(String),
}
41
42impl RealtimeError {
43 pub fn new(message: String) -> Self {
44 Self::ChannelError(message)
45 }
46}
47
/// Kind of database change to listen for.
///
/// Serialized in lowercase (`"insert"`), while `Display` renders the uppercase form
/// (`"INSERT"`), which is what channel configs are built from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChannelEvent {
    /// Row inserted.
    Insert,
    /// Row updated.
    Update,
    /// Row deleted.
    Delete,
    /// Every kind of change.
    All,
}
57
58impl std::fmt::Display for ChannelEvent {
59 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
60 match self {
61 Self::Insert => write!(f, "INSERT"),
62 Self::Update => write!(f, "UPDATE"),
63 Self::Delete => write!(f, "DELETE"),
64 Self::All => write!(f, "ALL"),
65 }
66 }
67}
68
/// A single column filter for a [`DatabaseChanges`] description.
///
/// `DatabaseChanges::to_channel_config` encodes it as the entry
/// `"<column>:<operator>": value` of the config's `filter` object.
#[derive(Debug, Clone, Serialize)]
pub struct DatabaseFilter {
    /// Column to compare.
    pub column: String,
    /// Comparison operator.
    pub operator: FilterOperator,
    /// Right-hand operand (`DatabaseChanges::in_values` stores a JSON array here).
    pub value: serde_json::Value,
}
79
/// Comparison operator of a [`DatabaseFilter`].
///
/// Its string form (shown per variant) is what appears in the channel configuration.
#[derive(Debug, Clone, PartialEq, Serialize)]
pub enum FilterOperator {
    /// `eq`
    Eq,
    /// `neq`
    Neq,
    /// `gt`
    Gt,
    /// `gte`
    Gte,
    /// `lt`
    Lt,
    /// `lte`
    Lte,
    /// `in`
    In,
    /// `not.in`
    NotIn,
    /// `contained_by`
    ContainedBy,
    /// `contains`
    Contains,
    /// `contained_by_array`
    ContainedByArray,
    /// `like`
    Like,
    /// `ilike`
    ILike,
}
110
111impl ToString for FilterOperator {
112 fn to_string(&self) -> String {
113 match self {
114 FilterOperator::Eq => "eq".to_string(),
115 FilterOperator::Neq => "neq".to_string(),
116 FilterOperator::Gt => "gt".to_string(),
117 FilterOperator::Gte => "gte".to_string(),
118 FilterOperator::Lt => "lt".to_string(),
119 FilterOperator::Lte => "lte".to_string(),
120 FilterOperator::In => "in".to_string(),
121 FilterOperator::NotIn => "not.in".to_string(),
122 FilterOperator::ContainedBy => "contained_by".to_string(),
123 FilterOperator::Contains => "contains".to_string(),
124 FilterOperator::ContainedByArray => "contained_by_array".to_string(),
125 FilterOperator::Like => "like".to_string(),
126 FilterOperator::ILike => "ilike".to_string(),
127 }
128 }
129}
130
/// Description of the database changes a channel should receive, built fluently
/// starting from `DatabaseChanges::new(table)`.
#[derive(Debug, Clone, Serialize)]
pub struct DatabaseChanges {
    /// Schema name (`"public"` unless overridden with `schema`).
    schema: String,
    /// Table name.
    table: String,
    /// Selected events, kept free of duplicates by `event`; empty means every event.
    events: Vec<ChannelEvent>,
    /// Column filters; `None` until the first filter is added.
    filter: Option<Vec<DatabaseFilter>>,
}
139
140impl DatabaseChanges {
141 pub fn new(table: &str) -> Self {
143 Self {
144 schema: "public".to_string(),
145 table: table.to_string(),
146 events: Vec::new(),
147 filter: None,
148 }
149 }
150
151 pub fn schema(mut self, schema: &str) -> Self {
153 self.schema = schema.to_string();
154 self
155 }
156
157 pub fn event(mut self, event: ChannelEvent) -> Self {
159 if !self.events.contains(&event) {
160 self.events.push(event);
161 }
162 self
163 }
164
165 pub fn filter(mut self, filter: DatabaseFilter) -> Self {
167 if self.filter.is_none() {
168 self.filter = Some(vec![filter]);
169 } else {
170 self.filter.as_mut().unwrap().push(filter);
171 }
172 self
173 }
174
175 pub fn eq<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
177 self.filter(DatabaseFilter {
178 column: column.to_string(),
179 operator: FilterOperator::Eq,
180 value: value.into(),
181 })
182 }
183
184 pub fn neq<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
186 self.filter(DatabaseFilter {
187 column: column.to_string(),
188 operator: FilterOperator::Neq,
189 value: value.into(),
190 })
191 }
192
193 pub fn gt<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
195 self.filter(DatabaseFilter {
196 column: column.to_string(),
197 operator: FilterOperator::Gt,
198 value: value.into(),
199 })
200 }
201
202 pub fn gte<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
204 self.filter(DatabaseFilter {
205 column: column.to_string(),
206 operator: FilterOperator::Gte,
207 value: value.into(),
208 })
209 }
210
211 pub fn lt<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
213 self.filter(DatabaseFilter {
214 column: column.to_string(),
215 operator: FilterOperator::Lt,
216 value: value.into(),
217 })
218 }
219
220 pub fn lte<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
222 self.filter(DatabaseFilter {
223 column: column.to_string(),
224 operator: FilterOperator::Lte,
225 value: value.into(),
226 })
227 }
228
229 pub fn in_values<T: Into<serde_json::Value>>(self, column: &str, values: Vec<T>) -> Self {
231 let json_values: Vec<serde_json::Value> = values.into_iter().map(|v| v.into()).collect();
232 self.filter(DatabaseFilter {
233 column: column.to_string(),
234 operator: FilterOperator::In,
235 value: serde_json::Value::Array(json_values),
236 })
237 }
238
239 pub fn contains<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
241 self.filter(DatabaseFilter {
242 column: column.to_string(),
243 operator: FilterOperator::Contains,
244 value: value.into(),
245 })
246 }
247
248 pub fn like(self, column: &str, pattern: &str) -> Self {
250 self.filter(DatabaseFilter {
251 column: column.to_string(),
252 operator: FilterOperator::Like,
253 value: serde_json::Value::String(pattern.to_string()),
254 })
255 }
256
257 pub fn ilike(self, column: &str, pattern: &str) -> Self {
259 self.filter(DatabaseFilter {
260 column: column.to_string(),
261 operator: FilterOperator::ILike,
262 value: serde_json::Value::String(pattern.to_string()),
263 })
264 }
265
266 fn to_channel_config(&self) -> serde_json::Value {
268 let mut events_str = String::new();
269
270 for (i, event) in self.events.iter().enumerate() {
272 if i > 0 {
273 events_str.push(',');
274 }
275 events_str.push_str(&event.to_string());
276 }
277
278 if events_str.is_empty() {
280 events_str = "*".to_string();
281 }
282
283 let mut config = serde_json::json!({
284 "schema": self.schema,
285 "table": self.table,
286 "event": events_str,
287 });
288
289 if let Some(filters) = &self.filter {
291 let mut filter_obj = serde_json::Map::new();
292
293 for filter in filters {
294 let filter_key = format!("{}:{}", filter.column, filter.operator.to_string());
295 filter_obj.insert(filter_key, filter.value.clone());
296 }
297
298 if !filter_obj.is_empty() {
299 config["filter"] = serde_json::Value::Object(filter_obj);
300 }
301 }
302
303 config
304 }
305}
306
/// Describes broadcast messages to listen for, by event name.
///
/// NOTE(review): `ChannelBuilder::on_broadcast` currently ignores this value.
#[derive(Debug, Clone, Serialize)]
pub struct BroadcastChanges {
    /// Broadcast event name.
    event: String,
}
312
313impl BroadcastChanges {
314 pub fn new(event: &str) -> Self {
316 Self {
317 event: event.to_string(),
318 }
319 }
320}
321
/// Data delivered to channel callbacks.
///
/// Built from an incoming frame's `payload` object when it has a `data` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Payload {
    /// The frame's `payload.data` value.
    pub data: serde_json::Value,
    /// The frame's `payload.type`, if present and a string.
    pub event_type: Option<String>,
    /// The frame's `payload.timestamp`, if present and an integer.
    pub timestamp: Option<i64>,
}
329
/// A presence diff: entries that joined and entries that left, keyed by presence key.
///
/// Applied to local state with `PresenceState::sync`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PresenceChange {
    /// Entries added or updated by this diff.
    pub joins: HashMap<String, serde_json::Value>,
    /// Entries removed by this diff (only the keys are used by `PresenceState::sync`).
    pub leaves: HashMap<String, serde_json::Value>,
}
336
/// Local view of presence, kept up to date by applying diffs with `sync`.
#[derive(Debug, Clone, Default)]
pub struct PresenceState {
    /// Current entries, keyed by presence key.
    pub state: HashMap<String, serde_json::Value>,
}
342
343impl PresenceState {
344 pub fn new() -> Self {
346 Self {
347 state: HashMap::new(),
348 }
349 }
350
351 pub fn sync(&mut self, presence_diff: &PresenceChange) {
353 for key in presence_diff.leaves.keys() {
355 self.state.remove(key);
356 }
357
358 for (key, value) in &presence_diff.joins {
360 self.state.insert(key.clone(), value.clone());
361 }
362 }
363
364 pub fn list(&self) -> Vec<(String, serde_json::Value)> {
366 self.state.iter()
367 .map(|(k, v)| (k.clone(), v.clone()))
368 .collect()
369 }
370
371 pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
373 self.state.get(key)
374 }
375}
376
/// Descriptor for the presence-state event (`"presence_state"` when built with `new`).
///
/// NOTE(review): not referenced elsewhere in this file.
#[derive(Debug, Clone, Serialize)]
pub struct PresenceChanges {
    /// Event name.
    event: String,
}
382
383impl PresenceChanges {
384 pub fn new() -> Self {
386 Self {
387 event: "presence_state".to_string(),
388 }
389 }
390}
391
/// Handle returned by `ChannelBuilder::subscribe`.
///
/// Dropping it asks the Tokio runtime to run `Channel::unsubscribe` with `id` in the
/// background.
pub struct Subscription {
    /// Key passed to `Channel::unsubscribe` on drop (the join `ref` of the subscribe call).
    id: String,
    /// Channel this subscription was registered on.
    channel: Arc<Channel>,
}
397
398impl Drop for Subscription {
399 fn drop(&mut self) {
400 let channel = self.channel.clone();
402 let id = self.id.clone();
403 tokio::spawn(async move {
404 let _ = channel.unsubscribe(&id).await;
405 });
406 }
407}
408
/// Per-topic state, shared (via `Arc`) between the client's channel registry and the
/// subscriptions created on it.
struct Channel {
    /// Topic used in the `topic` field of messages sent for this channel.
    topic: String,
    /// The client's socket sender slot (the same `Arc` the client holds).
    socket: Arc<RwLock<Option<mpsc::Sender<Message>>>>,
    /// Callbacks invoked for incoming payloads on this topic, keyed by id.
    callbacks: RwLock<HashMap<String, Box<dyn Fn(Payload) + Send + Sync>>>,
}
414
/// Lifecycle state of a [`RealtimeClient`] connection (see
/// `RealtimeClient::get_connection_state` and `RealtimeClient::on_state_change`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No connection: the initial state, after `disconnect`, or after reconnecting gave up.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The connection is established.
    Connected,
    /// The connection was lost and automatic reconnection is in progress.
    Reconnecting,
}
423
/// Reconnection and heartbeat settings for [`RealtimeClient`]. All intervals are in
/// milliseconds.
#[derive(Debug, Clone)]
pub struct RealtimeClientOptions {
    /// Whether to reconnect automatically after the connection is lost.
    pub auto_reconnect: bool,
    /// Maximum number of reconnect attempts before giving up; `None` means no limit.
    pub max_reconnect_attempts: Option<u32>,
    /// Delay before the first reconnect attempt, in milliseconds.
    pub reconnect_interval: u64,
    /// Growth factor of the reconnect delay: attempt `n` waits
    /// `reconnect_interval * reconnect_backoff_factor^(n - 1)` milliseconds.
    pub reconnect_backoff_factor: f64,
    /// Upper bound on the reconnect delay, in milliseconds.
    pub max_reconnect_interval: u64,
    /// Time between heartbeat messages, in milliseconds.
    pub heartbeat_interval: u64,
}
440
441impl Default for RealtimeClientOptions {
442 fn default() -> Self {
443 Self {
444 auto_reconnect: true,
445 max_reconnect_attempts: Some(20),
446 reconnect_interval: 1000,
447 reconnect_backoff_factor: 1.5,
448 max_reconnect_interval: 60000,
449 heartbeat_interval: 30000,
450 }
451 }
452}
453
454pub struct RealtimeClient {
456 url: String,
457 key: String,
458 next_ref: AtomicU32,
459 channels: Arc<RwLock<HashMap<String, Arc<Channel>>>>,
460 socket: Arc<RwLock<Option<mpsc::Sender<Message>>>>,
461 options: RealtimeClientOptions,
462 state: Arc<RwLock<ConnectionState>>,
463 reconnect_attempts: AtomicU32,
464 is_manually_closed: AtomicBool,
465 state_change: broadcast::Sender<ConnectionState>,
466}
467
468impl RealtimeClient {
469 pub fn new(url: &str, key: &str) -> Self {
471 let (state_sender, _) = broadcast::channel(100);
472
473 Self {
474 url: url.to_string(),
475 key: key.to_string(),
476 next_ref: AtomicU32::new(0),
477 channels: Arc::new(RwLock::new(HashMap::new())),
478 socket: Arc::new(RwLock::new(None)),
479 options: RealtimeClientOptions::default(),
480 state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
481 reconnect_attempts: AtomicU32::new(0),
482 is_manually_closed: AtomicBool::new(false),
483 state_change: state_sender,
484 }
485 }
486
487 pub fn new_with_options(url: &str, key: &str, options: RealtimeClientOptions) -> Self {
489 let (state_sender, _) = broadcast::channel(100);
490
491 Self {
492 url: url.to_string(),
493 key: key.to_string(),
494 next_ref: AtomicU32::new(0),
495 channels: Arc::new(RwLock::new(HashMap::new())),
496 socket: Arc::new(RwLock::new(None)),
497 options,
498 state: Arc::new(RwLock::new(ConnectionState::Disconnected)),
499 reconnect_attempts: AtomicU32::new(0),
500 is_manually_closed: AtomicBool::new(false),
501 state_change: state_sender,
502 }
503 }
504
505 pub fn on_state_change(&self) -> broadcast::Receiver<ConnectionState> {
507 self.state_change.subscribe()
508 }
509
510 pub async fn get_connection_state(&self) -> ConnectionState {
512 *self.state.read().await
513 }
514
515 pub fn channel(&self, topic: &str) -> ChannelBuilder {
517 ChannelBuilder {
518 client: self,
519 topic: topic.to_string(),
520 callbacks: HashMap::new(),
521 }
522 }
523
524 fn next_ref(&self) -> String {
526 self.next_ref.fetch_add(1, Ordering::SeqCst).to_string()
527 }
528
529 async fn set_connection_state(&self, state: ConnectionState) {
531 let mut state_guard = self.state.write().await;
532 let old_state = *state_guard;
533 *state_guard = state;
534
535 if old_state != state {
536 let _ = self.state_change.send(state);
537 }
538 }
539
540 fn connect(&self) -> impl std::future::Future<Output = Result<(), RealtimeError>> + Send + 'static {
542 let client_clone = self.clone();
543 async move {
544 if client_clone.get_connection_state().await == ConnectionState::Connected {
545 return Ok(());
546 }
547
548 client_clone.set_connection_state(ConnectionState::Connecting).await;
549
550 let mut url = Url::parse(&client_clone.url)?;
552 url.query_pairs_mut().append_pair("apikey", &client_clone.key);
553 url.query_pairs_mut().append_pair("vsn", "1.0.0");
554
555 let (ws_stream, _) = connect_async(url).await?;
557 let (mut write, read) = ws_stream.split();
558
559 let (tx, mut rx) = mpsc::channel::<Message>(32);
561
562 let client_state = client_clone.state.clone();
564 let auto_reconnect = client_clone.options.auto_reconnect;
565 let manual_close = client_clone.is_manually_closed.load(Ordering::SeqCst);
566 let reconnect_fn = client_clone.clone();
567
568 tokio::task::spawn(async move {
570 read.for_each(|message| async {
571 match message {
572 Ok(msg) => {
573 match msg {
574 Message::Text(text) => {
575 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&text) {
577 let topic = json.get("topic").and_then(|v| v.as_str()).unwrap_or_default();
579 let event = json.get("event").and_then(|v| v.as_str()).unwrap_or_default();
580 let payload = json.get("payload").cloned().unwrap_or(serde_json::json!({}));
581
582 if topic == "phoenix" && event == "phx_reply" {
584 let status = payload.get("status").and_then(|v| v.as_str()).unwrap_or_default();
585
586 if status == "ok" {
587 let mut state_guard = client_state.write().await;
588 *state_guard = ConnectionState::Connected;
589
590 reconnect_fn.reconnect_attempts.store(0, Ordering::SeqCst);
592
593 let _ = reconnect_fn.state_change.send(ConnectionState::Connected);
595 }
596 }
597 else if let Some(payload_data) = payload.get("data") {
599 let decoded_payload = Payload {
600 data: payload_data.clone(),
601 event_type: payload.get("type").and_then(|v| v.as_str()).map(|s| s.to_string()),
602 timestamp: payload.get("timestamp").and_then(|v| v.as_i64()),
603 };
604
605 if let Ok(channels_guard) = reconnect_fn.channels.try_read() {
606 if let Some(channel) = channels_guard.get(topic) {
607 if let Ok(callbacks_guard) = channel.callbacks.try_read() {
608 for callback in callbacks_guard.values() {
609 callback(decoded_payload.clone());
610 }
611 }
612 }
613 }
614 }
615 }
616 }
617 Message::Close(_) => {
618 if !manual_close && auto_reconnect {
620 let mut state_guard = client_state.write().await;
621 *state_guard = ConnectionState::Reconnecting;
622
623 let reconnect_client = reconnect_fn.clone();
625 reconnect_client.reconnect().await;
626 } else {
627 let mut state_guard = client_state.write().await;
628 *state_guard = ConnectionState::Disconnected;
629 }
630 }
631 _ => {}
632 }
633 }
634 Err(e) => {
635 eprintln!("WebSocket error: {}", e);
636
637 if !manual_close && auto_reconnect {
639 let mut state_guard = client_state.write().await;
640 *state_guard = ConnectionState::Reconnecting;
641
642 let reconnect_client = reconnect_fn.clone();
644 reconnect_client.reconnect().await;
645 } else {
646 let mut state_guard = client_state.write().await;
647 *state_guard = ConnectionState::Disconnected;
648 }
649 }
650 }
651 }).await;
652 });
653
654 tokio::task::spawn(async move {
656 while let Some(msg) = rx.recv().await {
657 if let Err(e) = write.send(msg).await {
658 eprintln!("Error sending message: {}", e);
659
660 if auto_reconnect && !manual_close {
662 break;
664 }
665 }
666 }
667 });
668
669 let mut socket_guard = client_clone.socket.write().await;
671 *socket_guard = Some(tx.clone());
672
673 let socket_clone = tx.clone();
675 let heartbeat_interval = client_clone.options.heartbeat_interval;
676 let is_manually_closed = Arc::new(AtomicBool::new(client_clone.is_manually_closed.load(Ordering::SeqCst)));
677
678 tokio::task::spawn(async move {
679 loop {
680 sleep(Duration::from_millis(heartbeat_interval)).await;
681
682 if is_manually_closed.load(Ordering::SeqCst) {
683 break;
684 }
685
686 let heartbeat_msg = serde_json::json!({
688 "topic": "phoenix",
689 "event": "heartbeat",
690 "payload": {},
691 "ref": null
692 });
693
694 if let Err(_) = socket_clone.send(Message::Text(heartbeat_msg.to_string())).await {
695 break;
696 }
697 }
698 });
699
700 Ok(())
701 }
702 }
703
704 pub async fn disconnect(&self) -> Result<(), RealtimeError> {
706 self.is_manually_closed.store(true, Ordering::SeqCst);
707
708 let mut socket_guard = self.socket.write().await;
709 if let Some(tx) = socket_guard.take() {
710 let close_msg = Message::Close(None);
712 let _ = tx.send(close_msg).await;
713 }
714
715 self.set_connection_state(ConnectionState::Disconnected).await;
716
717 Ok(())
718 }
719
720 fn reconnect(&self) -> impl std::future::Future<Output = ()> + Send + 'static {
722 let client_clone = self.clone();
723 async move {
724 if !client_clone.options.auto_reconnect || client_clone.is_manually_closed.load(Ordering::SeqCst) {
725 return;
726 }
727
728 client_clone.set_connection_state(ConnectionState::Reconnecting).await;
729
730 let current_attempt = client_clone.reconnect_attempts.fetch_add(1, Ordering::SeqCst) + 1;
732 if let Some(max) = client_clone.options.max_reconnect_attempts {
733 if current_attempt > max {
734 client_clone.set_connection_state(ConnectionState::Disconnected).await;
735 return;
736 }
737 }
738
739 let base_interval = client_clone.options.reconnect_interval as f64;
741 let factor = client_clone.options.reconnect_backoff_factor.powi(current_attempt as i32 - 1);
742 let interval = (base_interval * factor).min(client_clone.options.max_reconnect_interval as f64) as u64;
743
744 sleep(Duration::from_millis(interval)).await;
746
747 let _ = client_clone.connect().await;
749
750 if client_clone.get_connection_state().await == ConnectionState::Connected {
752 let channels_guard = client_clone.channels.read().await;
753
754 for (topic, _channel) in channels_guard.iter() {
755 let join_msg = serde_json::json!({
756 "topic": topic,
757 "event": "phx_join",
758 "payload": {},
759 "ref": client_clone.next_ref()
760 });
761
762 let socket_guard = client_clone.socket.read().await;
763 if let Some(tx) = &*socket_guard {
764 let _ = tx.send(Message::Text(join_msg.to_string())).await;
765 }
766 }
767 }
768 }
769 }
770}
771
772impl Clone for RealtimeClient {
774 fn clone(&self) -> Self {
775 Self {
776 url: self.url.clone(),
777 key: self.key.clone(),
778 next_ref: AtomicU32::new(self.next_ref.load(Ordering::SeqCst)),
779 channels: self.channels.clone(),
780 socket: self.socket.clone(),
781 options: self.options.clone(),
782 state: self.state.clone(),
783 reconnect_attempts: AtomicU32::new(self.reconnect_attempts.load(Ordering::SeqCst)),
784 is_manually_closed: AtomicBool::new(self.is_manually_closed.load(Ordering::SeqCst)),
785 state_change: self.state_change.clone(),
786 }
787 }
788}
789
/// Builder returned by `RealtimeClient::channel`. Collects callbacks, which are handed
/// to the channel when `subscribe` is called.
pub struct ChannelBuilder<'a> {
    /// Client the channel will be joined on.
    client: &'a RealtimeClient,
    /// Topic to join.
    topic: String,
    /// Callbacks added via `on`, `on_broadcast` and `on_presence`, each keyed by a ref
    /// obtained from the client.
    callbacks: HashMap<String, Box<dyn Fn(Payload) + Send + Sync>>,
}
796
797impl<'a> ChannelBuilder<'a> {
798 pub fn on<F>(mut self, _changes: DatabaseChanges, callback: F) -> Self
800 where
801 F: Fn(Payload) + Send + Sync + 'static,
802 {
803 let id = self.client.next_ref();
804 self.callbacks.insert(id, Box::new(callback));
805 self
806 }
807
808 pub fn on_broadcast<F>(mut self, _changes: BroadcastChanges, callback: F) -> Self
810 where
811 F: Fn(Payload) + Send + Sync + 'static,
812 {
813 let id = self.client.next_ref();
814 self.callbacks.insert(id, Box::new(callback));
815 self
816 }
817
818 pub fn on_presence<F>(mut self, callback: F) -> Self
820 where
821 F: Fn(PresenceChange) + Send + Sync + 'static,
822 {
823 let id = self.client.next_ref();
824
825 let wrapper_callback = move |payload: Payload| {
826 if let Ok(presence_diff) = serde_json::from_value::<PresenceChange>(payload.data) {
827 callback(presence_diff);
828 }
829 };
830
831 self.callbacks.insert(id, Box::new(wrapper_callback));
832 self
833 }
834
835 pub async fn subscribe(self) -> Result<Subscription, RealtimeError> {
837 let client = self.client;
838 let topic = self.topic;
839
840 let current_state = client.get_connection_state().await;
842 if current_state == ConnectionState::Disconnected {
843 client.connect().await?;
844 }
845
846 let channel = {
848 let mut channels = client.channels.write().await;
849
850 if let Some(channel) = channels.get(&topic) {
851 channel.clone()
852 } else {
853 let channel = Arc::new(Channel {
854 topic: topic.clone(),
855 socket: client.socket.clone(),
856 callbacks: RwLock::new(HashMap::new()),
857 });
858
859 channels.insert(topic.clone(), channel.clone());
860 channel
861 }
862 };
863
864 {
866 let mut callbacks = channel.callbacks.write().await;
867 for (id, callback) in self.callbacks {
868 callbacks.insert(id.clone(), callback);
869 }
870 }
871
872 let join_ref = client.next_ref();
874 let join_msg = serde_json::json!({
875 "topic": topic,
876 "event": "phx_join",
877 "payload": {},
878 "ref": join_ref
879 });
880
881 let socket_guard = client.socket.read().await;
882 if let Some(tx) = &*socket_guard {
883 tx.send(Message::Text(join_msg.to_string())).await?;
884 } else {
885 return Err(RealtimeError::ConnectionError("Socket not connected".to_string()));
886 }
887
888 Ok(Subscription {
890 id: join_ref,
891 channel,
892 })
893 }
894
895 pub async fn track_presence(
897 &self,
898 user_id: &str,
899 user_data: serde_json::Value
900 ) -> Result<(), RealtimeError> {
901 let socket_guard = self.client.socket.read().await;
902 if let Some(tx) = &*socket_guard {
903 let presence_msg = serde_json::json!({
904 "topic": self.topic,
905 "event": "presence",
906 "payload": {
907 "user_id": user_id,
908 "user_data": user_data
909 },
910 "ref": self.client.next_ref()
911 });
912
913 tx.send(Message::Text(presence_msg.to_string())).await
914 .map_err(|_| RealtimeError::ChannelError("Failed to send presence message".to_string()))?;
915
916 Ok(())
917 } else {
918 Err(RealtimeError::ConnectionError("Socket not connected".to_string()))
919 }
920 }
921}
922
923impl Channel {
924 async fn unsubscribe(&self, id: &str) -> Result<(), RealtimeError> {
926 let mut callbacks_guard = self.callbacks.write().await;
928 callbacks_guard.remove(id);
929
930 if callbacks_guard.is_empty() {
932 drop(callbacks_guard);
933
934 let unsubscribe_message = serde_json::json!({
936 "topic": self.topic,
937 "event": "phx_leave",
938 "payload": {},
939 "ref": id,
940 });
941
942 let socket_guard = self.socket.read().await;
943 if let Some(tx) = &*socket_guard {
944 tx.send(Message::Text(unsubscribe_message.to_string())).await
945 .map_err(|_| RealtimeError::SubscriptionError("Failed to send unsubscription message".to_string()))?;
946 } else {
947 return Err(RealtimeError::ConnectionError("WebSocket not connected".to_string()));
948 }
949 }
950
951 Ok(())
952 }
953}
954
955impl From<tokio::sync::mpsc::error::SendError<Message>> for RealtimeError {
956 fn from(err: tokio::sync::mpsc::error::SendError<Message>) -> Self {
957 RealtimeError::ChannelError(format!("Failed to send message: {}", err))
958 }
959}
960
#[cfg(test)]
mod tests {
    use super::*;

    /// A new client starts `Disconnected` and has not broadcast any state change.
    /// (Replaces the former `test_reconnection`, which only spawned a detached task and
    /// asserted nothing.)
    #[tokio::test]
    async fn new_client_starts_disconnected() {
        let client = RealtimeClient::new("https://example.supabase.co", "test-key");
        let mut status_receiver = client.on_state_change();

        assert_eq!(client.get_connection_state().await, ConnectionState::Disconnected);
        assert!(status_receiver.try_recv().is_err());
    }

    /// Disconnecting a client that never connected succeeds and changes nothing.
    #[tokio::test]
    async fn disconnect_without_connection_is_ok() {
        let client = RealtimeClient::new("https://example.supabase.co", "test-key");
        let mut status_receiver = client.on_state_change();

        assert!(client.disconnect().await.is_ok());
        assert_eq!(client.get_connection_state().await, ConnectionState::Disconnected);
        // The state did not change, so nothing should have been broadcast.
        assert!(status_receiver.try_recv().is_err());
    }

    #[test]
    fn filter_operator_wire_names() {
        assert_eq!(FilterOperator::Eq.to_string(), "eq");
        assert_eq!(FilterOperator::NotIn.to_string(), "not.in");
        assert_eq!(FilterOperator::ContainedByArray.to_string(), "contained_by_array");
        assert_eq!(FilterOperator::ILike.to_string(), "ilike");
    }

    #[test]
    fn channel_config_defaults_to_every_event() {
        let config = DatabaseChanges::new("messages").to_channel_config();
        assert_eq!(
            config,
            serde_json::json!({ "schema": "public", "table": "messages", "event": "*" })
        );
    }

    #[test]
    fn channel_config_lists_events_and_filters() {
        let config = DatabaseChanges::new("messages")
            .schema("chat")
            .event(ChannelEvent::Insert)
            .event(ChannelEvent::Update)
            .event(ChannelEvent::Insert)
            .eq("room_id", 7)
            .to_channel_config();

        assert_eq!(
            config,
            serde_json::json!({
                "schema": "chat",
                "table": "messages",
                "event": "INSERT,UPDATE",
                "filter": { "room_id:eq": 7 }
            })
        );
    }

    #[test]
    fn presence_sync_applies_leaves_then_joins() {
        let mut presence = PresenceState::new();
        presence.sync(&PresenceChange {
            joins: HashMap::from([
                ("a".to_string(), serde_json::json!(1)),
                ("b".to_string(), serde_json::json!(2)),
            ]),
            leaves: HashMap::new(),
        });
        presence.sync(&PresenceChange {
            joins: HashMap::from([("a".to_string(), serde_json::json!(3))]),
            leaves: HashMap::from([
                ("a".to_string(), serde_json::json!(1)),
                ("b".to_string(), serde_json::json!(2)),
            ]),
        });

        assert_eq!(presence.get("a"), Some(&serde_json::json!(3)));
        assert!(presence.get("b").is_none());
        assert_eq!(presence.list().len(), 1);
    }
}