1use std::fmt::{Debug, Formatter};
17use std::sync::Arc;
18
19use bytes::Bytes;
20use fred::clients::Client;
21use fred::interfaces::{ClientLike, PubsubInterface};
22use fred::types::Message;
23use futures::Stream;
24use futures::stream::unfold;
25use ruststream::codec::Codec;
26use ruststream::{
27 AckError, Headers, IncomingMessage, OutgoingMessage, Partitioned, Publisher, SubscriptionSource,
28};
29use tokio::sync::OnceCell;
30use tokio::sync::broadcast::{Receiver, error::RecvError};
31
32use crate::envelope::{SharedEnvelope, frame, unframe};
33use crate::{RedisBroker, error::RedisError, message::PARTITION_KEY_HEADER};
34
/// Delivery mode used for Redis pub/sub subscriptions and publishing.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum PubSubMode {
    /// Classic pub/sub; messages are published with `PUBLISH`. This is the default mode and
    /// the only one that allows pattern subscriptions.
    #[default]
    Classic,
    /// Sharded pub/sub; messages are published with `SPUBLISH`. Pattern subscriptions are
    /// rejected for this mode (there is no sharded `PSUBSCRIBE`).
    Sharded,
}
44
/// Subscription source describing a Redis pub/sub channel (or channel pattern).
///
/// Create it with [`RedisPubSub::new`] and refine it with the chained builder methods
/// (`mode`, `pattern`, `codec`).
#[derive(Clone)]
// A builder value that is discarded configures nothing, so dropping it is almost surely a bug.
#[must_use]
pub struct RedisPubSub {
    // Channel name; treated as a pattern when `pattern` is true.
    channel: String,
    // Classic vs. sharded delivery.
    mode: PubSubMode,
    // Whether `channel` is a pattern (only valid with `PubSubMode::Classic`, see `validate`).
    pattern: bool,
    // Optional codec handed out via `codec_handle`; presumably used by the broker to build the
    // subscriber that unframes incoming messages. TODO confirm against `subscribe_pubsub`.
    codec: Option<SharedEnvelope>,
}
65
66impl Debug for RedisPubSub {
67 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
68 f.debug_struct("RedisPubSub")
69 .field("channel", &self.channel)
70 .field("mode", &self.mode)
71 .field("pattern", &self.pattern)
72 .field("codec", &self.codec.is_some())
73 .finish()
74 }
75}
76
77impl RedisPubSub {
78 pub fn new(channel: impl Into<String>) -> Self {
80 Self {
81 channel: channel.into(),
82 mode: PubSubMode::default(),
83 pattern: false,
84 codec: None,
85 }
86 }
87
88 pub const fn mode(mut self, mode: PubSubMode) -> Self {
90 self.mode = mode;
91 self
92 }
93
94 pub const fn pattern(mut self) -> Self {
97 self.pattern = true;
98 self
99 }
100
101 pub fn codec(mut self, codec: impl Codec + 'static) -> Self {
104 self.codec = Some(Arc::new(codec));
105 self
106 }
107
108 #[must_use]
110 pub fn channel(&self) -> &str {
111 &self.channel
112 }
113
114 pub(crate) const fn delivery_mode(&self) -> PubSubMode {
115 self.mode
116 }
117
118 pub(crate) const fn is_pattern(&self) -> bool {
119 self.pattern
120 }
121
122 pub(crate) fn codec_handle(&self) -> Option<SharedEnvelope> {
123 self.codec.clone()
124 }
125
126 pub(crate) fn validate(&self) -> Result<(), RedisError> {
127 if self.pattern && matches!(self.mode, PubSubMode::Sharded) {
128 return Err(RedisError::InvalidOptions(
129 "pattern subscriptions are classic-only; sharded pub/sub has no PSUBSCRIBE"
130 .to_owned(),
131 ));
132 }
133 Ok(())
134 }
135}
136
137impl SubscriptionSource<RedisBroker> for RedisPubSub {
138 type Subscriber = RedisPubSubSubscriber;
139
140 fn name(&self) -> &str {
141 self.channel()
142 }
143
144 async fn subscribe(self, broker: &RedisBroker) -> Result<Self::Subscriber, RedisError> {
145 broker.subscribe_pubsub(self).await
146 }
147}
148
/// Active pub/sub subscription produced by subscribing a [`RedisPubSub`].
///
/// Messages are read from a Tokio broadcast receiver. On drop, a `quit` of `client` is
/// scheduled, so the client is presumably dedicated to this subscription (TODO confirm in
/// the broker).
pub struct RedisPubSubSubscriber {
    // Client owning the subscription; quit when the subscriber is dropped.
    client: Client,
    // Broadcast receiver the raw pub/sub messages arrive on.
    rx: Receiver<Message>,
    // Optional codec used to unframe each received message.
    codec: Option<SharedEnvelope>,
}
156
157impl Debug for RedisPubSubSubscriber {
158 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
159 f.debug_struct("RedisPubSubSubscriber")
160 .finish_non_exhaustive()
161 }
162}
163
164impl RedisPubSubSubscriber {
165 pub(crate) fn new(
166 client: Client,
167 rx: Receiver<Message>,
168 codec: Option<SharedEnvelope>,
169 ) -> Self {
170 Self { client, rx, codec }
171 }
172}
173
174impl Drop for RedisPubSubSubscriber {
175 fn drop(&mut self) {
176 let client = self.client.clone();
179 tokio::spawn(async move {
180 let _ = client.quit().await;
181 });
182 }
183}
184
185fn to_message(msg: &Message, codec: Option<&SharedEnvelope>) -> RedisPubSubMessage {
186 let raw = msg.value.as_bytes().unwrap_or(&[]);
187 let (payload, headers) = unframe(codec, raw);
188 RedisPubSubMessage {
189 channel: msg.channel.to_string(),
190 payload,
191 headers,
192 }
193}
194
195impl ruststream::Subscriber for RedisPubSubSubscriber {
196 type Message = RedisPubSubMessage;
197 type Error = RedisError;
198
199 fn stream(&mut self) -> impl Stream<Item = Result<Self::Message, Self::Error>> + Send + '_ {
207 let codec = self.codec.clone();
208 unfold((&mut self.rx, codec), |(rx, codec)| async move {
209 loop {
210 match rx.recv().await {
211 Ok(msg) => {
212 let message = to_message(&msg, codec.as_ref());
213 return Some((Ok(message), (rx, codec)));
214 }
215 Err(RecvError::Lagged(_)) => {}
217 Err(RecvError::Closed) => return None,
218 }
219 }
220 })
221 }
222}
223
/// A message received from a Redis pub/sub subscription.
///
/// Pub/sub has no acknowledgement, so `ack`/`nack` always return `AckError::Unsupported`.
pub struct RedisPubSubMessage {
    // Channel the message was delivered on (from the received fred `Message`).
    channel: String,
    // Payload after unframing with the subscriber's codec.
    payload: Bytes,
    // Headers recovered while unframing.
    headers: Headers,
}
230
231impl Debug for RedisPubSubMessage {
232 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
233 f.debug_struct("RedisPubSubMessage")
234 .field("channel", &self.channel)
235 .field("payload_len", &self.payload.len())
236 .finish_non_exhaustive()
237 }
238}
239
240impl RedisPubSubMessage {
241 #[must_use]
243 pub fn channel(&self) -> &str {
244 &self.channel
245 }
246}
247
248impl IncomingMessage for RedisPubSubMessage {
249 fn payload(&self) -> &[u8] {
250 &self.payload
251 }
252
253 fn headers(&self) -> &Headers {
254 &self.headers
255 }
256
257 async fn ack(self) -> Result<(), AckError> {
258 Err(AckError::Unsupported)
259 }
260
261 async fn nack(self, _requeue: bool) -> Result<(), AckError> {
262 Err(AckError::Unsupported)
263 }
264}
265
266impl Partitioned for RedisPubSubMessage {
267 fn partition_key(&self) -> Option<&[u8]> {
268 self.headers().get(PARTITION_KEY_HEADER)
269 }
270}
271
/// Publishes messages to Redis pub/sub channels through a shared connection pool.
///
/// Publishing fails with `RedisError::NotConnected` until the pool has been initialised.
#[derive(Clone)]
pub struct RedisPubSubPublisher {
    // Pool set once elsewhere; empty until a connection has been established.
    pool: Arc<OnceCell<fred::clients::Pool>>,
    // Chooses `PUBLISH` (classic) or `SPUBLISH` (sharded).
    mode: PubSubMode,
    // Optional codec used to frame payload and headers before publishing.
    codec: Option<SharedEnvelope>,
}
284
285impl Debug for RedisPubSubPublisher {
286 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
287 f.debug_struct("RedisPubSubPublisher")
288 .field("mode", &self.mode)
289 .field("codec", &self.codec.is_some())
290 .finish_non_exhaustive()
291 }
292}
293
294impl RedisPubSubPublisher {
295 pub(crate) fn new(pool: Arc<OnceCell<fred::clients::Pool>>, mode: PubSubMode) -> Self {
296 Self {
297 pool,
298 mode,
299 codec: None,
300 }
301 }
302
303 #[must_use]
306 pub const fn mode(mut self, mode: PubSubMode) -> Self {
307 self.mode = mode;
308 self
309 }
310
311 #[must_use]
314 pub fn codec(mut self, codec: impl Codec + 'static) -> Self {
315 self.codec = Some(Arc::new(codec));
316 self
317 }
318}
319
320impl Publisher for RedisPubSubPublisher {
321 type Error = RedisError;
322
323 async fn publish(&self, msg: OutgoingMessage<'_>) -> Result<(), Self::Error> {
324 let pool = self.pool.get().cloned().ok_or(RedisError::NotConnected)?;
325 let client = pool.next();
326 let channel = msg.name().to_owned();
327 let body = frame(self.codec.as_ref(), msg.payload(), msg.headers());
328 let _: i64 = match self.mode {
329 PubSubMode::Classic => client.publish(channel, body).await,
330 PubSubMode::Sharded => client.spublish(channel, body).await,
331 }
332 .map_err(RedisError::publish)?;
333 Ok(())
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pattern_with_sharded_is_rejected() {
        let err = RedisPubSub::new("e.*")
            .mode(PubSubMode::Sharded)
            .pattern()
            .validate()
            .unwrap_err();
        assert!(matches!(err, RedisError::InvalidOptions(msg) if msg.contains("classic-only")));
    }

    #[test]
    fn classic_pattern_validates() {
        RedisPubSub::new("e.*").pattern().validate().expect("ok");
    }

    #[test]
    fn sharded_literal_channel_validates() {
        RedisPubSub::new("events")
            .mode(PubSubMode::Sharded)
            .validate()
            .expect("sharded literal channel is valid");
    }

    #[test]
    fn new_uses_defaults() {
        let source = RedisPubSub::new("events");
        assert_eq!(source.channel(), "events");
        assert_eq!(source.delivery_mode(), PubSubMode::Classic);
        assert!(!source.is_pattern());
        assert!(source.codec_handle().is_none());
        source.validate().expect("defaults are valid");
    }

    #[test]
    fn builder_sets_mode_and_pattern() {
        let pattern = RedisPubSub::new("e.*").pattern();
        assert!(pattern.is_pattern());

        let sharded = RedisPubSub::new("events").mode(PubSubMode::Sharded);
        assert_eq!(sharded.delivery_mode(), PubSubMode::Sharded);
        assert!(!sharded.is_pattern());
    }

    #[test]
    fn debug_reports_codec_presence_only() {
        let rendered = format!("{:?}", RedisPubSub::new("events"));
        assert_eq!(
            rendered,
            r#"RedisPubSub { channel: "events", mode: Classic, pattern: false, codec: false }"#
        );
    }
}