Skip to main content

sierradb_server/request/
esub.rs

1use std::collections::{HashMap, HashSet};
2
3use combine::{Parser, choice, many1, optional};
4use redis_protocol::resp3;
5use redis_protocol::resp3::types::BytesFrame;
6use sierradb::StreamId;
7use sierradb::id::NAMESPACE_PARTITION_KEY;
8use sierradb_cluster::subscription::{FromVersions, Subscribe, SubscriptionMatcher};
9use tokio::io::{self, AsyncWriteExt};
10use tokio::sync::{mpsc, watch};
11use tracing::debug;
12use uuid::Uuid;
13
14use crate::error::AsRedisError;
15use crate::parser::{
16    FrameStream, keyword, number_u64, number_u64_min, partition_key, stream_id, stream_id_version,
17};
18use crate::request::{HandleRequest, number, simple_str};
19use crate::server::Conn;
20
21/// Subscribe to events from one or more streams.
22///
23/// # Syntax
24/// ```text
25/// # Single stream
26/// ESUB <stream_id> [PARTITION_KEY <partition_key>] [FROM <version>] [WINDOW <size>]
27///
28/// # Multiple streams
29/// ESUB <stream_id_1> [PARTITION_KEY <pk_1>] <stream_id_2> [PARTITION_KEY <pk_2>] ... [FROM LATEST | FROM <version> | FROM MAP <stream>=<ver>...] [WINDOW <size>]
30/// ```
31///
32/// # Examples
33/// ```text
34/// ESUB user-123                                         # Single stream, latest, no window
35/// ESUB user-123 WINDOW 100                              # Single stream, latest, window 100
36/// ESUB user-123 FROM 50 WINDOW 100                      # Single stream, from version 50, window 100
37/// ESUB user-123 PARTITION_KEY abc-def FROM 50           # Single stream with partition key
38/// ESUB user-123 PARTITION_KEY abc-def FROM 50 WINDOW 100  # With partition key and window
39///
40/// # Multiple streams
41/// ESUB user-1 user-2 user-3                             # Multiple streams, latest, no window
42/// ESUB user-1 user-2 user-3 WINDOW 500                  # Multiple streams with window
43/// ESUB user-1 user-2 user-3 FROM LATEST WINDOW 500      # Explicit latest with window
44/// ESUB user-1 user-2 user-3 FROM 100 WINDOW 500         # All from version 100
45///
46/// # Multiple streams with partition keys
47/// ESUB user-1 PARTITION_KEY abc user-2 PARTITION_KEY def user-3 PARTITION_KEY ghi FROM LATEST WINDOW 100
48/// ESUB user-1 PARTITION_KEY abc user-2 user-3 PARTITION_KEY ghi FROM LATEST WINDOW 100  # Mixed
49///
50/// # Multiple streams with per-stream versions
51/// ESUB user-1 user-2 user-3 FROM MAP user-1=10 user-2=20 user-3=30 WINDOW 50
52/// ESUB stream1 PARTITION_KEY pk1 stream2 stream3 PARTITION_KEY pk3 FROM MAP stream1=10 stream2=20 stream3=30 WINDOW 50
53/// ```
54///
55/// **Note:** Establishes a persistent connection to receive real-time stream
56/// events.
57pub struct ESub {
58    pub matcher: SubscriptionMatcher,
59    pub window_size: Option<u64>,
60}
61
62impl ESub {
63    pub fn parser<'a>() -> impl Parser<FrameStream<'a>, Output = ESub> + 'a {
64        (
65            Selector::parser(),
66            optional(from_versions()),
67            optional(window()),
68        )
69            .map(|(selector, from_versions, window_size)| {
70                let matcher = match selector {
71                    Selector::StreamId {
72                        stream_id,
73                        partition_key,
74                    } => {
75                        let partition_key = partition_key.unwrap_or_else(|| {
76                            Uuid::new_v5(&NAMESPACE_PARTITION_KEY, stream_id.as_bytes())
77                        });
78                        match from_versions {
79                            Some(FromVersionsArg::Latest) | None => SubscriptionMatcher::Stream {
80                                partition_key,
81                                stream_id,
82                                from_version: None,
83                            },
84                            Some(FromVersionsArg::Streams(from_versions)) => {
85                                SubscriptionMatcher::Stream {
86                                    partition_key,
87                                    from_version: from_versions.get(&stream_id).copied(),
88                                    stream_id,
89                                }
90                            }
91                            Some(FromVersionsArg::AllStreams(from_version)) => {
92                                SubscriptionMatcher::Stream {
93                                    partition_key,
94                                    stream_id,
95                                    from_version: Some(from_version),
96                                }
97                            }
98                        }
99                    }
100                    Selector::StreamIds(stream_ids) => {
101                        let stream_ids: HashSet<_> = stream_ids
102                            .into_iter()
103                            .map(|(stream_id, partition_key)| {
104                                let partition_key = partition_key.unwrap_or_else(|| {
105                                    Uuid::new_v5(&NAMESPACE_PARTITION_KEY, stream_id.as_bytes())
106                                });
107                                (partition_key, stream_id)
108                            })
109                            .collect();
110                        SubscriptionMatcher::Streams {
111                            from_versions: match from_versions {
112                                Some(FromVersionsArg::Latest) | None => FromVersions::Latest,
113                                Some(FromVersionsArg::Streams(from_versions)) => {
114                                    FromVersions::Streams(
115                                        from_versions
116                                            .into_iter()
117                                            .filter_map(|(stream_id, version)| {
118                                                let (partition_key, _) = stream_ids
119                                                    .iter()
120                                                    .find(|(_, sid)| sid == &stream_id)?;
121                                                Some(((*partition_key, stream_id), version))
122                                            })
123                                            .collect(),
124                                    )
125                                }
126                                Some(FromVersionsArg::AllStreams(from_version)) => {
127                                    FromVersions::AllStreams(from_version)
128                                }
129                            },
130                            stream_ids,
131                        }
132                    }
133                };
134                ESub {
135                    matcher,
136                    window_size,
137                }
138            })
139    }
140}
141
142enum Selector {
143    StreamId {
144        stream_id: StreamId,
145        partition_key: Option<Uuid>,
146    },
147    StreamIds(HashSet<(StreamId, Option<Uuid>)>),
148}
149
150impl Selector {
151    // <stream_id_1> [PARTITION_KEY <pk_1>] <stream_id_2> [PARTITION_KEY <pk_2>]
152    fn parser<'a>() -> impl Parser<FrameStream<'a>, Output = Self> + 'a {
153        many1::<HashSet<_>, _, _>((
154            stream_id(),
155            optional(keyword("PARTITION_KEY").with(partition_key())),
156        ))
157        .map(|stream_ids| {
158            if stream_ids.len() == 1 {
159                // SAFETY: We just verified the set has exactly one element
160                let (stream_id, partition_key) =
161                    unsafe { stream_ids.into_iter().next().unwrap_unchecked() };
162                return Selector::StreamId {
163                    stream_id,
164                    partition_key,
165                };
166            }
167
168            Selector::StreamIds(stream_ids)
169        })
170    }
171}
172
173pub enum FromVersionsArg {
174    Latest,
175    Streams(HashMap<StreamId, u64>),
176    AllStreams(u64),
177}
178
179// FROM LATEST | FROM <version> | FROM MAP <stream>=<ver>...
180fn from_versions<'a>() -> impl Parser<FrameStream<'a>, Output = FromVersionsArg> + 'a {
181    let latest = keyword("LATEST").map(|_| FromVersionsArg::Latest);
182    let sequence = number_u64().map(FromVersionsArg::AllStreams);
183    let map = (keyword("MAP").with(many1::<HashMap<_, _>, _, _>(stream_id_version())))
184        .map(FromVersionsArg::Streams);
185
186    keyword("FROM").with(choice((latest, sequence, map)))
187}
188
189fn window<'a>() -> impl Parser<FrameStream<'a>, Output = u64> + 'a {
190    keyword("WINDOW").with(number_u64_min(1))
191}
192
193impl HandleRequest for ESub {
194    type Error = String;
195    type Ok = BytesFrame;
196
197    async fn handle_request(self, conn: &mut Conn) -> Result<Option<Self::Ok>, Self::Error> {
198        let sender = match conn
199            .subscription_channel
200            .as_ref()
201            .and_then(|(weak_sender, _)| weak_sender.upgrade())
202        {
203            Some(sender) => sender,
204            None => {
205                let (sender, receiver) = mpsc::unbounded_channel();
206                conn.subscription_channel = Some((sender.downgrade(), receiver));
207                sender
208            }
209        };
210
211        let subscription_id = Uuid::new_v4();
212        let (last_ack_tx, last_ack_rx) = watch::channel(None);
213        conn.cluster_ref
214            .ask(Subscribe {
215                subscription_id,
216                matcher: self.matcher,
217                last_ack_rx,
218                update_tx: sender,
219                window_size: self.window_size.unwrap_or(1_000),
220            })
221            .await
222            .map_err(|err| {
223                err.map_err::<&'static str, _>(|_| unreachable!("infallible error"))
224                    .as_redis_error()
225            })?;
226
227        conn.subscriptions.insert(subscription_id, last_ack_tx);
228
229        debug!(
230            subscription_id = %subscription_id,
231            "created subscription"
232        );
233
234        Ok(Some(simple_str(subscription_id.to_string())))
235    }
236
237    async fn handle_request_failable(
238        self,
239        conn: &mut Conn,
240    ) -> Result<Option<BytesFrame>, io::Error> {
241        let subscription_id = match self.handle_request(conn).await {
242            Ok(Some(subscription_id)) => subscription_id,
243            Ok(None) => unreachable!("always returns some"),
244            Err(err) => {
245                return Ok(Some(BytesFrame::SimpleError {
246                    data: err.into(),
247                    attributes: None,
248                }));
249            }
250        };
251
252        resp3::encode::complete::extend_encode(&mut conn.write, &subscription_id, false)
253            .map_err(io::Error::other)?;
254
255        resp3::encode::complete::extend_encode(
256            &mut conn.write,
257            &BytesFrame::Push {
258                data: vec![
259                    simple_str("subscribe"),
260                    subscription_id,
261                    number(conn.subscriptions.len() as i64),
262                ],
263                attributes: None,
264            },
265            false,
266        )
267        .map_err(io::Error::other)?;
268
269        conn.stream.write_all(&conn.write).await?;
270        conn.stream.flush().await?;
271        conn.write.clear();
272
273        Ok(None)
274    }
275}