Skip to main content

s2_lite/backend/
core.rs

1use std::sync::{
2    Arc,
3    atomic::{AtomicU64, Ordering},
4};
5
6use bytesize::ByteSize;
7use dashmap::DashMap;
8use enum_ordinalize::Ordinalize;
9use futures::{
10    FutureExt as _,
11    future::{BoxFuture, Shared},
12};
13use s2_common::{
14    record::{NonZeroSeqNum, SeqNum, StreamPosition},
15    types::{
16        basin::BasinName,
17        config::{BasinConfig, OptionalStreamConfig},
18        resources::CreateMode,
19        stream::StreamName,
20    },
21};
22use slatedb::config::{DurabilityLevel, ScanOptions};
23use tokio::sync::{Semaphore, broadcast};
24
25use super::{
26    durability_notifier::DurabilityNotifier,
27    error::{
28        BasinDeletionPendingError, BasinNotFoundError, CreateStreamError, GetBasinConfigError,
29        StorageError, StreamDeletionPendingError, StreamNotFoundError, StreamerError,
30        TransactionConflictError,
31    },
32    kv,
33    stream_id::StreamId,
34    streamer::StreamerClient,
35};
36use crate::backend::bgtasks::BgtaskTrigger;
37
/// Cloneable handle to a streamer start-up attempt: a boxed future wrapped in `Shared`
/// that resolves to either a [`StreamerClient`] or a [`StreamerError`].
type StreamerInitFuture = Shared<BoxFuture<'static, Result<StreamerClient, StreamerError>>>;
39
/// Identifier that tells one streamer initialization attempt apart from another.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct StreamerInitId(u64);

impl StreamerInitId {
    /// Hands out the next identifier from a global atomic counter that starts at 1.
    fn next() -> Self {
        static COUNTER: AtomicU64 = AtomicU64::new(1);
        // The counter is only used to obtain distinct values, hence the relaxed ordering.
        let raw = COUNTER.fetch_add(1, Ordering::Relaxed);
        Self(raw)
    }
}
49
/// Per-stream entry kept in the backend's slot map.
#[derive(Clone)]
enum StreamerClientSlot {
    /// A start-up attempt has been registered and callers await its shared future.
    Initializing {
        /// Identifies the attempt that registered this slot.
        init_id: StreamerInitId,
        /// Shared future that resolves to the outcome of the attempt.
        future: StreamerInitFuture,
    },
    /// A start-up attempt completed and produced a client.
    Ready {
        client: StreamerClient,
    },
}
60
/// Handle that bundles the SlateDB instance with the per-stream streamer slots and the
/// resources passed to every spawned streamer.
///
/// The slot map and the semaphore sit behind `Arc`, so clones of a `Backend` share them.
#[derive(Clone)]
pub struct Backend {
    pub(super) db: slatedb::Db,
    // One slot per stream, either initializing or ready.
    streamer_slots: Arc<DashMap<StreamId, StreamerClientSlot>>,
    // Permit count is derived from the `append_inflight_bytes` argument of `Backend::new`.
    append_inflight_bytes_sema: Arc<Semaphore>,
    durability_notifier: DurabilityNotifier,
    // Sending half of the background-task trigger channel.
    bgtask_trigger_tx: broadcast::Sender<BgtaskTrigger>,
}
69
70impl Backend {
71    pub fn new(db: slatedb::Db, append_inflight_bytes: ByteSize) -> Self {
72        let (bgtask_trigger_tx, _) = broadcast::channel(16);
73        let append_inflight_bytes = Arc::new(Semaphore::new(
74            (append_inflight_bytes.as_u64() as usize).clamp(
75                s2_common::caps::RECORD_BATCH_MAX.bytes,
76                Semaphore::MAX_PERMITS,
77            ),
78        ));
79        let durability_notifier = DurabilityNotifier::spawn(&db);
80        Self {
81            db,
82            streamer_slots: Arc::new(DashMap::new()),
83            append_inflight_bytes_sema: append_inflight_bytes,
84            durability_notifier,
85            bgtask_trigger_tx,
86        }
87    }
88
89    pub(super) fn bgtask_trigger(&self, trigger: BgtaskTrigger) {
90        let _ = self.bgtask_trigger_tx.send(trigger);
91    }
92
93    pub(super) fn bgtask_trigger_subscribe(&self) -> broadcast::Receiver<BgtaskTrigger> {
94        self.bgtask_trigger_tx.subscribe()
95    }
96
97    async fn start_streamer(
98        &self,
99        basin: BasinName,
100        stream: StreamName,
101    ) -> Result<StreamerClient, StreamerError> {
102        let stream_id = StreamId::new(&basin, &stream);
103
104        let (meta, tail_pos, fencing_token, trim_point) = tokio::try_join!(
105            self.db_get(
106                kv::stream_meta::ser_key(&basin, &stream),
107                kv::stream_meta::deser_value,
108            ),
109            self.db_get(
110                kv::stream_tail_position::ser_key(stream_id),
111                kv::stream_tail_position::deser_value,
112            ),
113            self.db_get(
114                kv::stream_fencing_token::ser_key(stream_id),
115                kv::stream_fencing_token::deser_value,
116            ),
117            self.db_get(
118                kv::stream_trim_point::ser_key(stream_id),
119                kv::stream_trim_point::deser_value,
120            )
121        )?;
122
123        let Some(meta) = meta else {
124            return Err(StreamNotFoundError { basin, stream }.into());
125        };
126
127        let tail_pos = tail_pos.map(|(pos, _)| pos).unwrap_or(StreamPosition::MIN);
128        self.assert_no_records_following_tail(stream_id, &basin, &stream, tail_pos)
129            .await?;
130
131        let fencing_token = fencing_token.unwrap_or_default();
132
133        if trim_point == Some(..NonZeroSeqNum::MAX) {
134            return Err(StreamDeletionPendingError { basin, stream }.into());
135        }
136
137        let streamer_slots = self.streamer_slots.clone();
138        Ok(super::streamer::Spawner {
139            db: self.db.clone(),
140            stream_id,
141            config: meta.config,
142            tail_pos,
143            fencing_token,
144            trim_point: ..trim_point.map_or(SeqNum::MIN, |tp| tp.end.get()),
145            append_inflight_bytes_sema: self.append_inflight_bytes_sema.clone(),
146            durability_notifier: self.durability_notifier.clone(),
147            bgtask_trigger_tx: self.bgtask_trigger_tx.clone(),
148        }
149        .spawn(move |client_id| {
150            streamer_slots.remove_if(&stream_id, |_, slot| {
151                matches!(slot, StreamerClientSlot::Ready { client } if client.id() == client_id)
152            });
153        }))
154    }
155
156    async fn assert_no_records_following_tail(
157        &self,
158        stream_id: StreamId,
159        basin: &BasinName,
160        stream: &StreamName,
161        tail_pos: StreamPosition,
162    ) -> Result<(), StorageError> {
163        let start_key = kv::stream_record_data::ser_key(
164            stream_id,
165            StreamPosition {
166                seq_num: tail_pos.seq_num,
167                timestamp: 0,
168            },
169        );
170        static SCAN_OPTS: ScanOptions = ScanOptions {
171            durability_filter: DurabilityLevel::Remote,
172            dirty: false,
173            read_ahead_bytes: 1,
174            cache_blocks: false,
175            max_fetch_tasks: 1,
176        };
177        let mut it = self.db.scan_with_options(start_key.., &SCAN_OPTS).await?;
178        let Some(kv) = it.next().await? else {
179            return Ok(());
180        };
181        if kv.key.first().copied() != Some(kv::KeyType::StreamRecordData.ordinal()) {
182            return Ok(());
183        }
184        let (deser_stream_id, pos) = kv::stream_record_data::deser_key(kv.key)?;
185        assert!(
186            deser_stream_id != stream_id,
187            "invariant violation: stream `{basin}/{stream}` tail_pos {tail_pos:?} but found record at {pos:?}"
188        );
189        Ok(())
190    }
191
192    fn streamer_client_slot(&self, basin: &BasinName, stream: &StreamName) -> StreamerClientSlot {
193        match self.streamer_slots.entry(StreamId::new(basin, stream)) {
194            dashmap::Entry::Occupied(oe) => oe.get().clone(),
195            dashmap::Entry::Vacant(ve) => {
196                let this = self.clone();
197                let basin = basin.clone();
198                let stream = stream.clone();
199                let init_id = StreamerInitId::next();
200                let future = async move { this.start_streamer(basin, stream).await }
201                    .boxed()
202                    .shared();
203                let slot = StreamerClientSlot::Initializing {
204                    init_id,
205                    future: future.clone(),
206                };
207                ve.insert(slot.clone());
208                slot
209            }
210        }
211    }
212
213    fn streamer_finish_initialization(
214        &self,
215        stream_id: StreamId,
216        init_id: StreamerInitId,
217        result: &Result<StreamerClient, StreamerError>,
218    ) {
219        if let dashmap::Entry::Occupied(mut oe) = self.streamer_slots.entry(stream_id) {
220            let is_same_init = matches!(
221                oe.get(),
222                StreamerClientSlot::Initializing {
223                    init_id: state_init_id,
224                    ..
225                } if *state_init_id == init_id
226            );
227            if is_same_init {
228                match result {
229                    Ok(client) => {
230                        if client.is_dead() {
231                            oe.remove();
232                        } else {
233                            oe.insert(StreamerClientSlot::Ready {
234                                client: client.clone(),
235                            });
236                        }
237                    }
238                    Err(_) => {
239                        oe.remove();
240                    }
241                }
242            }
243        }
244    }
245
246    pub(super) async fn streamer_client(
247        &self,
248        basin: &BasinName,
249        stream: &StreamName,
250    ) -> Result<StreamerClient, StreamerError> {
251        let stream_id = StreamId::new(basin, stream);
252        match self.streamer_client_slot(basin, stream) {
253            StreamerClientSlot::Initializing { init_id, future } => {
254                let result = future.await;
255                self.streamer_finish_initialization(stream_id, init_id, &result);
256                result
257            }
258            StreamerClientSlot::Ready { client } => Ok(client),
259        }
260    }
261
262    pub(super) fn streamer_client_if_active(
263        &self,
264        basin: &BasinName,
265        stream: &StreamName,
266    ) -> Option<StreamerClient> {
267        let stream_id = StreamId::new(basin, stream);
268        let slot = self.streamer_slots.get(&stream_id)?;
269        match slot.value() {
270            StreamerClientSlot::Ready { client } => Some(client.clone()),
271            _ => None,
272        }
273    }
274
275    pub(super) async fn streamer_client_with_auto_create<E>(
276        &self,
277        basin: &BasinName,
278        stream: &StreamName,
279        should_auto_create: impl FnOnce(&BasinConfig) -> bool,
280    ) -> Result<StreamerClient, E>
281    where
282        E: From<StreamerError>
283            + From<StorageError>
284            + From<BasinNotFoundError>
285            + From<TransactionConflictError>
286            + From<BasinDeletionPendingError>
287            + From<StreamDeletionPendingError>
288            + From<StreamNotFoundError>,
289    {
290        match self.streamer_client(basin, stream).await {
291            Ok(client) => Ok(client),
292            Err(StreamerError::StreamNotFound(e)) => {
293                let config = match self.get_basin_config(basin.clone()).await {
294                    Ok(config) => config,
295                    Err(GetBasinConfigError::Storage(e)) => Err(e)?,
296                    Err(GetBasinConfigError::BasinNotFound(e)) => Err(e)?,
297                };
298                if should_auto_create(&config) {
299                    if let Err(e) = self
300                        .create_stream(
301                            basin.clone(),
302                            stream.clone(),
303                            OptionalStreamConfig::default(),
304                            CreateMode::CreateOnly(None),
305                        )
306                        .await
307                    {
308                        match e {
309                            CreateStreamError::Storage(e) => Err(e)?,
310                            CreateStreamError::TransactionConflict(e) => Err(e)?,
311                            CreateStreamError::BasinDeletionPending(e) => Err(e)?,
312                            CreateStreamError::StreamDeletionPending(e) => Err(e)?,
313                            CreateStreamError::BasinNotFound(e) => Err(e)?,
314                            CreateStreamError::StreamAlreadyExists(_) => {}
315                        }
316                    }
317                    Ok(self.streamer_client(basin, stream).await?)
318                } else {
319                    Err(e.into())
320                }
321            }
322            Err(e) => Err(e.into()),
323        }
324    }
325}
326
327#[cfg(test)]
328mod tests {
329    use std::str::FromStr as _;
330
331    use bytes::Bytes;
332    use s2_common::{
333        record::{Metered, Record, StreamPosition},
334        types::{config::BasinConfig, resources::CreateMode},
335    };
336    use slatedb::{WriteBatch, config::WriteOptions, object_store};
337    use time::OffsetDateTime;
338
339    use super::*;
340
341    async fn new_test_backend() -> Backend {
342        let object_store: Arc<dyn object_store::ObjectStore> =
343            Arc::new(object_store::memory::InMemory::new());
344        let db = slatedb::Db::builder("test", object_store)
345            .build()
346            .await
347            .unwrap();
348        Backend::new(db, ByteSize::b(1))
349    }
350
351    #[tokio::test]
352    #[should_panic(expected = "invariant violation: stream `testbasin1/stream1` tail_pos")]
353    async fn start_streamer_fails_if_records_exist_after_tail_pos() {
354        let backend = new_test_backend().await;
355
356        let basin = BasinName::from_str("testbasin1").unwrap();
357        let stream = StreamName::from_str("stream1").unwrap();
358        let stream_id = StreamId::new(&basin, &stream);
359
360        let meta = kv::stream_meta::StreamMeta {
361            config: OptionalStreamConfig::default(),
362            created_at: OffsetDateTime::now_utc(),
363            deleted_at: None,
364            creation_idempotency_key: None,
365        };
366
367        let tail_pos = StreamPosition {
368            seq_num: 1,
369            timestamp: 123,
370        };
371        let record_pos = StreamPosition {
372            seq_num: tail_pos.seq_num,
373            timestamp: tail_pos.timestamp,
374        };
375
376        let record = Record::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap();
377        let metered_record: Metered<Record> = record.into();
378
379        let mut wb = WriteBatch::new();
380        wb.put(
381            kv::stream_meta::ser_key(&basin, &stream),
382            kv::stream_meta::ser_value(&meta),
383        );
384        wb.put(
385            kv::stream_tail_position::ser_key(stream_id),
386            kv::stream_tail_position::ser_value(
387                tail_pos,
388                kv::timestamp::TimestampSecs::from_secs(1),
389            ),
390        );
391        wb.put(
392            kv::stream_record_data::ser_key(stream_id, record_pos),
393            kv::stream_record_data::ser_value(metered_record.as_ref()),
394        );
395        static WRITE_OPTS: WriteOptions = WriteOptions {
396            await_durable: true,
397        };
398        backend
399            .db
400            .write_with_options(wb, &WRITE_OPTS)
401            .await
402            .unwrap();
403
404        backend
405            .start_streamer(basin.clone(), stream.clone())
406            .await
407            .unwrap();
408    }
409
410    #[tokio::test]
411    async fn streamer_client_slot_uses_single_initializer() {
412        let backend = new_test_backend().await;
413        let basin = BasinName::from_str("testbasin2").unwrap();
414        let stream = StreamName::from_str("stream2").unwrap();
415
416        let slot_1 = backend.streamer_client_slot(&basin, &stream);
417        let slot_2 = backend.streamer_client_slot(&basin, &stream);
418
419        let (init_id_1, init_id_2) = match (slot_1, slot_2) {
420            (
421                StreamerClientSlot::Initializing {
422                    init_id: init_id_1, ..
423                },
424                StreamerClientSlot::Initializing {
425                    init_id: init_id_2, ..
426                },
427            ) => (init_id_1, init_id_2),
428            _ => panic!("expected both slots to be Initializing"),
429        };
430        assert_eq!(init_id_1, init_id_2);
431        assert_eq!(backend.streamer_slots.len(), 1);
432    }
433
434    #[tokio::test]
435    async fn streamer_client_if_active_is_peek_only() {
436        let backend = new_test_backend().await;
437        let basin = BasinName::from_str("testbasin3").unwrap();
438        let stream = StreamName::from_str("stream3").unwrap();
439
440        backend
441            .create_basin(
442                basin.clone(),
443                BasinConfig::default(),
444                CreateMode::CreateOnly(None),
445            )
446            .await
447            .unwrap();
448        backend
449            .create_stream(
450                basin.clone(),
451                stream.clone(),
452                OptionalStreamConfig::default(),
453                CreateMode::CreateOnly(None),
454            )
455            .await
456            .unwrap();
457
458        assert!(backend.streamer_slots.is_empty());
459        assert!(backend.streamer_client_if_active(&basin, &stream).is_none());
460        assert!(backend.streamer_slots.is_empty());
461    }
462
463    #[tokio::test]
464    async fn streamer_client_failed_init_is_not_memoized() {
465        let backend = new_test_backend().await;
466        let basin = BasinName::from_str("testbasin4").unwrap();
467        let stream = StreamName::from_str("stream4").unwrap();
468        let stream_id = StreamId::new(&basin, &stream);
469
470        for _ in 0..2 {
471            let err = backend.streamer_client(&basin, &stream).await;
472            assert!(matches!(err, Err(StreamerError::StreamNotFound(_))));
473            assert!(
474                backend.streamer_slots.get(&stream_id).is_none(),
475                "failed init should not be cached"
476            );
477        }
478    }
479
480    #[tokio::test]
481    async fn streamer_finish_initialization_ignores_stale_init_id() {
482        let backend = new_test_backend().await;
483        let basin = BasinName::from_str("testbasin5").unwrap();
484        let stream = StreamName::from_str("stream5").unwrap();
485        let stream_id = StreamId::new(&basin, &stream);
486
487        let stale_init_id = StreamerInitId::next();
488        let current_init_id = StreamerInitId::next();
489        let future = futures::future::pending::<Result<StreamerClient, StreamerError>>()
490            .boxed()
491            .shared();
492        backend.streamer_slots.insert(
493            stream_id,
494            StreamerClientSlot::Initializing {
495                init_id: current_init_id,
496                future: future.clone(),
497            },
498        );
499
500        let stale_result = Err(StreamNotFoundError { basin, stream }.into());
501        backend.streamer_finish_initialization(stream_id, stale_init_id, &stale_result);
502
503        let Some(slot) = backend.streamer_slots.get(&stream_id) else {
504            panic!("stale init completion should not alter slot state");
505        };
506        match slot.value() {
507            StreamerClientSlot::Initializing { init_id, .. } => {
508                assert_eq!(*init_id, current_init_id)
509            }
510            _ => panic!("expected initializing slot to remain unchanged"),
511        }
512    }
513}