1use std::{sync::Arc, time::Duration};
2
3use bytesize::ByteSize;
4use dashmap::DashMap;
5use enum_ordinalize::Ordinalize;
6use s2_common::{
7 record::{NonZeroSeqNum, SeqNum, StreamPosition},
8 types::{
9 basin::BasinName,
10 config::{BasinConfig, OptionalStreamConfig},
11 resources::CreateMode,
12 stream::StreamName,
13 },
14};
15use slatedb::config::{DurabilityLevel, ScanOptions};
16use tokio::{
17 sync::{Notify, broadcast},
18 time::Instant,
19};
20
21use super::{
22 error::{
23 BasinDeletionPendingError, BasinNotFoundError, CreateStreamError, GetBasinConfigError,
24 StorageError, StreamDeletionPendingError, StreamNotFoundError, StreamerError,
25 TransactionConflictError,
26 },
27 kv,
28 stream_id::StreamId,
29 streamer::{StreamerClient, StreamerClientState},
30};
31use crate::backend::bgtasks::BgtaskTrigger;
32
/// Cheaply cloneable handle to the storage backend; all clones share the
/// same underlying db and streamer state.
#[derive(Clone)]
pub struct Backend {
    /// Underlying SlateDB instance holding all basin/stream state.
    pub(super) db: slatedb::Db,
    /// Per-stream streamer lifecycle state, keyed by `StreamId`; shared
    /// across clones via `Arc`.
    client_states: Arc<DashMap<StreamId, StreamerClientState>>,
    /// Bound on in-flight append bytes, handed to each spawned streamer.
    append_inflight_max: ByteSize,
    /// Broadcast sender used to nudge background tasks (see `bgtask_trigger`).
    bgtask_trigger_tx: broadcast::Sender<BgtaskTrigger>,
}
40
impl Backend {
    /// How long a failed streamer-initialization result is remembered before
    /// a retry is allowed (see `streamer_client`).
    const FAILED_INIT_MEMORY: Duration = Duration::from_secs(1);

    /// Creates a backend over `db`, bounding in-flight append bytes for each
    /// streamer by `append_inflight_max`.
    pub fn new(db: slatedb::Db, append_inflight_max: ByteSize) -> Self {
        // Only the sender side is retained; receivers are handed out on
        // demand through `bgtask_trigger_subscribe`.
        let (tx, _rx) = broadcast::channel(16);
        Backend {
            db,
            client_states: Arc::new(DashMap::new()),
            append_inflight_max,
            bgtask_trigger_tx: tx,
        }
    }
53
54 pub(super) fn bgtask_trigger(&self, trigger: BgtaskTrigger) {
55 let _ = self.bgtask_trigger_tx.send(trigger);
56 }
57
    /// Returns a fresh receiver for background-task triggers; per tokio
    /// broadcast semantics it only observes triggers sent after this call.
    pub(super) fn bgtask_trigger_subscribe(&self) -> broadcast::Receiver<BgtaskTrigger> {
        self.bgtask_trigger_tx.subscribe()
    }
61
    /// Loads a stream's persisted state from the db and spawns its streamer
    /// task, returning a client handle to it.
    ///
    /// Errors if the stream does not exist, if its deletion is pending, or if
    /// a storage read fails.
    async fn start_streamer(
        &self,
        basin: BasinName,
        stream: StreamName,
    ) -> Result<StreamerClient, StreamerError> {
        let stream_id = StreamId::new(&basin, &stream);

        // Fetch the four pieces of persisted per-stream state concurrently;
        // the first storage error aborts the whole join.
        let (meta, tail_pos, fencing_token, trim_point) = tokio::try_join!(
            self.db_get(
                kv::stream_meta::ser_key(&basin, &stream),
                kv::stream_meta::deser_value,
            ),
            self.db_get(
                kv::stream_tail_position::ser_key(stream_id),
                kv::stream_tail_position::deser_value,
            ),
            self.db_get(
                kv::stream_fencing_token::ser_key(stream_id),
                kv::stream_fencing_token::deser_value,
            ),
            self.db_get(
                kv::stream_trim_point::ser_key(stream_id),
                kv::stream_trim_point::deser_value,
            )
        )?;

        // No metadata record means the stream does not exist.
        let Some(meta) = meta else {
            return Err(StreamNotFoundError { basin, stream }.into());
        };

        // A stream with no persisted tail yet starts at the minimum position.
        let tail_pos = tail_pos.map(|(pos, _)| pos).unwrap_or(StreamPosition::MIN);
        // Invariant check: no record data may exist at or past the tail.
        self.assert_no_records_following_tail(stream_id, &basin, &stream, tail_pos)
            .await?;

        let fencing_token = fencing_token.unwrap_or_default();

        // A trim point spanning every sequence number (`..MAX`) marks a
        // stream whose deletion is in progress.
        if trim_point == Some(..NonZeroSeqNum::MAX) {
            return Err(StreamDeletionPendingError { basin, stream }.into());
        }

        let client_states = self.client_states.clone();
        Ok(super::streamer::Spawner {
            db: self.db.clone(),
            stream_id,
            config: meta.config,
            tail_pos,
            fencing_token,
            // Absent trim point => nothing trimmed (`..SeqNum::MIN`).
            trim_point: ..trim_point.map_or(SeqNum::MIN, |tp| tp.end.get()),
            append_inflight_max: self.append_inflight_max,
            bgtask_trigger_tx: self.bgtask_trigger_tx.clone(),
        }
        // On streamer exit, drop its map entry — but only if the entry still
        // refers to this client, since a newer one may have replaced it.
        .spawn(move |client_id| {
            client_states.remove_if(&stream_id, |_, state| {
                matches!(state, StreamerClientState::Ready { client } if client.id() == client_id)
            });
        }))
    }
119
    /// Verifies the storage invariant that no record data exists for
    /// `stream_id` at or beyond `tail_pos`, panicking if it is violated.
    /// `basin`/`stream` are used only to label the panic message.
    async fn assert_no_records_following_tail(
        &self,
        stream_id: StreamId,
        basin: &BasinName,
        stream: &StreamName,
        tail_pos: StreamPosition,
    ) -> Result<(), StorageError> {
        // Start the scan at the smallest key for the tail's seq-num
        // (timestamp 0) so any record with seq_num >= tail is visible.
        let start_key = kv::stream_record_data::ser_key(
            stream_id,
            StreamPosition {
                seq_num: tail_pos.seq_num,
                timestamp: 0,
            },
        );
        // Only consider durably (remotely) persisted data, and keep the scan
        // as cheap as possible: at most one entry is ever inspected.
        static SCAN_OPTS: ScanOptions = ScanOptions {
            durability_filter: DurabilityLevel::Remote,
            dirty: false,
            read_ahead_bytes: 1,
            cache_blocks: false,
            max_fetch_tasks: 1,
        };
        let mut it = self.db.scan_with_options(start_key.., &SCAN_OPTS).await?;
        let Some(kv) = it.next().await? else {
            // Nothing at or after the start key at all.
            return Ok(());
        };
        // The scan is unbounded above: ignore keys that have already left
        // the record-data keyspace.
        if kv.key.first().copied() != Some(kv::KeyType::StreamRecordData.ordinal()) {
            return Ok(());
        }
        let (deser_stream_id, pos) = kv::stream_record_data::deser_key(kv.key)?;
        // A record belonging to a *different* (higher-sorting) stream is
        // fine; one for this stream sits at or past its tail — a bug.
        assert!(
            deser_stream_id != stream_id,
            "invariant violation: stream `{basin}/{stream}` tail_pos {tail_pos:?} but found record at {pos:?}"
        );
        Ok(())
    }
155
    /// Returns the current streamer state for the stream, lazily starting
    /// initialization when no entry exists yet.
    ///
    /// The first caller for a vacant entry installs `Blocked { notify }` —
    /// so concurrent callers can all wait on the same `Notify` — and spawns
    /// a task that runs `start_streamer`, stores the resulting
    /// `Ready`/`InitError` state, and then wakes all waiters.
    fn streamer_client_state(&self, basin: &BasinName, stream: &StreamName) -> StreamerClientState {
        match self.client_states.entry(StreamId::new(basin, stream)) {
            dashmap::Entry::Occupied(oe) => oe.get().clone(),
            dashmap::Entry::Vacant(ve) => {
                let this = self.clone();
                let stream_id = *(ve.key());
                let basin = basin.clone();
                let stream = stream.clone();
                let notify = Arc::new(Notify::new());
                let notify_waiters = {
                    let notify = notify.clone();
                    move || notify.notify_waiters()
                };
                tokio::spawn(async move {
                    let state = match this.start_streamer(basin, stream).await {
                        Ok(client) => StreamerClientState::Ready { client },
                        // Record when the failure happened so stale errors
                        // can be retried (see `streamer_client`).
                        Err(error) => StreamerClientState::InitError {
                            error: Box::new(error),
                            timestamp: Instant::now(),
                        },
                    };
                    this.client_states.insert(stream_id, state);
                    // Wake everyone blocked on this entry's `Notify`.
                    notify_waiters();
                });
                ve.insert(StreamerClientState::Blocked { notify })
                    .value()
                    .clone()
            }
        }
    }
186
187 fn streamer_remove_unready(&self, stream_id: StreamId) {
188 if let dashmap::Entry::Occupied(oe) = self.client_states.entry(stream_id)
189 && let StreamerClientState::InitError { .. } = oe.get()
190 {
191 oe.remove();
192 }
193 }
194
    /// Resolves a ready streamer client for the stream, initializing one on
    /// demand and waiting while initialization is in flight.
    ///
    /// A cached `InitError` is surfaced only to a caller that actually
    /// waited on that initialization attempt, and only while the error is
    /// still fresh (within `FAILED_INIT_MEMORY`); otherwise the stale entry
    /// is evicted and initialization is retried.
    pub(super) async fn streamer_client(
        &self,
        basin: &BasinName,
        stream: &StreamName,
    ) -> Result<StreamerClient, StreamerError> {
        let mut waited = false;
        loop {
            match self.streamer_client_state(basin, stream) {
                StreamerClientState::Blocked { notify } => {
                    // Initialization in progress: wait to be woken, then
                    // re-read the state on the next loop iteration.
                    notify.notified().await;
                    waited = true;
                }
                StreamerClientState::InitError { error, timestamp } => {
                    if !waited || timestamp.elapsed() > Self::FAILED_INIT_MEMORY {
                        // The error is stale, or this caller never witnessed
                        // the failed attempt: evict it and try again.
                        self.streamer_remove_unready(StreamId::new(basin, stream));
                    } else {
                        return Err(*error);
                    }
                }
                StreamerClientState::Ready { client } => {
                    return Ok(client);
                }
            }
        }
    }
220
221 pub(super) fn streamer_client_if_active(
222 &self,
223 basin: &BasinName,
224 stream: &StreamName,
225 ) -> Option<StreamerClient> {
226 match self.streamer_client_state(basin, stream) {
227 StreamerClientState::Ready { client } => Some(client),
228 _ => None,
229 }
230 }
231
    /// Like `streamer_client`, but when the stream is missing and
    /// `should_auto_create` approves (given the basin's config), the stream
    /// is first created with default config and the lookup retried.
    ///
    /// Losing a creation race (`StreamAlreadyExists`) counts as success.
    /// Every other basin/stream/storage failure is converted into the
    /// caller-chosen error type `E` via its `From` bounds.
    pub(super) async fn streamer_client_with_auto_create<E>(
        &self,
        basin: &BasinName,
        stream: &StreamName,
        should_auto_create: impl FnOnce(&BasinConfig) -> bool,
    ) -> Result<StreamerClient, E>
    where
        E: From<StreamerError>
            + From<StorageError>
            + From<BasinNotFoundError>
            + From<TransactionConflictError>
            + From<BasinDeletionPendingError>
            + From<StreamDeletionPendingError>
            + From<StreamNotFoundError>,
    {
        match self.streamer_client(basin, stream).await {
            Ok(client) => Ok(client),
            Err(StreamerError::StreamNotFound(e)) => {
                // The basin config decides whether auto-creation applies.
                let config = match self.get_basin_config(basin.clone()).await {
                    Ok(config) => config,
                    Err(GetBasinConfigError::Storage(e)) => Err(e)?,
                    Err(GetBasinConfigError::BasinNotFound(e)) => Err(e)?,
                };
                if should_auto_create(&config) {
                    if let Err(e) = self
                        .create_stream(
                            basin.clone(),
                            stream.clone(),
                            OptionalStreamConfig::default(),
                            CreateMode::CreateOnly(None),
                        )
                        .await
                    {
                        // Map each creation failure into `E` individually;
                        // a lost creation race is deliberately swallowed.
                        match e {
                            CreateStreamError::Storage(e) => Err(e)?,
                            CreateStreamError::TransactionConflict(e) => Err(e)?,
                            CreateStreamError::BasinDeletionPending(e) => Err(e)?,
                            CreateStreamError::StreamDeletionPending(e) => Err(e)?,
                            CreateStreamError::BasinNotFound(e) => Err(e)?,
                            CreateStreamError::StreamAlreadyExists(_) => {}
                        }
                    }
                    // The stream exists now; resolve its client.
                    Ok(self.streamer_client(basin, stream).await?)
                } else {
                    Err(e.into())
                }
            }
            Err(e) => Err(e.into()),
        }
    }
282}
283
#[cfg(test)]
mod tests {
    use std::str::FromStr as _;

    use bytes::Bytes;
    use s2_common::record::{Metered, Record, StreamPosition};
    use slatedb::{WriteBatch, config::WriteOptions, object_store};
    use time::OffsetDateTime;

    use super::*;

    /// Seeds the db with a stream whose persisted tail is seq-num 1 while a
    /// record also exists at seq-num 1 (i.e. at/after the tail), then checks
    /// that `start_streamer` panics on its invariant assertion.
    #[tokio::test]
    #[should_panic(expected = "invariant violation: stream `testbasin1/stream1` tail_pos")]
    async fn start_streamer_fails_if_records_exist_after_tail_pos() {
        // In-memory object store keeps the test hermetic.
        let object_store: Arc<dyn object_store::ObjectStore> =
            Arc::new(object_store::memory::InMemory::new());
        let db = slatedb::Db::builder("test", object_store)
            .build()
            .await
            .unwrap();

        let backend = Backend::new(db.clone(), ByteSize::b(1));

        let basin = BasinName::from_str("testbasin1").unwrap();
        let stream = StreamName::from_str("stream1").unwrap();
        let stream_id = StreamId::new(&basin, &stream);

        let meta = kv::stream_meta::StreamMeta {
            config: OptionalStreamConfig::default(),
            created_at: OffsetDateTime::now_utc(),
            deleted_at: None,
            creation_idempotency_key: None,
        };

        let tail_pos = StreamPosition {
            seq_num: 1,
            timestamp: 123,
        };
        // Place the offending record exactly at the tail position, which the
        // invariant scan (starting at the tail's seq-num) must find.
        let record_pos = StreamPosition {
            seq_num: tail_pos.seq_num,
            timestamp: tail_pos.timestamp,
        };

        let record = Record::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap();
        let metered_record: Metered<Record> = record.into();

        // Write stream meta, tail position, and the record in one batch.
        let mut wb = WriteBatch::new();
        wb.put(
            kv::stream_meta::ser_key(&basin, &stream),
            kv::stream_meta::ser_value(&meta),
        );
        wb.put(
            kv::stream_tail_position::ser_key(stream_id),
            kv::stream_tail_position::ser_value(
                tail_pos,
                kv::timestamp::TimestampSecs::from_secs(1),
            ),
        );
        wb.put(
            kv::stream_record_data::ser_key(stream_id, record_pos),
            kv::stream_record_data::ser_value(metered_record.as_ref()),
        );
        // Await durability so the invariant check's `Remote` durability
        // filter can observe the writes.
        static WRITE_OPTS: WriteOptions = WriteOptions {
            await_durable: true,
        };
        db.write_with_options(wb, &WRITE_OPTS).await.unwrap();

        backend
            .start_streamer(basin.clone(), stream.clone())
            .await
            .unwrap();
    }
}