Skip to main content

s2_lite/backend/
core.rs

1use std::sync::Arc;
2
3use bytesize::ByteSize;
4use dashmap::DashMap;
5use futures::{
6    FutureExt as _,
7    future::{BoxFuture, Shared},
8};
9use s2_common::{
10    encryption::{EncryptionAlgorithm, EncryptionSpec},
11    record::{NonZeroSeqNum, SeqNum, StreamPosition},
12    types::{
13        basin::BasinName,
14        config::{BasinConfig, OptionalStreamConfig},
15        resources::ProvisionMode,
16        stream::StreamName,
17    },
18};
19use slatedb::{
20    IterationOrder,
21    config::{DurabilityLevel, ScanOptions},
22};
23use tokio::sync::{Semaphore, broadcast};
24
25use super::{
26    StreamHandle,
27    durability_notifier::DurabilityNotifier,
28    error::{
29        BasinDeletionPendingError, BasinNotFoundError, GetBasinConfigError, ProvisionStreamError,
30        StorageError, StreamDeletionPendingError, StreamNotFoundError, StreamerError,
31        StreamerMissingInActionError, TransactionConflictError,
32    },
33    kv,
34    streamer::{GuardedStreamerClient, StreamerClient, StreamerGenerationId},
35};
36use crate::{backend::bgtasks::BgtaskTrigger, stream_id::StreamId};
37
38type StreamerInitFuture = Shared<BoxFuture<'static, Result<StreamerClient, StreamerError>>>;
39
40#[derive(Clone)]
41enum StreamerClientSlot {
42    Initializing {
43        generation_id: StreamerGenerationId,
44        future: StreamerInitFuture,
45    },
46    Ready {
47        client: StreamerClient,
48    },
49}
50
51#[derive(Clone)]
52pub struct Backend {
53    pub(super) db: slatedb::Db,
54    streamer_slots: Arc<DashMap<StreamId, StreamerClientSlot>>,
55    append_inflight_bytes_sema: Arc<Semaphore>,
56    durability_notifier: DurabilityNotifier,
57    bgtask_trigger_tx: broadcast::Sender<BgtaskTrigger>,
58}
59
60impl Backend {
61    pub fn new(db: slatedb::Db, append_inflight_bytes: ByteSize) -> Self {
62        let (bgtask_trigger_tx, _) = broadcast::channel(16);
63        let append_inflight_bytes = Arc::new(Semaphore::new(
64            (append_inflight_bytes.as_u64() as usize).clamp(
65                s2_common::caps::RECORD_BATCH_MAX.bytes,
66                Semaphore::MAX_PERMITS,
67            ),
68        ));
69        let durability_notifier = DurabilityNotifier::spawn(&db);
70        Self {
71            db,
72            streamer_slots: Arc::new(DashMap::new()),
73            append_inflight_bytes_sema: append_inflight_bytes,
74            durability_notifier,
75            bgtask_trigger_tx,
76        }
77    }
78
79    pub(super) fn bgtask_trigger(&self, trigger: BgtaskTrigger) {
80        let _ = self.bgtask_trigger_tx.send(trigger);
81    }
82
83    pub(super) fn bgtask_trigger_subscribe(&self) -> broadcast::Receiver<BgtaskTrigger> {
84        self.bgtask_trigger_tx.subscribe()
85    }
86
87    async fn start_streamer(
88        &self,
89        generation_id: StreamerGenerationId,
90        basin: BasinName,
91        stream: StreamName,
92    ) -> Result<StreamerClient, StreamerError> {
93        let stream_id = StreamId::new(&basin, &stream);
94
95        let (meta, tail_pos, fencing_token, trim_point) = tokio::try_join!(
96            self.db_get(
97                kv::stream_meta::ser_key(&basin, &stream),
98                kv::stream_meta::deser_value,
99            ),
100            self.db_get(
101                kv::stream_tail_position::ser_key(stream_id),
102                kv::stream_tail_position::deser_value,
103            ),
104            self.db_get(
105                kv::stream_fencing_token::ser_key(stream_id),
106                kv::stream_fencing_token::deser_value,
107            ),
108            self.db_get(
109                kv::stream_trim_point::ser_key(stream_id),
110                kv::stream_trim_point::deser_value,
111            )
112        )?;
113
114        let Some(meta) = meta else {
115            return Err(StreamNotFoundError { basin, stream }.into());
116        };
117
118        let tail_pos = tail_pos.map(|(pos, _)| pos).unwrap_or(StreamPosition::MIN);
119        self.assert_no_records_following_tail(stream_id, &basin, &stream, tail_pos)
120            .await?;
121
122        let fencing_token = fencing_token.unwrap_or_default();
123
124        if trim_point == Some(..NonZeroSeqNum::MAX) {
125            return Err(StreamDeletionPendingError { basin, stream }.into());
126        }
127
128        let streamer_slots = self.streamer_slots.clone();
129        Ok(super::streamer::Spawner {
130            generation_id,
131            db: self.db.clone(),
132            stream_id,
133            config: meta.config,
134            cipher: meta.cipher,
135            tail_pos,
136            fencing_token,
137            trim_point: ..trim_point.map_or(SeqNum::MIN, |tp| tp.end.get()),
138            append_inflight_bytes_sema: self.append_inflight_bytes_sema.clone(),
139            durability_notifier: self.durability_notifier.clone(),
140            bgtask_trigger_tx: self.bgtask_trigger_tx.clone(),
141        }
142        .spawn(move |client_id| {
143            streamer_slots.remove_if(&stream_id, |_, slot| {
144                matches!(slot, StreamerClientSlot::Ready { client } if client.generation_id() == client_id)
145            });
146        }))
147    }
148
149    async fn assert_no_records_following_tail(
150        &self,
151        stream_id: StreamId,
152        basin: &BasinName,
153        stream: &StreamName,
154        tail_pos: StreamPosition,
155    ) -> Result<(), StorageError> {
156        let start_key = kv::stream_record_data::ser_key(
157            stream_id,
158            StreamPosition {
159                seq_num: tail_pos.seq_num,
160                timestamp: 0,
161            },
162        );
163        static SCAN_OPTS: ScanOptions = ScanOptions {
164            durability_filter: DurabilityLevel::Remote,
165            dirty: false,
166            read_ahead_bytes: 1,
167            cache_blocks: false,
168            max_fetch_tasks: 1,
169            order: IterationOrder::Ascending,
170        };
171        let mut it = self.db.scan_with_options(start_key.., &SCAN_OPTS).await?;
172        let Some(kv) = it.next().await? else {
173            return Ok(());
174        };
175        if kv.key.first().copied() != Some(kv::KeyType::StreamRecordData as u8) {
176            return Ok(());
177        }
178        let (deser_stream_id, pos) = kv::stream_record_data::deser_key(kv.key)?;
179        assert!(
180            deser_stream_id != stream_id,
181            "invariant violation: stream `{basin}/{stream}` tail_pos {tail_pos:?} but found record at {pos:?}"
182        );
183        Ok(())
184    }
185
186    fn streamer_client_slot(&self, basin: &BasinName, stream: &StreamName) -> StreamerClientSlot {
187        match self.streamer_slots.entry(StreamId::new(basin, stream)) {
188            dashmap::Entry::Occupied(mut oe) => {
189                if matches!(oe.get(), StreamerClientSlot::Ready { client } if client.is_dead()) {
190                    let slot = self.clone().new_initializing_slot(basin, stream);
191                    oe.insert(slot.clone());
192                    slot
193                } else {
194                    oe.get().clone()
195                }
196            }
197            dashmap::Entry::Vacant(ve) => {
198                let slot = self.clone().new_initializing_slot(basin, stream);
199                ve.insert(slot.clone());
200                slot
201            }
202        }
203    }
204
205    fn new_initializing_slot(self, basin: &BasinName, stream: &StreamName) -> StreamerClientSlot {
206        let basin = basin.clone();
207        let stream = stream.clone();
208        let generation_id = StreamerGenerationId::next();
209        let future = async move { self.start_streamer(generation_id, basin, stream).await }
210            .boxed()
211            .shared();
212        StreamerClientSlot::Initializing {
213            generation_id,
214            future,
215        }
216    }
217
218    fn streamer_finish_initialization(
219        &self,
220        stream_id: StreamId,
221        generation_id: StreamerGenerationId,
222        result: &Result<StreamerClient, StreamerError>,
223    ) {
224        if let dashmap::Entry::Occupied(mut oe) = self.streamer_slots.entry(stream_id) {
225            let is_same_init = matches!(
226                oe.get(),
227                StreamerClientSlot::Initializing {
228                    generation_id: state_generation_id,
229                    ..
230                } if *state_generation_id == generation_id
231            );
232            if is_same_init {
233                match result {
234                    Ok(client) => {
235                        debug_assert_eq!(client.generation_id(), generation_id);
236                        if client.is_dead() {
237                            oe.remove();
238                        } else {
239                            oe.insert(StreamerClientSlot::Ready {
240                                client: client.clone(),
241                            });
242                        }
243                    }
244                    Err(_) => {
245                        oe.remove();
246                    }
247                }
248            }
249        }
250    }
251
252    pub(super) async fn streamer_client(
253        &self,
254        basin: &BasinName,
255        stream: &StreamName,
256    ) -> Result<StreamerClient, StreamerError> {
257        let stream_id = StreamId::new(basin, stream);
258        match self.streamer_client_slot(basin, stream) {
259            StreamerClientSlot::Initializing {
260                generation_id,
261                future,
262            } => {
263                let result = future.await;
264                self.streamer_finish_initialization(stream_id, generation_id, &result);
265                result
266            }
267            StreamerClientSlot::Ready { client } => Ok(client),
268        }
269    }
270
271    pub(super) fn streamer_client_if_active(
272        &self,
273        basin: &BasinName,
274        stream: &StreamName,
275    ) -> Option<StreamerClient> {
276        let stream_id = StreamId::new(basin, stream);
277        let slot = self.streamer_slots.get(&stream_id)?;
278        match slot.value() {
279            StreamerClientSlot::Ready { client } if !client.is_dead() => Some(client.clone()),
280            _ => None,
281        }
282    }
283
284    pub(super) async fn streamer_client_guarded(
285        &self,
286        basin: &BasinName,
287        stream: &StreamName,
288    ) -> Result<GuardedStreamerClient, StreamerError> {
289        loop {
290            let client = self.streamer_client(basin, stream).await?;
291            match client.guard() {
292                Ok(client) => return Ok(client),
293                Err(StreamerMissingInActionError) => continue,
294            }
295        }
296    }
297
298    pub(super) async fn stream_handle_with_auto_create<E>(
299        &self,
300        basin: &BasinName,
301        stream: &StreamName,
302        should_auto_create: impl FnOnce(&BasinConfig) -> bool,
303        resolve_encryption: impl FnOnce(Option<EncryptionAlgorithm>) -> Result<EncryptionSpec, E>,
304    ) -> Result<StreamHandle, E>
305    where
306        E: From<StreamerError>
307            + From<StorageError>
308            + From<BasinNotFoundError>
309            + From<TransactionConflictError>
310            + From<BasinDeletionPendingError>
311            + From<StreamDeletionPendingError>
312            + From<StreamNotFoundError>,
313    {
314        match self.streamer_client_guarded(basin, stream).await {
315            Ok(client) => Ok(StreamHandle {
316                db: self.db.clone(),
317                encryption: resolve_encryption(client.cipher())?,
318                client,
319            }),
320            Err(StreamerError::StreamNotFound(e)) => {
321                let config = match self.get_basin_config(basin.clone()).await {
322                    Ok(config) => config,
323                    Err(GetBasinConfigError::Storage(e)) => Err(e)?,
324                    Err(GetBasinConfigError::BasinNotFound(e)) => Err(e)?,
325                };
326                if should_auto_create(&config) {
327                    if let Err(e) = self
328                        .provision_stream(
329                            basin.clone(),
330                            stream.clone(),
331                            OptionalStreamConfig::default(),
332                            ProvisionMode::CreateOnly {
333                                request_token: None,
334                            },
335                        )
336                        .await
337                    {
338                        match e {
339                            ProvisionStreamError::Storage(e) => Err(e)?,
340                            ProvisionStreamError::TransactionConflict(e) => Err(e)?,
341                            ProvisionStreamError::BasinDeletionPending(e) => Err(e)?,
342                            ProvisionStreamError::StreamDeletionPending(e) => Err(e)?,
343                            ProvisionStreamError::BasinNotFound(e) => Err(e)?,
344                            ProvisionStreamError::StreamAlreadyExists(_) => {}
345                            ProvisionStreamError::Validation(_) => {
346                                unreachable!("auto-create uses default config")
347                            }
348                        }
349                    }
350                    let client = self.streamer_client_guarded(basin, stream).await?;
351                    let encryption = resolve_encryption(client.cipher())?;
352                    Ok(StreamHandle {
353                        db: self.db.clone(),
354                        encryption,
355                        client,
356                    })
357                } else {
358                    Err(e.into())
359                }
360            }
361            Err(e) => Err(e.into()),
362        }
363    }
364}
365
#[cfg(test)]
mod tests {
    use std::str::FromStr as _;

    use bytes::Bytes;
    use s2_common::{
        record::{Metered, Record, StoredRecord, StreamPosition},
        types::{
            config::{BasinConfig, OptionalStreamConfig},
            resources::ProvisionMode,
        },
    };
    use slatedb::{WriteBatch, config::WriteOptions, object_store};
    use time::OffsetDateTime;

    use super::*;

    /// Builds a backend over a fresh in-memory object store.
    async fn new_test_backend() -> Backend {
        let store: Arc<dyn object_store::ObjectStore> =
            Arc::new(object_store::memory::InMemory::new());
        let db = slatedb::Db::builder("test", store).build().await.unwrap();
        Backend::new(db, ByteSize::b(1))
    }

    /// Parses a basin/stream name pair, panicking on invalid names.
    fn parse_names(basin: &str, stream: &str) -> (BasinName, StreamName) {
        (
            BasinName::from_str(basin).unwrap(),
            StreamName::from_str(stream).unwrap(),
        )
    }

    /// A record stored exactly at the tail position must trip the invariant
    /// assertion while the streamer starts.
    #[tokio::test]
    #[should_panic(expected = "invariant violation: stream `testbasin1/stream1` tail_pos")]
    async fn start_streamer_fails_if_records_exist_after_tail_pos() {
        let backend = new_test_backend().await;
        let (basin, stream) = parse_names("testbasin1", "stream1");
        let stream_id = StreamId::new(&basin, &stream);

        let tail_pos = StreamPosition {
            seq_num: 1,
            timestamp: 123,
        };
        // Deliberately the same position as the tail.
        let record_pos = StreamPosition {
            seq_num: tail_pos.seq_num,
            timestamp: tail_pos.timestamp,
        };

        let stream_meta = kv::stream_meta::StreamMeta {
            config: OptionalStreamConfig::default(),
            cipher: None,
            created_at: OffsetDateTime::now_utc(),
            deleted_at: None,
            creation_idempotency_key: None,
        };

        let plain_record = Record::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap();
        let stored_record: Metered<StoredRecord> = StoredRecord::from(plain_record).into();

        // Persist metadata, tail position and the offending record together.
        let mut batch = WriteBatch::new();
        batch.put(
            kv::stream_meta::ser_key(&basin, &stream),
            kv::stream_meta::ser_value(&stream_meta),
        );
        batch.put(
            kv::stream_tail_position::ser_key(stream_id),
            kv::stream_tail_position::ser_value(
                tail_pos,
                kv::timestamp::TimestampSecs::from_secs(1),
            ),
        );
        batch.put(
            kv::stream_record_data::ser_key(stream_id, record_pos),
            kv::stream_record_data::ser_value(stored_record.as_ref()),
        );
        let durable_write = WriteOptions {
            await_durable: true,
        };
        backend
            .db
            .write_with_options(batch, &durable_write)
            .await
            .unwrap();

        backend
            .start_streamer(StreamerGenerationId::next(), basin, stream)
            .await
            .unwrap();
    }

    /// Two lookups before initialization completes must share one slot.
    #[tokio::test]
    async fn streamer_client_slot_uses_single_initializer() {
        let backend = new_test_backend().await;
        let (basin, stream) = parse_names("testbasin2", "stream2");

        let first = backend.streamer_client_slot(&basin, &stream);
        let second = backend.streamer_client_slot(&basin, &stream);

        let StreamerClientSlot::Initializing {
            generation_id: first_generation,
            ..
        } = first
        else {
            panic!("expected both slots to be Initializing");
        };
        let StreamerClientSlot::Initializing {
            generation_id: second_generation,
            ..
        } = second
        else {
            panic!("expected both slots to be Initializing");
        };

        assert_eq!(first_generation, second_generation);
        assert_eq!(backend.streamer_slots.len(), 1);
    }

    /// Peeking for an active client must not create a slot.
    #[tokio::test]
    async fn streamer_client_if_active_is_peek_only() {
        let backend = new_test_backend().await;
        let (basin, stream) = parse_names("testbasin3", "stream3");

        let basin_mode = ProvisionMode::CreateOnly {
            request_token: None,
        };
        backend
            .provision_basin(basin.clone(), BasinConfig::default(), basin_mode)
            .await
            .unwrap();

        let stream_mode = ProvisionMode::CreateOnly {
            request_token: None,
        };
        backend
            .provision_stream(
                basin.clone(),
                stream.clone(),
                OptionalStreamConfig::default(),
                stream_mode,
            )
            .await
            .unwrap();

        assert!(backend.streamer_slots.is_empty());
        let peeked = backend.streamer_client_if_active(&basin, &stream);
        assert!(peeked.is_none());
        assert!(backend.streamer_slots.is_empty());
    }

    /// A failed start-up must leave no slot behind, so each call retries.
    #[tokio::test]
    async fn streamer_client_failed_init_is_not_memoized() {
        let backend = new_test_backend().await;
        let (basin, stream) = parse_names("testbasin4", "stream4");
        let stream_id = StreamId::new(&basin, &stream);

        for _attempt in 0..2 {
            let outcome = backend.streamer_client(&basin, &stream).await;
            assert!(matches!(outcome, Err(StreamerError::StreamNotFound(_))));

            let cached = backend.streamer_slots.get(&stream_id);
            assert!(cached.is_none(), "failed init should not be cached");
        }
    }

    /// A completion from an older generation must not disturb the slot of
    /// the attempt that is currently installed.
    #[tokio::test]
    async fn streamer_finish_initialization_ignores_stale_generation_id() {
        let backend = new_test_backend().await;
        let (basin, stream) = parse_names("testbasin5", "stream5");
        let stream_id = StreamId::new(&basin, &stream);

        let stale_generation_id = StreamerGenerationId::next();
        let current_generation_id = StreamerGenerationId::next();

        // The installed attempt never completes by itself.
        let future = futures::future::pending::<Result<StreamerClient, StreamerError>>()
            .boxed()
            .shared();
        let installed = StreamerClientSlot::Initializing {
            generation_id: current_generation_id,
            future: future.clone(),
        };
        backend.streamer_slots.insert(stream_id, installed);

        let stale_result: Result<StreamerClient, StreamerError> =
            Err(StreamNotFoundError { basin, stream }.into());
        backend.streamer_finish_initialization(stream_id, stale_generation_id, &stale_result);

        let slot = backend
            .streamer_slots
            .get(&stream_id)
            .unwrap_or_else(|| panic!("stale init completion should not alter slot state"));
        let StreamerClientSlot::Initializing { generation_id, .. } = slot.value() else {
            panic!("expected initializing slot to remain unchanged");
        };
        assert_eq!(*generation_id, current_generation_id);
    }
}
562}