1use std::{
2 collections::VecDeque,
3 future::Future,
4 num::NonZeroU32,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8 time::Duration,
9};
10
11use futures::StreamExt;
12use tokio::{
13 sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot},
14 time::Instant,
15};
16use tokio_muxt::{CoalesceMode, MuxTimer};
17use tokio_stream::wrappers::ReceiverStream;
18use tokio_util::task::AbortOnDropHandle;
19use tracing::debug;
20
21use crate::{
22 api::{ApiError, BasinClient, Streaming, retry_builder},
23 retry::RetryBackoffBuilder,
24 types::{
25 AppendAck, AppendInput, AppendRetryPolicy, MeteredBytes, ONE_MIB, S2Error, StreamName,
26 StreamPosition, ValidationError,
27 },
28};
29
/// Errors that can end an append session or fail an individual append.
#[derive(Debug, thiserror::Error)]
pub enum AppendSessionError {
    /// Error from the underlying API: connecting, or an `Err` item on the ack stream.
    #[error(transparent)]
    Api(#[from] ApiError),
    /// An inflight append was not acknowledged before its deadline
    /// (`request_timeout` after it was sent).
    #[error("append acknowledgement timed out")]
    AckTimeout,
    /// Reserving capacity on the outbound input channel failed because it is closed.
    #[error("server disconnected")]
    ServerDisconnected,
    /// The ack stream ended while appends were still outstanding.
    #[error("response stream closed early while appends in flight")]
    StreamClosedEarly,
    /// The background session task no longer accepts commands, or went away
    /// before answering a close request.
    #[error("session already closed")]
    SessionClosed,
    /// A submission arrived after a close was requested.
    #[error("session is closing")]
    SessionClosing,
    /// The command channel closed, or the ack sender was dropped without a reply.
    #[error("session dropped without calling close")]
    SessionDropped,
}
47
48impl AppendSessionError {
49 pub fn is_retryable(&self) -> bool {
50 match self {
51 Self::Api(err) => err.is_retryable(),
52 Self::AckTimeout => true,
53 Self::ServerDisconnected => true,
54 _ => false,
55 }
56 }
57}
58
59impl From<AppendSessionError> for S2Error {
60 fn from(err: AppendSessionError) -> Self {
61 match err {
62 AppendSessionError::Api(api_err) => api_err.into(),
63 other => S2Error::Client(other.to_string()),
64 }
65 }
66}
67
68pub struct BatchSubmitTicket {
70 rx: oneshot::Receiver<Result<AppendAck, S2Error>>,
71}
72
73impl Future for BatchSubmitTicket {
74 type Output = Result<AppendAck, S2Error>;
75
76 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
77 match Pin::new(&mut self.rx).poll(cx) {
78 Poll::Ready(Ok(res)) => Poll::Ready(res),
79 Poll::Ready(Err(_)) => Poll::Ready(Err(AppendSessionError::SessionDropped.into())),
80 Poll::Pending => Poll::Pending,
81 }
82 }
83}
84
85#[derive(Debug, Clone)]
86pub struct AppendSessionConfig {
88 max_inflight_bytes: u32,
89 max_inflight_batches: Option<u32>,
90}
91
92impl Default for AppendSessionConfig {
93 fn default() -> Self {
94 Self {
95 max_inflight_bytes: 10 * ONE_MIB,
96 max_inflight_batches: None,
97 }
98 }
99}
100
101impl AppendSessionConfig {
102 pub fn new() -> Self {
104 Self::default()
105 }
106
107 pub fn with_max_inflight_bytes(self, max_inflight_bytes: u32) -> Result<Self, ValidationError> {
113 if max_inflight_bytes < ONE_MIB {
114 return Err(format!("max_inflight_bytes must be at least {ONE_MIB}").into());
115 }
116 Ok(Self {
117 max_inflight_bytes,
118 ..self
119 })
120 }
121
122 pub fn with_max_inflight_batches(
126 self,
127 max_inflight_batches: NonZeroU32,
128 ) -> Result<Self, ValidationError> {
129 Ok(Self {
130 max_inflight_batches: Some(max_inflight_batches.get()),
131 ..self
132 })
133 }
134}
135
/// Mutable state owned by the background session task.
///
/// It is created once per task and passed to every connection attempt, so
/// inflight appends can be resent after a reconnect.
struct SessionState {
    /// Commands coming from the session handles.
    cmd_rx: mpsc::Receiver<Command>,
    /// Appends written to the transport and awaiting acks, oldest first.
    inflight_appends: VecDeque<InflightAppend>,
    /// Sum of `input_metered_bytes` over `inflight_appends`.
    inflight_bytes: usize,
    /// Set when a `Close` command arrives; notified with the final outcome.
    close_tx: Option<oneshot::Sender<Result<(), S2Error>>>,
    /// Records written to the transport, counted when first sent (not on resend).
    total_records: usize,
    /// Records acknowledged by the server so far.
    total_acked_records: usize,
    /// End position of the most recent ack, used to assert monotonic acks.
    prev_ack_end: Option<StreamPosition>,
    /// A received submission waiting for capacity on the transport channel.
    stashed_submission: Option<StashedSubmission>,
}
146
/// Handle to a background append session with client-side backpressure.
///
/// Dropping the handle aborts the background task. Call `close` to wait until
/// accepted appends have been acknowledged.
pub struct AppendSession {
    cmd_tx: mpsc::Sender<Command>,
    /// Byte budget and optional batch-count budget for unacknowledged submissions.
    permits: AppendPermits,
    /// Aborts the session task when this handle is dropped.
    _handle: AbortOnDropHandle<()>,
}
156
157impl AppendSession {
158 pub(crate) fn new(
159 client: BasinClient,
160 stream: StreamName,
161 config: AppendSessionConfig,
162 ) -> Self {
163 let buffer_size = config
164 .max_inflight_batches
165 .map(|mib| mib as usize)
166 .unwrap_or(DEFAULT_CHANNEL_BUFFER_SIZE);
167 let (cmd_tx, cmd_rx) = mpsc::channel(buffer_size);
168 let permits = AppendPermits::new(config.max_inflight_batches, config.max_inflight_bytes);
169 let retry_builder = retry_builder(&client.config.retry);
170 let handle = AbortOnDropHandle::new(tokio::spawn(run_session_with_retry(
171 client,
172 stream,
173 cmd_rx,
174 retry_builder,
175 buffer_size,
176 )));
177 Self {
178 cmd_tx,
179 permits,
180 _handle: handle,
181 }
182 }
183
184 pub async fn submit(&self, input: AppendInput) -> Result<BatchSubmitTicket, S2Error> {
194 let permit = self.reserve(input.records.metered_bytes() as u32).await?;
195 Ok(permit.submit(input))
196 }
197
198 pub async fn reserve(&self, bytes: u32) -> Result<BatchSubmitPermit, S2Error> {
215 let append_permit = self.permits.acquire(bytes).await;
216 let cmd_tx_permit = self
217 .cmd_tx
218 .clone()
219 .reserve_owned()
220 .await
221 .map_err(|_| AppendSessionError::SessionClosed)?;
222 Ok(BatchSubmitPermit {
223 append_permit,
224 cmd_tx_permit,
225 })
226 }
227
228 pub async fn close(self) -> Result<(), S2Error> {
230 let (done_tx, done_rx) = oneshot::channel();
231 self.cmd_tx
232 .send(Command::Close { done_tx })
233 .await
234 .map_err(|_| AppendSessionError::SessionClosed)?;
235 done_rx
236 .await
237 .map_err(|_| AppendSessionError::SessionClosed)??;
238 Ok(())
239 }
240}
241
242pub struct BatchSubmitPermit {
244 append_permit: AppendPermit,
245 cmd_tx_permit: mpsc::OwnedPermit<Command>,
246}
247
248impl BatchSubmitPermit {
249 pub fn submit(self, input: AppendInput) -> BatchSubmitTicket {
251 let (ack_tx, ack_rx) = oneshot::channel();
252 self.cmd_tx_permit.send(Command::Submit {
253 input,
254 ack_tx,
255 permit: Some(self.append_permit),
256 });
257 BatchSubmitTicket { rx: ack_rx }
258 }
259}
260
261pub(crate) struct AppendSessionInternal {
262 cmd_tx: mpsc::Sender<Command>,
263 _handle: AbortOnDropHandle<()>,
264}
265
266impl AppendSessionInternal {
267 pub(crate) fn new(client: BasinClient, stream: StreamName) -> Self {
268 let buffer_size = DEFAULT_CHANNEL_BUFFER_SIZE;
269 let (cmd_tx, cmd_rx) = mpsc::channel(buffer_size);
270 let retry_builder = retry_builder(&client.config.retry);
271 let handle = AbortOnDropHandle::new(tokio::spawn(run_session_with_retry(
272 client,
273 stream,
274 cmd_rx,
275 retry_builder,
276 buffer_size,
277 )));
278 Self {
279 cmd_tx,
280 _handle: handle,
281 }
282 }
283
284 pub(crate) fn submit(
285 &self,
286 input: AppendInput,
287 ) -> impl Future<Output = Result<BatchSubmitTicket, S2Error>> + Send + 'static {
288 let cmd_tx = self.cmd_tx.clone();
289 async move {
290 let (ack_tx, ack_rx) = oneshot::channel();
291 cmd_tx
292 .send(Command::Submit {
293 input,
294 ack_tx,
295 permit: None,
296 })
297 .await
298 .map_err(|_| AppendSessionError::SessionClosed)?;
299 Ok(BatchSubmitTicket { rx: ack_rx })
300 }
301 }
302
303 pub(crate) async fn close(self) -> Result<(), S2Error> {
304 let (done_tx, done_rx) = oneshot::channel();
305 self.cmd_tx
306 .send(Command::Close { done_tx })
307 .await
308 .map_err(|_| AppendSessionError::SessionClosed)?;
309 done_rx
310 .await
311 .map_err(|_| AppendSessionError::SessionClosed)??;
312 Ok(())
313 }
314}
315
/// Capacity reserved from `AppendPermits`.
///
/// Dropping it returns the batch slot (if batches are limited) and the byte
/// permits to their semaphores.
#[derive(Debug)]
pub(crate) struct AppendPermit {
    _count: Option<OwnedSemaphorePermit>,
    _bytes: OwnedSemaphorePermit,
}
321
322#[derive(Clone)]
323pub(crate) struct AppendPermits {
324 count: Option<Arc<Semaphore>>,
325 bytes: Arc<Semaphore>,
326}
327
328impl AppendPermits {
329 pub(crate) fn new(count_permits: Option<u32>, bytes_permits: u32) -> Self {
330 Self {
331 count: count_permits.map(|permits| Arc::new(Semaphore::new(permits as usize))),
332 bytes: Arc::new(Semaphore::new(bytes_permits as usize)),
333 }
334 }
335
336 pub(crate) async fn acquire(&self, bytes: u32) -> AppendPermit {
337 AppendPermit {
338 _count: if let Some(count) = self.count.as_ref() {
339 Some(
340 count
341 .clone()
342 .acquire_many_owned(1)
343 .await
344 .expect("semaphore should not be closed"),
345 )
346 } else {
347 None
348 },
349 _bytes: self
350 .bytes
351 .clone()
352 .acquire_many_owned(bytes)
353 .await
354 .expect("semaphore should not be closed"),
355 }
356 }
357}
358
/// Background task that drives an append session across reconnects.
///
/// It calls `run_session` repeatedly with one shared `SessionState`.
///
/// A failed attempt is retried only when all of these hold:
/// - the error is retryable,
/// - every inflight append is compliant with the configured `append_retry_policy`,
/// - the backoff schedule still has a delay left.
///
/// When retrying, it sleeps for that delay and then reconnects. Otherwise it
/// fails every inflight, stashed and queued submission, plus any pending
/// close, with the error, and exits. When `run_session` returns `Ok`, a
/// pending close is completed with `Ok(())`.
async fn run_session_with_retry(
    client: BasinClient,
    stream: StreamName,
    cmd_rx: mpsc::Receiver<Command>,
    retry_builder: RetryBackoffBuilder,
    buffer_size: usize,
) {
    let mut state = SessionState {
        cmd_rx,
        inflight_appends: VecDeque::new(),
        inflight_bytes: 0,
        close_tx: None,
        total_records: 0,
        total_acked_records: 0,
        prev_ack_end: None,
        stashed_submission: None,
    };
    // Acked-record count seen at the last failure; used to detect progress.
    let mut prev_total_acked_records = 0;
    let mut retry_backoffs = retry_builder.build();

    loop {
        let result = run_session(&client, &stream, &mut state, buffer_size).await;

        match result {
            Ok(()) => {
                break;
            }
            Err(err) => {
                // New acks since the previous failure mean the session made
                // progress, so the backoff schedule starts over.
                if prev_total_acked_records < state.total_acked_records {
                    prev_total_acked_records = state.total_acked_records;
                    retry_backoffs.reset();
                }

                // Only inflight appends are checked. The stashed submission has
                // not been written to the transport yet.
                let retry_policy_compliant = retry_policy_compliant(
                    client.config.retry.append_retry_policy,
                    &state.inflight_appends,
                );

                if retry_policy_compliant
                    && err.is_retryable()
                    && let Some(backoff) = retry_backoffs.next()
                {
                    debug!(
                        %err,
                        ?backoff,
                        num_retries_remaining = retry_backoffs.remaining(),
                        "retrying append session"
                    );
                    tokio::time::sleep(backoff).await;
                } else {
                    debug!(
                        %err,
                        retry_policy_compliant,
                        retries_exhausted = retry_backoffs.is_exhausted(),
                        "not retrying append session"
                    );

                    let err: S2Error = err.into();

                    // Fail everything that was already accepted by the task.
                    for inflight_append in state.inflight_appends.drain(..) {
                        let _ = inflight_append.ack_tx.send(Err(err.clone()));
                    }

                    if let Some(stashed) = state.stashed_submission.take() {
                        let _ = stashed.ack_tx.send(Err(err.clone()));
                    }

                    // Stop accepting commands, then fail the ones already queued.
                    // Per tokio mpsc semantics, `recv` returns `None` only once
                    // the buffer is empty and any outstanding permits are released.
                    state.cmd_rx.close();
                    while let Some(cmd) = state.cmd_rx.recv().await {
                        match cmd {
                            Command::Submit { ack_tx, .. } => {
                                let _ = ack_tx.send(Err(err.clone()));
                            }
                            Command::Close { done_tx } => {
                                let _ = done_tx.send(Err(err.clone()));
                            }
                        }
                    }

                    if let Some(done_tx) = state.close_tx.take() {
                        let _ = done_tx.send(Err(err));
                    }
                    break;
                }
            }
        }
    }

    // Reached after a successful session end. On the failure path `close_tx`
    // has already been taken, so this is a no-op there.
    if let Some(done_tx) = state.close_tx.take() {
        let _ = done_tx.send(Ok(()));
    }
}
451
452async fn run_session(
453 client: &BasinClient,
454 stream: &StreamName,
455 state: &mut SessionState,
456 buffer_size: usize,
457) -> Result<(), AppendSessionError> {
458 let (input_tx, mut acks) = connect(client, stream, buffer_size).await?;
459 let ack_timeout = client.config.request_timeout;
460
461 if !state.inflight_appends.is_empty() {
462 resend(state, &input_tx, &mut acks, ack_timeout).await?;
463
464 assert!(state.inflight_appends.is_empty());
465 assert_eq!(state.inflight_bytes, 0);
466 }
467
468 let timer = MuxTimer::<N_TIMER_VARIANTS>::default();
469 tokio::pin!(timer);
470
471 loop {
472 tokio::select! {
473 (event_ord, _deadline) = &mut timer, if timer.is_armed() => {
474 match TimerEvent::from(event_ord) {
475 TimerEvent::AckDeadline => {
476 return Err(AppendSessionError::AckTimeout);
477 }
478 }
479 }
480
481 input_tx_permit = input_tx.reserve(), if state.stashed_submission.is_some() => {
482 let input_tx_permit = input_tx_permit
483 .map_err(|_| AppendSessionError::ServerDisconnected)?;
484 let submission = state.stashed_submission
485 .take()
486 .expect("stashed_submission should not be None");
487
488 input_tx_permit.send(submission.input.clone());
489
490 state.total_records += submission.input.records.len();
491 state.inflight_bytes += submission.input_metered_bytes;
492
493 timer.as_mut().fire_at(
494 TimerEvent::AckDeadline,
495 submission.since + ack_timeout,
496 CoalesceMode::Earliest,
497 );
498 state.inflight_appends.push_back(submission.into());
499 }
500
501 cmd = state.cmd_rx.recv(), if state.stashed_submission.is_none() => {
502 match cmd {
503 Some(Command::Submit { input, ack_tx, permit }) => {
504 if state.close_tx.is_some() {
505 let _ = ack_tx.send(
506 Err(AppendSessionError::SessionClosing.into())
507 );
508 } else {
509 let input_metered_bytes = input.records.metered_bytes();
510 state.stashed_submission = Some(StashedSubmission {
511 input,
512 input_metered_bytes,
513 ack_tx,
514 permit,
515 since: Instant::now(),
516 });
517 }
518 }
519 Some(Command::Close { done_tx }) => {
520 state.close_tx = Some(done_tx);
521 }
522 None => {
523 return Err(AppendSessionError::SessionDropped);
524 }
525 }
526 }
527
528 ack = acks.next() => {
529 match ack {
530 Some(Ok(ack)) => {
531 process_ack(
532 ack,
533 state,
534 timer.as_mut(),
535 ack_timeout,
536 );
537 }
538 Some(Err(err)) => {
539 return Err(err.into());
540 }
541 None => {
542 if !state.inflight_appends.is_empty() || state.stashed_submission.is_some() {
543 return Err(AppendSessionError::StreamClosedEarly);
544 }
545 break;
546 }
547 }
548 }
549 }
550
551 if state.close_tx.is_some()
552 && state.inflight_appends.is_empty()
553 && state.stashed_submission.is_none()
554 {
555 break;
556 }
557 }
558
559 assert!(state.inflight_appends.is_empty());
560 assert_eq!(state.inflight_bytes, 0);
561 assert!(state.stashed_submission.is_none());
562
563 Ok(())
564}
565
/// Resends the appends left inflight by a failed connection, over a new one.
///
/// It writes the inflight appends in order, refreshing each one's `since`,
/// while processing acks as they arrive. It returns once every inflight
/// append has been acknowledged.
///
/// # Errors
/// Fails on any of:
/// - an ack timeout,
/// - a closed input channel,
/// - an error on the ack stream,
/// - the ack stream ending early.
async fn resend(
    state: &mut SessionState,
    input_tx: &mpsc::Sender<AppendInput>,
    acks: &mut Streaming<AppendAck>,
    ack_timeout: Duration,
) -> Result<(), AppendSessionError> {
    debug!(
        inflight_appends_len = state.inflight_appends.len(),
        inflight_bytes = state.inflight_bytes,
        "resending inflight appends"
    );

    // Index into `inflight_appends` of the next append to resend. Each ack pops
    // the front entry, so the index is decremented on every ack.
    let mut resend_index = 0;
    // Becomes true once every inflight append has been written.
    let mut resend_finished = false;

    let timer = MuxTimer::<N_TIMER_VARIANTS>::default();
    tokio::pin!(timer);

    while !state.inflight_appends.is_empty() {
        tokio::select! {
            (event_ord, _deadline) = &mut timer, if timer.is_armed() => {
                match TimerEvent::from(event_ord) {
                    TimerEvent::AckDeadline => {
                        return Err(AppendSessionError::AckTimeout);
                    }
                }
            }

            input_tx_permit = input_tx.reserve(), if !resend_finished => {
                let input_tx_permit = input_tx_permit
                    .map_err(|_| AppendSessionError::ServerDisconnected)?;

                if let Some(inflight_append) = state.inflight_appends.get_mut(resend_index) {
                    // The ack deadline restarts from the resend time.
                    inflight_append.since = Instant::now();
                    timer.as_mut().fire_at(
                        TimerEvent::AckDeadline,
                        inflight_append.since + ack_timeout,
                        CoalesceMode::Earliest,
                    );
                    input_tx_permit.send(inflight_append.input.clone());
                    resend_index += 1;
                } else {
                    // Everything has been resent; this permit goes unused.
                    resend_finished = true;
                }
            }

            ack = acks.next() => {
                match ack {
                    Some(Ok(ack)) => {
                        process_ack(
                            ack,
                            state,
                            timer.as_mut(),
                            ack_timeout,
                        );
                        // process_ack removed the front entry, shifting the rest down.
                        resend_index -= 1;
                    }
                    Some(Err(err)) => {
                        return Err(err.into());
                    }
                    None => {
                        return Err(AppendSessionError::StreamClosedEarly);
                    }
                }
            }
        }
    }

    assert_eq!(
        resend_index, 0,
        "resend_index should be 0 after resend completes"
    );
    debug!("finished resending inflight appends");
    Ok(())
}
641
642async fn connect(
643 client: &BasinClient,
644 stream: &StreamName,
645 buffer_size: usize,
646) -> Result<(mpsc::Sender<AppendInput>, Streaming<AppendAck>), AppendSessionError> {
647 let (input_tx, input_rx) = mpsc::channel::<AppendInput>(buffer_size);
648 let ack_stream = Box::pin(
649 client
650 .append_session(stream, ReceiverStream::new(input_rx).map(|i| i.into()))
651 .await?
652 .map(|ack| match ack {
653 Ok(ack) => Ok(ack.into()),
654 Err(err) => Err(err),
655 }),
656 );
657 Ok((input_tx, ack_stream))
658}
659
660fn process_ack(
661 ack: AppendAck,
662 state: &mut SessionState,
663 timer: Pin<&mut MuxTimer<N_TIMER_VARIANTS>>,
664 ack_timeout: Duration,
665) {
666 let corresponding_append = state
667 .inflight_appends
668 .pop_front()
669 .expect("corresponding append should be present for an ack");
670
671 assert!(
672 ack.end.seq_num >= ack.start.seq_num,
673 "ack end seq_num should be greater than or equal to start seq_num"
674 );
675
676 if let Some(end) = state.prev_ack_end {
677 assert!(
678 ack.end.seq_num > end.seq_num,
679 "ack end seq_num should be greater than previous ack end"
680 );
681 }
682
683 let num_acked_records = (ack.end.seq_num - ack.start.seq_num) as usize;
684 assert_eq!(
685 num_acked_records,
686 corresponding_append.input.records.len(),
687 "ack record count should match submitted batch size"
688 );
689
690 state.total_acked_records += num_acked_records;
691 state.inflight_bytes -= corresponding_append.input_metered_bytes;
692 state.prev_ack_end = Some(ack.end);
693
694 let _ = corresponding_append.ack_tx.send(Ok(ack));
695
696 if let Some(oldest_append) = state.inflight_appends.front() {
697 timer.fire_at(
698 TimerEvent::AckDeadline,
699 oldest_append.since + ack_timeout,
700 CoalesceMode::Latest,
701 );
702 } else {
703 timer.cancel(TimerEvent::AckDeadline);
704 assert_eq!(
705 state.total_records, state.total_acked_records,
706 "all records should be acked when inflight is empty"
707 );
708 }
709}
710
711struct StashedSubmission {
712 input: AppendInput,
713 input_metered_bytes: usize,
714 ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
715 permit: Option<AppendPermit>,
716 since: Instant,
717}
718
719struct InflightAppend {
720 input: AppendInput,
721 input_metered_bytes: usize,
722 ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
723 since: Instant,
724 _permit: Option<AppendPermit>,
725}
726
727impl From<StashedSubmission> for InflightAppend {
728 fn from(value: StashedSubmission) -> Self {
729 Self {
730 input: value.input,
731 input_metered_bytes: value.input_metered_bytes,
732 ack_tx: value.ack_tx,
733 since: value.since,
734 _permit: value.permit,
735 }
736 }
737}
738
739fn retry_policy_compliant(
740 policy: AppendRetryPolicy,
741 inflight_appends: &VecDeque<InflightAppend>,
742) -> bool {
743 if policy == AppendRetryPolicy::All {
744 return true;
745 }
746 inflight_appends
747 .iter()
748 .all(|ia| policy.is_compliant(&ia.input))
749}
750
/// Messages sent from session handles to the background session task.
enum Command {
    /// Append `input` and deliver the outcome on `ack_tx`. `permit` holds the
    /// client-side capacity reserved for this submission, if any.
    Submit {
        input: AppendInput,
        ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
        permit: Option<AppendPermit>,
    },
    /// Request a graceful close; `done_tx` receives the final outcome.
    Close {
        done_tx: oneshot::Sender<Result<(), S2Error>>,
    },
}

/// Default capacity of the command channel and of each connection's input channel.
///
/// Used when `max_inflight_batches` is unset, and always by `AppendSessionInternal`.
const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 100;
763
/// Events multiplexed on the session's `MuxTimer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TimerEvent {
    /// An inflight append's acknowledgement deadline has passed.
    AckDeadline,
}

/// Number of `TimerEvent` variants; used as the `MuxTimer` const parameter.
const N_TIMER_VARIANTS: usize = 1;

impl From<TimerEvent> for usize {
    /// Maps an event to its `MuxTimer` slot ordinal.
    fn from(event: TimerEvent) -> Self {
        let ordinal = match event {
            TimerEvent::AckDeadline => 0,
        };
        ordinal
    }
}

impl From<usize> for TimerEvent {
    /// Maps a `MuxTimer` slot ordinal back to its event.
    ///
    /// # Panics
    /// If `value` is not a known ordinal.
    fn from(value: usize) -> Self {
        if value == 0 {
            Self::AckDeadline
        } else {
            panic!("invalid ordinal")
        }
    }
}
786}