1use std::{
2 collections::VecDeque,
3 future::Future,
4 num::NonZeroU32,
5 pin::Pin,
6 sync::{Arc, OnceLock},
7 task::{Context, Poll},
8 time::Duration,
9};
10
11use futures::StreamExt;
12use tokio::{
13 sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot},
14 time::Instant,
15};
16use tokio_muxt::{CoalesceMode, MuxTimer};
17use tokio_stream::wrappers::ReceiverStream;
18use tokio_util::task::AbortOnDropHandle;
19use tracing::debug;
20
21use crate::{
22 api::{ApiError, BasinClient, Streaming, retry_builder},
23 retry::RetryBackoffBuilder,
24 types::{
25 AppendAck, AppendInput, AppendRetryPolicy, MeteredBytes, ONE_MIB, S2Error, StreamName,
26 StreamPosition, ValidationError,
27 },
28};
29
30#[derive(Debug, thiserror::Error)]
31pub enum AppendSessionError {
32 #[error(transparent)]
33 Api(#[from] ApiError),
34 #[error("append acknowledgement timed out")]
35 AckTimeout,
36 #[error("server disconnected")]
37 ServerDisconnected,
38 #[error("response stream closed early while appends in flight")]
39 StreamClosedEarly,
40 #[error("session already closed")]
41 SessionClosed,
42 #[error("session is closing")]
43 SessionClosing,
44 #[error("session dropped without calling close")]
45 SessionDropped,
46}
47
48impl AppendSessionError {
49 pub fn is_retryable(&self) -> bool {
50 match self {
51 Self::Api(err) => err.is_retryable(),
52 Self::AckTimeout => true,
53 Self::ServerDisconnected => true,
54 _ => false,
55 }
56 }
57}
58
59impl From<AppendSessionError> for S2Error {
60 fn from(err: AppendSessionError) -> Self {
61 match err {
62 AppendSessionError::Api(api_err) => api_err.into(),
63 other => S2Error::Client(other.to_string()),
64 }
65 }
66}
67
68pub struct BatchSubmitTicket {
70 rx: oneshot::Receiver<Result<AppendAck, S2Error>>,
71 terminal_err: Arc<OnceLock<S2Error>>,
72}
73
74impl Future for BatchSubmitTicket {
75 type Output = Result<AppendAck, S2Error>;
76
77 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
78 match Pin::new(&mut self.rx).poll(cx) {
79 Poll::Ready(Ok(res)) => Poll::Ready(res),
80 Poll::Ready(Err(_)) => Poll::Ready(Err(self
81 .terminal_err
82 .get()
83 .cloned()
84 .unwrap_or_else(|| AppendSessionError::SessionDropped.into()))),
85 Poll::Pending => Poll::Pending,
86 }
87 }
88}
89
/// Flow-control settings for an append session.
#[derive(Debug, Clone)]
pub struct AppendSessionConfig {
    /// Number of byte permits available to batches that have been reserved or
    /// submitted but not yet acknowledged.
    max_unacked_bytes: u32,
    /// Optional cap on the number of such batches; `None` means no count limit.
    max_unacked_batches: Option<u32>,
}
96
97impl Default for AppendSessionConfig {
98 fn default() -> Self {
99 Self {
100 max_unacked_bytes: 10 * ONE_MIB,
101 max_unacked_batches: None,
102 }
103 }
104}
105
106impl AppendSessionConfig {
107 pub fn new() -> Self {
109 Self::default()
110 }
111
112 pub fn with_max_unacked_bytes(self, max_unacked_bytes: u32) -> Result<Self, ValidationError> {
118 if max_unacked_bytes < ONE_MIB {
119 return Err(format!("max_unacked_bytes must be at least {ONE_MIB}").into());
120 }
121 Ok(Self {
122 max_unacked_bytes,
123 ..self
124 })
125 }
126
127 pub fn with_max_unacked_batches(
131 self,
132 max_unacked_batches: NonZeroU32,
133 ) -> Result<Self, ValidationError> {
134 Ok(Self {
135 max_unacked_batches: Some(max_unacked_batches.get()),
136 ..self
137 })
138 }
139}
140
141struct SessionState {
142 cmd_rx: mpsc::Receiver<Command>,
143 inflight_appends: VecDeque<InflightAppend>,
144 inflight_bytes: usize,
145 close_tx: Option<oneshot::Sender<Result<(), S2Error>>>,
146 total_records: usize,
147 total_acked_records: usize,
148 prev_ack_end: Option<StreamPosition>,
149 stashed_submission: Option<StashedSubmission>,
150}
151
152pub struct AppendSession {
157 cmd_tx: mpsc::Sender<Command>,
158 permits: AppendPermits,
159 terminal_err: Arc<OnceLock<S2Error>>,
160 _handle: AbortOnDropHandle<()>,
161}
162
163impl AppendSession {
164 pub(crate) fn new(
165 client: BasinClient,
166 stream: StreamName,
167 config: AppendSessionConfig,
168 ) -> Self {
169 let buffer_size = config
170 .max_unacked_batches
171 .map(|mib| mib as usize)
172 .unwrap_or(DEFAULT_CHANNEL_BUFFER_SIZE);
173 let (cmd_tx, cmd_rx) = mpsc::channel(buffer_size);
174 let permits = AppendPermits::new(config.max_unacked_batches, config.max_unacked_bytes);
175 let retry_builder = retry_builder(&client.config.retry);
176 let terminal_err = Arc::new(OnceLock::new());
177 let handle = AbortOnDropHandle::new(tokio::spawn(run_session_with_retry(
178 client,
179 stream,
180 cmd_rx,
181 retry_builder,
182 buffer_size,
183 terminal_err.clone(),
184 )));
185 Self {
186 cmd_tx,
187 permits,
188 terminal_err,
189 _handle: handle,
190 }
191 }
192
193 pub async fn submit(&self, input: AppendInput) -> Result<BatchSubmitTicket, S2Error> {
203 let permit = self.reserve(input.records.metered_bytes() as u32).await?;
204 Ok(permit.submit(input))
205 }
206
207 pub async fn reserve(&self, bytes: u32) -> Result<BatchSubmitPermit, S2Error> {
224 let append_permit = self.permits.acquire(bytes).await;
225 let cmd_tx_permit = self
226 .cmd_tx
227 .clone()
228 .reserve_owned()
229 .await
230 .map_err(|_| self.terminal_err())?;
231 Ok(BatchSubmitPermit {
232 append_permit,
233 cmd_tx_permit,
234 terminal_err: self.terminal_err.clone(),
235 })
236 }
237
238 pub async fn close(self) -> Result<(), S2Error> {
240 let (done_tx, done_rx) = oneshot::channel();
241 self.cmd_tx
242 .send(Command::Close { done_tx })
243 .await
244 .map_err(|_| self.terminal_err())?;
245 done_rx.await.map_err(|_| self.terminal_err())??;
246 Ok(())
247 }
248
249 fn terminal_err(&self) -> S2Error {
250 self.terminal_err
251 .get()
252 .cloned()
253 .unwrap_or_else(|| AppendSessionError::SessionClosed.into())
254 }
255}
256
257pub struct BatchSubmitPermit {
259 append_permit: AppendPermit,
260 cmd_tx_permit: mpsc::OwnedPermit<Command>,
261 terminal_err: Arc<OnceLock<S2Error>>,
262}
263
264impl BatchSubmitPermit {
265 pub fn submit(self, input: AppendInput) -> BatchSubmitTicket {
267 let (ack_tx, ack_rx) = oneshot::channel();
268 self.cmd_tx_permit.send(Command::Submit {
269 input,
270 ack_tx,
271 permit: Some(self.append_permit),
272 });
273 BatchSubmitTicket {
274 rx: ack_rx,
275 terminal_err: self.terminal_err,
276 }
277 }
278}
279
280pub(crate) struct AppendSessionInternal {
281 cmd_tx: mpsc::Sender<Command>,
282 terminal_err: Arc<OnceLock<S2Error>>,
283 _handle: AbortOnDropHandle<()>,
284}
285
286impl AppendSessionInternal {
287 pub(crate) fn new(client: BasinClient, stream: StreamName) -> Self {
288 let buffer_size = DEFAULT_CHANNEL_BUFFER_SIZE;
289 let (cmd_tx, cmd_rx) = mpsc::channel(buffer_size);
290 let retry_builder = retry_builder(&client.config.retry);
291 let terminal_err = Arc::new(OnceLock::new());
292 let handle = AbortOnDropHandle::new(tokio::spawn(run_session_with_retry(
293 client,
294 stream,
295 cmd_rx,
296 retry_builder,
297 buffer_size,
298 terminal_err.clone(),
299 )));
300 Self {
301 cmd_tx,
302 terminal_err,
303 _handle: handle,
304 }
305 }
306
307 pub(crate) fn submit(
308 &self,
309 input: AppendInput,
310 ) -> impl Future<Output = Result<BatchSubmitTicket, S2Error>> + Send + 'static {
311 let cmd_tx = self.cmd_tx.clone();
312 let terminal_err = self.terminal_err.clone();
313 async move {
314 let (ack_tx, ack_rx) = oneshot::channel();
315 cmd_tx
316 .send(Command::Submit {
317 input,
318 ack_tx,
319 permit: None,
320 })
321 .await
322 .map_err(|_| {
323 terminal_err
324 .get()
325 .cloned()
326 .unwrap_or_else(|| AppendSessionError::SessionClosed.into())
327 })?;
328 Ok(BatchSubmitTicket {
329 rx: ack_rx,
330 terminal_err,
331 })
332 }
333 }
334
335 pub(crate) async fn close(self) -> Result<(), S2Error> {
336 let (done_tx, done_rx) = oneshot::channel();
337 self.cmd_tx
338 .send(Command::Close { done_tx })
339 .await
340 .map_err(|_| self.terminal_err())?;
341 done_rx.await.map_err(|_| self.terminal_err())??;
342 Ok(())
343 }
344
345 fn terminal_err(&self) -> S2Error {
346 self.terminal_err
347 .get()
348 .cloned()
349 .unwrap_or_else(|| AppendSessionError::SessionClosed.into())
350 }
351}
352
353#[derive(Debug)]
354pub(crate) struct AppendPermit {
355 _count: Option<OwnedSemaphorePermit>,
356 _bytes: OwnedSemaphorePermit,
357}
358
359#[derive(Clone)]
360pub(crate) struct AppendPermits {
361 count: Option<Arc<Semaphore>>,
362 bytes: Arc<Semaphore>,
363}
364
365impl AppendPermits {
366 pub(crate) fn new(count_permits: Option<u32>, bytes_permits: u32) -> Self {
367 Self {
368 count: count_permits.map(|permits| Arc::new(Semaphore::new(permits as usize))),
369 bytes: Arc::new(Semaphore::new(bytes_permits as usize)),
370 }
371 }
372
373 pub(crate) async fn acquire(&self, bytes: u32) -> AppendPermit {
374 AppendPermit {
375 _count: if let Some(count) = self.count.as_ref() {
376 Some(
377 count
378 .clone()
379 .acquire_many_owned(1)
380 .await
381 .expect("semaphore should not be closed"),
382 )
383 } else {
384 None
385 },
386 _bytes: self
387 .bytes
388 .clone()
389 .acquire_many_owned(bytes)
390 .await
391 .expect("semaphore should not be closed"),
392 }
393 }
394}
395
396async fn run_session_with_retry(
397 client: BasinClient,
398 stream: StreamName,
399 cmd_rx: mpsc::Receiver<Command>,
400 retry_builder: RetryBackoffBuilder,
401 buffer_size: usize,
402 terminal_err: Arc<OnceLock<S2Error>>,
403) {
404 let mut state = SessionState {
405 cmd_rx,
406 inflight_appends: VecDeque::new(),
407 inflight_bytes: 0,
408 close_tx: None,
409 total_records: 0,
410 total_acked_records: 0,
411 prev_ack_end: None,
412 stashed_submission: None,
413 };
414 let mut prev_total_acked_records = 0;
415 let mut retry_backoffs = retry_builder.build();
416
417 loop {
418 let result = run_session(&client, &stream, &mut state, buffer_size).await;
419
420 match result {
421 Ok(()) => {
422 break;
423 }
424 Err(err) => {
425 if prev_total_acked_records < state.total_acked_records {
426 prev_total_acked_records = state.total_acked_records;
427 retry_backoffs.reset();
428 }
429
430 let retry_policy_compliant = retry_policy_compliant(
431 client.config.retry.append_retry_policy,
432 &state.inflight_appends,
433 );
434
435 if retry_policy_compliant
436 && err.is_retryable()
437 && let Some(backoff) = retry_backoffs.next()
438 {
439 debug!(
440 %err,
441 ?backoff,
442 num_retries_remaining = retry_backoffs.remaining(),
443 "retrying append session"
444 );
445 tokio::time::sleep(backoff).await;
446 } else {
447 debug!(
448 %err,
449 retry_policy_compliant,
450 retries_exhausted = retry_backoffs.is_exhausted(),
451 "not retrying append session"
452 );
453
454 let err: S2Error = err.into();
455
456 let _ = terminal_err.set(err.clone());
457
458 for inflight_append in state.inflight_appends.drain(..) {
459 let _ = inflight_append.ack_tx.send(Err(err.clone()));
460 }
461
462 if let Some(stashed) = state.stashed_submission.take() {
463 let _ = stashed.ack_tx.send(Err(err.clone()));
464 }
465
466 if let Some(done_tx) = state.close_tx.take() {
467 let _ = done_tx.send(Err(err.clone()));
468 }
469
470 state.cmd_rx.close();
471 while let Some(cmd) = state.cmd_rx.recv().await {
472 cmd.reject(err.clone());
473 }
474 break;
475 }
476 }
477 }
478 }
479
480 if let Some(done_tx) = state.close_tx.take() {
481 let _ = done_tx.send(Ok(()));
482 }
483}
484
485async fn run_session(
486 client: &BasinClient,
487 stream: &StreamName,
488 state: &mut SessionState,
489 buffer_size: usize,
490) -> Result<(), AppendSessionError> {
491 let (input_tx, mut acks) = connect(client, stream, buffer_size).await?;
492 let ack_timeout = client.config.request_timeout;
493
494 if !state.inflight_appends.is_empty() {
495 resend(state, &input_tx, &mut acks, ack_timeout).await?;
496
497 assert!(state.inflight_appends.is_empty());
498 assert_eq!(state.inflight_bytes, 0);
499 }
500
501 let timer = MuxTimer::<N_TIMER_VARIANTS>::default();
502 tokio::pin!(timer);
503
504 loop {
505 tokio::select! {
506 (event_ord, _deadline) = &mut timer, if timer.is_armed() => {
507 match TimerEvent::from(event_ord) {
508 TimerEvent::AckDeadline => {
509 return Err(AppendSessionError::AckTimeout);
510 }
511 }
512 }
513
514 input_tx_permit = input_tx.reserve(), if state.stashed_submission.is_some() => {
515 let input_tx_permit = input_tx_permit
516 .map_err(|_| AppendSessionError::ServerDisconnected)?;
517 let submission = state.stashed_submission
518 .take()
519 .expect("stashed_submission should not be None");
520
521 input_tx_permit.send(submission.input.clone());
522
523 state.total_records += submission.input.records.len();
524 state.inflight_bytes += submission.input_metered_bytes;
525
526 timer.as_mut().fire_at(
527 TimerEvent::AckDeadline,
528 submission.since + ack_timeout,
529 CoalesceMode::Earliest,
530 );
531 state.inflight_appends.push_back(submission.into());
532 }
533
534 cmd = state.cmd_rx.recv(), if state.stashed_submission.is_none() => {
535 match cmd {
536 Some(Command::Submit { input, ack_tx, permit }) => {
537 if state.close_tx.is_some() {
538 let _ = ack_tx.send(
539 Err(AppendSessionError::SessionClosing.into())
540 );
541 } else {
542 let input_metered_bytes = input.records.metered_bytes();
543 state.stashed_submission = Some(StashedSubmission {
544 input,
545 input_metered_bytes,
546 ack_tx,
547 permit,
548 since: Instant::now(),
549 });
550 }
551 }
552 Some(Command::Close { done_tx }) => {
553 state.close_tx = Some(done_tx);
554 }
555 None => {
556 return Err(AppendSessionError::SessionDropped);
557 }
558 }
559 }
560
561 ack = acks.next() => {
562 match ack {
563 Some(Ok(ack)) => {
564 process_ack(
565 ack,
566 state,
567 timer.as_mut(),
568 ack_timeout,
569 );
570 }
571 Some(Err(err)) => {
572 return Err(err.into());
573 }
574 None => {
575 if !state.inflight_appends.is_empty() || state.stashed_submission.is_some() {
576 return Err(AppendSessionError::StreamClosedEarly);
577 }
578 break;
579 }
580 }
581 }
582 }
583
584 if state.close_tx.is_some()
585 && state.inflight_appends.is_empty()
586 && state.stashed_submission.is_none()
587 {
588 break;
589 }
590 }
591
592 assert!(state.inflight_appends.is_empty());
593 assert_eq!(state.inflight_bytes, 0);
594 assert!(state.stashed_submission.is_none());
595
596 Ok(())
597}
598
599async fn resend(
600 state: &mut SessionState,
601 input_tx: &mpsc::Sender<AppendInput>,
602 acks: &mut Streaming<AppendAck>,
603 ack_timeout: Duration,
604) -> Result<(), AppendSessionError> {
605 debug!(
606 inflight_appends_len = state.inflight_appends.len(),
607 inflight_bytes = state.inflight_bytes,
608 "resending inflight appends"
609 );
610
611 let mut resend_index = 0;
612 let mut resend_finished = false;
613
614 let timer = MuxTimer::<N_TIMER_VARIANTS>::default();
615 tokio::pin!(timer);
616
617 while !state.inflight_appends.is_empty() {
618 tokio::select! {
619 (event_ord, _deadline) = &mut timer, if timer.is_armed() => {
620 match TimerEvent::from(event_ord) {
621 TimerEvent::AckDeadline => {
622 return Err(AppendSessionError::AckTimeout);
623 }
624 }
625 }
626
627 input_tx_permit = input_tx.reserve(), if !resend_finished => {
628 let input_tx_permit = input_tx_permit
629 .map_err(|_| AppendSessionError::ServerDisconnected)?;
630
631 if let Some(inflight_append) = state.inflight_appends.get_mut(resend_index) {
632 inflight_append.since = Instant::now();
633 timer.as_mut().fire_at(
634 TimerEvent::AckDeadline,
635 inflight_append.since + ack_timeout,
636 CoalesceMode::Earliest,
637 );
638 input_tx_permit.send(inflight_append.input.clone());
639 resend_index += 1;
640 } else {
641 resend_finished = true;
642 }
643 }
644
645 ack = acks.next() => {
646 match ack {
647 Some(Ok(ack)) => {
648 process_ack(
649 ack,
650 state,
651 timer.as_mut(),
652 ack_timeout,
653 );
654 resend_index -= 1;
655 }
656 Some(Err(err)) => {
657 return Err(err.into());
658 }
659 None => {
660 return Err(AppendSessionError::StreamClosedEarly);
661 }
662 }
663 }
664 }
665 }
666
667 assert_eq!(
668 resend_index, 0,
669 "resend_index should be 0 after resend completes"
670 );
671 debug!("finished resending inflight appends");
672 Ok(())
673}
674
675async fn connect(
676 client: &BasinClient,
677 stream: &StreamName,
678 buffer_size: usize,
679) -> Result<(mpsc::Sender<AppendInput>, Streaming<AppendAck>), AppendSessionError> {
680 let (input_tx, input_rx) = mpsc::channel::<AppendInput>(buffer_size);
681 let ack_stream = Box::pin(
682 client
683 .append_session(stream, ReceiverStream::new(input_rx).map(|i| i.into()))
684 .await?
685 .map(|ack| match ack {
686 Ok(ack) => Ok(ack.into()),
687 Err(err) => Err(err),
688 }),
689 );
690 Ok((input_tx, ack_stream))
691}
692
693fn process_ack(
694 ack: AppendAck,
695 state: &mut SessionState,
696 timer: Pin<&mut MuxTimer<N_TIMER_VARIANTS>>,
697 ack_timeout: Duration,
698) {
699 let corresponding_append = state
700 .inflight_appends
701 .pop_front()
702 .expect("corresponding append should be present for an ack");
703
704 assert!(
705 ack.end.seq_num >= ack.start.seq_num,
706 "ack end seq_num should be greater than or equal to start seq_num"
707 );
708
709 if let Some(end) = state.prev_ack_end {
710 assert!(
711 ack.end.seq_num > end.seq_num,
712 "ack end seq_num should be greater than previous ack end"
713 );
714 }
715
716 let num_acked_records = (ack.end.seq_num - ack.start.seq_num) as usize;
717 assert_eq!(
718 num_acked_records,
719 corresponding_append.input.records.len(),
720 "ack record count should match submitted batch size"
721 );
722
723 state.total_acked_records += num_acked_records;
724 state.inflight_bytes -= corresponding_append.input_metered_bytes;
725 state.prev_ack_end = Some(ack.end);
726
727 let _ = corresponding_append.ack_tx.send(Ok(ack));
728
729 if let Some(oldest_append) = state.inflight_appends.front() {
730 timer.fire_at(
731 TimerEvent::AckDeadline,
732 oldest_append.since + ack_timeout,
733 CoalesceMode::Latest,
734 );
735 } else {
736 timer.cancel(TimerEvent::AckDeadline);
737 assert_eq!(
738 state.total_records, state.total_acked_records,
739 "all records should be acked when inflight is empty"
740 );
741 }
742}
743
744struct StashedSubmission {
745 input: AppendInput,
746 input_metered_bytes: usize,
747 ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
748 permit: Option<AppendPermit>,
749 since: Instant,
750}
751
752struct InflightAppend {
753 input: AppendInput,
754 input_metered_bytes: usize,
755 ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
756 since: Instant,
757 _permit: Option<AppendPermit>,
758}
759
760impl From<StashedSubmission> for InflightAppend {
761 fn from(value: StashedSubmission) -> Self {
762 Self {
763 input: value.input,
764 input_metered_bytes: value.input_metered_bytes,
765 ack_tx: value.ack_tx,
766 since: value.since,
767 _permit: value.permit,
768 }
769 }
770}
771
772fn retry_policy_compliant(
773 policy: AppendRetryPolicy,
774 inflight_appends: &VecDeque<InflightAppend>,
775) -> bool {
776 if policy == AppendRetryPolicy::All {
777 return true;
778 }
779 inflight_appends
780 .iter()
781 .all(|ia| policy.is_compliant(&ia.input))
782}
783
784enum Command {
785 Submit {
786 input: AppendInput,
787 ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
788 permit: Option<AppendPermit>,
789 },
790 Close {
791 done_tx: oneshot::Sender<Result<(), S2Error>>,
792 },
793}
794
795impl Command {
796 fn reject(self, err: S2Error) {
797 match self {
798 Command::Submit { ack_tx, .. } => {
799 let _ = ack_tx.send(Err(err));
800 }
801 Command::Close { done_tx } => {
802 let _ = done_tx.send(Err(err));
803 }
804 }
805 }
806}
807
/// Channel capacity used when no batch-count limit is configured.
const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 100;
809
/// Events multiplexed onto a session's timer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TimerEvent {
    /// Deadline for receiving an acknowledgement of an in-flight append.
    AckDeadline,
}

/// Number of [`TimerEvent`] variants; sizes the multiplexed timer.
const N_TIMER_VARIANTS: usize = 1;

impl From<TimerEvent> for usize {
    /// Maps an event to its timer ordinal.
    fn from(event: TimerEvent) -> Self {
        match event {
            TimerEvent::AckDeadline => 0,
        }
    }
}

impl From<usize> for TimerEvent {
    /// Maps a timer ordinal back to its event.
    ///
    /// # Panics
    ///
    /// Panics when `value` is not the ordinal of any event.
    fn from(value: usize) -> Self {
        if value == 0 {
            Self::AckDeadline
        } else {
            panic!("invalid ordinal")
        }
    }
}
832}