Skip to main content

shaperail_runtime/ws/
pubsub.rs

1use std::sync::Arc;
2
3use futures_util::StreamExt;
4use redis::AsyncCommands;
5use serde::{Deserialize, Serialize};
6use shaperail_core::ShaperailError;
7
8use super::room::RoomManager;
9
10/// A broadcast message published through Redis pub/sub.
11///
12/// Serialized to JSON for transport across server instances.
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct PubSubMessage {
15    /// The channel name (matches ChannelDefinition.channel).
16    pub channel: String,
17    /// The room within the channel.
18    pub room: String,
19    /// The event name (e.g., "user.created").
20    pub event: String,
21    /// The event payload.
22    pub data: serde_json::Value,
23}
24
25/// Redis pub/sub backend for cross-instance WebSocket broadcast.
26///
27/// When a server instance needs to broadcast to a room, it publishes
28/// to a Redis channel. All instances subscribe to these channels and
29/// route messages to their locally connected clients.
30#[derive(Clone)]
31pub struct RedisPubSub {
32    pool: Arc<deadpool_redis::Pool>,
33}
34
35impl RedisPubSub {
36    /// Creates a new Redis pub/sub backend.
37    pub fn new(pool: Arc<deadpool_redis::Pool>) -> Self {
38        Self { pool }
39    }
40
41    /// Returns the Redis channel name for a given WebSocket channel.
42    fn redis_channel(channel: &str) -> String {
43        format!("shaperail:ws:{channel}")
44    }
45
46    /// Publishes a broadcast message to Redis so all instances receive it.
47    pub async fn publish(&self, msg: &PubSubMessage) -> Result<(), ShaperailError> {
48        let mut conn = self
49            .pool
50            .get()
51            .await
52            .map_err(|e| ShaperailError::Internal(format!("Redis connection failed: {e}")))?;
53
54        let payload = serde_json::to_string(msg).map_err(|e| {
55            ShaperailError::Internal(format!("Failed to serialize pub/sub message: {e}"))
56        })?;
57
58        let redis_channel = Self::redis_channel(&msg.channel);
59        conn.publish::<_, _, ()>(&redis_channel, &payload)
60            .await
61            .map_err(|e| ShaperailError::Internal(format!("Redis publish failed: {e}")))?;
62
63        tracing::debug!(
64            channel = %msg.channel,
65            room = %msg.room,
66            event = %msg.event,
67            "Published broadcast via Redis pub/sub"
68        );
69
70        Ok(())
71    }
72
73    /// Starts a subscriber task that listens on Redis pub/sub and routes
74    /// messages to the local room manager.
75    ///
76    /// This spawns a background Tokio task that runs until the returned
77    /// `tokio::task::JoinHandle` is aborted.
78    pub fn start_subscriber(
79        &self,
80        channel_name: &str,
81        room_manager: RoomManager,
82        redis_url: &str,
83    ) -> tokio::task::JoinHandle<()> {
84        let redis_channel = Self::redis_channel(channel_name);
85        let redis_url = redis_url.to_string();
86
87        tokio::spawn(async move {
88            if let Err(e) = Self::subscriber_loop(&redis_url, &redis_channel, &room_manager).await {
89                tracing::error!(
90                    error = %e,
91                    channel = %redis_channel,
92                    "Redis subscriber exited with error"
93                );
94            }
95        })
96    }
97
98    /// Internal subscriber loop — connects to Redis and processes messages.
99    async fn subscriber_loop(
100        redis_url: &str,
101        redis_channel: &str,
102        room_manager: &RoomManager,
103    ) -> Result<(), ShaperailError> {
104        let client = redis::Client::open(redis_url)
105            .map_err(|e| ShaperailError::Internal(format!("Redis client creation failed: {e}")))?;
106
107        let mut pubsub_conn = client.get_async_pubsub().await.map_err(|e| {
108            ShaperailError::Internal(format!("Redis pub/sub connection failed: {e}"))
109        })?;
110
111        pubsub_conn
112            .subscribe(redis_channel)
113            .await
114            .map_err(|e| ShaperailError::Internal(format!("Redis subscribe failed: {e}")))?;
115
116        tracing::info!(channel = %redis_channel, "Redis pub/sub subscriber started");
117
118        // Use into_on_message to consume the PubSub and get an owned stream
119        let mut msg_stream = pubsub_conn.into_on_message();
120
121        while let Some(msg) = msg_stream.next().await {
122            let payload: String = match msg.get_payload() {
123                Ok(p) => p,
124                Err(e) => {
125                    tracing::warn!(error = %e, "Failed to get pub/sub payload");
126                    continue;
127                }
128            };
129
130            let broadcast: PubSubMessage = match serde_json::from_str(&payload) {
131                Ok(m) => m,
132                Err(e) => {
133                    tracing::warn!(error = %e, "Failed to parse pub/sub message");
134                    continue;
135                }
136            };
137
138            // Build server message and broadcast to local clients
139            let server_msg = shaperail_core::WsServerMessage::Broadcast {
140                room: broadcast.room.clone(),
141                event: broadcast.event,
142                data: broadcast.data,
143            };
144
145            if let Ok(text) = serde_json::to_string(&server_msg) {
146                room_manager.broadcast_to_room(&broadcast.room, &text).await;
147            }
148        }
149
150        tracing::warn!(channel = %redis_channel, "Redis pub/sub stream ended");
151        Ok(())
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// The Redis channel is the WebSocket channel name behind a fixed prefix.
    #[test]
    fn redis_channel_name() {
        assert_eq!(
            RedisPubSub::redis_channel("notifications"),
            "shaperail:ws:notifications"
        );
    }

    /// Every field of a `PubSubMessage`, including the JSON payload, must
    /// survive a serialize/deserialize round trip.
    #[test]
    fn pubsub_message_serde_roundtrip() {
        let msg = PubSubMessage {
            channel: "notifications".to_string(),
            room: "org:123".to_string(),
            event: "user.created".to_string(),
            data: serde_json::json!({"id": "abc"}),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let back: PubSubMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(back.channel, "notifications");
        assert_eq!(back.room, "org:123");
        assert_eq!(back.event, "user.created");
        // The payload is part of the wire format too, so check it as well.
        assert_eq!(back.data, msg.data);
    }
}