1use std::collections::{HashMap, HashSet};
2
3use combine::{Parser, choice, many1, optional};
4use redis_protocol::resp3;
5use redis_protocol::resp3::types::BytesFrame;
6use sierradb::bucket::PartitionId;
7use sierradb_cluster::subscription::{FromSequences, Subscribe, SubscriptionMatcher};
8use tokio::io::{self, AsyncWriteExt};
9use tokio::sync::{mpsc, watch};
10use tracing::debug;
11use uuid::Uuid;
12
13use crate::error::AsRedisError;
14use crate::parser::{
15 FrameStream, all_selector, keyword, number_u64, number_u64_min, partition_id,
16 partition_id_sequence, partition_ids,
17};
18use crate::request::{HandleRequest, number, simple_str};
19use crate::server::Conn;
20
21#[derive(Debug)]
47pub struct EPSub {
48 pub matcher: SubscriptionMatcher,
49 pub window_size: Option<u64>,
50}
51
52impl EPSub {
53 pub fn parser<'a>() -> impl Parser<FrameStream<'a>, Output = EPSub> + 'a {
54 (
55 Selector::parser(),
56 optional(from_sequences()),
57 optional(window()),
58 )
59 .map(|(selector, from_sequences, window_size)| {
60 let matcher = match selector {
61 Selector::All => SubscriptionMatcher::AllPartitions {
62 from_sequences: from_sequences.unwrap_or(FromSequences::Latest),
63 },
64 Selector::Partition(partition_id) => match from_sequences {
65 Some(FromSequences::Latest) => SubscriptionMatcher::Partition {
66 partition_id,
67 from_sequence: None,
68 },
69 Some(FromSequences::Partitions {
70 from_sequences,
71 fallback,
72 }) => SubscriptionMatcher::Partition {
73 partition_id,
74 from_sequence: from_sequences.get(&partition_id).copied().or(fallback),
75 },
76 Some(FromSequences::AllPartitions(sequence)) => {
77 SubscriptionMatcher::Partition {
78 partition_id,
79 from_sequence: Some(sequence),
80 }
81 }
82 None => SubscriptionMatcher::Partition {
83 partition_id,
84 from_sequence: None,
85 },
86 },
87 Selector::Partitions(partition_ids) => SubscriptionMatcher::Partitions {
88 partition_ids,
89 from_sequences: from_sequences.unwrap_or(FromSequences::Latest),
90 },
91 };
92 EPSub {
93 matcher,
94 window_size,
95 }
96 })
97 }
98}
99
100enum Selector {
101 All,
102 Partition(PartitionId),
103 Partitions(HashSet<PartitionId>),
104}
105
106impl Selector {
107 fn parser<'a>() -> impl Parser<FrameStream<'a>, Output = Self> + 'a {
108 choice!(
109 all_selector().map(|_| Selector::All),
110 partition_id().map(Selector::Partition),
111 partition_ids().map(Selector::Partitions)
112 )
113 }
114}
115
116fn from_sequences<'a>() -> impl Parser<FrameStream<'a>, Output = FromSequences> + 'a {
119 let latest = keyword("LATEST").map(|_| FromSequences::Latest);
120 let sequence = number_u64().map(FromSequences::AllPartitions);
121 let map = (keyword("MAP").with((
122 many1::<HashMap<_, _>, _, _>(partition_id_sequence()),
123 optional(keyword("DEFAULT").with(number_u64())),
124 )))
125 .map(|(from_sequences, fallback)| FromSequences::Partitions {
126 from_sequences,
127 fallback,
128 });
129
130 keyword("FROM").with(choice((latest, sequence, map)))
131}
132
133fn window<'a>() -> impl Parser<FrameStream<'a>, Output = u64> + 'a {
134 keyword("WINDOW").with(number_u64_min(1))
135}
136
137impl HandleRequest for EPSub {
138 type Error = String;
139 type Ok = BytesFrame;
140
141 async fn handle_request(self, conn: &mut Conn) -> Result<Option<Self::Ok>, Self::Error> {
142 let sender = match conn
143 .subscription_channel
144 .as_ref()
145 .and_then(|(weak_sender, _)| weak_sender.upgrade())
146 {
147 Some(sender) => sender,
148 None => {
149 let (sender, receiver) = mpsc::unbounded_channel();
150 conn.subscription_channel = Some((sender.downgrade(), receiver));
151 sender
152 }
153 };
154
155 let subscription_id = Uuid::new_v4();
156 let (last_ack_tx, last_ack_rx) = watch::channel(None);
157 conn.cluster_ref
158 .ask(Subscribe {
159 subscription_id,
160 matcher: self.matcher,
161 last_ack_rx,
162 update_tx: sender,
163 window_size: self.window_size.unwrap_or(1_000),
164 })
165 .await
166 .map_err(|err| {
167 err.map_err::<&'static str, _>(|_| unreachable!("infallible error"))
168 .as_redis_error()
169 })?;
170
171 conn.subscriptions.insert(subscription_id, last_ack_tx);
172
173 debug!(
174 subscription_id = %subscription_id,
175 "created subscription"
176 );
177
178 Ok(Some(simple_str(subscription_id.to_string())))
179 }
180
181 async fn handle_request_failable(
182 self,
183 conn: &mut Conn,
184 ) -> Result<Option<BytesFrame>, io::Error> {
185 let subscription_id = match self.handle_request(conn).await {
186 Ok(Some(subscription_id)) => subscription_id,
187 Ok(None) => unreachable!("always returns some"),
188 Err(err) => {
189 return Ok(Some(BytesFrame::SimpleError {
190 data: err.into(),
191 attributes: None,
192 }));
193 }
194 };
195
196 resp3::encode::complete::extend_encode(&mut conn.write, &subscription_id, false)
197 .map_err(io::Error::other)?;
198
199 resp3::encode::complete::extend_encode(
200 &mut conn.write,
201 &BytesFrame::Push {
202 data: vec![
203 simple_str("subscribe"),
204 subscription_id.clone(),
205 number(conn.subscriptions.len() as i64),
206 ],
207 attributes: None,
208 },
209 false,
210 )
211 .map_err(io::Error::other)?;
212
213 conn.stream.write_all(&conn.write).await?;
214 conn.stream.flush().await?;
215 conn.write.clear();
216
217 Ok(None)
218 }
219}