1use std::{
2 collections::VecDeque,
3 future::Future,
4 num::NonZeroU32,
5 pin::Pin,
6 sync::{Arc, OnceLock},
7 task::{Context, Poll},
8 time::Duration,
9};
10
11use futures::StreamExt;
12use tokio::{
13 sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot},
14 time::Instant,
15};
16use tokio_muxt::{CoalesceMode, MuxTimer};
17use tokio_stream::wrappers::ReceiverStream;
18use tokio_util::task::AbortOnDropHandle;
19use tracing::debug;
20
21use crate::{
22 api::{ApiError, BasinClient, Streaming, retry_builder},
23 retry::RetryBackoffBuilder,
24 types::{
25 AppendAck, AppendInput, AppendRetryPolicy, MeteredBytes, ONE_MIB, S2Error, StreamName,
26 StreamPosition, ValidationError,
27 },
28};
29
30#[derive(Debug, thiserror::Error)]
31pub enum AppendSessionError {
32 #[error(transparent)]
33 Api(#[from] ApiError),
34 #[error("append acknowledgement timed out")]
35 AckTimeout,
36 #[error("server disconnected")]
37 ServerDisconnected,
38 #[error("response stream closed early while appends in flight")]
39 StreamClosedEarly,
40 #[error("session already closed")]
41 SessionClosed,
42 #[error("session is closing")]
43 SessionClosing,
44 #[error("session dropped without calling close")]
45 SessionDropped,
46 #[error("unexpected append acknowledgement during resend")]
47 UnexpectedAck,
48}
49
50impl AppendSessionError {
51 pub fn is_retryable(&self) -> bool {
52 match self {
53 Self::Api(err) => err.is_retryable(),
54 Self::AckTimeout => true,
55 Self::ServerDisconnected => true,
56 _ => false,
57 }
58 }
59}
60
61impl From<AppendSessionError> for S2Error {
62 fn from(err: AppendSessionError) -> Self {
63 match err {
64 AppendSessionError::Api(api_err) => api_err.into(),
65 other => S2Error::Client(other.to_string()),
66 }
67 }
68}
69
70pub struct BatchSubmitTicket {
72 rx: oneshot::Receiver<Result<AppendAck, S2Error>>,
73 terminal_err: Arc<OnceLock<S2Error>>,
74}
75
76impl Future for BatchSubmitTicket {
77 type Output = Result<AppendAck, S2Error>;
78
79 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
80 match Pin::new(&mut self.rx).poll(cx) {
81 Poll::Ready(Ok(res)) => Poll::Ready(res),
82 Poll::Ready(Err(_)) => Poll::Ready(Err(self
83 .terminal_err
84 .get()
85 .cloned()
86 .unwrap_or_else(|| AppendSessionError::SessionDropped.into()))),
87 Poll::Pending => Poll::Pending,
88 }
89 }
90}
91
92#[derive(Debug, Clone)]
93pub struct AppendSessionConfig {
95 max_unacked_bytes: u32,
96 max_unacked_batches: Option<u32>,
97}
98
99impl Default for AppendSessionConfig {
100 fn default() -> Self {
101 Self {
102 max_unacked_bytes: 5 * ONE_MIB,
103 max_unacked_batches: None,
104 }
105 }
106}
107
108impl AppendSessionConfig {
109 pub fn new() -> Self {
111 Self::default()
112 }
113
114 pub fn with_max_unacked_bytes(self, max_unacked_bytes: u32) -> Result<Self, ValidationError> {
120 if max_unacked_bytes < ONE_MIB {
121 return Err(format!("max_unacked_bytes must be at least {ONE_MIB}").into());
122 }
123 Ok(Self {
124 max_unacked_bytes,
125 ..self
126 })
127 }
128
129 pub fn with_max_unacked_batches(self, max_unacked_batches: NonZeroU32) -> Self {
133 Self {
134 max_unacked_batches: Some(max_unacked_batches.get()),
135 ..self
136 }
137 }
138}
139
140struct SessionState {
141 cmd_rx: mpsc::Receiver<Command>,
142 inflight_appends: VecDeque<InflightAppend>,
143 inflight_bytes: usize,
144 close_tx: Option<oneshot::Sender<Result<(), S2Error>>>,
145 total_records: usize,
146 total_acked_records: usize,
147 prev_ack_end: Option<StreamPosition>,
148 stashed_submission: Option<StashedSubmission>,
149}
150
151pub struct AppendSession {
156 cmd_tx: mpsc::Sender<Command>,
157 permits: AppendPermits,
158 terminal_err: Arc<OnceLock<S2Error>>,
159 _handle: AbortOnDropHandle<()>,
160}
161
162impl AppendSession {
163 pub(crate) fn new(
164 client: BasinClient,
165 stream: StreamName,
166 config: AppendSessionConfig,
167 ) -> Self {
168 let buffer_size = config
169 .max_unacked_batches
170 .map(|mib| mib as usize)
171 .unwrap_or(DEFAULT_CHANNEL_BUFFER_SIZE);
172 let (cmd_tx, cmd_rx) = mpsc::channel(buffer_size);
173 let permits = AppendPermits::new(config.max_unacked_batches, config.max_unacked_bytes);
174 let retry_builder = retry_builder(&client.config.retry);
175 let terminal_err = Arc::new(OnceLock::new());
176 let handle = AbortOnDropHandle::new(tokio::spawn(run_session_with_retry(
177 client,
178 stream,
179 cmd_rx,
180 retry_builder,
181 buffer_size,
182 terminal_err.clone(),
183 )));
184 Self {
185 cmd_tx,
186 permits,
187 terminal_err,
188 _handle: handle,
189 }
190 }
191
192 pub async fn submit(&self, input: AppendInput) -> Result<BatchSubmitTicket, S2Error> {
202 let permit = self.reserve(input.records.metered_bytes() as u32).await?;
203 Ok(permit.submit(input))
204 }
205
206 pub async fn reserve(&self, bytes: u32) -> Result<BatchSubmitPermit, S2Error> {
223 let append_permit = self.permits.acquire(bytes).await;
224 let cmd_tx_permit = self
225 .cmd_tx
226 .clone()
227 .reserve_owned()
228 .await
229 .map_err(|_| self.terminal_err())?;
230 Ok(BatchSubmitPermit {
231 append_permit,
232 cmd_tx_permit,
233 terminal_err: self.terminal_err.clone(),
234 })
235 }
236
237 pub async fn close(self) -> Result<(), S2Error> {
239 let (done_tx, done_rx) = oneshot::channel();
240 self.cmd_tx
241 .send(Command::Close { done_tx })
242 .await
243 .map_err(|_| self.terminal_err())?;
244 done_rx.await.map_err(|_| self.terminal_err())??;
245 Ok(())
246 }
247
248 fn terminal_err(&self) -> S2Error {
249 self.terminal_err
250 .get()
251 .cloned()
252 .unwrap_or_else(|| AppendSessionError::SessionClosed.into())
253 }
254}
255
256pub struct BatchSubmitPermit {
258 append_permit: AppendPermit,
259 cmd_tx_permit: mpsc::OwnedPermit<Command>,
260 terminal_err: Arc<OnceLock<S2Error>>,
261}
262
263impl BatchSubmitPermit {
264 pub fn submit(self, input: AppendInput) -> BatchSubmitTicket {
266 let (ack_tx, ack_rx) = oneshot::channel();
267 self.cmd_tx_permit.send(Command::Submit {
268 input,
269 ack_tx,
270 permit: Some(self.append_permit),
271 });
272 BatchSubmitTicket {
273 rx: ack_rx,
274 terminal_err: self.terminal_err,
275 }
276 }
277}
278
279pub(crate) struct AppendSessionInternal {
280 cmd_tx: mpsc::Sender<Command>,
281 terminal_err: Arc<OnceLock<S2Error>>,
282 _handle: AbortOnDropHandle<()>,
283}
284
285impl AppendSessionInternal {
286 pub(crate) fn new(client: BasinClient, stream: StreamName) -> Self {
287 let buffer_size = DEFAULT_CHANNEL_BUFFER_SIZE;
288 let (cmd_tx, cmd_rx) = mpsc::channel(buffer_size);
289 let retry_builder = retry_builder(&client.config.retry);
290 let terminal_err = Arc::new(OnceLock::new());
291 let handle = AbortOnDropHandle::new(tokio::spawn(run_session_with_retry(
292 client,
293 stream,
294 cmd_rx,
295 retry_builder,
296 buffer_size,
297 terminal_err.clone(),
298 )));
299 Self {
300 cmd_tx,
301 terminal_err,
302 _handle: handle,
303 }
304 }
305
306 pub(crate) fn submit(
307 &self,
308 input: AppendInput,
309 ) -> impl Future<Output = Result<BatchSubmitTicket, S2Error>> + Send + 'static {
310 let cmd_tx = self.cmd_tx.clone();
311 let terminal_err = self.terminal_err.clone();
312 async move {
313 let (ack_tx, ack_rx) = oneshot::channel();
314 cmd_tx
315 .send(Command::Submit {
316 input,
317 ack_tx,
318 permit: None,
319 })
320 .await
321 .map_err(|_| {
322 terminal_err
323 .get()
324 .cloned()
325 .unwrap_or_else(|| AppendSessionError::SessionClosed.into())
326 })?;
327 Ok(BatchSubmitTicket {
328 rx: ack_rx,
329 terminal_err,
330 })
331 }
332 }
333
334 pub(crate) async fn close(self) -> Result<(), S2Error> {
335 let (done_tx, done_rx) = oneshot::channel();
336 self.cmd_tx
337 .send(Command::Close { done_tx })
338 .await
339 .map_err(|_| self.terminal_err())?;
340 done_rx.await.map_err(|_| self.terminal_err())??;
341 Ok(())
342 }
343
344 fn terminal_err(&self) -> S2Error {
345 self.terminal_err
346 .get()
347 .cloned()
348 .unwrap_or_else(|| AppendSessionError::SessionClosed.into())
349 }
350}
351
352#[derive(Debug)]
353pub(crate) struct AppendPermit {
354 _count: Option<OwnedSemaphorePermit>,
355 _bytes: OwnedSemaphorePermit,
356}
357
358#[derive(Clone)]
359pub(crate) struct AppendPermits {
360 count: Option<Arc<Semaphore>>,
361 bytes: Arc<Semaphore>,
362}
363
364impl AppendPermits {
365 pub(crate) fn new(count_permits: Option<u32>, bytes_permits: u32) -> Self {
366 Self {
367 count: count_permits.map(|permits| Arc::new(Semaphore::new(permits as usize))),
368 bytes: Arc::new(Semaphore::new(bytes_permits as usize)),
369 }
370 }
371
372 pub(crate) async fn acquire(&self, bytes: u32) -> AppendPermit {
373 AppendPermit {
374 _count: if let Some(count) = self.count.as_ref() {
375 Some(
376 count
377 .clone()
378 .acquire_many_owned(1)
379 .await
380 .expect("semaphore should not be closed"),
381 )
382 } else {
383 None
384 },
385 _bytes: self
386 .bytes
387 .clone()
388 .acquire_many_owned(bytes)
389 .await
390 .expect("semaphore should not be closed"),
391 }
392 }
393}
394
395async fn run_session_with_retry(
396 client: BasinClient,
397 stream: StreamName,
398 cmd_rx: mpsc::Receiver<Command>,
399 retry_builder: RetryBackoffBuilder,
400 buffer_size: usize,
401 terminal_err: Arc<OnceLock<S2Error>>,
402) {
403 let mut state = SessionState {
404 cmd_rx,
405 inflight_appends: VecDeque::new(),
406 inflight_bytes: 0,
407 close_tx: None,
408 total_records: 0,
409 total_acked_records: 0,
410 prev_ack_end: None,
411 stashed_submission: None,
412 };
413 let mut prev_total_acked_records = 0;
414 let mut retry_backoff = retry_builder.build();
415
416 loop {
417 let result = run_session(&client, &stream, &mut state, buffer_size).await;
418
419 match result {
420 Ok(()) => {
421 break;
422 }
423 Err(err) => {
424 if prev_total_acked_records < state.total_acked_records {
425 prev_total_acked_records = state.total_acked_records;
426 retry_backoff.reset();
427 }
428
429 let retry_policy_compliant = retry_policy_compliant(
430 client.config.retry.append_retry_policy,
431 &state.inflight_appends,
432 );
433
434 if retry_policy_compliant
435 && err.is_retryable()
436 && let Some(backoff) = retry_backoff.next()
437 {
438 debug!(
439 %err,
440 ?backoff,
441 num_retries_remaining = retry_backoff.remaining(),
442 "retrying append session"
443 );
444 tokio::time::sleep(backoff).await;
445 } else {
446 debug!(
447 %err,
448 retry_policy_compliant,
449 retries_exhausted = retry_backoff.is_exhausted(),
450 "not retrying append session"
451 );
452
453 let err: S2Error = err.into();
454
455 let _ = terminal_err.set(err.clone());
456
457 for inflight_append in state.inflight_appends.drain(..) {
458 let _ = inflight_append.ack_tx.send(Err(err.clone()));
459 }
460
461 if let Some(stashed) = state.stashed_submission.take() {
462 let _ = stashed.ack_tx.send(Err(err.clone()));
463 }
464
465 if let Some(done_tx) = state.close_tx.take() {
466 let _ = done_tx.send(Err(err.clone()));
467 }
468
469 state.cmd_rx.close();
470 while let Some(cmd) = state.cmd_rx.recv().await {
471 cmd.reject(err.clone());
472 }
473 break;
474 }
475 }
476 }
477 }
478
479 if let Some(done_tx) = state.close_tx.take() {
480 let _ = done_tx.send(Ok(()));
481 }
482}
483
484async fn run_session(
485 client: &BasinClient,
486 stream: &StreamName,
487 state: &mut SessionState,
488 buffer_size: usize,
489) -> Result<(), AppendSessionError> {
490 let (input_tx, mut acks) = connect(client, stream, buffer_size).await?;
491 let ack_timeout = client.config.request_timeout;
492
493 if !state.inflight_appends.is_empty() {
494 resend(state, &input_tx, &mut acks, ack_timeout).await?;
495
496 assert!(state.inflight_appends.is_empty());
497 assert_eq!(state.inflight_bytes, 0);
498 }
499
500 let timer = MuxTimer::<N_TIMER_VARIANTS>::default();
501 tokio::pin!(timer);
502
503 loop {
504 tokio::select! {
505 (event_ord, _deadline) = &mut timer, if timer.is_armed() => {
506 match TimerEvent::from(event_ord) {
507 TimerEvent::AckDeadline => {
508 return Err(AppendSessionError::AckTimeout);
509 }
510 }
511 }
512
513 input_tx_permit = input_tx.reserve(), if state.stashed_submission.is_some() => {
514 let input_tx_permit = input_tx_permit
515 .map_err(|_| AppendSessionError::ServerDisconnected)?;
516 let submission = state.stashed_submission
517 .take()
518 .expect("stashed_submission should not be None");
519
520 input_tx_permit.send(submission.input.clone());
521
522 state.total_records += submission.input.records.len();
523 state.inflight_bytes += submission.input_metered_bytes;
524
525 timer.as_mut().fire_at(
526 TimerEvent::AckDeadline,
527 submission.since + ack_timeout,
528 CoalesceMode::Earliest,
529 );
530 state.inflight_appends.push_back(submission.into());
531 }
532
533 cmd = state.cmd_rx.recv(), if state.stashed_submission.is_none() => {
534 match cmd {
535 Some(Command::Submit { input, ack_tx, permit }) => {
536 if state.close_tx.is_some() {
537 let _ = ack_tx.send(
538 Err(AppendSessionError::SessionClosing.into())
539 );
540 } else {
541 let input_metered_bytes = input.records.metered_bytes();
542 state.stashed_submission = Some(StashedSubmission {
543 input,
544 input_metered_bytes,
545 ack_tx,
546 permit,
547 since: Instant::now(),
548 });
549 }
550 }
551 Some(Command::Close { done_tx }) => {
552 state.close_tx = Some(done_tx);
553 }
554 None => {
555 return Err(AppendSessionError::SessionDropped);
556 }
557 }
558 }
559
560 ack = acks.next() => {
561 match ack {
562 Some(Ok(ack)) => {
563 process_ack(
564 ack,
565 state,
566 timer.as_mut(),
567 ack_timeout,
568 );
569 }
570 Some(Err(err)) => {
571 return Err(err.into());
572 }
573 None => {
574 if !state.inflight_appends.is_empty() || state.stashed_submission.is_some() {
575 return Err(AppendSessionError::StreamClosedEarly);
576 }
577 break;
578 }
579 }
580 }
581 }
582
583 if state.close_tx.is_some()
584 && state.inflight_appends.is_empty()
585 && state.stashed_submission.is_none()
586 {
587 break;
588 }
589 }
590
591 assert!(state.inflight_appends.is_empty());
592 assert_eq!(state.inflight_bytes, 0);
593 assert!(state.stashed_submission.is_none());
594
595 Ok(())
596}
597
598async fn resend(
599 state: &mut SessionState,
600 input_tx: &mpsc::Sender<AppendInput>,
601 acks: &mut Streaming<AppendAck>,
602 ack_timeout: Duration,
603) -> Result<(), AppendSessionError> {
604 debug!(
605 inflight_appends_len = state.inflight_appends.len(),
606 inflight_bytes = state.inflight_bytes,
607 "resending inflight appends"
608 );
609
610 let mut resend_index = 0;
611 let mut resend_finished = false;
612
613 let timer = MuxTimer::<N_TIMER_VARIANTS>::default();
614 tokio::pin!(timer);
615
616 while !state.inflight_appends.is_empty() {
617 tokio::select! {
618 (event_ord, _deadline) = &mut timer, if timer.is_armed() => {
619 match TimerEvent::from(event_ord) {
620 TimerEvent::AckDeadline => {
621 return Err(AppendSessionError::AckTimeout);
622 }
623 }
624 }
625
626 input_tx_permit = input_tx.reserve(), if !resend_finished => {
627 let input_tx_permit = input_tx_permit
628 .map_err(|_| AppendSessionError::ServerDisconnected)?;
629
630 if let Some(inflight_append) = state.inflight_appends.get_mut(resend_index) {
631 inflight_append.since = Instant::now();
632 timer.as_mut().fire_at(
633 TimerEvent::AckDeadline,
634 inflight_append.since + ack_timeout,
635 CoalesceMode::Earliest,
636 );
637 input_tx_permit.send(inflight_append.input.clone());
638 resend_index += 1;
639 } else {
640 resend_finished = true;
641 }
642 }
643
644 ack = acks.next() => {
645 match ack {
646 Some(Ok(ack)) => {
647 process_ack(
648 ack,
649 state,
650 timer.as_mut(),
651 ack_timeout,
652 );
653 resend_index = resend_index
654 .checked_sub(1)
655 .ok_or(AppendSessionError::UnexpectedAck)?;
656 }
657 Some(Err(err)) => {
658 return Err(err.into());
659 }
660 None => {
661 return Err(AppendSessionError::StreamClosedEarly);
662 }
663 }
664 }
665 }
666 }
667
668 assert_eq!(
669 resend_index, 0,
670 "resend_index should be 0 after resend completes"
671 );
672 debug!("finished resending inflight appends");
673 Ok(())
674}
675
676async fn connect(
677 client: &BasinClient,
678 stream: &StreamName,
679 buffer_size: usize,
680) -> Result<(mpsc::Sender<AppendInput>, Streaming<AppendAck>), AppendSessionError> {
681 let (input_tx, input_rx) = mpsc::channel::<AppendInput>(buffer_size);
682 let ack_stream = Box::pin(
683 client
684 .append_session(stream, ReceiverStream::new(input_rx).map(|i| i.into()))
685 .await?
686 .map(|ack| match ack {
687 Ok(ack) => Ok(ack.into()),
688 Err(err) => Err(err),
689 }),
690 );
691 Ok((input_tx, ack_stream))
692}
693
694fn process_ack(
695 ack: AppendAck,
696 state: &mut SessionState,
697 timer: Pin<&mut MuxTimer<N_TIMER_VARIANTS>>,
698 ack_timeout: Duration,
699) {
700 let corresponding_append = state
701 .inflight_appends
702 .pop_front()
703 .expect("corresponding append should be present for an ack");
704
705 assert!(
706 ack.end.seq_num >= ack.start.seq_num,
707 "ack end seq_num should be greater than or equal to start seq_num"
708 );
709
710 if let Some(end) = state.prev_ack_end {
711 assert!(
712 ack.end.seq_num > end.seq_num,
713 "ack end seq_num should be greater than previous ack end"
714 );
715 }
716
717 let num_acked_records = (ack.end.seq_num - ack.start.seq_num) as usize;
718 assert_eq!(
719 num_acked_records,
720 corresponding_append.input.records.len(),
721 "ack record count should match submitted batch size"
722 );
723
724 state.total_acked_records += num_acked_records;
725 state.inflight_bytes -= corresponding_append.input_metered_bytes;
726 state.prev_ack_end = Some(ack.end);
727
728 let _ = corresponding_append.ack_tx.send(Ok(ack));
729
730 if let Some(oldest_append) = state.inflight_appends.front() {
731 timer.fire_at(
732 TimerEvent::AckDeadline,
733 oldest_append.since + ack_timeout,
734 CoalesceMode::Latest,
735 );
736 } else {
737 timer.cancel(TimerEvent::AckDeadline);
738 assert_eq!(
739 state.total_records, state.total_acked_records,
740 "all records should be acked when inflight is empty"
741 );
742 }
743}
744
745struct StashedSubmission {
746 input: AppendInput,
747 input_metered_bytes: usize,
748 ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
749 permit: Option<AppendPermit>,
750 since: Instant,
751}
752
753struct InflightAppend {
754 input: AppendInput,
755 input_metered_bytes: usize,
756 ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
757 since: Instant,
758 _permit: Option<AppendPermit>,
759}
760
761impl From<StashedSubmission> for InflightAppend {
762 fn from(value: StashedSubmission) -> Self {
763 Self {
764 input: value.input,
765 input_metered_bytes: value.input_metered_bytes,
766 ack_tx: value.ack_tx,
767 since: value.since,
768 _permit: value.permit,
769 }
770 }
771}
772
773fn retry_policy_compliant(
774 policy: AppendRetryPolicy,
775 inflight_appends: &VecDeque<InflightAppend>,
776) -> bool {
777 if policy == AppendRetryPolicy::All {
778 return true;
779 }
780 inflight_appends
781 .iter()
782 .all(|ia| policy.is_compliant(&ia.input))
783}
784
785enum Command {
786 Submit {
787 input: AppendInput,
788 ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
789 permit: Option<AppendPermit>,
790 },
791 Close {
792 done_tx: oneshot::Sender<Result<(), S2Error>>,
793 },
794}
795
796impl Command {
797 fn reject(self, err: S2Error) {
798 match self {
799 Command::Submit { ack_tx, .. } => {
800 let _ = ack_tx.send(Err(err));
801 }
802 Command::Close { done_tx } => {
803 let _ = done_tx.send(Err(err));
804 }
805 }
806 }
807}
808
/// Command- and request-channel capacity used when no batch-count limit is configured.
const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 100;

/// Events multiplexed on the session's `MuxTimer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TimerEvent {
    /// The oldest in-flight append has not been acknowledged in time.
    AckDeadline,
}

/// Number of `TimerEvent` variants, i.e. timer slots needed by `MuxTimer`.
const N_TIMER_VARIANTS: usize = 1;

/// Maps an event to its timer slot ordinal.
impl From<TimerEvent> for usize {
    fn from(event: TimerEvent) -> Self {
        match event {
            TimerEvent::AckDeadline => 0,
        }
    }
}

/// Maps a timer slot ordinal back to its event; panics on an ordinal outside `0..N_TIMER_VARIANTS`.
impl From<usize> for TimerEvent {
    fn from(ordinal: usize) -> Self {
        match ordinal {
            0 => Self::AckDeadline,
            _ => panic!("invalid ordinal"),
        }
    }
}
833}