1use std::{
2 collections::VecDeque,
3 future::Future,
4 num::NonZeroU32,
5 pin::Pin,
6 sync::{Arc, OnceLock},
7 task::{Context, Poll},
8 time::Duration,
9};
10
11use futures::StreamExt;
12use tokio::{
13 sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot},
14 time::Instant,
15};
16use tokio_muxt::{CoalesceMode, MuxTimer};
17use tokio_stream::wrappers::ReceiverStream;
18use tokio_util::task::AbortOnDropHandle;
19use tracing::debug;
20
21use crate::{
22 api::{ApiError, BasinClient, Streaming, retry_builder},
23 retry::RetryBackoffBuilder,
24 types::{
25 AppendAck, AppendInput, AppendRetryPolicy, MeteredBytes, ONE_MIB, S2Error, StreamName,
26 StreamPosition, ValidationError,
27 },
28};
29
/// Failures that can end (or interrupt) an append session.
#[derive(Debug, thiserror::Error)]
pub enum AppendSessionError {
    /// An error surfaced by the underlying API call or its response stream.
    #[error(transparent)]
    Api(#[from] ApiError),
    /// The ack deadline timer fired before the awaited acknowledgement arrived.
    #[error("append acknowledgement timed out")]
    AckTimeout,
    /// Reserving capacity on the channel that feeds the request stream failed.
    #[error("server disconnected")]
    ServerDisconnected,
    /// The acknowledgement stream ended while appends were still outstanding.
    #[error("response stream closed early while appends in flight")]
    StreamClosedEarly,
    /// A command could not be delivered and no terminal error had been recorded.
    #[error("session already closed")]
    SessionClosed,
    /// A submission reached the session after a close had been requested.
    #[error("session is closing")]
    SessionClosing,
    /// The command channel (or an ack sender) went away without a close request.
    #[error("session dropped without calling close")]
    SessionDropped,
}
47
48impl AppendSessionError {
49 pub fn is_retryable(&self) -> bool {
50 match self {
51 Self::Api(err) => err.is_retryable(),
52 Self::AckTimeout => true,
53 Self::ServerDisconnected => true,
54 _ => false,
55 }
56 }
57}
58
59impl From<AppendSessionError> for S2Error {
60 fn from(err: AppendSessionError) -> Self {
61 match err {
62 AppendSessionError::Api(api_err) => api_err.into(),
63 other => S2Error::Client(other.to_string()),
64 }
65 }
66}
67
68pub struct BatchSubmitTicket {
70 rx: oneshot::Receiver<Result<AppendAck, S2Error>>,
71 terminal_err: Arc<OnceLock<S2Error>>,
72}
73
74impl Future for BatchSubmitTicket {
75 type Output = Result<AppendAck, S2Error>;
76
77 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
78 match Pin::new(&mut self.rx).poll(cx) {
79 Poll::Ready(Ok(res)) => Poll::Ready(res),
80 Poll::Ready(Err(_)) => Poll::Ready(Err(self
81 .terminal_err
82 .get()
83 .cloned()
84 .unwrap_or_else(|| AppendSessionError::SessionDropped.into()))),
85 Poll::Pending => Poll::Pending,
86 }
87 }
88}
89
90#[derive(Debug, Clone)]
91pub struct AppendSessionConfig {
93 max_unacked_bytes: u32,
94 max_unacked_batches: Option<u32>,
95}
96
97impl Default for AppendSessionConfig {
98 fn default() -> Self {
99 Self {
100 max_unacked_bytes: 5 * ONE_MIB,
101 max_unacked_batches: None,
102 }
103 }
104}
105
106impl AppendSessionConfig {
107 pub fn new() -> Self {
109 Self::default()
110 }
111
112 pub fn with_max_unacked_bytes(self, max_unacked_bytes: u32) -> Result<Self, ValidationError> {
118 if max_unacked_bytes < ONE_MIB {
119 return Err(format!("max_unacked_bytes must be at least {ONE_MIB}").into());
120 }
121 Ok(Self {
122 max_unacked_bytes,
123 ..self
124 })
125 }
126
127 pub fn with_max_unacked_batches(self, max_unacked_batches: NonZeroU32) -> Self {
131 Self {
132 max_unacked_batches: Some(max_unacked_batches.get()),
133 ..self
134 }
135 }
136}
137
/// Mutable state of the session task, carried across reconnect attempts.
struct SessionState {
    /// Receives submit/close commands from the session handles.
    cmd_rx: mpsc::Receiver<Command>,
    /// Appends that have been sent and are awaiting acknowledgement, oldest first.
    inflight_appends: VecDeque<InflightAppend>,
    /// Metered bytes of the appends in `inflight_appends`; increased when an
    /// append is sent and decreased when it is acknowledged.
    inflight_bytes: usize,
    /// Present once a close has been requested; used to report the close outcome.
    close_tx: Option<oneshot::Sender<Result<(), S2Error>>>,
    /// Number of records sent so far (counted once, when first sent).
    total_records: usize,
    /// Number of records acknowledged so far.
    total_acked_records: usize,
    /// End position of the most recently processed acknowledgement.
    prev_ack_end: Option<StreamPosition>,
    /// A submission taken from the command channel that has not yet been
    /// handed to the request stream.
    stashed_submission: Option<StashedSubmission>,
}
148
/// Handle to a background append session that applies flow control to submissions.
pub struct AppendSession {
    /// Sends submit/close commands to the session task.
    cmd_tx: mpsc::Sender<Command>,
    /// Permits limiting unacknowledged bytes and, when configured, batches.
    permits: AppendPermits,
    /// Written by the session task when it stops with an unrecoverable error.
    terminal_err: Arc<OnceLock<S2Error>>,
    /// Keeps the spawned session task tied to this handle's lifetime.
    _handle: AbortOnDropHandle<()>,
}
159
160impl AppendSession {
161 pub(crate) fn new(
162 client: BasinClient,
163 stream: StreamName,
164 config: AppendSessionConfig,
165 ) -> Self {
166 let buffer_size = config
167 .max_unacked_batches
168 .map(|mib| mib as usize)
169 .unwrap_or(DEFAULT_CHANNEL_BUFFER_SIZE);
170 let (cmd_tx, cmd_rx) = mpsc::channel(buffer_size);
171 let permits = AppendPermits::new(config.max_unacked_batches, config.max_unacked_bytes);
172 let retry_builder = retry_builder(&client.config.retry);
173 let terminal_err = Arc::new(OnceLock::new());
174 let handle = AbortOnDropHandle::new(tokio::spawn(run_session_with_retry(
175 client,
176 stream,
177 cmd_rx,
178 retry_builder,
179 buffer_size,
180 terminal_err.clone(),
181 )));
182 Self {
183 cmd_tx,
184 permits,
185 terminal_err,
186 _handle: handle,
187 }
188 }
189
190 pub async fn submit(&self, input: AppendInput) -> Result<BatchSubmitTicket, S2Error> {
200 let permit = self.reserve(input.records.metered_bytes() as u32).await?;
201 Ok(permit.submit(input))
202 }
203
204 pub async fn reserve(&self, bytes: u32) -> Result<BatchSubmitPermit, S2Error> {
221 let append_permit = self.permits.acquire(bytes).await;
222 let cmd_tx_permit = self
223 .cmd_tx
224 .clone()
225 .reserve_owned()
226 .await
227 .map_err(|_| self.terminal_err())?;
228 Ok(BatchSubmitPermit {
229 append_permit,
230 cmd_tx_permit,
231 terminal_err: self.terminal_err.clone(),
232 })
233 }
234
235 pub async fn close(self) -> Result<(), S2Error> {
237 let (done_tx, done_rx) = oneshot::channel();
238 self.cmd_tx
239 .send(Command::Close { done_tx })
240 .await
241 .map_err(|_| self.terminal_err())?;
242 done_rx.await.map_err(|_| self.terminal_err())??;
243 Ok(())
244 }
245
246 fn terminal_err(&self) -> S2Error {
247 self.terminal_err
248 .get()
249 .cloned()
250 .unwrap_or_else(|| AppendSessionError::SessionClosed.into())
251 }
252}
253
254pub struct BatchSubmitPermit {
256 append_permit: AppendPermit,
257 cmd_tx_permit: mpsc::OwnedPermit<Command>,
258 terminal_err: Arc<OnceLock<S2Error>>,
259}
260
261impl BatchSubmitPermit {
262 pub fn submit(self, input: AppendInput) -> BatchSubmitTicket {
264 let (ack_tx, ack_rx) = oneshot::channel();
265 self.cmd_tx_permit.send(Command::Submit {
266 input,
267 ack_tx,
268 permit: Some(self.append_permit),
269 });
270 BatchSubmitTicket {
271 rx: ack_rx,
272 terminal_err: self.terminal_err,
273 }
274 }
275}
276
/// Crate-internal session handle whose submissions carry no flow-control permit.
pub(crate) struct AppendSessionInternal {
    /// Sends submit/close commands to the session task.
    cmd_tx: mpsc::Sender<Command>,
    /// Written by the session task when it stops with an unrecoverable error.
    terminal_err: Arc<OnceLock<S2Error>>,
    /// Keeps the spawned session task tied to this handle's lifetime.
    _handle: AbortOnDropHandle<()>,
}
282
283impl AppendSessionInternal {
284 pub(crate) fn new(client: BasinClient, stream: StreamName) -> Self {
285 let buffer_size = DEFAULT_CHANNEL_BUFFER_SIZE;
286 let (cmd_tx, cmd_rx) = mpsc::channel(buffer_size);
287 let retry_builder = retry_builder(&client.config.retry);
288 let terminal_err = Arc::new(OnceLock::new());
289 let handle = AbortOnDropHandle::new(tokio::spawn(run_session_with_retry(
290 client,
291 stream,
292 cmd_rx,
293 retry_builder,
294 buffer_size,
295 terminal_err.clone(),
296 )));
297 Self {
298 cmd_tx,
299 terminal_err,
300 _handle: handle,
301 }
302 }
303
304 pub(crate) fn submit(
305 &self,
306 input: AppendInput,
307 ) -> impl Future<Output = Result<BatchSubmitTicket, S2Error>> + Send + 'static {
308 let cmd_tx = self.cmd_tx.clone();
309 let terminal_err = self.terminal_err.clone();
310 async move {
311 let (ack_tx, ack_rx) = oneshot::channel();
312 cmd_tx
313 .send(Command::Submit {
314 input,
315 ack_tx,
316 permit: None,
317 })
318 .await
319 .map_err(|_| {
320 terminal_err
321 .get()
322 .cloned()
323 .unwrap_or_else(|| AppendSessionError::SessionClosed.into())
324 })?;
325 Ok(BatchSubmitTicket {
326 rx: ack_rx,
327 terminal_err,
328 })
329 }
330 }
331
332 pub(crate) async fn close(self) -> Result<(), S2Error> {
333 let (done_tx, done_rx) = oneshot::channel();
334 self.cmd_tx
335 .send(Command::Close { done_tx })
336 .await
337 .map_err(|_| self.terminal_err())?;
338 done_rx.await.map_err(|_| self.terminal_err())??;
339 Ok(())
340 }
341
342 fn terminal_err(&self) -> S2Error {
343 self.terminal_err
344 .get()
345 .cloned()
346 .unwrap_or_else(|| AppendSessionError::SessionClosed.into())
347 }
348}
349
/// Flow-control permits held on behalf of one batch; they are returned to
/// their semaphores when this value is dropped.
#[derive(Debug)]
pub(crate) struct AppendPermit {
    /// Batch-count permit; `None` when no batch-count limit is configured.
    _count: Option<OwnedSemaphorePermit>,
    /// Byte permits covering the batch.
    _bytes: OwnedSemaphorePermit,
}
355
356#[derive(Clone)]
357pub(crate) struct AppendPermits {
358 count: Option<Arc<Semaphore>>,
359 bytes: Arc<Semaphore>,
360}
361
362impl AppendPermits {
363 pub(crate) fn new(count_permits: Option<u32>, bytes_permits: u32) -> Self {
364 Self {
365 count: count_permits.map(|permits| Arc::new(Semaphore::new(permits as usize))),
366 bytes: Arc::new(Semaphore::new(bytes_permits as usize)),
367 }
368 }
369
370 pub(crate) async fn acquire(&self, bytes: u32) -> AppendPermit {
371 AppendPermit {
372 _count: if let Some(count) = self.count.as_ref() {
373 Some(
374 count
375 .clone()
376 .acquire_many_owned(1)
377 .await
378 .expect("semaphore should not be closed"),
379 )
380 } else {
381 None
382 },
383 _bytes: self
384 .bytes
385 .clone()
386 .acquire_many_owned(bytes)
387 .await
388 .expect("semaphore should not be closed"),
389 }
390 }
391}
392
393async fn run_session_with_retry(
394 client: BasinClient,
395 stream: StreamName,
396 cmd_rx: mpsc::Receiver<Command>,
397 retry_builder: RetryBackoffBuilder,
398 buffer_size: usize,
399 terminal_err: Arc<OnceLock<S2Error>>,
400) {
401 let mut state = SessionState {
402 cmd_rx,
403 inflight_appends: VecDeque::new(),
404 inflight_bytes: 0,
405 close_tx: None,
406 total_records: 0,
407 total_acked_records: 0,
408 prev_ack_end: None,
409 stashed_submission: None,
410 };
411 let mut prev_total_acked_records = 0;
412 let mut retry_backoff = retry_builder.build();
413
414 loop {
415 let result = run_session(&client, &stream, &mut state, buffer_size).await;
416
417 match result {
418 Ok(()) => {
419 break;
420 }
421 Err(err) => {
422 if prev_total_acked_records < state.total_acked_records {
423 prev_total_acked_records = state.total_acked_records;
424 retry_backoff.reset();
425 }
426
427 let retry_policy_compliant = retry_policy_compliant(
428 client.config.retry.append_retry_policy,
429 &state.inflight_appends,
430 );
431
432 if retry_policy_compliant
433 && err.is_retryable()
434 && let Some(backoff) = retry_backoff.next()
435 {
436 debug!(
437 %err,
438 ?backoff,
439 num_retries_remaining = retry_backoff.remaining(),
440 "retrying append session"
441 );
442 tokio::time::sleep(backoff).await;
443 } else {
444 debug!(
445 %err,
446 retry_policy_compliant,
447 retries_exhausted = retry_backoff.is_exhausted(),
448 "not retrying append session"
449 );
450
451 let err: S2Error = err.into();
452
453 let _ = terminal_err.set(err.clone());
454
455 for inflight_append in state.inflight_appends.drain(..) {
456 let _ = inflight_append.ack_tx.send(Err(err.clone()));
457 }
458
459 if let Some(stashed) = state.stashed_submission.take() {
460 let _ = stashed.ack_tx.send(Err(err.clone()));
461 }
462
463 if let Some(done_tx) = state.close_tx.take() {
464 let _ = done_tx.send(Err(err.clone()));
465 }
466
467 state.cmd_rx.close();
468 while let Some(cmd) = state.cmd_rx.recv().await {
469 cmd.reject(err.clone());
470 }
471 break;
472 }
473 }
474 }
475 }
476
477 if let Some(done_tx) = state.close_tx.take() {
478 let _ = done_tx.send(Ok(()));
479 }
480}
481
/// Runs one connection's worth of the session.
///
/// Connects, re-sends any appends left inflight by a previous attempt, then
/// loops: forwarding a stashed submission to the request stream, accepting the
/// next command, processing acknowledgements, and watching the ack deadline.
///
/// Returns `Ok(())` when the ack stream ends with nothing outstanding, or when
/// a close has been requested and nothing is outstanding.
///
/// # Errors
///
/// - `AckTimeout` when the ack deadline timer fires.
/// - `ServerDisconnected` when capacity on the request channel cannot be reserved.
/// - `SessionDropped` when the command channel yields `None`.
/// - `StreamClosedEarly` when the ack stream ends with work outstanding.
/// - Any error from connecting, re-sending, or the ack stream itself.
async fn run_session(
    client: &BasinClient,
    stream: &StreamName,
    state: &mut SessionState,
    buffer_size: usize,
) -> Result<(), AppendSessionError> {
    let (input_tx, mut acks) = connect(client, stream, buffer_size).await?;
    let ack_timeout = client.config.request_timeout;

    // Appends left over from a previous attempt are re-sent and fully
    // acknowledged before any new work is accepted.
    if !state.inflight_appends.is_empty() {
        resend(state, &input_tx, &mut acks, ack_timeout).await?;

        assert!(state.inflight_appends.is_empty());
        assert_eq!(state.inflight_bytes, 0);
    }

    let timer = MuxTimer::<N_TIMER_VARIANTS>::default();
    tokio::pin!(timer);

    loop {
        tokio::select! {
            // Only polled while a deadline is armed.
            (event_ord, _deadline) = &mut timer, if timer.is_armed() => {
                match TimerEvent::from(event_ord) {
                    TimerEvent::AckDeadline => {
                        return Err(AppendSessionError::AckTimeout);
                    }
                }
            }

            // A stashed submission is waiting for room on the request channel.
            input_tx_permit = input_tx.reserve(), if state.stashed_submission.is_some() => {
                let input_tx_permit = input_tx_permit
                    .map_err(|_| AppendSessionError::ServerDisconnected)?;
                let submission = state.stashed_submission
                    .take()
                    .expect("stashed_submission should not be None");

                // The input is cloned so the original stays available in the
                // inflight queue.
                input_tx_permit.send(submission.input.clone());

                state.total_records += submission.input.records.len();
                state.inflight_bytes += submission.input_metered_bytes;

                // The deadline is measured from when the submission was
                // accepted (`since`), not from when it was sent.
                timer.as_mut().fire_at(
                    TimerEvent::AckDeadline,
                    submission.since + ack_timeout,
                    CoalesceMode::Earliest,
                );
                state.inflight_appends.push_back(submission.into());
            }

            // New commands are taken only when nothing is stashed, so at most
            // one submission is held at a time.
            cmd = state.cmd_rx.recv(), if state.stashed_submission.is_none() => {
                match cmd {
                    Some(Command::Submit { input, ack_tx, permit }) => {
                        if state.close_tx.is_some() {
                            // A close was already requested: refuse new work.
                            let _ = ack_tx.send(
                                Err(AppendSessionError::SessionClosing.into())
                            );
                        } else {
                            let input_metered_bytes = input.records.metered_bytes();
                            state.stashed_submission = Some(StashedSubmission {
                                input,
                                input_metered_bytes,
                                ack_tx,
                                permit,
                                since: Instant::now(),
                            });
                        }
                    }
                    Some(Command::Close { done_tx }) => {
                        state.close_tx = Some(done_tx);
                    }
                    None => {
                        // The channel is closed and drained.
                        return Err(AppendSessionError::SessionDropped);
                    }
                }
            }

            ack = acks.next() => {
                match ack {
                    Some(Ok(ack)) => {
                        process_ack(
                            ack,
                            state,
                            timer.as_mut(),
                            ack_timeout,
                        );
                    }
                    Some(Err(err)) => {
                        return Err(err.into());
                    }
                    None => {
                        // The ack stream ending is only an error when work is
                        // still outstanding.
                        if !state.inflight_appends.is_empty() || state.stashed_submission.is_some() {
                            return Err(AppendSessionError::StreamClosedEarly);
                        }
                        break;
                    }
                }
            }
        }

        // A requested close completes once nothing is outstanding.
        if state.close_tx.is_some()
            && state.inflight_appends.is_empty()
            && state.stashed_submission.is_none()
        {
            break;
        }
    }

    assert!(state.inflight_appends.is_empty());
    assert_eq!(state.inflight_bytes, 0);
    assert!(state.stashed_submission.is_none());

    Ok(())
}
595
/// Re-sends every append in `state.inflight_appends` over a new connection and
/// returns once all of them have been acknowledged.
///
/// # Errors
///
/// - `AckTimeout` when the ack deadline timer fires.
/// - `ServerDisconnected` when capacity on the request channel cannot be reserved.
/// - `StreamClosedEarly` when the ack stream ends before the queue is empty.
/// - Any error yielded by the ack stream.
async fn resend(
    state: &mut SessionState,
    input_tx: &mpsc::Sender<AppendInput>,
    acks: &mut Streaming<AppendAck>,
    ack_timeout: Duration,
) -> Result<(), AppendSessionError> {
    debug!(
        inflight_appends_len = state.inflight_appends.len(),
        inflight_bytes = state.inflight_bytes,
        "resending inflight appends"
    );

    // Index into `inflight_appends` of the next append to re-send.
    let mut resend_index = 0;
    let mut resend_finished = false;

    // A timer local to the resend phase.
    let timer = MuxTimer::<N_TIMER_VARIANTS>::default();
    tokio::pin!(timer);

    while !state.inflight_appends.is_empty() {
        tokio::select! {
            (event_ord, _deadline) = &mut timer, if timer.is_armed() => {
                match TimerEvent::from(event_ord) {
                    TimerEvent::AckDeadline => {
                        return Err(AppendSessionError::AckTimeout);
                    }
                }
            }

            input_tx_permit = input_tx.reserve(), if !resend_finished => {
                let input_tx_permit = input_tx_permit
                    .map_err(|_| AppendSessionError::ServerDisconnected)?;

                if let Some(inflight_append) = state.inflight_appends.get_mut(resend_index) {
                    // The ack deadline restarts from the moment of re-sending.
                    inflight_append.since = Instant::now();
                    timer.as_mut().fire_at(
                        TimerEvent::AckDeadline,
                        inflight_append.since + ack_timeout,
                        CoalesceMode::Earliest,
                    );
                    input_tx_permit.send(inflight_append.input.clone());
                    resend_index += 1;
                } else {
                    // Everything queued has been re-sent; only acks remain.
                    resend_finished = true;
                }
            }

            ack = acks.next() => {
                match ack {
                    Some(Ok(ack)) => {
                        process_ack(
                            ack,
                            state,
                            timer.as_mut(),
                            ack_timeout,
                        );
                        // `process_ack` removed the front of the queue, so the
                        // next append to re-send moved down by one position.
                        // NOTE(review): assumes an ack only arrives after at
                        // least one re-send; otherwise this would underflow.
                        resend_index -= 1;
                    }
                    Some(Err(err)) => {
                        return Err(err.into());
                    }
                    None => {
                        // The loop only runs while appends are outstanding, so
                        // an ended ack stream here is always premature.
                        return Err(AppendSessionError::StreamClosedEarly);
                    }
                }
            }
        }
    }

    assert_eq!(
        resend_index, 0,
        "resend_index should be 0 after resend completes"
    );
    debug!("finished resending inflight appends");
    Ok(())
}
671
672async fn connect(
673 client: &BasinClient,
674 stream: &StreamName,
675 buffer_size: usize,
676) -> Result<(mpsc::Sender<AppendInput>, Streaming<AppendAck>), AppendSessionError> {
677 let (input_tx, input_rx) = mpsc::channel::<AppendInput>(buffer_size);
678 let ack_stream = Box::pin(
679 client
680 .append_session(stream, ReceiverStream::new(input_rx).map(|i| i.into()))
681 .await?
682 .map(|ack| match ack {
683 Ok(ack) => Ok(ack.into()),
684 Err(err) => Err(err),
685 }),
686 );
687 Ok((input_tx, ack_stream))
688}
689
690fn process_ack(
691 ack: AppendAck,
692 state: &mut SessionState,
693 timer: Pin<&mut MuxTimer<N_TIMER_VARIANTS>>,
694 ack_timeout: Duration,
695) {
696 let corresponding_append = state
697 .inflight_appends
698 .pop_front()
699 .expect("corresponding append should be present for an ack");
700
701 assert!(
702 ack.end.seq_num >= ack.start.seq_num,
703 "ack end seq_num should be greater than or equal to start seq_num"
704 );
705
706 if let Some(end) = state.prev_ack_end {
707 assert!(
708 ack.end.seq_num > end.seq_num,
709 "ack end seq_num should be greater than previous ack end"
710 );
711 }
712
713 let num_acked_records = (ack.end.seq_num - ack.start.seq_num) as usize;
714 assert_eq!(
715 num_acked_records,
716 corresponding_append.input.records.len(),
717 "ack record count should match submitted batch size"
718 );
719
720 state.total_acked_records += num_acked_records;
721 state.inflight_bytes -= corresponding_append.input_metered_bytes;
722 state.prev_ack_end = Some(ack.end);
723
724 let _ = corresponding_append.ack_tx.send(Ok(ack));
725
726 if let Some(oldest_append) = state.inflight_appends.front() {
727 timer.fire_at(
728 TimerEvent::AckDeadline,
729 oldest_append.since + ack_timeout,
730 CoalesceMode::Latest,
731 );
732 } else {
733 timer.cancel(TimerEvent::AckDeadline);
734 assert_eq!(
735 state.total_records, state.total_acked_records,
736 "all records should be acked when inflight is empty"
737 );
738 }
739}
740
741struct StashedSubmission {
742 input: AppendInput,
743 input_metered_bytes: usize,
744 ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
745 permit: Option<AppendPermit>,
746 since: Instant,
747}
748
749struct InflightAppend {
750 input: AppendInput,
751 input_metered_bytes: usize,
752 ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
753 since: Instant,
754 _permit: Option<AppendPermit>,
755}
756
757impl From<StashedSubmission> for InflightAppend {
758 fn from(value: StashedSubmission) -> Self {
759 Self {
760 input: value.input,
761 input_metered_bytes: value.input_metered_bytes,
762 ack_tx: value.ack_tx,
763 since: value.since,
764 _permit: value.permit,
765 }
766 }
767}
768
769fn retry_policy_compliant(
770 policy: AppendRetryPolicy,
771 inflight_appends: &VecDeque<InflightAppend>,
772) -> bool {
773 if policy == AppendRetryPolicy::All {
774 return true;
775 }
776 inflight_appends
777 .iter()
778 .all(|ia| policy.is_compliant(&ia.input))
779}
780
781enum Command {
782 Submit {
783 input: AppendInput,
784 ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
785 permit: Option<AppendPermit>,
786 },
787 Close {
788 done_tx: oneshot::Sender<Result<(), S2Error>>,
789 },
790}
791
792impl Command {
793 fn reject(self, err: S2Error) {
794 match self {
795 Command::Submit { ack_tx, .. } => {
796 let _ = ack_tx.send(Err(err));
797 }
798 Command::Close { done_tx } => {
799 let _ = done_tx.send(Err(err));
800 }
801 }
802 }
803}
804
/// Channel capacity used when no batch limit is configured (and always for the
/// internal session).
const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 100;
806
/// Events scheduled on the session's multiplexed timer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TimerEvent {
    /// The deadline for the awaited acknowledgement.
    AckDeadline,
}

/// Number of [`TimerEvent`] variants; sizes the timer.
const N_TIMER_VARIANTS: usize = 1;

impl From<TimerEvent> for usize {
    /// Maps an event to its ordinal slot in the timer.
    fn from(timer_event: TimerEvent) -> usize {
        match timer_event {
            TimerEvent::AckDeadline => 0,
        }
    }
}

impl From<usize> for TimerEvent {
    /// Maps an ordinal slot back to its event.
    ///
    /// # Panics
    ///
    /// Panics when `ordinal` does not correspond to a variant.
    fn from(ordinal: usize) -> TimerEvent {
        if ordinal == 0 {
            TimerEvent::AckDeadline
        } else {
            panic!("invalid ordinal")
        }
    }
}