socketioxide_redis/drivers/
fred.rs1use std::{
2 collections::HashMap,
3 fmt,
4 sync::{Arc, RwLock},
5};
6
7use tokio::sync::{broadcast, mpsc};
8
9use super::{ChanItem, Driver, MessageStream};
10
11use fred::{
12 interfaces::PubsubInterface,
13 prelude::{ClientLike, EventInterface, FredResult},
14 types::Message,
15};
16
17pub use fred as fred_client;
18
19#[derive(Debug)]
21pub struct FredError(fred::error::Error);
22
23impl From<fred::error::Error> for FredError {
24 fn from(e: fred::error::Error) -> Self {
25 Self(e)
26 }
27}
28impl fmt::Display for FredError {
29 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30 self.0.fmt(f)
31 }
32}
33impl std::error::Error for FredError {}
34
35type HandlerMap = HashMap<String, mpsc::Sender<ChanItem>>;
36
37fn read_msg(msg: Message) -> Option<ChanItem> {
39 let chan = msg.channel.to_string();
40 let data = msg.value.into_owned_bytes()?;
41 Some((chan, data))
42}
43
44async fn msg_handler(mut rx: broadcast::Receiver<Message>, handlers: Arc<RwLock<HandlerMap>>) {
46 loop {
47 match rx.recv().await {
48 Ok(msg) => {
49 if let Some((chan, data)) = read_msg(msg) {
50 if let Some(tx) = handlers.read().unwrap().get(&chan) {
51 tx.try_send((chan, data)).unwrap();
52 } else {
53 tracing::warn!(chan, "no handler for channel");
54 }
55 }
56 }
57 Err(broadcast::error::RecvError::Closed) => return,
60 Err(broadcast::error::RecvError::Lagged(n)) => {
61 tracing::warn!("fred driver pubsub channel lagged by {}", n);
62 }
63 }
64 }
65}
66
67#[derive(Clone)]
69pub struct FredDriver {
70 handlers: Arc<RwLock<HandlerMap>>,
71 conn: fred::clients::SubscriberClient,
72}
73
74impl FredDriver {
75 pub async fn new(client: fred::clients::SubscriberClient) -> FredResult<Self> {
77 let handlers = Arc::new(RwLock::new(HashMap::new()));
78 tokio::spawn(msg_handler(client.message_rx(), handlers.clone()));
79 client.init().await?;
80
81 Ok(Self {
82 conn: client,
83 handlers,
84 })
85 }
86}
87
88impl Driver for FredDriver {
89 type Error = FredError;
90
91 async fn publish(&self, chan: String, val: Vec<u8>) -> Result<(), Self::Error> {
92 self.conn.spublish::<u16, _, _>(chan, val).await?;
94 Ok(())
95 }
96
97 async fn subscribe(
98 &self,
99 chan: String,
100 size: usize,
101 ) -> Result<MessageStream<ChanItem>, Self::Error> {
102 self.conn.clone().ssubscribe(chan.as_str()).await?;
103 let (tx, rx) = mpsc::channel(size);
104 self.handlers.write().unwrap().insert(chan, tx);
105 Ok(MessageStream::new(rx))
106 }
107
108 async fn unsubscribe(&self, chan: String) -> Result<(), Self::Error> {
109 self.handlers.write().unwrap().remove(&chan);
110 self.conn.sunsubscribe(chan).await?;
111 Ok(())
112 }
113
114 async fn num_serv(&self, chan: &str) -> Result<u16, Self::Error> {
115 let (_, num): (String, u16) = self.conn.pubsub_shardnumsub(chan).await?;
116 Ok(num)
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use fred::{
        prelude::Server,
        types::{MessageKind, Value},
    };
    use std::time::Duration;
    use tokio::time;

    /// Upper bound on how long a test waits for `msg_handler` to forward a message.
    const TIMEOUT: Duration = Duration::from_millis(100);

    /// Spawns `msg_handler` with exactly one handler, registered under `chan`.
    ///
    /// Returns the broadcast sender that feeds the dispatch loop and the receiving
    /// end of the registered handler.
    fn spawn_dispatch(chan: &str) -> (broadcast::Sender<Message>, mpsc::Receiver<ChanItem>) {
        let (handler_tx, handler_rx) = mpsc::channel(1);
        let (msg_tx, msg_rx) = broadcast::channel(1);
        let handlers: HandlerMap = HashMap::from([(chan.to_string(), handler_tx)]);
        tokio::spawn(msg_handler(msg_rx, Arc::new(RwLock::new(handlers))));
        (msg_tx, handler_rx)
    }

    /// Builds a plain pub/sub `Message` on `channel` carrying `value`.
    fn message(channel: &'static str, value: Value) -> Message {
        Message {
            channel: channel.into(),
            kind: MessageKind::Message,
            value,
            server: Server::new("0.0.0.0", 0),
        }
    }

    #[tokio::test]
    async fn watch_handle_message() {
        let (msg_tx, mut handler_rx) = spawn_dispatch("test");
        msg_tx.send(message("test", "foo".into())).unwrap();

        let (chan, data) = time::timeout(TIMEOUT, handler_rx.recv())
            .await
            .expect("timed out waiting for the handler")
            .expect("handler channel closed");
        assert_eq!(chan, "test");
        assert_eq!(data, "foo".as_bytes());
    }

    #[tokio::test]
    async fn watch_handler_pattern() {
        let (msg_tx, mut handler_rx) = spawn_dispatch("test-response#namespace#uid#");
        msg_tx
            .send(message(
                "test-response#namespace#uid#",
                Value::from_static(b"foo"),
            ))
            .unwrap();

        let (chan, data) = time::timeout(TIMEOUT, handler_rx.recv())
            .await
            .expect("timed out waiting for the handler")
            .expect("handler channel closed");
        assert_eq!(chan, "test-response#namespace#uid#");
        assert_eq!(data, "foo".as_bytes());
    }
}