socketioxide_redis/drivers/
redis.rs
1use std::{
2 collections::HashMap,
3 fmt,
4 sync::{Arc, RwLock},
5};
6
7use redis::{AsyncCommands, FromRedisValue, PushInfo, RedisResult, aio::MultiplexedConnection};
8use tokio::sync::mpsc;
9
10use super::{ChanItem, Driver, MessageStream};
11
12pub use redis as redis_client;
13
14#[derive(Debug)]
16pub struct RedisError(redis::RedisError);
17
18impl From<redis::RedisError> for RedisError {
19 fn from(e: redis::RedisError) -> Self {
20 Self(e)
21 }
22}
23impl fmt::Display for RedisError {
24 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25 self.0.fmt(f)
26 }
27}
28impl std::error::Error for RedisError {}
29
30type HandlerMap = HashMap<String, mpsc::Sender<ChanItem>>;
31#[derive(Clone)]
33pub struct RedisDriver {
34 handlers: Arc<RwLock<HandlerMap>>,
35 conn: MultiplexedConnection,
36}
37
#[cfg(feature = "redis-cluster")]
/// A [`Driver`] for a redis cluster, using sharded pub/sub
/// (`SPUBLISH`/`SSUBSCRIBE`).
#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
#[derive(Clone)]
pub struct ClusterDriver {
    // Async cluster connection; cloned per command.
    conn: redis::cluster_async::ClusterConnection,
    // Shared with the connection's push callback, which reads it to dispatch messages.
    handlers: Arc<RwLock<HandlerMap>>,
}
46
47fn read_msg(msg: redis::PushInfo) -> RedisResult<Option<(String, Vec<u8>)>> {
48 match msg.kind {
49 redis::PushKind::Message | redis::PushKind::SMessage => {
50 if msg.data.len() < 2 {
51 return Ok(None);
52 }
53 let mut iter = msg.data.into_iter();
54 let channel: String = FromRedisValue::from_owned_redis_value(iter.next().unwrap())?;
55 let message = FromRedisValue::from_owned_redis_value(iter.next().unwrap())?;
56 Ok(Some((channel, message)))
57 }
58 _ => Ok(None),
59 }
60}
61
62fn handle_msg(msg: PushInfo, handlers: Arc<RwLock<HandlerMap>>) {
63 match read_msg(msg) {
64 Ok(Some((chan, msg))) => {
65 if let Some(tx) = handlers.read().unwrap().get(&chan) {
66 if let Err(e) = tx.try_send((chan, msg)) {
67 tracing::warn!("redis pubsub channel full {e}");
68 }
69 } else {
70 tracing::warn!(chan, "no handler for channel");
71 }
72 }
73 Ok(_) => {}
74 Err(e) => {
75 tracing::error!("error reading message from redis: {e}");
76 }
77 }
78}
79impl RedisDriver {
80 pub async fn new(client: &redis::Client) -> Result<Self, redis::RedisError> {
82 let handlers = Arc::new(RwLock::new(HashMap::new()));
83 let handlers_clone = handlers.clone();
84 let config = redis::AsyncConnectionConfig::new().set_push_sender(move |msg| {
85 handle_msg(msg, handlers_clone.clone());
86 Ok::<(), std::convert::Infallible>(())
87 });
88
89 let conn = client
90 .get_multiplexed_async_connection_with_config(&config)
91 .await?;
92
93 Ok(Self { conn, handlers })
94 }
95}
96
#[cfg(feature = "redis-cluster")]
impl ClusterDriver {
    /// Create a new [`ClusterDriver`] from a redis
    /// [`ClusterClient`](redis::cluster::ClusterClient).
    ///
    /// Opens an async cluster connection. Its push notifications are
    /// dispatched through `handle_msg` to the handlers stored in this driver.
    ///
    /// # Errors
    ///
    /// Returns the underlying [`redis::RedisError`] if the connection cannot
    /// be established.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
    pub async fn new(client: &redis::cluster::ClusterClient) -> Result<Self, redis::RedisError> {
        let handlers = Arc::new(RwLock::new(HashMap::new()));
        let config = redis::cluster::ClusterConfig::new().set_push_sender({
            // The push callback keeps its own handle to the shared handler map.
            let handlers = Arc::clone(&handlers);
            move |msg| {
                handle_msg(msg, Arc::clone(&handlers));
                Ok::<(), std::convert::Infallible>(())
            }
        });

        let conn = client.get_async_connection_with_config(config).await?;
        Ok(Self { conn, handlers })
    }
}
113
114impl Driver for RedisDriver {
115 type Error = RedisError;
116
117 async fn publish(&self, chan: String, val: Vec<u8>) -> Result<(), Self::Error> {
118 self.conn
119 .clone()
120 .publish::<_, _, redis::Value>(chan, val)
121 .await?;
122 Ok(())
123 }
124
125 async fn subscribe(
126 &self,
127 chan: String,
128 size: usize,
129 ) -> Result<MessageStream<ChanItem>, Self::Error> {
130 self.conn.clone().subscribe(chan.as_str()).await?;
131 let (tx, rx) = mpsc::channel(size);
132 self.handlers.write().unwrap().insert(chan, tx);
133 Ok(MessageStream::new(rx))
134 }
135
136 async fn unsubscribe(&self, chan: String) -> Result<(), Self::Error> {
137 self.handlers.write().unwrap().remove(&chan);
138 self.conn.clone().unsubscribe(chan).await?;
139 Ok(())
140 }
141
142 async fn num_serv(&self, chan: &str) -> Result<u16, Self::Error> {
143 let mut conn = self.conn.clone();
144 let (_, count): (String, u16) = redis::cmd("PUBSUB")
145 .arg("NUMSUB")
146 .arg(chan)
147 .query_async(&mut conn)
148 .await?;
149 Ok(count)
150 }
151}
152
#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
#[cfg(feature = "redis-cluster")]
impl Driver for ClusterDriver {
    type Error = RedisError;

    /// Publish `val` on `chan` with a sharded `SPUBLISH`.
    async fn publish(&self, chan: String, val: Vec<u8>) -> Result<(), Self::Error> {
        self.conn
            .clone()
            .spublish::<_, _, redis::Value>(chan, val)
            .await?;
        Ok(())
    }

    /// `SSUBSCRIBE` to `chan` and return a stream of the messages received on it.
    ///
    /// Messages are buffered in a channel of capacity `size`. `handle_msg`
    /// drops messages that do not fit.
    ///
    /// The handler is registered *before* the `SSUBSCRIBE` command is sent.
    /// Otherwise a message delivered right after the subscription is
    /// confirmed, but before the handler is inserted, would find no handler
    /// and be lost. If subscribing fails, the handler map is restored to its
    /// previous state.
    async fn subscribe(
        &self,
        chan: String,
        size: usize,
    ) -> Result<MessageStream<ChanItem>, Self::Error> {
        let (tx, rx) = mpsc::channel(size);
        // The write guard is a temporary dropped at the end of this statement,
        // so it is never held across the `.await` below.
        let previous = self.handlers.write().unwrap().insert(chan.clone(), tx);

        let subscribed = self.conn.clone().ssubscribe(chan.as_str()).await;
        if let Err(e) = subscribed {
            let mut handlers = self.handlers.write().unwrap();
            match previous {
                Some(prev) => {
                    handlers.insert(chan, prev);
                }
                None => {
                    handlers.remove(&chan);
                }
            }
            return Err(e.into());
        }
        Ok(MessageStream::new(rx))
    }

    /// Stop routing messages for `chan`, then `SUNSUBSCRIBE` from it.
    async fn unsubscribe(&self, chan: String) -> Result<(), Self::Error> {
        self.handlers.write().unwrap().remove(&chan);
        self.conn.clone().sunsubscribe(chan).await?;
        Ok(())
    }

    /// Number of shard subscribers to `chan`, as reported by `PUBSUB SHARDNUMSUB`.
    async fn num_serv(&self, chan: &str) -> Result<u16, Self::Error> {
        let mut conn = self.conn.clone();
        // Parsed as a `[channel, count]` pair for the single channel queried.
        let (_, count): (String, u16) = redis::cmd("PUBSUB")
            .arg("SHARDNUMSUB")
            .arg(chan)
            .query_async(&mut conn)
            .await?;
        Ok(count)
    }
}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a push notification of `kind` whose data entries are bulk strings.
    fn push(kind: redis::PushKind, data: &[&str]) -> redis::PushInfo {
        redis::PushInfo {
            kind,
            data: data
                .iter()
                .map(|s| redis::Value::BulkString(s.as_bytes().to_vec()))
                .collect(),
        }
    }

    /// Wrap a single `chan -> tx` entry in the shared map shape `handle_msg` expects.
    fn registry<T>(chan: &str, tx: T) -> Arc<RwLock<HashMap<String, T>>> {
        Arc::new(RwLock::new(HashMap::from([(chan.to_string(), tx)])))
    }

    #[test]
    fn watch_handle_message() {
        let (tx, mut rx) = mpsc::channel(1);
        handle_msg(
            push(redis::PushKind::Message, &["test", "foo"]),
            registry("test", tx),
        );
        let (chan, data) = rx.try_recv().unwrap();
        assert_eq!(chan, "test");
        assert_eq!(data, "foo".as_bytes());
    }

    #[test]
    fn watch_handler_pattern() {
        let (tx, mut rx) = mpsc::channel(1);
        handle_msg(
            push(
                redis::PushKind::Message,
                &["test-response#namespace#uid#", "foo"],
            ),
            registry("test-response#namespace#uid#", tx),
        );
        let (chan, data) = rx.try_recv().unwrap();
        assert_eq!(chan, "test-response#namespace#uid#");
        assert_eq!(data, "foo".as_bytes());
    }

    #[test]
    fn watch_handle_sharded_message() {
        let (tx, mut rx) = mpsc::channel(1);
        handle_msg(
            push(redis::PushKind::SMessage, &["test", "foo"]),
            registry("test", tx),
        );
        let (chan, data) = rx.try_recv().unwrap();
        assert_eq!(chan, "test");
        assert_eq!(data, "foo".as_bytes());
    }

    #[test]
    fn ignore_non_message_push() {
        let (tx, mut rx) = mpsc::channel(1);
        handle_msg(
            push(redis::PushKind::Subscribe, &["test", "1"]),
            registry("test", tx),
        );
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn ignore_truncated_message() {
        let (tx, mut rx) = mpsc::channel(1);
        handle_msg(
            push(redis::PushKind::Message, &["test"]),
            registry("test", tx),
        );
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn ignore_unknown_channel() {
        let (tx, mut rx) = mpsc::channel(1);
        handle_msg(
            push(redis::PushKind::Message, &["other", "foo"]),
            registry("test", tx),
        );
        assert!(rx.try_recv().is_err());
    }

    #[test]
    fn drop_message_when_handler_full() {
        let (tx, mut rx) = mpsc::channel(1);
        let handlers = registry("test", tx);
        handle_msg(
            push(redis::PushKind::Message, &["test", "first"]),
            Arc::clone(&handlers),
        );
        // The buffer (capacity 1) is full: this one must be dropped, not block.
        handle_msg(
            push(redis::PushKind::Message, &["test", "second"]),
            Arc::clone(&handlers),
        );
        let (_, data) = rx.try_recv().unwrap();
        assert_eq!(data, "first".as_bytes());
        assert!(rx.try_recv().is_err());
    }
}