socketioxide_redis/drivers/redis.rs

1use std::{
2    collections::HashMap,
3    fmt,
4    sync::{Arc, RwLock},
5};
6
7use redis::{AsyncCommands, FromRedisValue, PushInfo, RedisResult, aio::MultiplexedConnection};
8use tokio::sync::mpsc;
9
10use super::{ChanItem, Driver, MessageStream};
11
12pub use redis as redis_client;
13
14/// An error type for the redis driver.
15#[derive(Debug)]
16pub struct RedisError(redis::RedisError);
17
18impl From<redis::RedisError> for RedisError {
19    fn from(e: redis::RedisError) -> Self {
20        Self(e)
21    }
22}
23impl fmt::Display for RedisError {
24    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25        self.0.fmt(f)
26    }
27}
28impl std::error::Error for RedisError {}
29
30type HandlerMap = HashMap<String, mpsc::Sender<ChanItem>>;
31/// A driver implementation for the [redis](docs.rs/redis) pub/sub backend.
32#[derive(Clone)]
33pub struct RedisDriver {
34    handlers: Arc<RwLock<HandlerMap>>,
35    conn: MultiplexedConnection,
36}
37
/// A driver implementation for the [redis](https://docs.rs/redis) cluster backend.
///
/// Uses redis *sharded* pub/sub (`SPUBLISH` / `SSUBSCRIBE` / `SUNSUBSCRIBE`) over a cluster
/// connection; incoming messages are routed to the per-channel sender stored in `handlers`.
#[cfg(feature = "redis-cluster")]
#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
#[derive(Clone)]
pub struct ClusterDriver {
    /// Per-channel senders, shared with the connection's push-message callback.
    handlers: Arc<RwLock<HandlerMap>>,
    /// Cluster connection used both to issue commands and to receive push messages.
    conn: redis::cluster_async::ClusterConnection,
}
46
47fn read_msg(msg: redis::PushInfo) -> RedisResult<Option<(String, Vec<u8>)>> {
48    match msg.kind {
49        redis::PushKind::Message | redis::PushKind::SMessage => {
50            if msg.data.len() < 2 {
51                return Ok(None);
52            }
53            let mut iter = msg.data.into_iter();
54            let channel: String = FromRedisValue::from_owned_redis_value(iter.next().unwrap())?;
55            let message = FromRedisValue::from_owned_redis_value(iter.next().unwrap())?;
56            Ok(Some((channel, message)))
57        }
58        _ => Ok(None),
59    }
60}
61
62fn handle_msg(msg: PushInfo, handlers: Arc<RwLock<HandlerMap>>) {
63    match read_msg(msg) {
64        Ok(Some((chan, msg))) => {
65            if let Some(tx) = handlers.read().unwrap().get(&chan) {
66                if let Err(e) = tx.try_send((chan, msg)) {
67                    tracing::warn!("redis pubsub channel full {e}");
68                }
69            } else {
70                tracing::warn!(chan, "no handler for channel");
71            }
72        }
73        Ok(_) => {}
74        Err(e) => {
75            tracing::error!("error reading message from redis: {e}");
76        }
77    }
78}
79impl RedisDriver {
80    /// Create a new redis driver from a redis client.
81    pub async fn new(client: &redis::Client) -> Result<Self, redis::RedisError> {
82        let handlers = Arc::new(RwLock::new(HashMap::new()));
83        let handlers_clone = handlers.clone();
84        let config = redis::AsyncConnectionConfig::new().set_push_sender(move |msg| {
85            handle_msg(msg, handlers_clone.clone());
86            Ok::<(), std::convert::Infallible>(())
87        });
88
89        let conn = client
90            .get_multiplexed_async_connection_with_config(&config)
91            .await?;
92
93        Ok(Self { conn, handlers })
94    }
95}
96
#[cfg(feature = "redis-cluster")]
impl ClusterDriver {
    /// Create a new redis driver from a redis cluster client.
    ///
    /// Opens a cluster async connection whose push-message callback routes every incoming
    /// (sharded) pub/sub message to the handler registered for its channel.
    ///
    /// # Errors
    /// Returns the underlying [`redis::RedisError`] if the connection cannot be established.
    #[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
    pub async fn new(client: &redis::cluster::ClusterClient) -> Result<Self, redis::RedisError> {
        let handlers: Arc<RwLock<HandlerMap>> = Arc::new(RwLock::new(HandlerMap::new()));

        // The push callback keeps its own reference to the shared handler map.
        let push_handlers = Arc::clone(&handlers);
        let on_push = move |push: PushInfo| {
            handle_msg(push, Arc::clone(&push_handlers));
            Ok::<(), std::convert::Infallible>(())
        };
        let config = redis::cluster::ClusterConfig::new().set_push_sender(on_push);

        let conn = client.get_async_connection_with_config(config).await?;
        Ok(Self { handlers, conn })
    }
}
113
114impl Driver for RedisDriver {
115    type Error = RedisError;
116
117    async fn publish(&self, chan: String, val: Vec<u8>) -> Result<(), Self::Error> {
118        self.conn
119            .clone()
120            .publish::<_, _, redis::Value>(chan, val)
121            .await?;
122        Ok(())
123    }
124
125    async fn subscribe(
126        &self,
127        chan: String,
128        size: usize,
129    ) -> Result<MessageStream<ChanItem>, Self::Error> {
130        self.conn.clone().subscribe(chan.as_str()).await?;
131        let (tx, rx) = mpsc::channel(size);
132        self.handlers.write().unwrap().insert(chan, tx);
133        Ok(MessageStream::new(rx))
134    }
135
136    async fn unsubscribe(&self, chan: String) -> Result<(), Self::Error> {
137        self.handlers.write().unwrap().remove(&chan);
138        self.conn.clone().unsubscribe(chan).await?;
139        Ok(())
140    }
141
142    async fn num_serv(&self, chan: &str) -> Result<u16, Self::Error> {
143        let mut conn = self.conn.clone();
144        let (_, count): (String, u16) = redis::cmd("PUBSUB")
145            .arg("NUMSUB")
146            .arg(chan)
147            .query_async(&mut conn)
148            .await?;
149        Ok(count)
150    }
151}
152
#[cfg_attr(docsrs, doc(cfg(feature = "redis-cluster")))]
#[cfg(feature = "redis-cluster")]
impl Driver for ClusterDriver {
    type Error = RedisError;

    /// Publish `val` on the shard channel `chan` with an `SPUBLISH` command.
    async fn publish(&self, chan: String, val: Vec<u8>) -> Result<(), Self::Error> {
        self.conn
            .clone()
            .spublish::<_, _, redis::Value>(chan, val)
            .await?;
        Ok(())
    }

    /// Subscribe to the shard channel `chan` and return a stream of its messages, buffered up
    /// to `size` items.
    ///
    /// The handler is registered *before* `SSUBSCRIBE` is sent. Messages are dispatched by the
    /// connection's push callback as soon as they arrive. If the handler were registered after
    /// the await, a message published right after the subscription is acknowledged could be
    /// dropped with "no handler for channel". If the command fails, any handler previously
    /// registered for `chan` is put back (or the new one removed) before the error is returned.
    async fn subscribe(
        &self,
        chan: String,
        size: usize,
    ) -> Result<MessageStream<ChanItem>, Self::Error> {
        let (tx, rx) = mpsc::channel(size);
        // The write guard is a temporary dropped at the end of this statement,
        // so it is never held across the `.await` below.
        let previous = self.handlers.write().unwrap().insert(chan.clone(), tx);
        if let Err(e) = self.conn.clone().ssubscribe(chan.as_str()).await {
            let mut handlers = self.handlers.write().unwrap();
            match previous {
                Some(prev) => handlers.insert(chan, prev),
                None => handlers.remove(&chan),
            };
            return Err(e.into());
        }
        Ok(MessageStream::new(rx))
    }

    /// Stop routing messages for `chan`, then send `SUNSUBSCRIBE`.
    ///
    /// The handler is removed first so no further messages are dispatched to it.
    async fn unsubscribe(&self, chan: String) -> Result<(), Self::Error> {
        self.handlers.write().unwrap().remove(&chan);
        self.conn.clone().sunsubscribe(chan).await?;
        Ok(())
    }

    /// Return the subscriber count for `chan` as reported by `PUBSUB SHARDNUMSUB`.
    ///
    /// The reply is decoded as a `(channel, count)` pair.
    async fn num_serv(&self, chan: &str) -> Result<u16, Self::Error> {
        let mut conn = self.conn.clone();
        let (_, count): (String, u16) = redis::cmd("PUBSUB")
            .arg("SHARDNUMSUB")
            .arg(chan)
            .query_async(&mut conn)
            .await?;
        Ok(count)
    }
}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `Message` push frame carrying `payload` on channel `chan`.
    fn message_push(chan: &str, payload: &str) -> redis::PushInfo {
        redis::PushInfo {
            kind: redis::PushKind::Message,
            data: vec![
                redis::Value::BulkString(chan.into()),
                redis::Value::BulkString(payload.into()),
            ],
        }
    }

    /// Register a single handler for `chan`, route `push` through `handle_msg`, and return
    /// the item that handler received.
    fn dispatch_to(chan: &str, push: redis::PushInfo) -> ChanItem {
        let (tx, mut rx) = mpsc::channel(1);
        let handlers: HandlerMap = HashMap::from([(chan.to_string(), tx)]);
        super::handle_msg(push, Arc::new(RwLock::new(handlers)));
        rx.try_recv().unwrap()
    }

    #[test]
    fn watch_handle_message() {
        let (chan, data) = dispatch_to("test", message_push("test", "foo"));
        assert_eq!(chan, "test");
        assert_eq!(data, "foo".as_bytes());
    }

    #[test]
    fn watch_handler_pattern() {
        const CHAN: &str = "test-response#namespace#uid#";
        let (chan, data) = dispatch_to(CHAN, message_push(CHAN, "foo"));
        assert_eq!(chan, CHAN);
        assert_eq!(data, "foo".as_bytes());
    }
}