1use std::collections::{HashMap, HashSet};
2
3use combine::{Parser, choice, many1, optional};
4use redis_protocol::resp3;
5use redis_protocol::resp3::types::BytesFrame;
6use sierradb::StreamId;
7use sierradb::id::NAMESPACE_PARTITION_KEY;
8use sierradb_cluster::subscription::{FromVersions, Subscribe, SubscriptionMatcher};
9use tokio::io::{self, AsyncWriteExt};
10use tokio::sync::{mpsc, watch};
11use tracing::debug;
12use uuid::Uuid;
13
14use crate::error::AsRedisError;
15use crate::parser::{
16 FrameStream, keyword, number_u64, number_u64_min, partition_key, stream_id, stream_id_version,
17};
18use crate::request::{HandleRequest, number, simple_str};
19use crate::server::Conn;
20
/// Parsed arguments of a subscription request.
///
/// Built by [`ESub::parser`] and consumed by the [`HandleRequest`] impl,
/// which registers the subscription with the cluster and replies with the
/// new subscription id.
pub struct ESub {
    /// Which stream(s) to follow and the version(s) to start from, already
    /// resolved to concrete partition keys.
    pub matcher: SubscriptionMatcher,
    /// Value of the optional `WINDOW <n>` argument (parsed with a minimum of
    /// 1). When `None`, the request handler falls back to a window size of
    /// 1_000.
    pub window_size: Option<u64>,
}
61
62impl ESub {
63 pub fn parser<'a>() -> impl Parser<FrameStream<'a>, Output = ESub> + 'a {
64 (
65 Selector::parser(),
66 optional(from_versions()),
67 optional(window()),
68 )
69 .map(|(selector, from_versions, window_size)| {
70 let matcher = match selector {
71 Selector::StreamId {
72 stream_id,
73 partition_key,
74 } => {
75 let partition_key = partition_key.unwrap_or_else(|| {
76 Uuid::new_v5(&NAMESPACE_PARTITION_KEY, stream_id.as_bytes())
77 });
78 match from_versions {
79 Some(FromVersionsArg::Latest) | None => SubscriptionMatcher::Stream {
80 partition_key,
81 stream_id,
82 from_version: None,
83 },
84 Some(FromVersionsArg::Streams(from_versions)) => {
85 SubscriptionMatcher::Stream {
86 partition_key,
87 from_version: from_versions.get(&stream_id).copied(),
88 stream_id,
89 }
90 }
91 Some(FromVersionsArg::AllStreams(from_version)) => {
92 SubscriptionMatcher::Stream {
93 partition_key,
94 stream_id,
95 from_version: Some(from_version),
96 }
97 }
98 }
99 }
100 Selector::StreamIds(stream_ids) => {
101 let stream_ids: HashSet<_> = stream_ids
102 .into_iter()
103 .map(|(stream_id, partition_key)| {
104 let partition_key = partition_key.unwrap_or_else(|| {
105 Uuid::new_v5(&NAMESPACE_PARTITION_KEY, stream_id.as_bytes())
106 });
107 (partition_key, stream_id)
108 })
109 .collect();
110 SubscriptionMatcher::Streams {
111 from_versions: match from_versions {
112 Some(FromVersionsArg::Latest) | None => FromVersions::Latest,
113 Some(FromVersionsArg::Streams(from_versions)) => {
114 FromVersions::Streams(
115 from_versions
116 .into_iter()
117 .filter_map(|(stream_id, version)| {
118 let (partition_key, _) = stream_ids
119 .iter()
120 .find(|(_, sid)| sid == &stream_id)?;
121 Some(((*partition_key, stream_id), version))
122 })
123 .collect(),
124 )
125 }
126 Some(FromVersionsArg::AllStreams(from_version)) => {
127 FromVersions::AllStreams(from_version)
128 }
129 },
130 stream_ids,
131 }
132 }
133 };
134 ESub {
135 matcher,
136 window_size,
137 }
138 })
139 }
140}
141
/// Stream selection parsed from the leading arguments of the request.
enum Selector {
    /// Exactly one distinct stream was given.
    StreamId {
        stream_id: StreamId,
        /// Value of `PARTITION_KEY <uuid>` when supplied. When `None`, the
        /// key is derived from the stream id when the matcher is built.
        partition_key: Option<Uuid>,
    },
    /// Two or more distinct `(stream_id, partition_key)` pairs. `many1`
    /// guarantees at least one pair, and a single pair becomes `StreamId`.
    StreamIds(HashSet<(StreamId, Option<Uuid>)>),
}
149
150impl Selector {
151 fn parser<'a>() -> impl Parser<FrameStream<'a>, Output = Self> + 'a {
153 many1::<HashSet<_>, _, _>((
154 stream_id(),
155 optional(keyword("PARTITION_KEY").with(partition_key())),
156 ))
157 .map(|stream_ids| {
158 if stream_ids.len() == 1 {
159 let (stream_id, partition_key) =
161 unsafe { stream_ids.into_iter().next().unwrap_unchecked() };
162 return Selector::StreamId {
163 stream_id,
164 partition_key,
165 };
166 }
167
168 Selector::StreamIds(stream_ids)
169 })
170 }
171}
172
/// Starting position requested by the optional `FROM ...` clause.
pub enum FromVersionsArg {
    /// `FROM LATEST`: no explicit starting version.
    Latest,
    /// `FROM MAP ...`: per-stream starting versions, collected from one or
    /// more `stream_id_version()` entries.
    Streams(HashMap<StreamId, u64>),
    /// `FROM <n>`: a single starting version applied to every selected stream.
    AllStreams(u64),
}
178
179fn from_versions<'a>() -> impl Parser<FrameStream<'a>, Output = FromVersionsArg> + 'a {
181 let latest = keyword("LATEST").map(|_| FromVersionsArg::Latest);
182 let sequence = number_u64().map(FromVersionsArg::AllStreams);
183 let map = (keyword("MAP").with(many1::<HashMap<_, _>, _, _>(stream_id_version())))
184 .map(FromVersionsArg::Streams);
185
186 keyword("FROM").with(choice((latest, sequence, map)))
187}
188
189fn window<'a>() -> impl Parser<FrameStream<'a>, Output = u64> + 'a {
190 keyword("WINDOW").with(number_u64_min(1))
191}
192
193impl HandleRequest for ESub {
194 type Error = String;
195 type Ok = BytesFrame;
196
197 async fn handle_request(self, conn: &mut Conn) -> Result<Option<Self::Ok>, Self::Error> {
198 let sender = match conn
199 .subscription_channel
200 .as_ref()
201 .and_then(|(weak_sender, _)| weak_sender.upgrade())
202 {
203 Some(sender) => sender,
204 None => {
205 let (sender, receiver) = mpsc::unbounded_channel();
206 conn.subscription_channel = Some((sender.downgrade(), receiver));
207 sender
208 }
209 };
210
211 let subscription_id = Uuid::new_v4();
212 let (last_ack_tx, last_ack_rx) = watch::channel(None);
213 conn.cluster_ref
214 .ask(Subscribe {
215 subscription_id,
216 matcher: self.matcher,
217 last_ack_rx,
218 update_tx: sender,
219 window_size: self.window_size.unwrap_or(1_000),
220 })
221 .await
222 .map_err(|err| {
223 err.map_err::<&'static str, _>(|_| unreachable!("infallible error"))
224 .as_redis_error()
225 })?;
226
227 conn.subscriptions.insert(subscription_id, last_ack_tx);
228
229 debug!(
230 subscription_id = %subscription_id,
231 "created subscription"
232 );
233
234 Ok(Some(simple_str(subscription_id.to_string())))
235 }
236
237 async fn handle_request_failable(
238 self,
239 conn: &mut Conn,
240 ) -> Result<Option<BytesFrame>, io::Error> {
241 let subscription_id = match self.handle_request(conn).await {
242 Ok(Some(subscription_id)) => subscription_id,
243 Ok(None) => unreachable!("always returns some"),
244 Err(err) => {
245 return Ok(Some(BytesFrame::SimpleError {
246 data: err.into(),
247 attributes: None,
248 }));
249 }
250 };
251
252 resp3::encode::complete::extend_encode(&mut conn.write, &subscription_id, false)
253 .map_err(io::Error::other)?;
254
255 resp3::encode::complete::extend_encode(
256 &mut conn.write,
257 &BytesFrame::Push {
258 data: vec![
259 simple_str("subscribe"),
260 subscription_id,
261 number(conn.subscriptions.len() as i64),
262 ],
263 attributes: None,
264 },
265 false,
266 )
267 .map_err(io::Error::other)?;
268
269 conn.stream.write_all(&conn.write).await?;
270 conn.stream.flush().await?;
271 conn.write.clear();
272
273 Ok(None)
274 }
275}