socketioxide_redis/drivers/
fred.rs

1use std::{
2    collections::HashMap,
3    fmt,
4    sync::{Arc, RwLock},
5};
6
7use tokio::sync::{broadcast, mpsc};
8
9use super::{ChanItem, Driver, MessageStream};
10
11use fred::{
12    interfaces::PubsubInterface,
13    prelude::{ClientLike, EventInterface, FredResult},
14    types::Message,
15};
16
17pub use fred as fred_client;
18
19/// An error type for the fred driver.
20#[derive(Debug)]
21pub struct FredError(fred::error::Error);
22
23impl From<fred::error::Error> for FredError {
24    fn from(e: fred::error::Error) -> Self {
25        Self(e)
26    }
27}
28impl fmt::Display for FredError {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        self.0.fmt(f)
31    }
32}
33impl std::error::Error for FredError {}
34
/// Maps a channel name to the sender that feeds the subscriber stream of that channel.
type HandlerMap = HashMap<String, mpsc::Sender<ChanItem>>;
36
37/// Return the channel, data and an optional req_id from a message.
38fn read_msg(msg: Message) -> Option<ChanItem> {
39    let chan = msg.channel.to_string();
40    let data = msg.value.into_owned_bytes()?;
41    Some((chan, data))
42}
43
44/// Pipe messages from the fred client to the handlers.
45async fn msg_handler(mut rx: broadcast::Receiver<Message>, handlers: Arc<RwLock<HandlerMap>>) {
46    loop {
47        match rx.recv().await {
48            Ok(msg) => {
49                if let Some((chan, data)) = read_msg(msg) {
50                    if let Some(tx) = handlers.read().unwrap().get(&chan) {
51                        tx.try_send((chan, data)).unwrap();
52                    } else {
53                        tracing::warn!(chan, "no handler for channel");
54                    }
55                }
56            }
57            // From the fred docs, even if the connection closed, the receiver will not be closed.
58            // Therefore if it happens, we should just return.
59            Err(broadcast::error::RecvError::Closed) => return,
60            Err(broadcast::error::RecvError::Lagged(n)) => {
61                tracing::warn!("fred driver pubsub channel lagged by {}", n);
62            }
63        }
64    }
65}
66
67/// A driver implementation for the [fred](docs.rs/fred) pub/sub backend.
68#[derive(Clone)]
69pub struct FredDriver {
70    handlers: Arc<RwLock<HandlerMap>>,
71    conn: fred::clients::SubscriberClient,
72}
73
74impl FredDriver {
75    /// Create a new redis driver from a redis client.
76    pub async fn new(client: fred::clients::SubscriberClient) -> FredResult<Self> {
77        let handlers = Arc::new(RwLock::new(HashMap::new()));
78        tokio::spawn(msg_handler(client.message_rx(), handlers.clone()));
79        client.init().await?;
80
81        Ok(Self {
82            conn: client,
83            handlers,
84        })
85    }
86}
87
88impl Driver for FredDriver {
89    type Error = FredError;
90
91    async fn publish(&self, chan: String, val: Vec<u8>) -> Result<(), Self::Error> {
92        // We could use the receiver count from here. This would avoid a call to `server_cnt`.
93        self.conn.spublish::<u16, _, _>(chan, val).await?;
94        Ok(())
95    }
96
97    async fn subscribe(
98        &self,
99        chan: String,
100        size: usize,
101    ) -> Result<MessageStream<ChanItem>, Self::Error> {
102        self.conn.clone().ssubscribe(chan.as_str()).await?;
103        let (tx, rx) = mpsc::channel(size);
104        self.handlers.write().unwrap().insert(chan, tx);
105        Ok(MessageStream::new(rx))
106    }
107
108    async fn unsubscribe(&self, chan: String) -> Result<(), Self::Error> {
109        self.handlers.write().unwrap().remove(&chan);
110        self.conn.sunsubscribe(chan).await?;
111        Ok(())
112    }
113
114    async fn num_serv(&self, chan: &str) -> Result<u16, Self::Error> {
115        let (_, num): (String, u16) = self.conn.pubsub_shardnumsub(chan).await?;
116        Ok(num)
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    use fred::{
        prelude::Server,
        types::{MessageKind, Value},
    };
    use std::time::Duration;
    use tokio::time;

    /// Upper bound on how long a test waits for a forwarded message.
    const TIMEOUT: Duration = Duration::from_millis(100);

    /// Build a regular pub/sub message for `channel` carrying `value`.
    fn message(channel: &'static str, value: Value) -> Message {
        Message {
            channel: channel.into(),
            kind: MessageKind::Message,
            value,
            server: Server::new("0.0.0.0", 0),
        }
    }

    /// Register a handler for `channel`, push one message through
    /// `msg_handler` and return what the handler received.
    async fn roundtrip(channel: &'static str, value: Value) -> ChanItem {
        let (handler_tx, mut handler_rx) = mpsc::channel(1);
        let (source_tx, source_rx) = broadcast::channel(1);
        let handlers = HandlerMap::from([(channel.to_string(), handler_tx)]);
        tokio::spawn(msg_handler(source_rx, Arc::new(RwLock::new(handlers))));

        source_tx.send(message(channel, value)).unwrap();
        time::timeout(TIMEOUT, handler_rx.recv())
            .await
            .unwrap()
            .unwrap()
    }

    #[tokio::test]
    async fn watch_handle_message() {
        let (chan, data) = roundtrip("test", "foo".into()).await;
        assert_eq!(chan, "test");
        assert_eq!(data, "foo".as_bytes());
    }

    #[tokio::test]
    async fn watch_handler_pattern() {
        const CHANNEL: &str = "test-response#namespace#uid#";

        let (chan, data) = roundtrip(CHANNEL, Value::from_static(b"foo")).await;
        assert_eq!(chan, "test-response#namespace#uid#");
        assert_eq!(data, "foo".as_bytes());
    }
}
170}