shaperail_runtime/ws/
pubsub.rs1use std::sync::Arc;
2
3use futures_util::StreamExt;
4use redis::AsyncCommands;
5use serde::{Deserialize, Serialize};
6use shaperail_core::ShaperailError;
7
8use super::room::RoomManager;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct PubSubMessage {
15 pub channel: String,
17 pub room: String,
19 pub event: String,
21 pub data: serde_json::Value,
23}
24
25#[derive(Clone)]
31pub struct RedisPubSub {
32 pool: Arc<deadpool_redis::Pool>,
33}
34
35impl RedisPubSub {
36 pub fn new(pool: Arc<deadpool_redis::Pool>) -> Self {
38 Self { pool }
39 }
40
41 fn redis_channel(channel: &str) -> String {
43 format!("shaperail:ws:{channel}")
44 }
45
46 pub async fn publish(&self, msg: &PubSubMessage) -> Result<(), ShaperailError> {
48 let mut conn = self
49 .pool
50 .get()
51 .await
52 .map_err(|e| ShaperailError::Internal(format!("Redis connection failed: {e}")))?;
53
54 let payload = serde_json::to_string(msg).map_err(|e| {
55 ShaperailError::Internal(format!("Failed to serialize pub/sub message: {e}"))
56 })?;
57
58 let redis_channel = Self::redis_channel(&msg.channel);
59 conn.publish::<_, _, ()>(&redis_channel, &payload)
60 .await
61 .map_err(|e| ShaperailError::Internal(format!("Redis publish failed: {e}")))?;
62
63 tracing::debug!(
64 channel = %msg.channel,
65 room = %msg.room,
66 event = %msg.event,
67 "Published broadcast via Redis pub/sub"
68 );
69
70 Ok(())
71 }
72
73 pub fn start_subscriber(
79 &self,
80 channel_name: &str,
81 room_manager: RoomManager,
82 redis_url: &str,
83 ) -> tokio::task::JoinHandle<()> {
84 let redis_channel = Self::redis_channel(channel_name);
85 let redis_url = redis_url.to_string();
86
87 tokio::spawn(async move {
88 if let Err(e) = Self::subscriber_loop(&redis_url, &redis_channel, &room_manager).await {
89 tracing::error!(
90 error = %e,
91 channel = %redis_channel,
92 "Redis subscriber exited with error"
93 );
94 }
95 })
96 }
97
98 async fn subscriber_loop(
100 redis_url: &str,
101 redis_channel: &str,
102 room_manager: &RoomManager,
103 ) -> Result<(), ShaperailError> {
104 let client = redis::Client::open(redis_url)
105 .map_err(|e| ShaperailError::Internal(format!("Redis client creation failed: {e}")))?;
106
107 let mut pubsub_conn = client.get_async_pubsub().await.map_err(|e| {
108 ShaperailError::Internal(format!("Redis pub/sub connection failed: {e}"))
109 })?;
110
111 pubsub_conn
112 .subscribe(redis_channel)
113 .await
114 .map_err(|e| ShaperailError::Internal(format!("Redis subscribe failed: {e}")))?;
115
116 tracing::info!(channel = %redis_channel, "Redis pub/sub subscriber started");
117
118 let mut msg_stream = pubsub_conn.into_on_message();
120
121 while let Some(msg) = msg_stream.next().await {
122 let payload: String = match msg.get_payload() {
123 Ok(p) => p,
124 Err(e) => {
125 tracing::warn!(error = %e, "Failed to get pub/sub payload");
126 continue;
127 }
128 };
129
130 let broadcast: PubSubMessage = match serde_json::from_str(&payload) {
131 Ok(m) => m,
132 Err(e) => {
133 tracing::warn!(error = %e, "Failed to parse pub/sub message");
134 continue;
135 }
136 };
137
138 let server_msg = shaperail_core::WsServerMessage::Broadcast {
140 room: broadcast.room.clone(),
141 event: broadcast.event,
142 data: broadcast.data,
143 };
144
145 if let Ok(text) = serde_json::to_string(&server_msg) {
146 room_manager.broadcast_to_room(&broadcast.room, &text).await;
147 }
148 }
149
150 tracing::warn!(channel = %redis_channel, "Redis pub/sub stream ended");
151 Ok(())
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the message shared by the serialization tests.
    fn sample_message() -> PubSubMessage {
        PubSubMessage {
            channel: "notifications".to_string(),
            room: "org:123".to_string(),
            event: "user.created".to_string(),
            data: serde_json::json!({"id": "abc"}),
        }
    }

    #[test]
    fn redis_channel_name() {
        assert_eq!(
            RedisPubSub::redis_channel("notifications"),
            "shaperail:ws:notifications"
        );
    }

    #[test]
    fn pubsub_message_serde_roundtrip() {
        let msg = sample_message();
        let json = serde_json::to_string(&msg).unwrap();
        let back: PubSubMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(back.channel, "notifications");
        assert_eq!(back.room, "org:123");
        assert_eq!(back.event, "user.created");
        // The payload must survive the roundtrip too, not just the routing fields.
        assert_eq!(back.data, serde_json::json!({"id": "abc"}));
    }

    #[test]
    fn pubsub_message_wire_format() {
        // Other server instances decode this JSON, so the field names are a contract.
        let value = serde_json::to_value(sample_message()).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "channel": "notifications",
                "room": "org:123",
                "event": "user.created",
                "data": {"id": "abc"}
            })
        );
    }

    #[test]
    fn pubsub_message_rejects_missing_fields() {
        let parsed = serde_json::from_str::<PubSubMessage>(r#"{"channel":"notifications"}"#);
        assert!(parsed.is_err());
    }
}