Skip to main content

sierradb_server/request/
epsub.rs

1use std::collections::{HashMap, HashSet};
2
3use combine::{Parser, choice, many1, optional};
4use redis_protocol::resp3;
5use redis_protocol::resp3::types::BytesFrame;
6use sierradb::bucket::PartitionId;
7use sierradb_cluster::subscription::{FromSequences, Subscribe, SubscriptionMatcher};
8use tokio::io::{self, AsyncWriteExt};
9use tokio::sync::{mpsc, watch};
10use tracing::debug;
11use uuid::Uuid;
12
13use crate::error::AsRedisError;
14use crate::parser::{
15    FrameStream, all_selector, keyword, number_u64, number_u64_min, partition_id,
16    partition_id_sequence, partition_ids,
17};
18use crate::request::{HandleRequest, number, simple_str};
19use crate::server::Conn;
20
21/// Subscribe to events from one or more partitions.
22///
23/// # Syntax
24/// ```text
25/// # All partitions
26/// EPSUB * [FROM LATEST | FROM <sequence> | FROM MAP <p1>=<s1> <p2>=<s2>... [DEFAULT <seq>]] [WINDOW <size>]
27///
28/// # Single partition
29/// EPSUB <partition_id> [FROM <sequence>] [WINDOW <size>]
30///
31/// # Multiple partitions
32/// EPSUB <p1>,<p2>,<p3> [FROM LATEST | FROM <sequence> | FROM MAP <p1>=<s1> <p2>=<s2>... [DEFAULT <seq>]] [WINDOW <size>]
33/// ```
34///
35/// # Examples
36/// ```text
37/// EPSUB *                                               # All partitions, latest, no window
38/// EPSUB * WINDOW 100                                    # All partitions, latest, window 100
39/// EPSUB * FROM 1000 WINDOW 100                          # All partitions, from seq 1000, window 100
40/// EPSUB 5 FROM 100 WINDOW 50                            # Partition 5, from seq 100, window 50
41/// EPSUB 1,2,3 FROM MAP 1=100 2=200 DEFAULT 0 WINDOW 500
42/// ```
43///
44/// **Note:** Establishes a persistent connection to receive real-time events
45/// from the specified partitions.
46#[derive(Debug)]
47pub struct EPSub {
48    pub matcher: SubscriptionMatcher,
49    pub window_size: Option<u64>,
50}
51
52impl EPSub {
53    pub fn parser<'a>() -> impl Parser<FrameStream<'a>, Output = EPSub> + 'a {
54        (
55            Selector::parser(),
56            optional(from_sequences()),
57            optional(window()),
58        )
59            .map(|(selector, from_sequences, window_size)| {
60                let matcher = match selector {
61                    Selector::All => SubscriptionMatcher::AllPartitions {
62                        from_sequences: from_sequences.unwrap_or(FromSequences::Latest),
63                    },
64                    Selector::Partition(partition_id) => match from_sequences {
65                        Some(FromSequences::Latest) => SubscriptionMatcher::Partition {
66                            partition_id,
67                            from_sequence: None,
68                        },
69                        Some(FromSequences::Partitions {
70                            from_sequences,
71                            fallback,
72                        }) => SubscriptionMatcher::Partition {
73                            partition_id,
74                            from_sequence: from_sequences.get(&partition_id).copied().or(fallback),
75                        },
76                        Some(FromSequences::AllPartitions(sequence)) => {
77                            SubscriptionMatcher::Partition {
78                                partition_id,
79                                from_sequence: Some(sequence),
80                            }
81                        }
82                        None => SubscriptionMatcher::Partition {
83                            partition_id,
84                            from_sequence: None,
85                        },
86                    },
87                    Selector::Partitions(partition_ids) => SubscriptionMatcher::Partitions {
88                        partition_ids,
89                        from_sequences: from_sequences.unwrap_or(FromSequences::Latest),
90                    },
91                };
92                EPSub {
93                    matcher,
94                    window_size,
95                }
96            })
97    }
98}
99
100enum Selector {
101    All,
102    Partition(PartitionId),
103    Partitions(HashSet<PartitionId>),
104}
105
106impl Selector {
107    fn parser<'a>() -> impl Parser<FrameStream<'a>, Output = Self> + 'a {
108        choice!(
109            all_selector().map(|_| Selector::All),
110            partition_id().map(Selector::Partition),
111            partition_ids().map(Selector::Partitions)
112        )
113    }
114}
115
116// [FROM LATEST | FROM <sequence> | FROM MAP <p1>=<s1> <p2>=<s2>... [DEFAULT
117// <seq>]]
118fn from_sequences<'a>() -> impl Parser<FrameStream<'a>, Output = FromSequences> + 'a {
119    let latest = keyword("LATEST").map(|_| FromSequences::Latest);
120    let sequence = number_u64().map(FromSequences::AllPartitions);
121    let map = (keyword("MAP").with((
122        many1::<HashMap<_, _>, _, _>(partition_id_sequence()),
123        optional(keyword("DEFAULT").with(number_u64())),
124    )))
125    .map(|(from_sequences, fallback)| FromSequences::Partitions {
126        from_sequences,
127        fallback,
128    });
129
130    keyword("FROM").with(choice((latest, sequence, map)))
131}
132
133fn window<'a>() -> impl Parser<FrameStream<'a>, Output = u64> + 'a {
134    keyword("WINDOW").with(number_u64_min(1))
135}
136
137impl HandleRequest for EPSub {
138    type Error = String;
139    type Ok = BytesFrame;
140
141    async fn handle_request(self, conn: &mut Conn) -> Result<Option<Self::Ok>, Self::Error> {
142        let sender = match conn
143            .subscription_channel
144            .as_ref()
145            .and_then(|(weak_sender, _)| weak_sender.upgrade())
146        {
147            Some(sender) => sender,
148            None => {
149                let (sender, receiver) = mpsc::unbounded_channel();
150                conn.subscription_channel = Some((sender.downgrade(), receiver));
151                sender
152            }
153        };
154
155        let subscription_id = Uuid::new_v4();
156        let (last_ack_tx, last_ack_rx) = watch::channel(None);
157        conn.cluster_ref
158            .ask(Subscribe {
159                subscription_id,
160                matcher: self.matcher,
161                last_ack_rx,
162                update_tx: sender,
163                window_size: self.window_size.unwrap_or(1_000),
164            })
165            .await
166            .map_err(|err| {
167                err.map_err::<&'static str, _>(|_| unreachable!("infallible error"))
168                    .as_redis_error()
169            })?;
170
171        conn.subscriptions.insert(subscription_id, last_ack_tx);
172
173        debug!(
174            subscription_id = %subscription_id,
175            "created subscription"
176        );
177
178        Ok(Some(simple_str(subscription_id.to_string())))
179    }
180
181    async fn handle_request_failable(
182        self,
183        conn: &mut Conn,
184    ) -> Result<Option<BytesFrame>, io::Error> {
185        let subscription_id = match self.handle_request(conn).await {
186            Ok(Some(subscription_id)) => subscription_id,
187            Ok(None) => unreachable!("always returns some"),
188            Err(err) => {
189                return Ok(Some(BytesFrame::SimpleError {
190                    data: err.into(),
191                    attributes: None,
192                }));
193            }
194        };
195
196        resp3::encode::complete::extend_encode(&mut conn.write, &subscription_id, false)
197            .map_err(io::Error::other)?;
198
199        resp3::encode::complete::extend_encode(
200            &mut conn.write,
201            &BytesFrame::Push {
202                data: vec![
203                    simple_str("subscribe"),
204                    subscription_id.clone(),
205                    number(conn.subscriptions.len() as i64),
206                ],
207                attributes: None,
208            },
209            false,
210        )
211        .map_err(io::Error::other)?;
212
213        conn.stream.write_all(&conn.write).await?;
214        conn.stream.flush().await?;
215        conn.write.clear();
216
217        Ok(None)
218    }
219}