Skip to main content

s2_lite/backend/
core.rs

1use std::sync::Arc;
2
3use bytesize::ByteSize;
4use dashmap::DashMap;
5use futures::{
6    FutureExt as _,
7    future::{BoxFuture, Shared},
8};
9use s2_common::{
10    encryption::{EncryptionAlgorithm, EncryptionSpec},
11    record::{NonZeroSeqNum, SeqNum, StreamPosition},
12    types::{
13        basin::BasinName,
14        config::{BasinConfig, OptionalStreamConfig},
15        resources::ProvisionMode,
16        stream::StreamName,
17    },
18};
19use slatedb::config::{DurabilityLevel, ReadOptions, ScanOptions};
20use tokio::sync::{Semaphore, broadcast};
21
22use super::{
23    StreamHandle,
24    durability_notifier::DurabilityNotifier,
25    error::{
26        BasinDeletionPendingError, BasinNotFoundError, GetBasinConfigError, ProvisionStreamError,
27        StorageError, StreamDeletionPendingError, StreamNotFoundError, StreamerError,
28        StreamerMissingInActionError, TransactionConflictError,
29    },
30    kv,
31    streamer::{GuardedStreamerClient, StreamerClient, StreamerGenerationId},
32};
33use crate::{backend::bgtasks::BgtaskTrigger, stream_id::StreamId};
34
/// Cloneable handle to a streamer start-up in progress.
///
/// The boxed future resolves to the started [`StreamerClient`] or to the [`StreamerError`] that
/// prevented start-up. It is wrapped in [`Shared`] so that clones of the handle can all await the
/// same result.
type StreamerInitFuture = Shared<BoxFuture<'static, Result<StreamerClient, StreamerError>>>;
36
/// State of one stream's entry in `Backend::streamer_slots`.
#[derive(Clone)]
enum StreamerClientSlot {
    /// A streamer start-up is in progress (or has finished but has not been recorded yet).
    Initializing {
        /// Generation assigned to this start-up attempt; used to tell attempts apart.
        generation_id: StreamerGenerationId,
        /// Shared start-up future handed to every caller that finds this slot.
        future: StreamerInitFuture,
    },
    /// A streamer was started and `client` is the handle to it.
    Ready {
        client: StreamerClient,
    },
}
47
/// Storage backend handle.
///
/// The slot map and the semaphore are held behind `Arc`, so clones of a `Backend` share them.
#[derive(Clone)]
pub struct Backend {
    /// Underlying SlateDB database.
    pub(super) db: slatedb::Db,
    /// Per-stream streamer state, keyed by stream id.
    streamer_slots: Arc<DashMap<StreamId, StreamerClientSlot>>,
    /// Semaphore sized from the `append_inflight_bytes` budget; handed to every spawned streamer.
    append_inflight_bytes_sema: Arc<Semaphore>,
    /// Notifier spawned for `db`; handed to every spawned streamer.
    durability_notifier: DurabilityNotifier,
    /// Sending half of the background-task trigger broadcast channel.
    bgtask_trigger_tx: broadcast::Sender<BgtaskTrigger>,
}
56
57impl Backend {
58    pub fn new(db: slatedb::Db, append_inflight_bytes: ByteSize) -> Self {
59        let (bgtask_trigger_tx, _) = broadcast::channel(16);
60        let append_inflight_bytes = Arc::new(Semaphore::new(
61            (append_inflight_bytes.as_u64() as usize).clamp(
62                s2_common::caps::RECORD_BATCH_MAX.bytes,
63                Semaphore::MAX_PERMITS,
64            ),
65        ));
66        let durability_notifier = DurabilityNotifier::spawn(&db);
67        Self {
68            db,
69            streamer_slots: Arc::new(DashMap::new()),
70            append_inflight_bytes_sema: append_inflight_bytes,
71            durability_notifier,
72            bgtask_trigger_tx,
73        }
74    }
75
    /// Broadcasts `trigger` on the background-task trigger channel.
    ///
    /// The send result is intentionally discarded, so a failed send is not reported to the caller.
    pub(super) fn bgtask_trigger(&self, trigger: BgtaskTrigger) {
        // NOTE(review): tokio's broadcast `send` errors when there are no active receivers —
        // presumably the reason this is best-effort; confirm.
        let _ = self.bgtask_trigger_tx.send(trigger);
    }
79
    /// Returns a new receiver subscribed to the background-task trigger channel.
    pub(super) fn bgtask_trigger_subscribe(&self) -> broadcast::Receiver<BgtaskTrigger> {
        self.bgtask_trigger_tx.subscribe()
    }
83
84    async fn start_streamer(
85        &self,
86        generation_id: StreamerGenerationId,
87        basin: BasinName,
88        stream: StreamName,
89    ) -> Result<StreamerClient, StreamerError> {
90        let stream_id = StreamId::new(&basin, &stream);
91
92        let (meta, persisted_tail, fencing_token, trim_point) = tokio::try_join!(
93            self.db_get(
94                kv::stream_meta::ser_key(&basin, &stream),
95                kv::stream_meta::deser_value,
96            ),
97            self.load_persisted_stream_tail(stream_id),
98            self.db_get(
99                kv::stream_fencing_token::ser_key(stream_id),
100                kv::stream_fencing_token::deser_value,
101            ),
102            self.db_get(
103                kv::stream_trim_point::ser_key(stream_id),
104                kv::stream_trim_point::deser_value,
105            )
106        )?;
107
108        let Some(meta) = meta else {
109            return Err(StreamNotFoundError { basin, stream }.into());
110        };
111
112        let (tail_pos, last_tail_write_timestamp) =
113            persisted_tail.unwrap_or((StreamPosition::MIN, kv::timestamp::TimestampSecs::ZERO));
114
115        self.assert_no_records_following_tail(stream_id, &basin, &stream, tail_pos)
116            .await?;
117
118        let fencing_token = fencing_token.unwrap_or_default();
119
120        if trim_point == Some(..NonZeroSeqNum::MAX) {
121            return Err(StreamDeletionPendingError.into());
122        }
123
124        let streamer_slots = self.streamer_slots.clone();
125        Ok(super::streamer::Spawner {
126            generation_id,
127            db: self.db.clone(),
128            stream_id,
129            config: meta.config,
130            cipher: meta.cipher,
131            tail_pos,
132            last_tail_write_timestamp,
133            fencing_token,
134            trim_point: ..trim_point.map_or(SeqNum::MIN, |tp| tp.end.get()),
135            append_inflight_bytes_sema: self.append_inflight_bytes_sema.clone(),
136            durability_notifier: self.durability_notifier.clone(),
137            bgtask_trigger_tx: self.bgtask_trigger_tx.clone(),
138        }
139        .spawn(move |client_id| {
140            streamer_slots.remove_if(&stream_id, |_, slot| {
141                matches!(slot, StreamerClientSlot::Ready { client } if client.generation_id() == client_id)
142            });
143        }))
144    }
145
146    async fn load_persisted_stream_tail(
147        &self,
148        stream_id: StreamId,
149    ) -> Result<Option<(StreamPosition, kv::timestamp::TimestampSecs)>, StorageError> {
150        let read_opts = ReadOptions {
151            durability_filter: DurabilityLevel::Remote,
152            ..Default::default()
153        };
154        let Some(entry) = self
155            .db
156            .get_key_value_with_options(kv::stream_tail_position::ser_key(stream_id), &read_opts)
157            .await?
158        else {
159            return Ok(None);
160        };
161        Ok(Some((
162            kv::stream_tail_position::deser_value(entry.value)?,
163            kv::timestamp::TimestampSecs::from_millis(entry.create_ts),
164        )))
165    }
166
167    async fn assert_no_records_following_tail(
168        &self,
169        stream_id: StreamId,
170        basin: &BasinName,
171        stream: &StreamName,
172        tail_pos: StreamPosition,
173    ) -> Result<(), StorageError> {
174        let start_key = kv::stream_record_data::ser_key(
175            stream_id,
176            StreamPosition {
177                seq_num: tail_pos.seq_num,
178                timestamp: 0,
179            },
180        );
181        let scan_opts = ScanOptions {
182            durability_filter: DurabilityLevel::Remote,
183            ..Default::default()
184        };
185        let mut it = self.db.scan_with_options(start_key.., &scan_opts).await?;
186        let Some(kv) = it.next().await? else {
187            return Ok(());
188        };
189        if kv.key.first().copied() != Some(kv::KeyType::StreamRecordData as u8) {
190            return Ok(());
191        }
192        let (deser_stream_id, pos) = kv::stream_record_data::deser_key(kv.key)?;
193        assert!(
194            deser_stream_id != stream_id,
195            "invariant violation: stream `{basin}/{stream}` tail_pos {tail_pos:?} but found record at {pos:?}"
196        );
197        Ok(())
198    }
199
200    fn streamer_client_slot(&self, basin: &BasinName, stream: &StreamName) -> StreamerClientSlot {
201        match self.streamer_slots.entry(StreamId::new(basin, stream)) {
202            dashmap::Entry::Occupied(mut oe) => {
203                if matches!(oe.get(), StreamerClientSlot::Ready { client } if client.is_dead()) {
204                    let slot = self.clone().new_initializing_slot(basin, stream);
205                    oe.insert(slot.clone());
206                    slot
207                } else {
208                    oe.get().clone()
209                }
210            }
211            dashmap::Entry::Vacant(ve) => {
212                let slot = self.clone().new_initializing_slot(basin, stream);
213                ve.insert(slot.clone());
214                slot
215            }
216        }
217    }
218
219    fn new_initializing_slot(self, basin: &BasinName, stream: &StreamName) -> StreamerClientSlot {
220        let basin = basin.clone();
221        let stream = stream.clone();
222        let generation_id = StreamerGenerationId::next();
223        let future = async move { self.start_streamer(generation_id, basin, stream).await }
224            .boxed()
225            .shared();
226        StreamerClientSlot::Initializing {
227            generation_id,
228            future,
229        }
230    }
231
232    fn streamer_finish_initialization(
233        &self,
234        stream_id: StreamId,
235        generation_id: StreamerGenerationId,
236        result: &Result<StreamerClient, StreamerError>,
237    ) {
238        if let dashmap::Entry::Occupied(mut oe) = self.streamer_slots.entry(stream_id) {
239            let is_same_init = matches!(
240                oe.get(),
241                StreamerClientSlot::Initializing {
242                    generation_id: state_generation_id,
243                    ..
244                } if *state_generation_id == generation_id
245            );
246            if is_same_init {
247                match result {
248                    Ok(client) => {
249                        debug_assert_eq!(client.generation_id(), generation_id);
250                        if client.is_dead() {
251                            oe.remove();
252                        } else {
253                            oe.insert(StreamerClientSlot::Ready {
254                                client: client.clone(),
255                            });
256                        }
257                    }
258                    Err(_) => {
259                        oe.remove();
260                    }
261                }
262            }
263        }
264    }
265
266    pub(super) async fn streamer_client(
267        &self,
268        basin: &BasinName,
269        stream: &StreamName,
270    ) -> Result<StreamerClient, StreamerError> {
271        let stream_id = StreamId::new(basin, stream);
272        match self.streamer_client_slot(basin, stream) {
273            StreamerClientSlot::Initializing {
274                generation_id,
275                future,
276            } => {
277                let result = future.await;
278                self.streamer_finish_initialization(stream_id, generation_id, &result);
279                result
280            }
281            StreamerClientSlot::Ready { client } => Ok(client),
282        }
283    }
284
285    pub(super) fn streamer_client_if_active(
286        &self,
287        basin: &BasinName,
288        stream: &StreamName,
289    ) -> Option<StreamerClient> {
290        let stream_id = StreamId::new(basin, stream);
291        let slot = self.streamer_slots.get(&stream_id)?;
292        match slot.value() {
293            StreamerClientSlot::Ready { client } if !client.is_dead() => Some(client.clone()),
294            _ => None,
295        }
296    }
297
298    pub(super) async fn streamer_client_guarded(
299        &self,
300        basin: &BasinName,
301        stream: &StreamName,
302    ) -> Result<GuardedStreamerClient, StreamerError> {
303        loop {
304            let client = self.streamer_client(basin, stream).await?;
305            match client.guard() {
306                Ok(client) => return Ok(client),
307                Err(StreamerMissingInActionError) => continue,
308            }
309        }
310    }
311
312    pub(super) async fn stream_handle_with_auto_create<E>(
313        &self,
314        basin: &BasinName,
315        stream: &StreamName,
316        should_auto_create: impl FnOnce(&BasinConfig) -> bool,
317        resolve_encryption: impl FnOnce(Option<EncryptionAlgorithm>) -> Result<EncryptionSpec, E>,
318    ) -> Result<StreamHandle, E>
319    where
320        E: From<StreamerError>
321            + From<StorageError>
322            + From<BasinNotFoundError>
323            + From<TransactionConflictError>
324            + From<BasinDeletionPendingError>
325            + From<StreamDeletionPendingError>
326            + From<StreamNotFoundError>,
327    {
328        match self.streamer_client_guarded(basin, stream).await {
329            Ok(client) => Ok(StreamHandle {
330                db: self.db.clone(),
331                encryption: resolve_encryption(client.cipher())?,
332                client,
333            }),
334            Err(StreamerError::StreamNotFound(e)) => {
335                let config = match self.get_basin_config(basin.clone()).await {
336                    Ok(config) => config,
337                    Err(GetBasinConfigError::Storage(e)) => Err(e)?,
338                    Err(GetBasinConfigError::BasinNotFound(e)) => Err(e)?,
339                };
340                if should_auto_create(&config) {
341                    if let Err(e) = self
342                        .provision_stream(
343                            basin.clone(),
344                            stream.clone(),
345                            OptionalStreamConfig::default(),
346                            ProvisionMode::CreateOnly {
347                                request_token: None,
348                            },
349                        )
350                        .await
351                    {
352                        match e {
353                            ProvisionStreamError::Storage(e) => Err(e)?,
354                            ProvisionStreamError::TransactionConflict(e) => Err(e)?,
355                            ProvisionStreamError::BasinDeletionPending(e) => Err(e)?,
356                            ProvisionStreamError::StreamDeletionPending(e) => Err(e)?,
357                            ProvisionStreamError::BasinNotFound(e) => Err(e)?,
358                            ProvisionStreamError::StreamAlreadyExists(_) => {}
359                            ProvisionStreamError::Validation(_) => {
360                                unreachable!("auto-create uses default config")
361                            }
362                        }
363                    }
364                    let client = self.streamer_client_guarded(basin, stream).await?;
365                    let encryption = resolve_encryption(client.cipher())?;
366                    Ok(StreamHandle {
367                        db: self.db.clone(),
368                        encryption,
369                        client,
370                    })
371                } else {
372                    Err(e.into())
373                }
374            }
375            Err(e) => Err(e.into()),
376        }
377    }
378}
379
380#[cfg(test)]
381mod tests {
382    use std::str::FromStr as _;
383
384    use bytes::Bytes;
385    use s2_common::{
386        record::{Metered, Record, StoredRecord, StreamPosition},
387        types::{
388            config::{BasinConfig, OptionalStreamConfig, StreamConfig},
389            resources::ProvisionMode,
390        },
391    };
392    use slatedb::{WriteBatch, object_store};
393    use time::OffsetDateTime;
394
395    use super::*;
396
397    async fn new_test_backend() -> Backend {
398        let object_store: Arc<dyn object_store::ObjectStore> =
399            Arc::new(object_store::memory::InMemory::new());
400        let db = slatedb::Db::builder("test", object_store)
401            .build()
402            .await
403            .unwrap();
404        Backend::new(db, ByteSize::b(1))
405    }
406
407    #[tokio::test]
408    #[should_panic(expected = "invariant violation: stream `testbasin1/stream1` tail_pos")]
409    async fn start_streamer_fails_if_records_exist_after_tail_pos() {
410        let backend = new_test_backend().await;
411
412        let basin = BasinName::from_str("testbasin1").unwrap();
413        let stream = StreamName::from_str("stream1").unwrap();
414        let stream_id = StreamId::new(&basin, &stream);
415
416        let meta = kv::stream_meta::StreamMeta {
417            config: StreamConfig::default(),
418            cipher: None,
419            created_at: OffsetDateTime::now_utc(),
420            deleted_at: None,
421            creation_idempotency_key: None,
422        };
423
424        let tail_pos = StreamPosition {
425            seq_num: 1,
426            timestamp: 123,
427        };
428        let record_pos = StreamPosition {
429            seq_num: tail_pos.seq_num,
430            timestamp: tail_pos.timestamp,
431        };
432
433        let record = Record::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap();
434        let metered_record: Metered<StoredRecord> = StoredRecord::from(record).into();
435
436        let mut wb = WriteBatch::new();
437        wb.put(
438            kv::stream_meta::ser_key(&basin, &stream),
439            kv::stream_meta::ser_value(&meta),
440        );
441        wb.put(
442            kv::stream_tail_position::ser_key(stream_id),
443            kv::stream_tail_position::ser_value(tail_pos),
444        );
445        wb.put(
446            kv::stream_record_data::ser_key(stream_id, record_pos),
447            kv::stream_record_data::ser_value(metered_record.as_ref()),
448        );
449        backend.db.write(wb).await.unwrap();
450
451        backend
452            .start_streamer(StreamerGenerationId::next(), basin.clone(), stream.clone())
453            .await
454            .unwrap();
455    }
456
457    #[tokio::test]
458    async fn streamer_client_slot_uses_single_initializer() {
459        let backend = new_test_backend().await;
460        let basin = BasinName::from_str("testbasin2").unwrap();
461        let stream = StreamName::from_str("stream2").unwrap();
462
463        let slot_1 = backend.streamer_client_slot(&basin, &stream);
464        let slot_2 = backend.streamer_client_slot(&basin, &stream);
465
466        let (generation_id_1, generation_id_2) = match (slot_1, slot_2) {
467            (
468                StreamerClientSlot::Initializing {
469                    generation_id: generation_id_1,
470                    ..
471                },
472                StreamerClientSlot::Initializing {
473                    generation_id: generation_id_2,
474                    ..
475                },
476            ) => (generation_id_1, generation_id_2),
477            _ => panic!("expected both slots to be Initializing"),
478        };
479        assert_eq!(generation_id_1, generation_id_2);
480        assert_eq!(backend.streamer_slots.len(), 1);
481    }
482
483    #[tokio::test]
484    async fn streamer_client_if_active_is_peek_only() {
485        let backend = new_test_backend().await;
486        let basin = BasinName::from_str("testbasin3").unwrap();
487        let stream = StreamName::from_str("stream3").unwrap();
488
489        backend
490            .provision_basin(
491                basin.clone(),
492                BasinConfig::default(),
493                ProvisionMode::CreateOnly {
494                    request_token: None,
495                },
496            )
497            .await
498            .unwrap();
499        backend
500            .provision_stream(
501                basin.clone(),
502                stream.clone(),
503                OptionalStreamConfig::default(),
504                ProvisionMode::CreateOnly {
505                    request_token: None,
506                },
507            )
508            .await
509            .unwrap();
510
511        assert!(backend.streamer_slots.is_empty());
512        assert!(backend.streamer_client_if_active(&basin, &stream).is_none());
513        assert!(backend.streamer_slots.is_empty());
514    }
515
516    #[tokio::test]
517    async fn streamer_client_failed_init_is_not_memoized() {
518        let backend = new_test_backend().await;
519        let basin = BasinName::from_str("testbasin4").unwrap();
520        let stream = StreamName::from_str("stream4").unwrap();
521        let stream_id = StreamId::new(&basin, &stream);
522
523        for _ in 0..2 {
524            let err = backend.streamer_client(&basin, &stream).await;
525            assert!(matches!(err, Err(StreamerError::StreamNotFound(_))));
526            assert!(
527                backend.streamer_slots.get(&stream_id).is_none(),
528                "failed init should not be cached"
529            );
530        }
531    }
532
533    #[tokio::test]
534    async fn streamer_finish_initialization_ignores_stale_generation_id() {
535        let backend = new_test_backend().await;
536        let basin = BasinName::from_str("testbasin5").unwrap();
537        let stream = StreamName::from_str("stream5").unwrap();
538        let stream_id = StreamId::new(&basin, &stream);
539
540        let stale_generation_id = StreamerGenerationId::next();
541        let current_generation_id = StreamerGenerationId::next();
542        let future = futures::future::pending::<Result<StreamerClient, StreamerError>>()
543            .boxed()
544            .shared();
545        backend.streamer_slots.insert(
546            stream_id,
547            StreamerClientSlot::Initializing {
548                generation_id: current_generation_id,
549                future: future.clone(),
550            },
551        );
552
553        let stale_result = Err(StreamNotFoundError { basin, stream }.into());
554        backend.streamer_finish_initialization(stream_id, stale_generation_id, &stale_result);
555
556        let Some(slot) = backend.streamer_slots.get(&stream_id) else {
557            panic!("stale init completion should not alter slot state");
558        };
559        match slot.value() {
560            StreamerClientSlot::Initializing { generation_id, .. } => {
561                assert_eq!(*generation_id, current_generation_id)
562            }
563            _ => panic!("expected initializing slot to remain unchanged"),
564        }
565    }
566}