1use std::sync::Arc;
2
3use bytesize::ByteSize;
4use dashmap::DashMap;
5use futures::{
6 FutureExt as _,
7 future::{BoxFuture, Shared},
8};
9use s2_common::{
10 encryption::{EncryptionAlgorithm, EncryptionSpec},
11 record::{NonZeroSeqNum, SeqNum, StreamPosition},
12 types::{
13 basin::BasinName,
14 config::{BasinConfig, OptionalStreamConfig},
15 stream::{CreateStreamIntent, StreamName},
16 },
17};
18use slatedb::{
19 IterationOrder,
20 config::{DurabilityLevel, ScanOptions},
21};
22use tokio::sync::{Semaphore, broadcast};
23
24use super::{
25 StreamHandle,
26 durability_notifier::DurabilityNotifier,
27 error::{
28 BasinDeletionPendingError, BasinNotFoundError, CreateStreamError, GetBasinConfigError,
29 StorageError, StreamDeletionPendingError, StreamNotFoundError, StreamerError,
30 StreamerMissingInActionError, TransactionConflictError,
31 },
32 kv,
33 streamer::{GuardedStreamerClient, StreamerClient, StreamerGenerationId},
34};
35use crate::{backend::bgtasks::BgtaskTrigger, stream_id::StreamId};
36
/// Cloneable (`Shared`), boxed future that resolves to the outcome of starting a streamer.
type StreamerInitFuture = Shared<BoxFuture<'static, Result<StreamerClient, StreamerError>>>;
38
/// Value stored per stream in `Backend::streamer_slots`.
#[derive(Clone)]
enum StreamerClientSlot {
    /// A streamer start is in flight; every holder of a clone of this slot can await `future`.
    Initializing {
        /// Identifies this particular initialization attempt.
        generation_id: StreamerGenerationId,
        /// Shared future that yields the result of `Backend::start_streamer`.
        future: StreamerInitFuture,
    },
    /// Initialization finished and produced `client`.
    Ready {
        client: StreamerClient,
    },
}
49
/// Handle bundling the database with the state shared by all streamers.
///
/// The slot map and the semaphore sit behind `Arc`s, so clones of a `Backend` share them.
#[derive(Clone)]
pub struct Backend {
    /// Underlying SlateDB handle.
    pub(super) db: slatedb::Db,
    /// Per-stream slot holding either an in-flight initialization or a ready client.
    streamer_slots: Arc<DashMap<StreamId, StreamerClientSlot>>,
    /// Semaphore handed to every spawned streamer; sized from `append_inflight_bytes` in `new`.
    append_inflight_bytes_sema: Arc<Semaphore>,
    /// Notifier spawned from `db` in `new` and cloned into every spawned streamer.
    durability_notifier: DurabilityNotifier,
    /// Sending half of the background-task trigger broadcast channel.
    bgtask_trigger_tx: broadcast::Sender<BgtaskTrigger>,
}
58
59impl Backend {
60 pub fn new(db: slatedb::Db, append_inflight_bytes: ByteSize) -> Self {
61 let (bgtask_trigger_tx, _) = broadcast::channel(16);
62 let append_inflight_bytes = Arc::new(Semaphore::new(
63 (append_inflight_bytes.as_u64() as usize).clamp(
64 s2_common::caps::RECORD_BATCH_MAX.bytes,
65 Semaphore::MAX_PERMITS,
66 ),
67 ));
68 let durability_notifier = DurabilityNotifier::spawn(&db);
69 Self {
70 db,
71 streamer_slots: Arc::new(DashMap::new()),
72 append_inflight_bytes_sema: append_inflight_bytes,
73 durability_notifier,
74 bgtask_trigger_tx,
75 }
76 }
77
    /// Broadcasts `trigger` on the background-task channel.
    ///
    /// The send result is deliberately discarded, so a failed send is not reported to the
    /// caller.
    pub(super) fn bgtask_trigger(&self, trigger: BgtaskTrigger) {
        let _ = self.bgtask_trigger_tx.send(trigger);
    }
81
    /// Returns a new receiver subscribed to the background-task trigger channel.
    pub(super) fn bgtask_trigger_subscribe(&self) -> broadcast::Receiver<BgtaskTrigger> {
        self.bgtask_trigger_tx.subscribe()
    }
85
86 async fn start_streamer(
87 &self,
88 generation_id: StreamerGenerationId,
89 basin: BasinName,
90 stream: StreamName,
91 ) -> Result<StreamerClient, StreamerError> {
92 let stream_id = StreamId::new(&basin, &stream);
93
94 let (meta, tail_pos, fencing_token, trim_point) = tokio::try_join!(
95 self.db_get(
96 kv::stream_meta::ser_key(&basin, &stream),
97 kv::stream_meta::deser_value,
98 ),
99 self.db_get(
100 kv::stream_tail_position::ser_key(stream_id),
101 kv::stream_tail_position::deser_value,
102 ),
103 self.db_get(
104 kv::stream_fencing_token::ser_key(stream_id),
105 kv::stream_fencing_token::deser_value,
106 ),
107 self.db_get(
108 kv::stream_trim_point::ser_key(stream_id),
109 kv::stream_trim_point::deser_value,
110 )
111 )?;
112
113 let Some(meta) = meta else {
114 return Err(StreamNotFoundError { basin, stream }.into());
115 };
116
117 let tail_pos = tail_pos.map(|(pos, _)| pos).unwrap_or(StreamPosition::MIN);
118 self.assert_no_records_following_tail(stream_id, &basin, &stream, tail_pos)
119 .await?;
120
121 let fencing_token = fencing_token.unwrap_or_default();
122
123 if trim_point == Some(..NonZeroSeqNum::MAX) {
124 return Err(StreamDeletionPendingError { basin, stream }.into());
125 }
126
127 let streamer_slots = self.streamer_slots.clone();
128 Ok(super::streamer::Spawner {
129 generation_id,
130 db: self.db.clone(),
131 stream_id,
132 config: meta.config,
133 cipher: meta.cipher,
134 tail_pos,
135 fencing_token,
136 trim_point: ..trim_point.map_or(SeqNum::MIN, |tp| tp.end.get()),
137 append_inflight_bytes_sema: self.append_inflight_bytes_sema.clone(),
138 durability_notifier: self.durability_notifier.clone(),
139 bgtask_trigger_tx: self.bgtask_trigger_tx.clone(),
140 }
141 .spawn(move |client_id| {
142 streamer_slots.remove_if(&stream_id, |_, slot| {
143 matches!(slot, StreamerClientSlot::Ready { client } if client.generation_id() == client_id)
144 });
145 }))
146 }
147
148 async fn assert_no_records_following_tail(
149 &self,
150 stream_id: StreamId,
151 basin: &BasinName,
152 stream: &StreamName,
153 tail_pos: StreamPosition,
154 ) -> Result<(), StorageError> {
155 let start_key = kv::stream_record_data::ser_key(
156 stream_id,
157 StreamPosition {
158 seq_num: tail_pos.seq_num,
159 timestamp: 0,
160 },
161 );
162 static SCAN_OPTS: ScanOptions = ScanOptions {
163 durability_filter: DurabilityLevel::Remote,
164 dirty: false,
165 read_ahead_bytes: 1,
166 cache_blocks: false,
167 max_fetch_tasks: 1,
168 order: IterationOrder::Ascending,
169 };
170 let mut it = self.db.scan_with_options(start_key.., &SCAN_OPTS).await?;
171 let Some(kv) = it.next().await? else {
172 return Ok(());
173 };
174 if kv.key.first().copied() != Some(kv::KeyType::StreamRecordData as u8) {
175 return Ok(());
176 }
177 let (deser_stream_id, pos) = kv::stream_record_data::deser_key(kv.key)?;
178 assert!(
179 deser_stream_id != stream_id,
180 "invariant violation: stream `{basin}/{stream}` tail_pos {tail_pos:?} but found record at {pos:?}"
181 );
182 Ok(())
183 }
184
185 fn streamer_client_slot(&self, basin: &BasinName, stream: &StreamName) -> StreamerClientSlot {
186 match self.streamer_slots.entry(StreamId::new(basin, stream)) {
187 dashmap::Entry::Occupied(mut oe) => {
188 if matches!(oe.get(), StreamerClientSlot::Ready { client } if client.is_dead()) {
189 let slot = self.clone().new_initializing_slot(basin, stream);
190 oe.insert(slot.clone());
191 slot
192 } else {
193 oe.get().clone()
194 }
195 }
196 dashmap::Entry::Vacant(ve) => {
197 let slot = self.clone().new_initializing_slot(basin, stream);
198 ve.insert(slot.clone());
199 slot
200 }
201 }
202 }
203
204 fn new_initializing_slot(self, basin: &BasinName, stream: &StreamName) -> StreamerClientSlot {
205 let basin = basin.clone();
206 let stream = stream.clone();
207 let generation_id = StreamerGenerationId::next();
208 let future = async move { self.start_streamer(generation_id, basin, stream).await }
209 .boxed()
210 .shared();
211 StreamerClientSlot::Initializing {
212 generation_id,
213 future,
214 }
215 }
216
217 fn streamer_finish_initialization(
218 &self,
219 stream_id: StreamId,
220 generation_id: StreamerGenerationId,
221 result: &Result<StreamerClient, StreamerError>,
222 ) {
223 if let dashmap::Entry::Occupied(mut oe) = self.streamer_slots.entry(stream_id) {
224 let is_same_init = matches!(
225 oe.get(),
226 StreamerClientSlot::Initializing {
227 generation_id: state_generation_id,
228 ..
229 } if *state_generation_id == generation_id
230 );
231 if is_same_init {
232 match result {
233 Ok(client) => {
234 debug_assert_eq!(client.generation_id(), generation_id);
235 if client.is_dead() {
236 oe.remove();
237 } else {
238 oe.insert(StreamerClientSlot::Ready {
239 client: client.clone(),
240 });
241 }
242 }
243 Err(_) => {
244 oe.remove();
245 }
246 }
247 }
248 }
249 }
250
251 pub(super) async fn streamer_client(
252 &self,
253 basin: &BasinName,
254 stream: &StreamName,
255 ) -> Result<StreamerClient, StreamerError> {
256 let stream_id = StreamId::new(basin, stream);
257 match self.streamer_client_slot(basin, stream) {
258 StreamerClientSlot::Initializing {
259 generation_id,
260 future,
261 } => {
262 let result = future.await;
263 self.streamer_finish_initialization(stream_id, generation_id, &result);
264 result
265 }
266 StreamerClientSlot::Ready { client } => Ok(client),
267 }
268 }
269
270 pub(super) fn streamer_client_if_active(
271 &self,
272 basin: &BasinName,
273 stream: &StreamName,
274 ) -> Option<StreamerClient> {
275 let stream_id = StreamId::new(basin, stream);
276 let slot = self.streamer_slots.get(&stream_id)?;
277 match slot.value() {
278 StreamerClientSlot::Ready { client } if !client.is_dead() => Some(client.clone()),
279 _ => None,
280 }
281 }
282
283 pub(super) async fn streamer_client_guarded(
284 &self,
285 basin: &BasinName,
286 stream: &StreamName,
287 ) -> Result<GuardedStreamerClient, StreamerError> {
288 loop {
289 let client = self.streamer_client(basin, stream).await?;
290 match client.guard() {
291 Ok(client) => return Ok(client),
292 Err(StreamerMissingInActionError) => continue,
293 }
294 }
295 }
296
297 pub(super) async fn stream_handle_with_auto_create<E>(
298 &self,
299 basin: &BasinName,
300 stream: &StreamName,
301 should_auto_create: impl FnOnce(&BasinConfig) -> bool,
302 resolve_encryption: impl FnOnce(Option<EncryptionAlgorithm>) -> Result<EncryptionSpec, E>,
303 ) -> Result<StreamHandle, E>
304 where
305 E: From<StreamerError>
306 + From<StorageError>
307 + From<BasinNotFoundError>
308 + From<TransactionConflictError>
309 + From<BasinDeletionPendingError>
310 + From<StreamDeletionPendingError>
311 + From<StreamNotFoundError>,
312 {
313 match self.streamer_client_guarded(basin, stream).await {
314 Ok(client) => Ok(StreamHandle {
315 db: self.db.clone(),
316 encryption: resolve_encryption(client.cipher())?,
317 client,
318 }),
319 Err(StreamerError::StreamNotFound(e)) => {
320 let config = match self.get_basin_config(basin.clone()).await {
321 Ok(config) => config,
322 Err(GetBasinConfigError::Storage(e)) => Err(e)?,
323 Err(GetBasinConfigError::BasinNotFound(e)) => Err(e)?,
324 };
325 if should_auto_create(&config) {
326 if let Err(e) = self
327 .create_stream(
328 basin.clone(),
329 stream.clone(),
330 CreateStreamIntent::CreateOnly {
331 config: OptionalStreamConfig::default(),
332 request_token: None,
333 },
334 )
335 .await
336 {
337 match e {
338 CreateStreamError::Storage(e) => Err(e)?,
339 CreateStreamError::TransactionConflict(e) => Err(e)?,
340 CreateStreamError::BasinDeletionPending(e) => Err(e)?,
341 CreateStreamError::StreamDeletionPending(e) => Err(e)?,
342 CreateStreamError::BasinNotFound(e) => Err(e)?,
343 CreateStreamError::StreamAlreadyExists(_) => {}
344 CreateStreamError::Validation(_) => {
345 unreachable!("auto-create uses default config")
346 }
347 }
348 }
349 let client = self.streamer_client_guarded(basin, stream).await?;
350 let encryption = resolve_encryption(client.cipher())?;
351 Ok(StreamHandle {
352 db: self.db.clone(),
353 encryption,
354 client,
355 })
356 } else {
357 Err(e.into())
358 }
359 }
360 Err(e) => Err(e.into()),
361 }
362 }
363}
364
365#[cfg(test)]
366mod tests {
367 use std::str::FromStr as _;
368
369 use bytes::Bytes;
370 use s2_common::{
371 record::{Metered, Record, StoredRecord, StreamPosition},
372 types::{
373 basin::CreateBasinIntent,
374 config::{BasinConfig, OptionalStreamConfig},
375 },
376 };
377 use slatedb::{WriteBatch, config::WriteOptions, object_store};
378 use time::OffsetDateTime;
379
380 use super::*;
381
382 async fn new_test_backend() -> Backend {
383 let object_store: Arc<dyn object_store::ObjectStore> =
384 Arc::new(object_store::memory::InMemory::new());
385 let db = slatedb::Db::builder("test", object_store)
386 .build()
387 .await
388 .unwrap();
389 Backend::new(db, ByteSize::b(1))
390 }
391
    /// Starting a streamer must panic when a record is stored at the persisted tail position.
    #[tokio::test]
    #[should_panic(expected = "invariant violation: stream `testbasin1/stream1` tail_pos")]
    async fn start_streamer_fails_if_records_exist_after_tail_pos() {
        let backend = new_test_backend().await;

        let basin = BasinName::from_str("testbasin1").unwrap();
        let stream = StreamName::from_str("stream1").unwrap();
        let stream_id = StreamId::new(&basin, &stream);

        let meta = kv::stream_meta::StreamMeta {
            config: OptionalStreamConfig::default(),
            cipher: None,
            created_at: OffsetDateTime::now_utc(),
            deleted_at: None,
            creation_idempotency_key: None,
        };

        let tail_pos = StreamPosition {
            seq_num: 1,
            timestamp: 123,
        };
        // The offending record sits exactly at the tail position.
        let record_pos = StreamPosition {
            seq_num: tail_pos.seq_num,
            timestamp: tail_pos.timestamp,
        };

        let record = Record::try_from_parts(vec![], Bytes::from_static(b"hello")).unwrap();
        let metered_record: Metered<StoredRecord> = StoredRecord::from(record).into();

        // Write the meta, the tail position and the conflicting record in one batch.
        let mut wb = WriteBatch::new();
        wb.put(
            kv::stream_meta::ser_key(&basin, &stream),
            kv::stream_meta::ser_value(&meta),
        );
        wb.put(
            kv::stream_tail_position::ser_key(stream_id),
            kv::stream_tail_position::ser_value(
                tail_pos,
                kv::timestamp::TimestampSecs::from_secs(1),
            ),
        );
        wb.put(
            kv::stream_record_data::ser_key(stream_id, record_pos),
            kv::stream_record_data::ser_value(metered_record.as_ref()),
        );
        // Wait for durability: the tail scan in `start_streamer` filters on
        // `DurabilityLevel::Remote`.
        static WRITE_OPTS: WriteOptions = WriteOptions {
            await_durable: true,
        };
        backend
            .db
            .write_with_options(wb, &WRITE_OPTS)
            .await
            .unwrap();

        // Expected to panic inside `assert_no_records_following_tail`.
        backend
            .start_streamer(StreamerGenerationId::next(), basin.clone(), stream.clone())
            .await
            .unwrap();
    }
451
452 #[tokio::test]
453 async fn streamer_client_slot_uses_single_initializer() {
454 let backend = new_test_backend().await;
455 let basin = BasinName::from_str("testbasin2").unwrap();
456 let stream = StreamName::from_str("stream2").unwrap();
457
458 let slot_1 = backend.streamer_client_slot(&basin, &stream);
459 let slot_2 = backend.streamer_client_slot(&basin, &stream);
460
461 let (generation_id_1, generation_id_2) = match (slot_1, slot_2) {
462 (
463 StreamerClientSlot::Initializing {
464 generation_id: generation_id_1,
465 ..
466 },
467 StreamerClientSlot::Initializing {
468 generation_id: generation_id_2,
469 ..
470 },
471 ) => (generation_id_1, generation_id_2),
472 _ => panic!("expected both slots to be Initializing"),
473 };
474 assert_eq!(generation_id_1, generation_id_2);
475 assert_eq!(backend.streamer_slots.len(), 1);
476 }
477
    /// `streamer_client_if_active` must not create a slot for an existing but idle stream.
    #[tokio::test]
    async fn streamer_client_if_active_is_peek_only() {
        let backend = new_test_backend().await;
        let basin = BasinName::from_str("testbasin3").unwrap();
        let stream = StreamName::from_str("stream3").unwrap();

        backend
            .create_basin(
                basin.clone(),
                CreateBasinIntent::CreateOnly {
                    config: BasinConfig::default(),
                    request_token: None,
                },
            )
            .await
            .unwrap();
        backend
            .create_stream(
                basin.clone(),
                stream.clone(),
                CreateStreamIntent::CreateOnly {
                    config: OptionalStreamConfig::default(),
                    request_token: None,
                },
            )
            .await
            .unwrap();

        // The slot map must be empty both before and after the lookup.
        assert!(backend.streamer_slots.is_empty());
        assert!(backend.streamer_client_if_active(&basin, &stream).is_none());
        assert!(backend.streamer_slots.is_empty());
    }
510
    /// A failed initialization must leave no slot behind, so each attempt starts afresh.
    #[tokio::test]
    async fn streamer_client_failed_init_is_not_memoized() {
        let backend = new_test_backend().await;
        let basin = BasinName::from_str("testbasin4").unwrap();
        let stream = StreamName::from_str("stream4").unwrap();
        let stream_id = StreamId::new(&basin, &stream);

        // The stream was never created, so both attempts must report it as not found.
        for _ in 0..2 {
            let err = backend.streamer_client(&basin, &stream).await;
            assert!(matches!(err, Err(StreamerError::StreamNotFound(_))));
            assert!(
                backend.streamer_slots.get(&stream_id).is_none(),
                "failed init should not be cached"
            );
        }
    }
527
    /// A completion carrying a different generation id must leave the current slot untouched.
    #[tokio::test]
    async fn streamer_finish_initialization_ignores_stale_generation_id() {
        let backend = new_test_backend().await;
        let basin = BasinName::from_str("testbasin5").unwrap();
        let stream = StreamName::from_str("stream5").unwrap();
        let stream_id = StreamId::new(&basin, &stream);

        let stale_generation_id = StreamerGenerationId::next();
        let current_generation_id = StreamerGenerationId::next();
        // A never-resolving future stands in for the in-flight initialization.
        let future = futures::future::pending::<Result<StreamerClient, StreamerError>>()
            .boxed()
            .shared();
        backend.streamer_slots.insert(
            stream_id,
            StreamerClientSlot::Initializing {
                generation_id: current_generation_id,
                future: future.clone(),
            },
        );

        // Report a failure under the stale generation id.
        let stale_result = Err(StreamNotFoundError { basin, stream }.into());
        backend.streamer_finish_initialization(stream_id, stale_generation_id, &stale_result);

        let Some(slot) = backend.streamer_slots.get(&stream_id) else {
            panic!("stale init completion should not alter slot state");
        };
        match slot.value() {
            StreamerClientSlot::Initializing { generation_id, .. } => {
                assert_eq!(*generation_id, current_generation_id)
            }
            _ => panic!("expected initializing slot to remain unchanged"),
        }
    }
561}