1use std::collections::{HashMap, VecDeque};
4use std::fmt::{Debug, Formatter};
5use std::time::Duration;
6
7use fred::clients::Pool;
8use fred::interfaces::StreamsInterface;
9use fred::types::streams::XReadValue;
10use futures::Stream;
11use futures::stream::unfold;
12use ruststream::{BatchSubscriber, Subscriber};
13
14use crate::{
15 convert::parts_from_fields, error::RedisError, message::RedisMessage, stream::ReadMode,
16};
17
/// One stream entry as buffered by the subscriber: the entry ID and its
/// field map (field name to raw value bytes).
type Entry = (String, HashMap<String, Vec<u8>>);
20
/// Shape the XREADGROUP reply is decoded into: for each stream key, the list
/// of `(entry ID, [(field, value bytes)])` pairs.
type RawStreams = Vec<(String, Vec<(String, Vec<(String, Vec<u8>)>)>)>;
26
/// XAUTOCLAIM cursor that starts a scan at the beginning of the group's
/// pending entries list; Redis also returns it once a scan has covered the
/// whole list.
const RECLAIM_START: &str = "0-0";
29
/// Converts a [`Duration`] to whole milliseconds for Redis command arguments.
///
/// Sub-millisecond remainders are truncated. Durations whose millisecond count
/// does not fit in a `u64` saturate to `u64::MAX` instead of wrapping.
fn duration_to_millis(d: Duration) -> u64 {
    // The clamp guarantees the value fits, so the narrowing cast cannot truncate.
    d.as_millis().min(u128::from(u64::MAX)) as u64
}
33
/// Consumer-group subscriber for a single Redis stream.
///
/// Reads entries from `key` as `consumer` within `group`, either new entries
/// via XREADGROUP or idle pending entries via XAUTOCLAIM, depending on `mode`.
/// Fetched entries are buffered and handed out one at a time through
/// `Subscriber::stream` or as whole batches through `BatchSubscriber::batches`.
pub struct RedisSubscriber {
    /// Connection pool used for reads; a clone is passed into every `RedisMessage`.
    pool: Pool,
    /// Stream key to read from.
    key: String,
    /// Consumer group name.
    group: String,
    /// Consumer name within `group`.
    consumer: String,
    /// Maximum number of entries requested per Redis round-trip.
    count: u64,
    /// Fresh mode: XREADGROUP block timeout. Reclaim mode: back-off sleep used
    /// when XAUTOCLAIM claims nothing (see `fetch_reclaim` for the exact condition).
    block: Duration,
    /// Read strategy: fresh entries or reclaimed idle pending entries.
    mode: ReadMode,
    /// XAUTOCLAIM cursor carried between calls; only consulted in reclaim mode.
    cursor: String,
    /// Entries fetched from Redis but not yet yielded to the caller.
    buffer: VecDeque<Entry>,
}
51
52impl Debug for RedisSubscriber {
53 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
54 f.debug_struct("RedisSubscriber")
55 .field("key", &self.key)
56 .field("group", &self.group)
57 .field("consumer", &self.consumer)
58 .field("mode", &self.mode)
59 .finish_non_exhaustive()
60 }
61}
62
63impl RedisSubscriber {
64 pub(crate) fn new(
65 pool: Pool,
66 key: String,
67 group: String,
68 consumer: String,
69 count: u64,
70 block: Duration,
71 mode: ReadMode,
72 ) -> Self {
73 Self {
74 pool,
75 key,
76 group,
77 consumer,
78 count,
79 block,
80 mode,
81 cursor: RECLAIM_START.to_owned(),
82 buffer: VecDeque::new(),
83 }
84 }
85
86 fn message(&self, id: String, fields: HashMap<String, Vec<u8>>) -> RedisMessage {
87 let (payload, headers) = parts_from_fields(fields);
88 RedisMessage::new(
89 self.pool.clone(),
90 self.key.clone(),
91 self.group.clone(),
92 id,
93 payload,
94 headers,
95 )
96 }
97
98 async fn fetch(&mut self) -> Result<(), RedisError> {
101 let entries = match self.mode.clone() {
102 ReadMode::Fresh => self.fetch_fresh().await?,
103 ReadMode::Reclaim { min_idle } => self.fetch_reclaim(min_idle).await?,
104 };
105 self.buffer.extend(entries);
106 Ok(())
107 }
108
109 async fn fetch_fresh(&self) -> Result<Vec<Entry>, RedisError> {
110 let resp: RawStreams = self
111 .pool
112 .xreadgroup(
113 self.group.as_str(),
114 self.consumer.as_str(),
115 Some(self.count),
116 Some(duration_to_millis(self.block)),
117 false,
118 self.key.as_str(),
119 ">",
120 )
121 .await
122 .map_err(RedisError::stream)?;
123 let entries = resp
124 .into_iter()
125 .find(|(key, _)| key == &self.key)
126 .map(|(_, entries)| entries)
127 .unwrap_or_default();
128 Ok(entries
129 .into_iter()
130 .map(|(id, fields)| (id, fields.into_iter().collect()))
131 .collect())
132 }
133
134 async fn fetch_reclaim(&mut self, min_idle: Duration) -> Result<Vec<Entry>, RedisError> {
135 let (cursor, entries): (String, Vec<XReadValue<String, String, Vec<u8>>>) = self
136 .pool
137 .xautoclaim_values(
138 self.key.as_str(),
139 self.group.as_str(),
140 self.consumer.as_str(),
141 duration_to_millis(min_idle),
142 self.cursor.as_str(),
143 Some(self.count),
144 false,
145 )
146 .await
147 .map_err(RedisError::stream)?;
148 self.cursor = cursor;
149 if entries.is_empty() {
151 tokio::time::sleep(self.block).await;
152 }
153 Ok(entries)
154 }
155}
156
157impl Subscriber for RedisSubscriber {
158 type Message = RedisMessage;
159 type Error = RedisError;
160
161 fn stream(&mut self) -> impl Stream<Item = Result<Self::Message, Self::Error>> + Send + '_ {
169 unfold(self, |s| async move {
170 loop {
171 if let Some((id, fields)) = s.buffer.pop_front() {
172 return Some((Ok(s.message(id, fields)), s));
173 }
174 if let Err(err) = s.fetch().await {
176 return Some((Err(err), s));
177 }
178 }
179 })
180 }
181}
182
183impl BatchSubscriber for RedisSubscriber {
184 type Batch = Vec<RedisMessage>;
185
186 fn batches(&mut self) -> impl Stream<Item = Result<Self::Batch, Self::Error>> + Send + '_ {
194 unfold(self, |s| async move {
195 loop {
196 if !s.buffer.is_empty() {
197 let entries = std::mem::take(&mut s.buffer);
200 let batch = entries
201 .into_iter()
202 .map(|(id, fields)| s.message(id, fields))
203 .collect::<Vec<_>>();
204 return Some((Ok(batch), s));
205 }
206 if let Err(err) = s.fetch().await {
207 return Some((Err(err), s));
208 }
209 }
210 })
211 }
212}