Skip to main content

s2_lite/backend/
core.rs

1use std::sync::Arc;
2
3use bytesize::ByteSize;
4use dashmap::DashMap;
5use futures::{
6    FutureExt as _,
7    future::{BoxFuture, Shared},
8};
9use s2_common::{
10    encryption::{EncryptionAlgorithm, EncryptionSpec},
11    record::{NonZeroSeqNum, SeqNum, StreamPosition},
12    types::{
13        basin::BasinName,
14        config::{BasinConfig, OptionalStreamConfig},
15        resources::CreateMode,
16        stream::StreamName,
17    },
18};
19use slatedb::{
20    IterationOrder,
21    config::{DurabilityLevel, ScanOptions},
22};
23use tokio::sync::{Semaphore, broadcast};
24
25use super::{
26    StreamHandle,
27    durability_notifier::DurabilityNotifier,
28    error::{
29        BasinDeletionPendingError, BasinNotFoundError, CreateStreamError, GetBasinConfigError,
30        StorageError, StreamDeletionPendingError, StreamNotFoundError, StreamerError,
31        StreamerMissingInActionError, TransactionConflictError,
32    },
33    kv,
34    streamer::{GuardedStreamerClient, StreamerClient, StreamerGenerationId},
35};
36use crate::{backend::bgtasks::BgtaskTrigger, stream_id::StreamId};
37
38type StreamerInitFuture = Shared<BoxFuture<'static, Result<StreamerClient, StreamerError>>>;
39
40#[derive(Clone)]
41enum StreamerClientSlot {
42    Initializing {
43        generation_id: StreamerGenerationId,
44        future: StreamerInitFuture,
45    },
46    Ready {
47        client: StreamerClient,
48    },
49}
50
51#[derive(Clone)]
52pub struct Backend {
53    pub(super) db: slatedb::Db,
54    streamer_slots: Arc<DashMap<StreamId, StreamerClientSlot>>,
55    append_inflight_bytes_sema: Arc<Semaphore>,
56    durability_notifier: DurabilityNotifier,
57    bgtask_trigger_tx: broadcast::Sender<BgtaskTrigger>,
58}
59
60impl Backend {
61    pub fn new(db: slatedb::Db, append_inflight_bytes: ByteSize) -> Self {
62        let (bgtask_trigger_tx, _) = broadcast::channel(16);
63        let append_inflight_bytes = Arc::new(Semaphore::new(
64            (append_inflight_bytes.as_u64() as usize).clamp(
65                s2_common::caps::RECORD_BATCH_MAX.bytes,
66                Semaphore::MAX_PERMITS,
67            ),
68        ));
69        let durability_notifier = DurabilityNotifier::spawn(&db);
70        Self {
71            db,
72            streamer_slots: Arc::new(DashMap::new()),
73            append_inflight_bytes_sema: append_inflight_bytes,
74            durability_notifier,
75            bgtask_trigger_tx,
76        }
77    }
78
79    pub(super) fn bgtask_trigger(&self, trigger: BgtaskTrigger) {
80        let _ = self.bgtask_trigger_tx.send(trigger);
81    }
82
83    pub(super) fn bgtask_trigger_subscribe(&self) -> broadcast::Receiver<BgtaskTrigger> {
84        self.bgtask_trigger_tx.subscribe()
85    }
86
87    async fn start_streamer(
88        &self,
89        generation_id: StreamerGenerationId,
90        basin: BasinName,
91        stream: StreamName,
92    ) -> Result<StreamerClient, StreamerError> {
93        let stream_id = StreamId::new(&basin, &stream);
94
95        let (meta, tail_pos, fencing_token, trim_point) = tokio::try_join!(
96            self.db_get(
97                kv::stream_meta::ser_key(&basin, &stream),
98                kv::stream_meta::deser_value,
99            ),
100            self.db_get(
101                kv::stream_tail_position::ser_key(stream_id),
102                kv::stream_tail_position::deser_value,
103            ),
104            self.db_get(
105                kv::stream_fencing_token::ser_key(stream_id),
106                kv::stream_fencing_token::deser_value,
107            ),
108            self.db_get(
109                kv::stream_trim_point::ser_key(stream_id),
110                kv::stream_trim_point::deser_value,
111            )
112        )?;
113
114        let Some(meta) = meta else {
115            return Err(StreamNotFoundError { basin, stream }.into());
116        };
117
118        let tail_pos = tail_pos.map(|(pos, _)| pos).unwrap_or(StreamPosition::MIN);
119        self.assert_no_records_following_tail(stream_id, &basin, &stream, tail_pos)
120            .await?;
121
122        let fencing_token = fencing_token.unwrap_or_default();
123
124        if trim_point == Some(..NonZeroSeqNum::MAX) {
125            return Err(StreamDeletionPendingError { basin, stream }.into());
126        }
127
128        let streamer_slots = self.streamer_slots.clone();
129        Ok(super::streamer::Spawner {
130            generation_id,
131            db: self.db.clone(),
132            stream_id,
133            config: meta.config,
134            cipher: meta.cipher,
135            tail_pos,
136            fencing_token,
137            trim_point: ..trim_point.map_or(SeqNum::MIN, |tp| tp.end.get()),
138            append_inflight_bytes_sema: self.append_inflight_bytes_sema.clone(),
139            durability_notifier: self.durability_notifier.clone(),
140            bgtask_trigger_tx: self.bgtask_trigger_tx.clone(),
141        }
142        .spawn(move |client_id| {
143            streamer_slots.remove_if(&stream_id, |_, slot| {
144                matches!(slot, StreamerClientSlot::Ready { client } if client.generation_id() == client_id)
145            });
146        }))
147    }
148
149    async fn assert_no_records_following_tail(
150        &self,
151        stream_id: StreamId,
152        basin: &BasinName,
153        stream: &StreamName,
154        tail_pos: StreamPosition,
155    ) -> Result<(), StorageError> {
156        let start_key = kv::stream_record_data::ser_key(
157            stream_id,
158            StreamPosition {
159                seq_num: tail_pos.seq_num,
160                timestamp: 0,
161            },
162        );
163        static SCAN_OPTS: ScanOptions = ScanOptions {
164            durability_filter: DurabilityLevel::Remote,
165            dirty: false,
166            read_ahead_bytes: 1,
167            cache_blocks: false,
168            max_fetch_tasks: 1,
169            order: IterationOrder::Ascending,
170        };
171        let mut it = self.db.scan_with_options(start_key.., &SCAN_OPTS).await?;
172        let Some(kv) = it.next().await? else {
173            return Ok(());
174        };
175        if kv.key.first().copied() != Some(kv::KeyType::StreamRecordData as u8) {
176            return Ok(());
177        }
178        let (deser_stream_id, pos) = kv::stream_record_data::deser_key(kv.key)?;
179        assert!(
180            deser_stream_id != stream_id,
181            "invariant violation: stream `{basin}/{stream}` tail_pos {tail_pos:?} but found record at {pos:?}"
182        );
183        Ok(())
184    }
185
186    fn streamer_client_slot(&self, basin: &BasinName, stream: &StreamName) -> StreamerClientSlot {
187        match self.streamer_slots.entry(StreamId::new(basin, stream)) {
188            dashmap::Entry::Occupied(mut oe) => {
189                if matches!(oe.get(), StreamerClientSlot::Ready { client } if client.is_dead()) {
190                    let slot = self.clone().new_initializing_slot(basin, stream);
191                    oe.insert(slot.clone());
192                    slot
193                } else {
194                    oe.get().clone()
195                }
196            }
197            dashmap::Entry::Vacant(ve) => {
198                let slot = self.clone().new_initializing_slot(basin, stream);
199                ve.insert(slot.clone());
200                slot
201            }
202        }
203    }
204
205    fn new_initializing_slot(self, basin: &BasinName, stream: &StreamName) -> StreamerClientSlot {
206        let basin = basin.clone();
207        let stream = stream.clone();
208        let generation_id = StreamerGenerationId::next();
209        let future = async move { self.start_streamer(generation_id, basin, stream).await }
210            .boxed()
211            .shared();
212        StreamerClientSlot::Initializing {
213            generation_id,
214            future,
215        }
216    }
217
218    fn streamer_finish_initialization(
219        &self,
220        stream_id: StreamId,
221        generation_id: StreamerGenerationId,
222        result: &Result<StreamerClient, StreamerError>,
223    ) {
224        if let dashmap::Entry::Occupied(mut oe) = self.streamer_slots.entry(stream_id) {
225            let is_same_init = matches!(
226                oe.get(),
227                StreamerClientSlot::Initializing {
228                    generation_id: state_generation_id,
229                    ..
230                } if *state_generation_id == generation_id
231            );
232            if is_same_init {
233                match result {
234                    Ok(client) => {
235                        debug_assert_eq!(client.generation_id(), generation_id);
236                        if client.is_dead() {
237                            oe.remove();
238                        } else {
239                            oe.insert(StreamerClientSlot::Ready {
240                                client: client.clone(),
241                            });
242                        }
243                    }
244                    Err(_) => {
245                        oe.remove();
246                    }
247                }
248            }
249        }
250    }
251
252    pub(super) async fn streamer_client(
253        &self,
254        basin: &BasinName,
255        stream: &StreamName,
256    ) -> Result<StreamerClient, StreamerError> {
257        let stream_id = StreamId::new(basin, stream);
258        match self.streamer_client_slot(basin, stream) {
259            StreamerClientSlot::Initializing {
260                generation_id,
261                future,
262            } => {
263                let result = future.await;
264                self.streamer_finish_initialization(stream_id, generation_id, &result);
265                result
266            }
267            StreamerClientSlot::Ready { client } => Ok(client),
268        }
269    }
270
271    pub(super) fn streamer_client_if_active(
272        &self,
273        basin: &BasinName,
274        stream: &StreamName,
275    ) -> Option<StreamerClient> {
276        let stream_id = StreamId::new(basin, stream);
277        let slot = self.streamer_slots.get(&stream_id)?;
278        match slot.value() {
279            StreamerClientSlot::Ready { client } if !client.is_dead() => Some(client.clone()),
280            _ => None,
281        }
282    }
283
284    pub(super) async fn streamer_client_guarded(
285        &self,
286        basin: &BasinName,
287        stream: &StreamName,
288    ) -> Result<GuardedStreamerClient, StreamerError> {
289        loop {
290            let client = self.streamer_client(basin, stream).await?;
291            match client.guard() {
292                Ok(client) => return Ok(client),
293                Err(StreamerMissingInActionError) => continue,
294            }
295        }
296    }
297
298    pub(super) async fn stream_handle_with_auto_create<E>(
299        &self,
300        basin: &BasinName,
301        stream: &StreamName,
302        should_auto_create: impl FnOnce(&BasinConfig) -> bool,
303        resolve_encryption: impl FnOnce(Option<EncryptionAlgorithm>) -> Result<EncryptionSpec, E>,
304    ) -> Result<StreamHandle, E>
305    where
306        E: From<StreamerError>
307            + From<StorageError>
308            + From<BasinNotFoundError>
309            + From<TransactionConflictError>
310            + From<BasinDeletionPendingError>
311            + From<StreamDeletionPendingError>
312            + From<StreamNotFoundError>,
313    {
314        match self.streamer_client_guarded(basin, stream).await {
315            Ok(client) => Ok(StreamHandle {
316                db: self.db.clone(),
317                encryption: resolve_encryption(client.cipher())?,
318                client,
319            }),
320            Err(StreamerError::StreamNotFound(e)) => {
321                let config = match self.get_basin_config(basin.clone()).await {
322                    Ok(config) => config,
323                    Err(GetBasinConfigError::Storage(e)) => Err(e)?,
324                    Err(GetBasinConfigError::BasinNotFound(e)) => Err(e)?,
325                };
326                if should_auto_create(&config) {
327                    if let Err(e) = self
328                        .create_stream(
329                            basin.clone(),
330                            stream.clone(),
331                            OptionalStreamConfig::default(),
332                            CreateMode::CreateOnly(None),
333                        )
334                        .await
335                    {
336                        match e {
337                            CreateStreamError::Storage(e) => Err(e)?,
338                            CreateStreamError::TransactionConflict(e) => Err(e)?,
339                            CreateStreamError::BasinDeletionPending(e) => Err(e)?,
340                            CreateStreamError::StreamDeletionPending(e) => Err(e)?,
341                            CreateStreamError::BasinNotFound(e) => Err(e)?,
342                            CreateStreamError::StreamAlreadyExists(_) => {}
343                            CreateStreamError::Validation(_) => {
344                                unreachable!("auto-create uses default config")
345                            }
346                        }
347                    }
348                    let client = self.streamer_client_guarded(basin, stream).await?;
349                    let encryption = resolve_encryption(client.cipher())?;
350                    Ok(StreamHandle {
351                        db: self.db.clone(),
352                        encryption,
353                        client,
354                    })
355                } else {
356                    Err(e.into())
357                }
358            }
359            Err(e) => Err(e.into()),
360        }
361    }
362}
363
364#[cfg(test)]
365mod tests {
366    use std::str::FromStr as _;
367
368    use bytes::Bytes;
369    use s2_common::{
370        record::{Metered, Record, StoredRecord, StreamPosition},
371        types::{
372            config::{BasinConfig, OptionalStreamConfig},
373            resources::CreateMode,
374        },
375    };
376    use slatedb::{WriteBatch, config::WriteOptions, object_store};
377    use time::OffsetDateTime;
378
379    use super::*;
380
381    async fn new_test_backend() -> Backend {
382        let object_store: Arc<dyn object_store::ObjectStore> =
383            Arc::new(object_store::memory::InMemory::new());
384        let db = slatedb::Db::builder("test", object_store)
385            .build()
386            .await
387            .unwrap();
388        Backend::new(db, ByteSize::b(1))
389    }
390
391    #[tokio::test]
392    #[should_panic(expected = "invariant violation: stream `testbasin1/stream1` tail_pos")]
393    async fn start_streamer_fails_if_records_exist_after_tail_pos() {
394        let backend = new_test_backend().await;
395
396        let basin = BasinName::from_str("testbasin1").unwrap();
397        let stream = StreamName::from_str("stream1").unwrap();
398        let stream_id = StreamId::new(&basin, &stream);
399
400        let meta = kv::stream_meta::StreamMeta {
401            config: OptionalStreamConfig::default(),
402            cipher: None,
403            created_at: OffsetDateTime::now_utc(),
404            deleted_at: None,
405            creation_idempotency_key: None,
406        };
407
408        let tail_pos = StreamPosition {
409            seq_num: 1,
410            timestamp: 123,
411        };
412        let record_pos = StreamPosition {
413            seq_num: tail_pos.seq_num,
414            timestamp: tail_pos.timestamp,
415        };
416
417        let record = Record::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap();
418        let metered_record: Metered<StoredRecord> = StoredRecord::from(record).into();
419
420        let mut wb = WriteBatch::new();
421        wb.put(
422            kv::stream_meta::ser_key(&basin, &stream),
423            kv::stream_meta::ser_value(&meta),
424        );
425        wb.put(
426            kv::stream_tail_position::ser_key(stream_id),
427            kv::stream_tail_position::ser_value(
428                tail_pos,
429                kv::timestamp::TimestampSecs::from_secs(1),
430            ),
431        );
432        wb.put(
433            kv::stream_record_data::ser_key(stream_id, record_pos),
434            kv::stream_record_data::ser_value(metered_record.as_ref()),
435        );
436        static WRITE_OPTS: WriteOptions = WriteOptions {
437            await_durable: true,
438        };
439        backend
440            .db
441            .write_with_options(wb, &WRITE_OPTS)
442            .await
443            .unwrap();
444
445        backend
446            .start_streamer(StreamerGenerationId::next(), basin.clone(), stream.clone())
447            .await
448            .unwrap();
449    }
450
451    #[tokio::test]
452    async fn streamer_client_slot_uses_single_initializer() {
453        let backend = new_test_backend().await;
454        let basin = BasinName::from_str("testbasin2").unwrap();
455        let stream = StreamName::from_str("stream2").unwrap();
456
457        let slot_1 = backend.streamer_client_slot(&basin, &stream);
458        let slot_2 = backend.streamer_client_slot(&basin, &stream);
459
460        let (generation_id_1, generation_id_2) = match (slot_1, slot_2) {
461            (
462                StreamerClientSlot::Initializing {
463                    generation_id: generation_id_1,
464                    ..
465                },
466                StreamerClientSlot::Initializing {
467                    generation_id: generation_id_2,
468                    ..
469                },
470            ) => (generation_id_1, generation_id_2),
471            _ => panic!("expected both slots to be Initializing"),
472        };
473        assert_eq!(generation_id_1, generation_id_2);
474        assert_eq!(backend.streamer_slots.len(), 1);
475    }
476
477    #[tokio::test]
478    async fn streamer_client_if_active_is_peek_only() {
479        let backend = new_test_backend().await;
480        let basin = BasinName::from_str("testbasin3").unwrap();
481        let stream = StreamName::from_str("stream3").unwrap();
482
483        backend
484            .create_basin(
485                basin.clone(),
486                BasinConfig::default(),
487                CreateMode::CreateOnly(None),
488            )
489            .await
490            .unwrap();
491        backend
492            .create_stream(
493                basin.clone(),
494                stream.clone(),
495                OptionalStreamConfig::default(),
496                CreateMode::CreateOnly(None),
497            )
498            .await
499            .unwrap();
500
501        assert!(backend.streamer_slots.is_empty());
502        assert!(backend.streamer_client_if_active(&basin, &stream).is_none());
503        assert!(backend.streamer_slots.is_empty());
504    }
505
506    #[tokio::test]
507    async fn streamer_client_failed_init_is_not_memoized() {
508        let backend = new_test_backend().await;
509        let basin = BasinName::from_str("testbasin4").unwrap();
510        let stream = StreamName::from_str("stream4").unwrap();
511        let stream_id = StreamId::new(&basin, &stream);
512
513        for _ in 0..2 {
514            let err = backend.streamer_client(&basin, &stream).await;
515            assert!(matches!(err, Err(StreamerError::StreamNotFound(_))));
516            assert!(
517                backend.streamer_slots.get(&stream_id).is_none(),
518                "failed init should not be cached"
519            );
520        }
521    }
522
523    #[tokio::test]
524    async fn streamer_finish_initialization_ignores_stale_generation_id() {
525        let backend = new_test_backend().await;
526        let basin = BasinName::from_str("testbasin5").unwrap();
527        let stream = StreamName::from_str("stream5").unwrap();
528        let stream_id = StreamId::new(&basin, &stream);
529
530        let stale_generation_id = StreamerGenerationId::next();
531        let current_generation_id = StreamerGenerationId::next();
532        let future = futures::future::pending::<Result<StreamerClient, StreamerError>>()
533            .boxed()
534            .shared();
535        backend.streamer_slots.insert(
536            stream_id,
537            StreamerClientSlot::Initializing {
538                generation_id: current_generation_id,
539                future: future.clone(),
540            },
541        );
542
543        let stale_result = Err(StreamNotFoundError { basin, stream }.into());
544        backend.streamer_finish_initialization(stream_id, stale_generation_id, &stale_result);
545
546        let Some(slot) = backend.streamer_slots.get(&stream_id) else {
547            panic!("stale init completion should not alter slot state");
548        };
549        match slot.value() {
550            StreamerClientSlot::Initializing { generation_id, .. } => {
551                assert_eq!(*generation_id, current_generation_id)
552            }
553            _ => panic!("expected initializing slot to remain unchanged"),
554        }
555    }
556}