1use crate::client::RealtimeClient; use crate::error::RealtimeError;
3use crate::filters::{DatabaseFilter, FilterOperator};
4use crate::message::{ChannelEvent, Payload, PresenceChange, RealtimeMessage};
5use log::{debug, error, info, trace}; use serde::Serialize;
7use serde_json::json;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
12use tokio::time::{timeout, Duration};
13#[derive(Debug, Clone, Serialize)]
17pub struct DatabaseChanges {
18 schema: String,
19 table: String,
20 events: Vec<ChannelEvent>,
21 filter: Option<Vec<DatabaseFilter>>,
22}
23
24impl DatabaseChanges {
25 pub fn new(table: &str) -> Self {
27 Self {
28 schema: "public".to_string(),
29 table: table.to_string(),
30 events: Vec::new(),
31 filter: None,
32 }
33 }
34
35 pub fn schema(mut self, schema: &str) -> Self {
37 self.schema = schema.to_string();
38 self
39 }
40
41 pub fn event(mut self, event: ChannelEvent) -> Self {
43 if !self.events.contains(&event) {
44 self.events.push(event);
45 }
46 self
47 }
48
49 pub fn filter(mut self, filter: DatabaseFilter) -> Self {
51 self.filter.get_or_insert_with(Vec::new).push(filter);
52 self
53 }
54
55 pub fn eq<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
58 self.filter(DatabaseFilter {
59 column: column.to_string(),
60 operator: FilterOperator::Eq,
61 value: value.into(),
62 })
63 }
64
65 pub fn neq<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
66 self.filter(DatabaseFilter {
67 column: column.to_string(),
68 operator: FilterOperator::Neq,
69 value: value.into(),
70 })
71 }
72
73 pub fn gt<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
74 self.filter(DatabaseFilter {
75 column: column.to_string(),
76 operator: FilterOperator::Gt,
77 value: value.into(),
78 })
79 }
80
81 pub fn gte<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
82 self.filter(DatabaseFilter {
83 column: column.to_string(),
84 operator: FilterOperator::Gte,
85 value: value.into(),
86 })
87 }
88
89 pub fn lt<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
90 self.filter(DatabaseFilter {
91 column: column.to_string(),
92 operator: FilterOperator::Lt,
93 value: value.into(),
94 })
95 }
96
97 pub fn lte<T: Into<serde_json::Value>>(self, column: &str, value: T) -> Self {
98 self.filter(DatabaseFilter {
99 column: column.to_string(),
100 operator: FilterOperator::Lte,
101 value: value.into(),
102 })
103 }
104
105 pub fn in_values<T: Into<serde_json::Value>>(self, column: &str, values: Vec<T>) -> Self {
106 self.filter(DatabaseFilter {
107 column: column.to_string(),
108 operator: FilterOperator::In,
109 value: values
110 .into_iter()
111 .map(|v| v.into())
112 .collect::<Vec<_>>()
113 .into(),
114 })
115 }
116
117 }
126
127#[derive(Debug, Clone, Serialize)]
129pub struct BroadcastChanges {
130 event: String, }
132
133impl BroadcastChanges {
134 pub fn new(event: &str) -> Self {
135 Self {
136 event: event.to_string(),
137 }
138 }
139
140 #[allow(dead_code)] pub(crate) fn get_event_name(&self) -> &str {
142 &self.event
143 }
144}
145
146#[derive(Debug, Clone, Default, Serialize)]
148pub struct PresenceChanges;
149
150impl PresenceChanges {
151 pub fn new() -> Self {
152 PresenceChanges }
155}
156
157pub struct Subscription {
159 id: String, channel: Arc<Channel>,
161}
162
163impl Drop for Subscription {
164 fn drop(&mut self) {
165 let id_clone = self.id.clone();
166 let channel_clone = self.channel.clone();
167 tokio::spawn(async move {
168 if let Err(e) = channel_clone.unsubscribe(&id_clone).await {
169 eprintln!("Error unsubscribing from channel: {}", e);
171 }
172 });
173 }
174}
175
176type CallbackFn = Box<dyn Fn(Payload) + Send + Sync>;
177type PresenceCallbackFn = Box<dyn Fn(PresenceChange) + Send + Sync>;
178
179pub(crate) struct Channel {
181 topic: String,
182 client: Arc<RealtimeClient>, callbacks: Arc<RwLock<HashMap<String, CallbackFn>>>,
184 presence_callbacks: Arc<RwLock<Vec<PresenceCallbackFn>>>,
185 state: Arc<RwLock<ChannelState>>,
187}
188
/// Lifecycle of a [`Channel`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum ChannelState {
    /// Not joined; the state a new channel starts in and the state after a close.
    Closed,
    /// Join request sent, waiting for the server's reply.
    Joining,
    /// A reply arrived while joining.
    Joined,
    /// Leaving; a reply received in this state moves the channel to `Closed`.
    Leaving,
    /// Server reported an error, or the join could not be sent or timed out.
    Errored,
}
197
198impl Channel {
199 pub(crate) fn new(topic: String, client: Arc<RealtimeClient>) -> Self {
200 debug!("Channel::new created for topic: {}", topic);
201 Self {
202 topic,
203 client,
204 callbacks: Arc::new(RwLock::new(HashMap::new())),
205 presence_callbacks: Arc::new(RwLock::new(Vec::new())),
206 state: Arc::new(RwLock::new(ChannelState::Closed)),
207 }
208 }
209
210 async fn set_state(&self, state: ChannelState) {
211 let mut current_state = self.state.write().await;
212 if *current_state != state {
213 info!(
214 "Channel '{}' state changing from {:?} to {:?}",
215 self.topic, *current_state, state
216 );
217 *current_state = state;
218 } else {
219 trace!(
220 "Channel '{}' state already {:?}, not changing.",
221 self.topic,
222 state
223 );
224 }
225 }
226
227 async fn join(&self) -> Result<(), RealtimeError> {
229 self.set_state(ChannelState::Joining).await;
230 let join_ref = self.client.next_ref();
231 info!(
232 "Channel '{}' sending join message with ref {}",
233 self.topic, join_ref
234 );
235 let join_msg = json!({
236 "topic": self.topic,
237 "event": ChannelEvent::PhoenixJoin,
238 "payload": {},
239 "ref": join_ref
240 });
241 self.client.send_message(join_msg).await
243 }
245
246 async fn unsubscribe(&self, id: &str) -> Result<(), RealtimeError> {
251 self.callbacks.write().await.remove(id);
253 println!(
258 "Subscription {} dropped. Channel {} might need explicit leave.",
259 id, self.topic
260 );
261 Ok(())
262 }
263
264 pub(crate) async fn handle_message(&self, message: RealtimeMessage) {
266 debug!(
267 "Channel '{}' handling message: event={:?}, ref={:?}",
268 self.topic, message.event, message.message_ref
269 );
270
271 match message.event {
272 ChannelEvent::PhoenixReply => {
273 info!(
275 "Channel '{}' received PhoenixReply: {:?}",
276 self.topic, message.payload
277 );
278 if *self.state.read().await == ChannelState::Joining {
279 self.set_state(ChannelState::Joined).await;
281 } else if *self.state.read().await == ChannelState::Leaving {
282 self.set_state(ChannelState::Closed).await;
283 }
284 }
285 ChannelEvent::PhoenixClose => {
286 info!(
287 "Channel '{}' received PhoenixClose. Setting state to Closed.",
288 self.topic
289 );
290 self.set_state(ChannelState::Closed).await;
291 }
292 ChannelEvent::PhoenixError => {
293 error!(
294 "Channel '{}' received PhoenixError: {:?}",
295 self.topic, message.payload
296 );
297 self.set_state(ChannelState::Errored).await;
298 }
299 ChannelEvent::PostgresChanges | ChannelEvent::Broadcast | ChannelEvent::Presence => {
300 let payload = Payload {
302 data: message.payload.clone(), event_type: Some(message.event.to_string()), timestamp: None, };
306 trace!(
307 "Channel '{}' dispatching event {:?} to callbacks",
308 self.topic,
309 message.event
310 );
311 let callbacks_guard = self.callbacks.read().await;
312 for callback in callbacks_guard.values() {
313 callback(payload.clone());
315 }
316 }
318 _ => {
321 trace!(
322 "Channel '{}' ignored event: {:?}",
323 self.topic,
324 message.event
325 );
326 }
327 }
328 }
329}
330
331pub struct ChannelBuilder<'a> {
333 client: &'a RealtimeClient,
334 topic: String,
335 db_callbacks: HashMap<String, (DatabaseChanges, CallbackFn)>,
336 broadcast_callbacks: HashMap<String, (BroadcastChanges, CallbackFn)>,
337 presence_callbacks: Vec<PresenceCallbackFn>,
338}
339
340impl<'a> ChannelBuilder<'a> {
341 pub(crate) fn new(client: &'a RealtimeClient, topic: &str) -> Self {
342 debug!("ChannelBuilder::new for topic: {}", topic);
343 Self {
344 client,
345 topic: topic.to_string(),
346 db_callbacks: HashMap::new(),
347 broadcast_callbacks: HashMap::new(),
348 presence_callbacks: Vec::new(),
349 }
350 }
351
352 pub fn on<F>(mut self, changes: DatabaseChanges, callback: F) -> Self
354 where
355 F: Fn(Payload) + Send + Sync + 'static,
356 {
357 let id = uuid::Uuid::new_v4().to_string();
359 self.db_callbacks.insert(id, (changes, Box::new(callback)));
360 self
361 }
362
363 pub fn on_broadcast<F>(mut self, changes: BroadcastChanges, callback: F) -> Self
365 where
366 F: Fn(Payload) + Send + Sync + 'static,
367 {
368 let id = uuid::Uuid::new_v4().to_string();
369 self.broadcast_callbacks
370 .insert(id, (changes, Box::new(callback)));
371 self
372 }
373
374 pub fn on_presence<F>(mut self, callback: F) -> Self
376 where
377 F: Fn(PresenceChange) + Send + Sync + 'static,
378 {
379 self.presence_callbacks.push(Box::new(callback));
380 self
381 }
382
383 pub async fn subscribe(self) -> Result<Vec<Subscription>, RealtimeError> {
385 info!("ChannelBuilder subscribing for topic: {}", self.topic);
386 let client_arc = Arc::new(self.client.clone()); let mut channels_guard = client_arc.channels.write().await;
390 let channel = channels_guard
391 .entry(self.topic.clone())
392 .or_insert_with(|| Arc::new(Channel::new(self.topic.clone(), client_arc.clone())))
393 .clone();
394 drop(channels_guard); debug!("Got or created Channel Arc for topic: {}", self.topic);
396
397 let mut subscriptions = Vec::new();
398 let mut callbacks_guard = channel.callbacks.write().await;
399 let mut presence_callbacks_guard = channel.presence_callbacks.write().await;
400
401 for (id, (_changes, callback)) in self.db_callbacks {
403 debug!("Adding DB callback ID {} to channel {}", id, self.topic);
404 callbacks_guard.insert(id.clone(), callback);
405 subscriptions.push(Subscription {
406 id,
407 channel: channel.clone(),
408 });
409 }
410
411 for (id, (_changes, callback)) in self.broadcast_callbacks {
413 debug!(
414 "Adding Broadcast callback ID {} to channel {}",
415 id, self.topic
416 );
417 callbacks_guard.insert(id.clone(), callback);
419 subscriptions.push(Subscription {
420 id,
421 channel: channel.clone(),
422 });
423 }
424
425 for callback in self.presence_callbacks {
427 debug!("Adding Presence callback to channel {}", self.topic);
428 presence_callbacks_guard.push(callback);
429 let id = format!("presence_{}", self.topic); subscriptions.push(Subscription {
432 id,
433 channel: channel.clone(),
434 });
435 }
436
437 drop(callbacks_guard);
438 drop(presence_callbacks_guard);
439
440 let current_state = *channel.state.read().await;
442 if current_state == ChannelState::Closed || current_state == ChannelState::Errored {
443 info!(
444 "Channel '{}' is {:?}, attempting to join.",
445 self.topic, current_state
446 );
447 match channel.join().await {
448 Ok(_) => {
449 debug!(
451 "Join message sent for channel '{}'. Waiting for reply.",
452 self.topic
453 );
454 match timeout(Duration::from_secs(10), async {
456 while *channel.state.read().await != ChannelState::Joined {
457 tokio::time::sleep(Duration::from_millis(50)).await;
458 let check_state = *channel.state.read().await;
460 if check_state == ChannelState::Errored
461 || check_state == ChannelState::Closed
462 {
463 return Err(RealtimeError::SubscriptionError(format!(
464 "Channel '{}' entered state {:?} while waiting for join reply",
465 self.topic, check_state
466 )));
467 }
468 }
469 Ok(())
470 })
471 .await
472 {
473 Ok(Ok(_)) => info!("Channel '{}' successfully joined.", self.topic),
474 Ok(Err(e)) => {
475 error!(
476 "Error waiting for join confirmation for channel '{}': {:?}",
477 self.topic, e
478 );
479 return Err(e);
480 }
481 Err(_) => {
482 error!(
483 "Timed out waiting for join confirmation for channel '{}'",
484 self.topic
485 );
486 channel.set_state(ChannelState::Errored).await;
487 return Err(RealtimeError::SubscriptionError(format!(
488 "Timed out waiting for join confirmation for channel '{}'",
489 self.topic
490 )));
491 }
492 }
493 }
494 Err(e) => {
495 error!(
496 "Failed to send join message for channel '{}': {}",
497 self.topic, e
498 );
499 channel.set_state(ChannelState::Errored).await;
500 return Err(e);
501 }
502 }
503 } else {
504 info!(
505 "Channel '{}' is already {:?}, not sending join message.",
506 self.topic, current_state
507 );
508 }
509
510 info!(
511 "ChannelBuilder subscribe finished for topic '{}', returning {} subscriptions.",
512 self.topic,
513 subscriptions.len()
514 );
515 Ok(subscriptions)
516 }
517
518 pub async fn track_presence(
520 &self,
521 _user_id: &str,
522 _user_data: serde_json::Value,
523 ) -> Result<(), RealtimeError> {
524 Err(RealtimeError::ChannelError(
526 "track_presence not implemented".to_string(),
527 ))
528 }
529}