Skip to main content

s2_lite/backend/
core.rs

1use std::sync::Arc;
2
3use bytesize::ByteSize;
4use dashmap::DashMap;
5use futures::{
6    FutureExt as _,
7    future::{BoxFuture, Shared},
8};
9use s2_common::{
10    encryption::{EncryptionAlgorithm, EncryptionSpec},
11    record::{NonZeroSeqNum, SeqNum, StreamPosition},
12    types::{
13        basin::BasinName,
14        config::{BasinConfig, OptionalStreamConfig},
15        stream::{CreateStreamIntent, StreamName},
16    },
17};
18use slatedb::{
19    IterationOrder,
20    config::{DurabilityLevel, ScanOptions},
21};
22use tokio::sync::{Semaphore, broadcast};
23
24use super::{
25    StreamHandle,
26    durability_notifier::DurabilityNotifier,
27    error::{
28        BasinDeletionPendingError, BasinNotFoundError, CreateStreamError, GetBasinConfigError,
29        StorageError, StreamDeletionPendingError, StreamNotFoundError, StreamerError,
30        StreamerMissingInActionError, TransactionConflictError,
31    },
32    kv,
33    streamer::{GuardedStreamerClient, StreamerClient, StreamerGenerationId},
34};
35use crate::{backend::bgtasks::BgtaskTrigger, stream_id::StreamId};
36
37type StreamerInitFuture = Shared<BoxFuture<'static, Result<StreamerClient, StreamerError>>>;
38
39#[derive(Clone)]
40enum StreamerClientSlot {
41    Initializing {
42        generation_id: StreamerGenerationId,
43        future: StreamerInitFuture,
44    },
45    Ready {
46        client: StreamerClient,
47    },
48}
49
50#[derive(Clone)]
51pub struct Backend {
52    pub(super) db: slatedb::Db,
53    streamer_slots: Arc<DashMap<StreamId, StreamerClientSlot>>,
54    append_inflight_bytes_sema: Arc<Semaphore>,
55    durability_notifier: DurabilityNotifier,
56    bgtask_trigger_tx: broadcast::Sender<BgtaskTrigger>,
57}
58
59impl Backend {
60    pub fn new(db: slatedb::Db, append_inflight_bytes: ByteSize) -> Self {
61        let (bgtask_trigger_tx, _) = broadcast::channel(16);
62        let append_inflight_bytes = Arc::new(Semaphore::new(
63            (append_inflight_bytes.as_u64() as usize).clamp(
64                s2_common::caps::RECORD_BATCH_MAX.bytes,
65                Semaphore::MAX_PERMITS,
66            ),
67        ));
68        let durability_notifier = DurabilityNotifier::spawn(&db);
69        Self {
70            db,
71            streamer_slots: Arc::new(DashMap::new()),
72            append_inflight_bytes_sema: append_inflight_bytes,
73            durability_notifier,
74            bgtask_trigger_tx,
75        }
76    }
77
78    pub(super) fn bgtask_trigger(&self, trigger: BgtaskTrigger) {
79        let _ = self.bgtask_trigger_tx.send(trigger);
80    }
81
82    pub(super) fn bgtask_trigger_subscribe(&self) -> broadcast::Receiver<BgtaskTrigger> {
83        self.bgtask_trigger_tx.subscribe()
84    }
85
86    async fn start_streamer(
87        &self,
88        generation_id: StreamerGenerationId,
89        basin: BasinName,
90        stream: StreamName,
91    ) -> Result<StreamerClient, StreamerError> {
92        let stream_id = StreamId::new(&basin, &stream);
93
94        let (meta, tail_pos, fencing_token, trim_point) = tokio::try_join!(
95            self.db_get(
96                kv::stream_meta::ser_key(&basin, &stream),
97                kv::stream_meta::deser_value,
98            ),
99            self.db_get(
100                kv::stream_tail_position::ser_key(stream_id),
101                kv::stream_tail_position::deser_value,
102            ),
103            self.db_get(
104                kv::stream_fencing_token::ser_key(stream_id),
105                kv::stream_fencing_token::deser_value,
106            ),
107            self.db_get(
108                kv::stream_trim_point::ser_key(stream_id),
109                kv::stream_trim_point::deser_value,
110            )
111        )?;
112
113        let Some(meta) = meta else {
114            return Err(StreamNotFoundError { basin, stream }.into());
115        };
116
117        let tail_pos = tail_pos.map(|(pos, _)| pos).unwrap_or(StreamPosition::MIN);
118        self.assert_no_records_following_tail(stream_id, &basin, &stream, tail_pos)
119            .await?;
120
121        let fencing_token = fencing_token.unwrap_or_default();
122
123        if trim_point == Some(..NonZeroSeqNum::MAX) {
124            return Err(StreamDeletionPendingError { basin, stream }.into());
125        }
126
127        let streamer_slots = self.streamer_slots.clone();
128        Ok(super::streamer::Spawner {
129            generation_id,
130            db: self.db.clone(),
131            stream_id,
132            config: meta.config,
133            cipher: meta.cipher,
134            tail_pos,
135            fencing_token,
136            trim_point: ..trim_point.map_or(SeqNum::MIN, |tp| tp.end.get()),
137            append_inflight_bytes_sema: self.append_inflight_bytes_sema.clone(),
138            durability_notifier: self.durability_notifier.clone(),
139            bgtask_trigger_tx: self.bgtask_trigger_tx.clone(),
140        }
141        .spawn(move |client_id| {
142            streamer_slots.remove_if(&stream_id, |_, slot| {
143                matches!(slot, StreamerClientSlot::Ready { client } if client.generation_id() == client_id)
144            });
145        }))
146    }
147
148    async fn assert_no_records_following_tail(
149        &self,
150        stream_id: StreamId,
151        basin: &BasinName,
152        stream: &StreamName,
153        tail_pos: StreamPosition,
154    ) -> Result<(), StorageError> {
155        let start_key = kv::stream_record_data::ser_key(
156            stream_id,
157            StreamPosition {
158                seq_num: tail_pos.seq_num,
159                timestamp: 0,
160            },
161        );
162        static SCAN_OPTS: ScanOptions = ScanOptions {
163            durability_filter: DurabilityLevel::Remote,
164            dirty: false,
165            read_ahead_bytes: 1,
166            cache_blocks: false,
167            max_fetch_tasks: 1,
168            order: IterationOrder::Ascending,
169        };
170        let mut it = self.db.scan_with_options(start_key.., &SCAN_OPTS).await?;
171        let Some(kv) = it.next().await? else {
172            return Ok(());
173        };
174        if kv.key.first().copied() != Some(kv::KeyType::StreamRecordData as u8) {
175            return Ok(());
176        }
177        let (deser_stream_id, pos) = kv::stream_record_data::deser_key(kv.key)?;
178        assert!(
179            deser_stream_id != stream_id,
180            "invariant violation: stream `{basin}/{stream}` tail_pos {tail_pos:?} but found record at {pos:?}"
181        );
182        Ok(())
183    }
184
185    fn streamer_client_slot(&self, basin: &BasinName, stream: &StreamName) -> StreamerClientSlot {
186        match self.streamer_slots.entry(StreamId::new(basin, stream)) {
187            dashmap::Entry::Occupied(mut oe) => {
188                if matches!(oe.get(), StreamerClientSlot::Ready { client } if client.is_dead()) {
189                    let slot = self.clone().new_initializing_slot(basin, stream);
190                    oe.insert(slot.clone());
191                    slot
192                } else {
193                    oe.get().clone()
194                }
195            }
196            dashmap::Entry::Vacant(ve) => {
197                let slot = self.clone().new_initializing_slot(basin, stream);
198                ve.insert(slot.clone());
199                slot
200            }
201        }
202    }
203
204    fn new_initializing_slot(self, basin: &BasinName, stream: &StreamName) -> StreamerClientSlot {
205        let basin = basin.clone();
206        let stream = stream.clone();
207        let generation_id = StreamerGenerationId::next();
208        let future = async move { self.start_streamer(generation_id, basin, stream).await }
209            .boxed()
210            .shared();
211        StreamerClientSlot::Initializing {
212            generation_id,
213            future,
214        }
215    }
216
217    fn streamer_finish_initialization(
218        &self,
219        stream_id: StreamId,
220        generation_id: StreamerGenerationId,
221        result: &Result<StreamerClient, StreamerError>,
222    ) {
223        if let dashmap::Entry::Occupied(mut oe) = self.streamer_slots.entry(stream_id) {
224            let is_same_init = matches!(
225                oe.get(),
226                StreamerClientSlot::Initializing {
227                    generation_id: state_generation_id,
228                    ..
229                } if *state_generation_id == generation_id
230            );
231            if is_same_init {
232                match result {
233                    Ok(client) => {
234                        debug_assert_eq!(client.generation_id(), generation_id);
235                        if client.is_dead() {
236                            oe.remove();
237                        } else {
238                            oe.insert(StreamerClientSlot::Ready {
239                                client: client.clone(),
240                            });
241                        }
242                    }
243                    Err(_) => {
244                        oe.remove();
245                    }
246                }
247            }
248        }
249    }
250
251    pub(super) async fn streamer_client(
252        &self,
253        basin: &BasinName,
254        stream: &StreamName,
255    ) -> Result<StreamerClient, StreamerError> {
256        let stream_id = StreamId::new(basin, stream);
257        match self.streamer_client_slot(basin, stream) {
258            StreamerClientSlot::Initializing {
259                generation_id,
260                future,
261            } => {
262                let result = future.await;
263                self.streamer_finish_initialization(stream_id, generation_id, &result);
264                result
265            }
266            StreamerClientSlot::Ready { client } => Ok(client),
267        }
268    }
269
270    pub(super) fn streamer_client_if_active(
271        &self,
272        basin: &BasinName,
273        stream: &StreamName,
274    ) -> Option<StreamerClient> {
275        let stream_id = StreamId::new(basin, stream);
276        let slot = self.streamer_slots.get(&stream_id)?;
277        match slot.value() {
278            StreamerClientSlot::Ready { client } if !client.is_dead() => Some(client.clone()),
279            _ => None,
280        }
281    }
282
283    pub(super) async fn streamer_client_guarded(
284        &self,
285        basin: &BasinName,
286        stream: &StreamName,
287    ) -> Result<GuardedStreamerClient, StreamerError> {
288        loop {
289            let client = self.streamer_client(basin, stream).await?;
290            match client.guard() {
291                Ok(client) => return Ok(client),
292                Err(StreamerMissingInActionError) => continue,
293            }
294        }
295    }
296
297    pub(super) async fn stream_handle_with_auto_create<E>(
298        &self,
299        basin: &BasinName,
300        stream: &StreamName,
301        should_auto_create: impl FnOnce(&BasinConfig) -> bool,
302        resolve_encryption: impl FnOnce(Option<EncryptionAlgorithm>) -> Result<EncryptionSpec, E>,
303    ) -> Result<StreamHandle, E>
304    where
305        E: From<StreamerError>
306            + From<StorageError>
307            + From<BasinNotFoundError>
308            + From<TransactionConflictError>
309            + From<BasinDeletionPendingError>
310            + From<StreamDeletionPendingError>
311            + From<StreamNotFoundError>,
312    {
313        match self.streamer_client_guarded(basin, stream).await {
314            Ok(client) => Ok(StreamHandle {
315                db: self.db.clone(),
316                encryption: resolve_encryption(client.cipher())?,
317                client,
318            }),
319            Err(StreamerError::StreamNotFound(e)) => {
320                let config = match self.get_basin_config(basin.clone()).await {
321                    Ok(config) => config,
322                    Err(GetBasinConfigError::Storage(e)) => Err(e)?,
323                    Err(GetBasinConfigError::BasinNotFound(e)) => Err(e)?,
324                };
325                if should_auto_create(&config) {
326                    if let Err(e) = self
327                        .create_stream(
328                            basin.clone(),
329                            stream.clone(),
330                            CreateStreamIntent::CreateOnly {
331                                config: OptionalStreamConfig::default(),
332                                request_token: None,
333                            },
334                        )
335                        .await
336                    {
337                        match e {
338                            CreateStreamError::Storage(e) => Err(e)?,
339                            CreateStreamError::TransactionConflict(e) => Err(e)?,
340                            CreateStreamError::BasinDeletionPending(e) => Err(e)?,
341                            CreateStreamError::StreamDeletionPending(e) => Err(e)?,
342                            CreateStreamError::BasinNotFound(e) => Err(e)?,
343                            CreateStreamError::StreamAlreadyExists(_) => {}
344                            CreateStreamError::Validation(_) => {
345                                unreachable!("auto-create uses default config")
346                            }
347                        }
348                    }
349                    let client = self.streamer_client_guarded(basin, stream).await?;
350                    let encryption = resolve_encryption(client.cipher())?;
351                    Ok(StreamHandle {
352                        db: self.db.clone(),
353                        encryption,
354                        client,
355                    })
356                } else {
357                    Err(e.into())
358                }
359            }
360            Err(e) => Err(e.into()),
361        }
362    }
363}
364
365#[cfg(test)]
366mod tests {
367    use std::str::FromStr as _;
368
369    use bytes::Bytes;
370    use s2_common::{
371        record::{Metered, Record, StoredRecord, StreamPosition},
372        types::{
373            basin::CreateBasinIntent,
374            config::{BasinConfig, OptionalStreamConfig},
375        },
376    };
377    use slatedb::{WriteBatch, config::WriteOptions, object_store};
378    use time::OffsetDateTime;
379
380    use super::*;
381
382    async fn new_test_backend() -> Backend {
383        let object_store: Arc<dyn object_store::ObjectStore> =
384            Arc::new(object_store::memory::InMemory::new());
385        let db = slatedb::Db::builder("test", object_store)
386            .build()
387            .await
388            .unwrap();
389        Backend::new(db, ByteSize::b(1))
390    }
391
392    #[tokio::test]
393    #[should_panic(expected = "invariant violation: stream `testbasin1/stream1` tail_pos")]
394    async fn start_streamer_fails_if_records_exist_after_tail_pos() {
395        let backend = new_test_backend().await;
396
397        let basin = BasinName::from_str("testbasin1").unwrap();
398        let stream = StreamName::from_str("stream1").unwrap();
399        let stream_id = StreamId::new(&basin, &stream);
400
401        let meta = kv::stream_meta::StreamMeta {
402            config: OptionalStreamConfig::default(),
403            cipher: None,
404            created_at: OffsetDateTime::now_utc(),
405            deleted_at: None,
406            creation_idempotency_key: None,
407        };
408
409        let tail_pos = StreamPosition {
410            seq_num: 1,
411            timestamp: 123,
412        };
413        let record_pos = StreamPosition {
414            seq_num: tail_pos.seq_num,
415            timestamp: tail_pos.timestamp,
416        };
417
418        let record = Record::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap();
419        let metered_record: Metered<StoredRecord> = StoredRecord::from(record).into();
420
421        let mut wb = WriteBatch::new();
422        wb.put(
423            kv::stream_meta::ser_key(&basin, &stream),
424            kv::stream_meta::ser_value(&meta),
425        );
426        wb.put(
427            kv::stream_tail_position::ser_key(stream_id),
428            kv::stream_tail_position::ser_value(
429                tail_pos,
430                kv::timestamp::TimestampSecs::from_secs(1),
431            ),
432        );
433        wb.put(
434            kv::stream_record_data::ser_key(stream_id, record_pos),
435            kv::stream_record_data::ser_value(metered_record.as_ref()),
436        );
437        static WRITE_OPTS: WriteOptions = WriteOptions {
438            await_durable: true,
439        };
440        backend
441            .db
442            .write_with_options(wb, &WRITE_OPTS)
443            .await
444            .unwrap();
445
446        backend
447            .start_streamer(StreamerGenerationId::next(), basin.clone(), stream.clone())
448            .await
449            .unwrap();
450    }
451
452    #[tokio::test]
453    async fn streamer_client_slot_uses_single_initializer() {
454        let backend = new_test_backend().await;
455        let basin = BasinName::from_str("testbasin2").unwrap();
456        let stream = StreamName::from_str("stream2").unwrap();
457
458        let slot_1 = backend.streamer_client_slot(&basin, &stream);
459        let slot_2 = backend.streamer_client_slot(&basin, &stream);
460
461        let (generation_id_1, generation_id_2) = match (slot_1, slot_2) {
462            (
463                StreamerClientSlot::Initializing {
464                    generation_id: generation_id_1,
465                    ..
466                },
467                StreamerClientSlot::Initializing {
468                    generation_id: generation_id_2,
469                    ..
470                },
471            ) => (generation_id_1, generation_id_2),
472            _ => panic!("expected both slots to be Initializing"),
473        };
474        assert_eq!(generation_id_1, generation_id_2);
475        assert_eq!(backend.streamer_slots.len(), 1);
476    }
477
478    #[tokio::test]
479    async fn streamer_client_if_active_is_peek_only() {
480        let backend = new_test_backend().await;
481        let basin = BasinName::from_str("testbasin3").unwrap();
482        let stream = StreamName::from_str("stream3").unwrap();
483
484        backend
485            .create_basin(
486                basin.clone(),
487                CreateBasinIntent::CreateOnly {
488                    config: BasinConfig::default(),
489                    request_token: None,
490                },
491            )
492            .await
493            .unwrap();
494        backend
495            .create_stream(
496                basin.clone(),
497                stream.clone(),
498                CreateStreamIntent::CreateOnly {
499                    config: OptionalStreamConfig::default(),
500                    request_token: None,
501                },
502            )
503            .await
504            .unwrap();
505
506        assert!(backend.streamer_slots.is_empty());
507        assert!(backend.streamer_client_if_active(&basin, &stream).is_none());
508        assert!(backend.streamer_slots.is_empty());
509    }
510
511    #[tokio::test]
512    async fn streamer_client_failed_init_is_not_memoized() {
513        let backend = new_test_backend().await;
514        let basin = BasinName::from_str("testbasin4").unwrap();
515        let stream = StreamName::from_str("stream4").unwrap();
516        let stream_id = StreamId::new(&basin, &stream);
517
518        for _ in 0..2 {
519            let err = backend.streamer_client(&basin, &stream).await;
520            assert!(matches!(err, Err(StreamerError::StreamNotFound(_))));
521            assert!(
522                backend.streamer_slots.get(&stream_id).is_none(),
523                "failed init should not be cached"
524            );
525        }
526    }
527
528    #[tokio::test]
529    async fn streamer_finish_initialization_ignores_stale_generation_id() {
530        let backend = new_test_backend().await;
531        let basin = BasinName::from_str("testbasin5").unwrap();
532        let stream = StreamName::from_str("stream5").unwrap();
533        let stream_id = StreamId::new(&basin, &stream);
534
535        let stale_generation_id = StreamerGenerationId::next();
536        let current_generation_id = StreamerGenerationId::next();
537        let future = futures::future::pending::<Result<StreamerClient, StreamerError>>()
538            .boxed()
539            .shared();
540        backend.streamer_slots.insert(
541            stream_id,
542            StreamerClientSlot::Initializing {
543                generation_id: current_generation_id,
544                future: future.clone(),
545            },
546        );
547
548        let stale_result = Err(StreamNotFoundError { basin, stream }.into());
549        backend.streamer_finish_initialization(stream_id, stale_generation_id, &stale_result);
550
551        let Some(slot) = backend.streamer_slots.get(&stream_id) else {
552            panic!("stale init completion should not alter slot state");
553        };
554        match slot.value() {
555            StreamerClientSlot::Initializing { generation_id, .. } => {
556                assert_eq!(*generation_id, current_generation_id)
557            }
558            _ => panic!("expected initializing slot to remain unchanged"),
559        }
560    }
561}