Skip to main content

s2_lite/backend/
core.rs

1use std::sync::Arc;
2
3use bytesize::ByteSize;
4use dashmap::DashMap;
5use enum_ordinalize::Ordinalize;
6use futures::{
7    FutureExt as _,
8    future::{BoxFuture, Shared},
9};
10use s2_common::{
11    encryption::{EncryptionAlgorithm, EncryptionSpec},
12    record::{NonZeroSeqNum, SeqNum, StreamPosition},
13    types::{
14        basin::BasinName,
15        config::{BasinConfig, OptionalStreamConfig},
16        resources::CreateMode,
17        stream::StreamName,
18    },
19};
20use slatedb::{
21    IterationOrder,
22    config::{DurabilityLevel, ScanOptions},
23};
24use tokio::sync::{Semaphore, broadcast};
25
26use super::{
27    StreamHandle,
28    durability_notifier::DurabilityNotifier,
29    error::{
30        BasinDeletionPendingError, BasinNotFoundError, CreateStreamError, GetBasinConfigError,
31        StorageError, StreamDeletionPendingError, StreamNotFoundError, StreamerError,
32        StreamerMissingInActionError, TransactionConflictError,
33    },
34    kv,
35    streamer::{GuardedStreamerClient, StreamerClient, StreamerGenerationId},
36};
37use crate::{backend::bgtasks::BgtaskTrigger, stream_id::StreamId};
38
39type StreamerInitFuture = Shared<BoxFuture<'static, Result<StreamerClient, StreamerError>>>;
40
41#[derive(Clone)]
42enum StreamerClientSlot {
43    Initializing {
44        generation_id: StreamerGenerationId,
45        future: StreamerInitFuture,
46    },
47    Ready {
48        client: StreamerClient,
49    },
50}
51
52#[derive(Clone)]
53pub struct Backend {
54    pub(super) db: slatedb::Db,
55    streamer_slots: Arc<DashMap<StreamId, StreamerClientSlot>>,
56    append_inflight_bytes_sema: Arc<Semaphore>,
57    durability_notifier: DurabilityNotifier,
58    bgtask_trigger_tx: broadcast::Sender<BgtaskTrigger>,
59}
60
61impl Backend {
62    pub fn new(db: slatedb::Db, append_inflight_bytes: ByteSize) -> Self {
63        let (bgtask_trigger_tx, _) = broadcast::channel(16);
64        let append_inflight_bytes = Arc::new(Semaphore::new(
65            (append_inflight_bytes.as_u64() as usize).clamp(
66                s2_common::caps::RECORD_BATCH_MAX.bytes,
67                Semaphore::MAX_PERMITS,
68            ),
69        ));
70        let durability_notifier = DurabilityNotifier::spawn(&db);
71        Self {
72            db,
73            streamer_slots: Arc::new(DashMap::new()),
74            append_inflight_bytes_sema: append_inflight_bytes,
75            durability_notifier,
76            bgtask_trigger_tx,
77        }
78    }
79
80    pub(super) fn bgtask_trigger(&self, trigger: BgtaskTrigger) {
81        let _ = self.bgtask_trigger_tx.send(trigger);
82    }
83
84    pub(super) fn bgtask_trigger_subscribe(&self) -> broadcast::Receiver<BgtaskTrigger> {
85        self.bgtask_trigger_tx.subscribe()
86    }
87
88    async fn start_streamer(
89        &self,
90        generation_id: StreamerGenerationId,
91        basin: BasinName,
92        stream: StreamName,
93    ) -> Result<StreamerClient, StreamerError> {
94        let stream_id = StreamId::new(&basin, &stream);
95
96        let (meta, tail_pos, fencing_token, trim_point) = tokio::try_join!(
97            self.db_get(
98                kv::stream_meta::ser_key(&basin, &stream),
99                kv::stream_meta::deser_value,
100            ),
101            self.db_get(
102                kv::stream_tail_position::ser_key(stream_id),
103                kv::stream_tail_position::deser_value,
104            ),
105            self.db_get(
106                kv::stream_fencing_token::ser_key(stream_id),
107                kv::stream_fencing_token::deser_value,
108            ),
109            self.db_get(
110                kv::stream_trim_point::ser_key(stream_id),
111                kv::stream_trim_point::deser_value,
112            )
113        )?;
114
115        let Some(meta) = meta else {
116            return Err(StreamNotFoundError { basin, stream }.into());
117        };
118
119        let tail_pos = tail_pos.map(|(pos, _)| pos).unwrap_or(StreamPosition::MIN);
120        self.assert_no_records_following_tail(stream_id, &basin, &stream, tail_pos)
121            .await?;
122
123        let fencing_token = fencing_token.unwrap_or_default();
124
125        if trim_point == Some(..NonZeroSeqNum::MAX) {
126            return Err(StreamDeletionPendingError { basin, stream }.into());
127        }
128
129        let streamer_slots = self.streamer_slots.clone();
130        Ok(super::streamer::Spawner {
131            generation_id,
132            db: self.db.clone(),
133            stream_id,
134            config: meta.config,
135            cipher: meta.cipher,
136            tail_pos,
137            fencing_token,
138            trim_point: ..trim_point.map_or(SeqNum::MIN, |tp| tp.end.get()),
139            append_inflight_bytes_sema: self.append_inflight_bytes_sema.clone(),
140            durability_notifier: self.durability_notifier.clone(),
141            bgtask_trigger_tx: self.bgtask_trigger_tx.clone(),
142        }
143        .spawn(move |client_id| {
144            streamer_slots.remove_if(&stream_id, |_, slot| {
145                matches!(slot, StreamerClientSlot::Ready { client } if client.generation_id() == client_id)
146            });
147        }))
148    }
149
150    async fn assert_no_records_following_tail(
151        &self,
152        stream_id: StreamId,
153        basin: &BasinName,
154        stream: &StreamName,
155        tail_pos: StreamPosition,
156    ) -> Result<(), StorageError> {
157        let start_key = kv::stream_record_data::ser_key(
158            stream_id,
159            StreamPosition {
160                seq_num: tail_pos.seq_num,
161                timestamp: 0,
162            },
163        );
164        static SCAN_OPTS: ScanOptions = ScanOptions {
165            durability_filter: DurabilityLevel::Remote,
166            dirty: false,
167            read_ahead_bytes: 1,
168            cache_blocks: false,
169            max_fetch_tasks: 1,
170            order: IterationOrder::Ascending,
171        };
172        let mut it = self.db.scan_with_options(start_key.., &SCAN_OPTS).await?;
173        let Some(kv) = it.next().await? else {
174            return Ok(());
175        };
176        if kv.key.first().copied() != Some(kv::KeyType::StreamRecordData.ordinal()) {
177            return Ok(());
178        }
179        let (deser_stream_id, pos) = kv::stream_record_data::deser_key(kv.key)?;
180        assert!(
181            deser_stream_id != stream_id,
182            "invariant violation: stream `{basin}/{stream}` tail_pos {tail_pos:?} but found record at {pos:?}"
183        );
184        Ok(())
185    }
186
187    fn streamer_client_slot(&self, basin: &BasinName, stream: &StreamName) -> StreamerClientSlot {
188        match self.streamer_slots.entry(StreamId::new(basin, stream)) {
189            dashmap::Entry::Occupied(mut oe) => {
190                if matches!(oe.get(), StreamerClientSlot::Ready { client } if client.is_dead()) {
191                    let slot = self.clone().new_initializing_slot(basin, stream);
192                    oe.insert(slot.clone());
193                    slot
194                } else {
195                    oe.get().clone()
196                }
197            }
198            dashmap::Entry::Vacant(ve) => {
199                let slot = self.clone().new_initializing_slot(basin, stream);
200                ve.insert(slot.clone());
201                slot
202            }
203        }
204    }
205
206    fn new_initializing_slot(self, basin: &BasinName, stream: &StreamName) -> StreamerClientSlot {
207        let basin = basin.clone();
208        let stream = stream.clone();
209        let generation_id = StreamerGenerationId::next();
210        let future = async move { self.start_streamer(generation_id, basin, stream).await }
211            .boxed()
212            .shared();
213        StreamerClientSlot::Initializing {
214            generation_id,
215            future,
216        }
217    }
218
219    fn streamer_finish_initialization(
220        &self,
221        stream_id: StreamId,
222        generation_id: StreamerGenerationId,
223        result: &Result<StreamerClient, StreamerError>,
224    ) {
225        if let dashmap::Entry::Occupied(mut oe) = self.streamer_slots.entry(stream_id) {
226            let is_same_init = matches!(
227                oe.get(),
228                StreamerClientSlot::Initializing {
229                    generation_id: state_generation_id,
230                    ..
231                } if *state_generation_id == generation_id
232            );
233            if is_same_init {
234                match result {
235                    Ok(client) => {
236                        debug_assert_eq!(client.generation_id(), generation_id);
237                        if client.is_dead() {
238                            oe.remove();
239                        } else {
240                            oe.insert(StreamerClientSlot::Ready {
241                                client: client.clone(),
242                            });
243                        }
244                    }
245                    Err(_) => {
246                        oe.remove();
247                    }
248                }
249            }
250        }
251    }
252
253    pub(super) async fn streamer_client(
254        &self,
255        basin: &BasinName,
256        stream: &StreamName,
257    ) -> Result<StreamerClient, StreamerError> {
258        let stream_id = StreamId::new(basin, stream);
259        match self.streamer_client_slot(basin, stream) {
260            StreamerClientSlot::Initializing {
261                generation_id,
262                future,
263            } => {
264                let result = future.await;
265                self.streamer_finish_initialization(stream_id, generation_id, &result);
266                result
267            }
268            StreamerClientSlot::Ready { client } => Ok(client),
269        }
270    }
271
272    pub(super) fn streamer_client_if_active(
273        &self,
274        basin: &BasinName,
275        stream: &StreamName,
276    ) -> Option<StreamerClient> {
277        let stream_id = StreamId::new(basin, stream);
278        let slot = self.streamer_slots.get(&stream_id)?;
279        match slot.value() {
280            StreamerClientSlot::Ready { client } if !client.is_dead() => Some(client.clone()),
281            _ => None,
282        }
283    }
284
285    pub(super) async fn streamer_client_guarded(
286        &self,
287        basin: &BasinName,
288        stream: &StreamName,
289    ) -> Result<GuardedStreamerClient, StreamerError> {
290        loop {
291            let client = self.streamer_client(basin, stream).await?;
292            match client.guard() {
293                Ok(client) => return Ok(client),
294                Err(StreamerMissingInActionError) => continue,
295            }
296        }
297    }
298
299    pub(super) async fn stream_handle_with_auto_create<E>(
300        &self,
301        basin: &BasinName,
302        stream: &StreamName,
303        should_auto_create: impl FnOnce(&BasinConfig) -> bool,
304        resolve_encryption: impl FnOnce(Option<EncryptionAlgorithm>) -> Result<EncryptionSpec, E>,
305    ) -> Result<StreamHandle, E>
306    where
307        E: From<StreamerError>
308            + From<StorageError>
309            + From<BasinNotFoundError>
310            + From<TransactionConflictError>
311            + From<BasinDeletionPendingError>
312            + From<StreamDeletionPendingError>
313            + From<StreamNotFoundError>,
314    {
315        match self.streamer_client_guarded(basin, stream).await {
316            Ok(client) => Ok(StreamHandle {
317                db: self.db.clone(),
318                encryption: resolve_encryption(client.cipher())?,
319                client,
320            }),
321            Err(StreamerError::StreamNotFound(e)) => {
322                let config = match self.get_basin_config(basin.clone()).await {
323                    Ok(config) => config,
324                    Err(GetBasinConfigError::Storage(e)) => Err(e)?,
325                    Err(GetBasinConfigError::BasinNotFound(e)) => Err(e)?,
326                };
327                if should_auto_create(&config) {
328                    if let Err(e) = self
329                        .create_stream(
330                            basin.clone(),
331                            stream.clone(),
332                            OptionalStreamConfig::default(),
333                            CreateMode::CreateOnly(None),
334                        )
335                        .await
336                    {
337                        match e {
338                            CreateStreamError::Storage(e) => Err(e)?,
339                            CreateStreamError::TransactionConflict(e) => Err(e)?,
340                            CreateStreamError::BasinDeletionPending(e) => Err(e)?,
341                            CreateStreamError::StreamDeletionPending(e) => Err(e)?,
342                            CreateStreamError::BasinNotFound(e) => Err(e)?,
343                            CreateStreamError::StreamAlreadyExists(_) => {}
344                            CreateStreamError::Validation(_) => {
345                                unreachable!("auto-create uses default config")
346                            }
347                        }
348                    }
349                    let client = self.streamer_client_guarded(basin, stream).await?;
350                    let encryption = resolve_encryption(client.cipher())?;
351                    Ok(StreamHandle {
352                        db: self.db.clone(),
353                        encryption,
354                        client,
355                    })
356                } else {
357                    Err(e.into())
358                }
359            }
360            Err(e) => Err(e.into()),
361        }
362    }
363}
364
#[cfg(test)]
mod tests {
    use std::str::FromStr as _;

    use bytes::Bytes;
    use s2_common::{
        record::{Metered, Record, StoredRecord, StreamPosition},
        types::{
            config::{BasinConfig, OptionalStreamConfig},
            resources::CreateMode,
        },
    };
    use slatedb::{WriteBatch, config::WriteOptions, object_store};
    use time::OffsetDateTime;

    use super::*;

    /// Builds a `Backend` over a brand-new in-memory object store.
    ///
    /// The 1-byte append budget is clamped up by `Backend::new`.
    async fn new_test_backend() -> Backend {
        let store: Arc<dyn object_store::ObjectStore> =
            Arc::new(object_store::memory::InMemory::new());
        let db = slatedb::Db::builder("test", store)
            .build()
            .await
            .unwrap();
        Backend::new(db, ByteSize::b(1))
    }

    /// A record persisted at the stream's tail position must trip the invariant assertion.
    #[tokio::test]
    #[should_panic(expected = "invariant violation: stream `testbasin1/stream1` tail_pos")]
    async fn start_streamer_fails_if_records_exist_after_tail_pos() {
        let backend = new_test_backend().await;

        let basin = BasinName::from_str("testbasin1").unwrap();
        let stream = StreamName::from_str("stream1").unwrap();
        let stream_id = StreamId::new(&basin, &stream);

        let tail_pos = StreamPosition {
            seq_num: 1,
            timestamp: 123,
        };
        // The offending record sits exactly at the persisted tail.
        let record_pos = tail_pos;

        let meta = kv::stream_meta::StreamMeta {
            config: OptionalStreamConfig::default(),
            cipher: None,
            created_at: OffsetDateTime::now_utc(),
            deleted_at: None,
            creation_idempotency_key: None,
        };
        let payload = Record::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap();
        let stored: Metered<StoredRecord> = StoredRecord::from(payload).into();

        let mut batch = WriteBatch::new();
        batch.put(
            kv::stream_meta::ser_key(&basin, &stream),
            kv::stream_meta::ser_value(&meta),
        );
        batch.put(
            kv::stream_tail_position::ser_key(stream_id),
            kv::stream_tail_position::ser_value(
                tail_pos,
                kv::timestamp::TimestampSecs::from_secs(1),
            ),
        );
        batch.put(
            kv::stream_record_data::ser_key(stream_id, record_pos),
            kv::stream_record_data::ser_value(stored.as_ref()),
        );

        // The tail check scans with a Remote durability filter, so wait for durability.
        static DURABLE_WRITE: WriteOptions = WriteOptions {
            await_durable: true,
        };
        backend
            .db
            .write_with_options(batch, &DURABLE_WRITE)
            .await
            .unwrap();

        backend
            .start_streamer(StreamerGenerationId::next(), basin, stream)
            .await
            .unwrap();
    }

    /// Two back-to-back slot lookups must share one initializer and one map entry.
    #[tokio::test]
    async fn streamer_client_slot_uses_single_initializer() {
        let backend = new_test_backend().await;
        let basin = BasinName::from_str("testbasin2").unwrap();
        let stream = StreamName::from_str("stream2").unwrap();

        let first = backend.streamer_client_slot(&basin, &stream);
        let second = backend.streamer_client_slot(&basin, &stream);

        let initializing_generation = |slot: StreamerClientSlot| match slot {
            StreamerClientSlot::Initializing { generation_id, .. } => generation_id,
            StreamerClientSlot::Ready { .. } => panic!("expected both slots to be Initializing"),
        };
        let first_generation = initializing_generation(first);
        let second_generation = initializing_generation(second);

        assert_eq!(first_generation, second_generation);
        assert_eq!(backend.streamer_slots.len(), 1);
    }

    /// Peeking for an active client must not create a slot.
    #[tokio::test]
    async fn streamer_client_if_active_is_peek_only() {
        let backend = new_test_backend().await;
        let basin = BasinName::from_str("testbasin3").unwrap();
        let stream = StreamName::from_str("stream3").unwrap();

        backend
            .create_basin(
                basin.clone(),
                BasinConfig::default(),
                CreateMode::CreateOnly(None),
            )
            .await
            .unwrap();
        backend
            .create_stream(
                basin.clone(),
                stream.clone(),
                OptionalStreamConfig::default(),
                CreateMode::CreateOnly(None),
            )
            .await
            .unwrap();

        assert!(backend.streamer_slots.is_empty());
        let peeked = backend.streamer_client_if_active(&basin, &stream);
        assert!(peeked.is_none());
        assert!(backend.streamer_slots.is_empty());
    }

    /// A failed initialization must leave no slot behind, on every attempt.
    #[tokio::test]
    async fn streamer_client_failed_init_is_not_memoized() {
        let backend = new_test_backend().await;
        let basin = BasinName::from_str("testbasin4").unwrap();
        let stream = StreamName::from_str("stream4").unwrap();
        let stream_id = StreamId::new(&basin, &stream);

        for _ in 0..2 {
            let outcome = backend.streamer_client(&basin, &stream).await;
            assert!(matches!(outcome, Err(StreamerError::StreamNotFound(_))));
            assert!(
                !backend.streamer_slots.contains_key(&stream_id),
                "failed init should not be cached"
            );
        }
    }

    /// A completion report from an older generation must not disturb the current slot.
    #[tokio::test]
    async fn streamer_finish_initialization_ignores_stale_generation_id() {
        let backend = new_test_backend().await;
        let basin = BasinName::from_str("testbasin5").unwrap();
        let stream = StreamName::from_str("stream5").unwrap();
        let stream_id = StreamId::new(&basin, &stream);

        let stale_generation_id = StreamerGenerationId::next();
        let current_generation_id = StreamerGenerationId::next();
        backend.streamer_slots.insert(
            stream_id,
            StreamerClientSlot::Initializing {
                generation_id: current_generation_id,
                future: futures::future::pending::<Result<StreamerClient, StreamerError>>()
                    .boxed()
                    .shared(),
            },
        );

        let stale_result = Err(StreamNotFoundError { basin, stream }.into());
        backend.streamer_finish_initialization(stream_id, stale_generation_id, &stale_result);

        let slot = backend
            .streamer_slots
            .get(&stream_id)
            .expect("stale init completion should not alter slot state");
        let StreamerClientSlot::Initializing { generation_id, .. } = slot.value() else {
            panic!("expected initializing slot to remain unchanged");
        };
        assert_eq!(*generation_id, current_generation_id);
    }
}