1use std::{
2 collections::VecDeque,
3 future::Future,
4 num::NonZeroU32,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8 time::Duration,
9};
10
11use futures::StreamExt;
12use tokio::{
13 sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot},
14 time::Instant,
15};
16use tokio_muxt::{CoalesceMode, MuxTimer};
17use tokio_stream::wrappers::ReceiverStream;
18use tokio_util::task::AbortOnDropHandle;
19use tracing::debug;
20
21use crate::{
22 api::{ApiError, BasinClient, Streaming, retry_builder},
23 retry::RetryBackoffBuilder,
24 types::{
25 AppendAck, AppendInput, AppendRetryPolicy, MeteredBytes, ONE_MIB, S2Error, StreamName,
26 StreamPosition, ValidationError,
27 },
28};
29
/// Errors raised by an append session itself, in addition to errors surfaced by the API.
#[derive(Debug, thiserror::Error)]
pub enum AppendSessionError {
    /// An error returned while opening the append stream or read from its ack stream.
    #[error(transparent)]
    Api(#[from] ApiError),
    /// The oldest inflight append was not acknowledged within the client's request timeout.
    #[error("append acknowledgement timed out")]
    AckTimeout,
    /// Reserving capacity on the channel feeding the server's input stream failed.
    #[error("server disconnected")]
    ServerDisconnected,
    /// The ack stream ended while appends were still inflight (or waiting to be sent).
    #[error("response stream closed early while appends in flight")]
    StreamClosedEarly,
    /// The session's background task is no longer accepting commands, or it stopped
    /// without reporting the outcome of a close.
    #[error("session already closed")]
    SessionClosed,
    /// A submission was received after a close had been requested.
    #[error("session is closing")]
    SessionClosing,
    /// Every command sender was dropped without requesting a close, or a batch's result
    /// channel was dropped before a result was delivered.
    #[error("session dropped without calling close")]
    SessionDropped,
}
47
48impl AppendSessionError {
49 pub fn is_retryable(&self) -> bool {
50 match self {
51 Self::Api(err) => err.is_retryable(),
52 Self::AckTimeout => true,
53 Self::ServerDisconnected => true,
54 _ => false,
55 }
56 }
57}
58
59impl From<AppendSessionError> for S2Error {
60 fn from(err: AppendSessionError) -> Self {
61 match err {
62 AppendSessionError::Api(api_err) => api_err.into(),
63 other => S2Error::Client(other.to_string()),
64 }
65 }
66}
67
/// A future that resolves to the acknowledgement (or error) for one submitted batch.
///
/// If the session drops the batch's result channel without sending a result, the ticket
/// resolves to [`AppendSessionError::SessionDropped`] converted into [`S2Error`].
pub struct BatchSubmitTicket {
    rx: oneshot::Receiver<Result<AppendAck, S2Error>>,
}
72
73impl Future for BatchSubmitTicket {
74 type Output = Result<AppendAck, S2Error>;
75
76 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
77 match Pin::new(&mut self.rx).poll(cx) {
78 Poll::Ready(Ok(res)) => Poll::Ready(res),
79 Poll::Ready(Err(_)) => Poll::Ready(Err(AppendSessionError::SessionDropped.into())),
80 Poll::Pending => Poll::Pending,
81 }
82 }
83}
84
/// Client-side flow-control limits for an [`AppendSession`].
#[derive(Debug, Clone)]
pub struct AppendSessionConfig {
    /// Upper bound on the total metered bytes of batches that are reserved or submitted
    /// but not yet acknowledged.
    max_inflight_bytes: u32,
    /// Optional upper bound on the number of such batches; `None` disables the count
    /// limit. When set, it also sizes the session's command channel.
    max_inflight_batches: Option<u32>,
}
91
92impl Default for AppendSessionConfig {
93 fn default() -> Self {
94 Self {
95 max_inflight_bytes: 10 * ONE_MIB,
96 max_inflight_batches: None,
97 }
98 }
99}
100
101impl AppendSessionConfig {
102 pub fn new() -> Self {
104 Self::default()
105 }
106
107 pub fn with_max_inflight_bytes(self, max_inflight_bytes: u32) -> Result<Self, ValidationError> {
113 if max_inflight_bytes < ONE_MIB {
114 return Err(format!("max_inflight_bytes must be at least {ONE_MIB}").into());
115 }
116 Ok(Self {
117 max_inflight_bytes,
118 ..self
119 })
120 }
121
122 pub fn with_max_inflight_batches(
126 self,
127 max_inflight_batches: NonZeroU32,
128 ) -> Result<Self, ValidationError> {
129 Ok(Self {
130 max_inflight_batches: Some(max_inflight_batches.get()),
131 ..self
132 })
133 }
134}
135
/// State owned by the session task. One instance is shared by every connection attempt,
/// so unacknowledged appends survive a reconnect.
struct SessionState {
    /// Commands sent by the session handle(s).
    cmd_rx: mpsc::Receiver<Command>,
    /// Appends awaiting acknowledgement, oldest first; resent in order after a reconnect.
    inflight_appends: VecDeque<InflightAppend>,
    /// Sum of the metered bytes of `inflight_appends`.
    inflight_bytes: usize,
    /// Where to report the outcome of a requested close.
    close_tx: Option<oneshot::Sender<Result<(), S2Error>>>,
    /// Set once a close has been requested; later submissions are rejected.
    closing: bool,
    /// Total records added to `inflight_appends` over the session's lifetime.
    total_records: usize,
    /// Total records acknowledged; progress here resets the retry backoff.
    total_acked_records: usize,
    /// End position of the most recent ack, used to check that acks keep advancing.
    prev_ack_end: Option<StreamPosition>,
}
146
/// A pipelined append session on a single stream.
///
/// Batches are submitted with [`AppendSession::submit`] (or [`AppendSession::reserve`]
/// followed by [`BatchSubmitPermit::submit`]) and handed to a background task. That task
/// forwards them to the server, matches acknowledgements, and reconnects according to the
/// client's retry configuration. Call [`AppendSession::close`] to wait for all submitted
/// batches to be acknowledged. Dropping the session aborts the background task.
pub struct AppendSession {
    cmd_tx: mpsc::Sender<Command>,
    permits: AppendPermits,
    // Aborts the background session task when the session is dropped.
    _handle: AbortOnDropHandle<()>,
}
156
157impl AppendSession {
158 pub(crate) fn new(
159 client: BasinClient,
160 stream: StreamName,
161 config: AppendSessionConfig,
162 ) -> Self {
163 let buffer_size = config
164 .max_inflight_batches
165 .map(|mib| mib as usize)
166 .unwrap_or(DEFAULT_CHANNEL_BUFFER_SIZE);
167 let (cmd_tx, cmd_rx) = mpsc::channel(buffer_size);
168 let permits = AppendPermits::new(config.max_inflight_batches, config.max_inflight_bytes);
169 let retry_builder = retry_builder(&client.config.retry);
170 let handle = AbortOnDropHandle::new(tokio::spawn(run_session_with_retry(
171 client,
172 stream,
173 cmd_rx,
174 retry_builder,
175 buffer_size,
176 )));
177 Self {
178 cmd_tx,
179 permits,
180 _handle: handle,
181 }
182 }
183
184 pub async fn submit(&self, input: AppendInput) -> Result<BatchSubmitTicket, S2Error> {
194 let permit = self.reserve(input.records.metered_bytes() as u32).await?;
195 Ok(permit.submit(input))
196 }
197
198 pub async fn reserve(&self, bytes: u32) -> Result<BatchSubmitPermit, S2Error> {
215 let append_permit = self.permits.acquire(bytes).await;
216 let cmd_tx_permit = self
217 .cmd_tx
218 .clone()
219 .reserve_owned()
220 .await
221 .map_err(|_| AppendSessionError::SessionClosed)?;
222 Ok(BatchSubmitPermit {
223 append_permit,
224 cmd_tx_permit,
225 })
226 }
227
228 pub async fn close(self) -> Result<(), S2Error> {
230 let (done_tx, done_rx) = oneshot::channel();
231 self.cmd_tx
232 .send(Command::Close { done_tx })
233 .await
234 .map_err(|_| AppendSessionError::SessionClosed)?;
235 done_rx
236 .await
237 .map_err(|_| AppendSessionError::SessionClosed)??;
238 Ok(())
239 }
240}
241
/// Capacity reserved for submitting exactly one batch, obtained from
/// [`AppendSession::reserve`].
///
/// Dropping the permit without calling [`BatchSubmitPermit::submit`] releases both the
/// inflight-capacity reservation and the command-channel slot.
pub struct BatchSubmitPermit {
    append_permit: AppendPermit,
    cmd_tx_permit: mpsc::OwnedPermit<Command>,
}
247
248impl BatchSubmitPermit {
249 pub fn submit(self, input: AppendInput) -> BatchSubmitTicket {
251 let (ack_tx, ack_rx) = oneshot::channel();
252 self.cmd_tx_permit.send(Command::Submit {
253 input,
254 ack_tx,
255 permit: Some(self.append_permit),
256 });
257 BatchSubmitTicket { rx: ack_rx }
258 }
259}
260
/// Crate-internal append session whose submissions carry no [`AppendPermit`].
///
/// Submissions are therefore not subject to the inflight count/byte limits of
/// [`AppendSession`]. Only the command channel's capacity applies backpressure.
pub(crate) struct AppendSessionInternal {
    cmd_tx: mpsc::Sender<Command>,
    // Aborts the background session task when dropped.
    _handle: AbortOnDropHandle<()>,
}
265
266impl AppendSessionInternal {
267 pub(crate) fn new(client: BasinClient, stream: StreamName) -> Self {
268 let buffer_size = DEFAULT_CHANNEL_BUFFER_SIZE;
269 let (cmd_tx, cmd_rx) = mpsc::channel(buffer_size);
270 let retry_builder = retry_builder(&client.config.retry);
271 let handle = AbortOnDropHandle::new(tokio::spawn(run_session_with_retry(
272 client,
273 stream,
274 cmd_rx,
275 retry_builder,
276 buffer_size,
277 )));
278 Self {
279 cmd_tx,
280 _handle: handle,
281 }
282 }
283
284 pub(crate) fn submit(
285 &self,
286 input: AppendInput,
287 ) -> impl Future<Output = Result<BatchSubmitTicket, S2Error>> + Send + 'static {
288 let cmd_tx = self.cmd_tx.clone();
289 async move {
290 let (ack_tx, ack_rx) = oneshot::channel();
291 cmd_tx
292 .send(Command::Submit {
293 input,
294 ack_tx,
295 permit: None,
296 })
297 .await
298 .map_err(|_| AppendSessionError::SessionClosed)?;
299 Ok(BatchSubmitTicket { rx: ack_rx })
300 }
301 }
302
303 pub(crate) async fn close(self) -> Result<(), S2Error> {
304 let (done_tx, done_rx) = oneshot::channel();
305 self.cmd_tx
306 .send(Command::Close { done_tx })
307 .await
308 .map_err(|_| AppendSessionError::SessionClosed)?;
309 done_rx
310 .await
311 .map_err(|_| AppendSessionError::SessionClosed)??;
312 Ok(())
313 }
314}
315
/// Inflight capacity held by one batch. The capacity is released when this value is
/// dropped.
#[derive(Debug)]
pub(crate) struct AppendPermit {
    /// One unit of the batch-count limit, present only if that limit is configured.
    _count: Option<OwnedSemaphorePermit>,
    /// The batch's metered bytes, counted against the inflight-bytes limit.
    _bytes: OwnedSemaphorePermit,
}
321
/// Shared semaphores that bound inflight batches by bytes and, optionally, by count.
#[derive(Clone)]
pub(crate) struct AppendPermits {
    /// Batch-count semaphore; `None` when no count limit is configured.
    count: Option<Arc<Semaphore>>,
    /// Metered-bytes semaphore.
    bytes: Arc<Semaphore>,
}
327
328impl AppendPermits {
329 pub(crate) fn new(count_permits: Option<u32>, bytes_permits: u32) -> Self {
330 Self {
331 count: count_permits.map(|permits| Arc::new(Semaphore::new(permits as usize))),
332 bytes: Arc::new(Semaphore::new(bytes_permits as usize)),
333 }
334 }
335
336 pub(crate) async fn acquire(&self, bytes: u32) -> AppendPermit {
337 AppendPermit {
338 _count: if let Some(count) = self.count.as_ref() {
339 Some(
340 count
341 .clone()
342 .acquire_many_owned(1)
343 .await
344 .expect("semaphore should not be closed"),
345 )
346 } else {
347 None
348 },
349 _bytes: self
350 .bytes
351 .clone()
352 .acquire_many_owned(bytes)
353 .await
354 .expect("semaphore should not be closed"),
355 }
356 }
357}
358
/// Background task that drives an append session across reconnects.
///
/// Calls `run_session` repeatedly with the same `SessionState`, so appends left inflight
/// by a failed attempt are resent on the next one. After an error, the backoff is reset
/// if records were acknowledged since the last reset. The session is retried only when
/// all three hold:
/// - the inflight appends comply with the client's append retry policy,
/// - the error is retryable,
/// - the backoff yields another delay.
///
/// Otherwise every inflight append and any pending close receive the error, and the task
/// ends.
async fn run_session_with_retry(
    client: BasinClient,
    stream: StreamName,
    cmd_rx: mpsc::Receiver<Command>,
    retry_builder: RetryBackoffBuilder,
    buffer_size: usize,
) {
    let mut state = SessionState {
        cmd_rx,
        inflight_appends: VecDeque::new(),
        inflight_bytes: 0,
        close_tx: None,
        closing: false,
        total_records: 0,
        total_acked_records: 0,
        prev_ack_end: None,
    };
    // Acked-record count at the last backoff reset; progress beyond it resets the backoff.
    let mut prev_total_acked_records = 0;
    let mut retry_backoffs = retry_builder.build();

    loop {
        let result = run_session(&client, &stream, &mut state, buffer_size).await;

        match result {
            Ok(()) => {
                break;
            }
            Err(err) => {
                // Forward progress since the last failure earns a fresh backoff budget.
                if prev_total_acked_records < state.total_acked_records {
                    prev_total_acked_records = state.total_acked_records;
                    retry_backoffs.reset();
                }

                // Evaluated against the appends that a reconnect would resend.
                let retry_policy_compliant = retry_policy_compliant(
                    client.config.retry.append_retry_policy,
                    &state.inflight_appends,
                );

                if retry_policy_compliant
                    && err.is_retryable()
                    && let Some(backoff) = retry_backoffs.next()
                {
                    debug!(
                        %err,
                        ?backoff,
                        num_retries_remaining = retry_backoffs.remaining(),
                        "retrying append session"
                    );
                    tokio::time::sleep(backoff).await;
                } else {
                    debug!(
                        %err,
                        retry_policy_compliant,
                        retries_exhausted = retry_backoffs.is_exhausted(),
                        "not retrying append session"
                    );

                    // Terminal failure: report the error to every waiting submitter and to
                    // a pending close.
                    let err: S2Error = err.into();
                    for inflight_append in state.inflight_appends.drain(..) {
                        let _ = inflight_append.ack_tx.send(Err(err.clone()));
                    }

                    if let Some(done_tx) = state.close_tx.take() {
                        let _ = done_tx.send(Err(err));
                    }
                    break;
                }
            }
        }
    }

    // On success, report the close. After a terminal failure `close_tx` was already taken.
    if let Some(done_tx) = state.close_tx.take() {
        let _ = done_tx.send(Ok(()));
    }
}
434
435async fn run_session(
436 client: &BasinClient,
437 stream: &StreamName,
438 state: &mut SessionState,
439 buffer_size: usize,
440) -> Result<(), AppendSessionError> {
441 let (input_tx, mut acks) = connect(client, stream, buffer_size).await?;
442 let ack_timeout = client.config.request_timeout;
443
444 if !state.inflight_appends.is_empty() {
445 resend(state, &input_tx, &mut acks, ack_timeout).await?;
446
447 assert!(state.inflight_appends.is_empty());
448 assert_eq!(state.inflight_bytes, 0);
449 }
450
451 let timer = MuxTimer::<N_TIMER_VARIANTS>::default();
452 tokio::pin!(timer);
453
454 let mut stashed_submission: Option<StashedSubmission> = None;
455
456 loop {
457 tokio::select! {
458 (event_ord, _deadline) = &mut timer, if timer.is_armed() => {
459 match TimerEvent::from(event_ord) {
460 TimerEvent::AckDeadline => {
461 return Err(AppendSessionError::AckTimeout);
462 }
463 }
464 }
465
466 input_tx_permit = input_tx.reserve(), if stashed_submission.is_some() => {
467 let input_tx_permit = input_tx_permit
468 .map_err(|_| AppendSessionError::ServerDisconnected)?;
469 let submission = stashed_submission
470 .take()
471 .expect("stashed_submission should not be None");
472
473 input_tx_permit.send(submission.input.clone());
474
475 state.total_records += submission.input.records.len();
476 state.inflight_bytes += submission.input_metered_bytes;
477
478 timer.as_mut().fire_at(
479 TimerEvent::AckDeadline,
480 submission.since + ack_timeout,
481 CoalesceMode::Earliest,
482 );
483 state.inflight_appends.push_back(submission.into());
484 }
485
486 cmd = state.cmd_rx.recv(), if stashed_submission.is_none() => {
487 match cmd {
488 Some(Command::Submit { input, ack_tx, permit }) => {
489 if state.closing {
490 let _ = ack_tx.send(
491 Err(AppendSessionError::SessionClosing.into())
492 );
493 } else {
494 let input_metered_bytes = input.records.metered_bytes();
495 stashed_submission = Some(StashedSubmission {
496 input,
497 input_metered_bytes,
498 ack_tx,
499 permit,
500 since: Instant::now(),
501 });
502 }
503 }
504 Some(Command::Close { done_tx }) => {
505 state.closing = true;
506 state.close_tx = Some(done_tx);
507 }
508 None => {
509 return Err(AppendSessionError::SessionDropped);
510 }
511 }
512 }
513
514 ack = acks.next() => {
515 match ack {
516 Some(Ok(ack)) => {
517 process_ack(
518 ack,
519 state,
520 timer.as_mut(),
521 ack_timeout,
522 );
523 }
524 Some(Err(err)) => {
525 return Err(err.into());
526 }
527 None => {
528 if !state.inflight_appends.is_empty() || stashed_submission.is_some() {
529 return Err(AppendSessionError::StreamClosedEarly);
530 }
531 break;
532 }
533 }
534 }
535 }
536
537 if state.closing && state.inflight_appends.is_empty() && stashed_submission.is_none() {
538 break;
539 }
540 }
541
542 assert!(state.inflight_appends.is_empty());
543 assert_eq!(state.inflight_bytes, 0);
544 assert!(stashed_submission.is_none());
545
546 Ok(())
547}
548
/// Resends every append left inflight by a failed attempt, in order, on a freshly
/// connected stream, and waits until all of them are acknowledged.
///
/// `resend_index` counts appends that have been resent but not yet acknowledged. Acks
/// pop from the front of `state.inflight_appends`, so it is also the index of the next
/// append to resend. Each resent append's `since` is reset, so its ack deadline is
/// measured from the resend.
///
/// Returns an error on ack timeout, failure to reserve input capacity, an ack-stream
/// error, or an early end of the ack stream.
async fn resend(
    state: &mut SessionState,
    input_tx: &mpsc::Sender<AppendInput>,
    acks: &mut Streaming<AppendAck>,
    ack_timeout: Duration,
) -> Result<(), AppendSessionError> {
    debug!(
        inflight_appends_len = state.inflight_appends.len(),
        inflight_bytes = state.inflight_bytes,
        "resending inflight appends"
    );

    let mut resend_index = 0;
    // Set once an input permit is obtained with nothing left to resend; this disables the
    // send branch while the remaining acks arrive.
    let mut resend_finished = false;

    let timer = MuxTimer::<N_TIMER_VARIANTS>::default();
    tokio::pin!(timer);

    while !state.inflight_appends.is_empty() {
        tokio::select! {
            (event_ord, _deadline) = &mut timer, if timer.is_armed() => {
                match TimerEvent::from(event_ord) {
                    TimerEvent::AckDeadline => {
                        return Err(AppendSessionError::AckTimeout);
                    }
                }
            }

            input_tx_permit = input_tx.reserve(), if !resend_finished => {
                let input_tx_permit = input_tx_permit
                    .map_err(|_| AppendSessionError::ServerDisconnected)?;

                if let Some(inflight_append) = state.inflight_appends.get_mut(resend_index) {
                    inflight_append.since = Instant::now();
                    // Earliest keeps the deadline of the oldest resent append armed.
                    timer.as_mut().fire_at(
                        TimerEvent::AckDeadline,
                        inflight_append.since + ack_timeout,
                        CoalesceMode::Earliest,
                    );
                    input_tx_permit.send(inflight_append.input.clone());
                    resend_index += 1;
                } else {
                    resend_finished = true;
                }
            }

            ack = acks.next() => {
                match ack {
                    Some(Ok(ack)) => {
                        process_ack(
                            ack,
                            state,
                            timer.as_mut(),
                            ack_timeout,
                        );
                        // The acked append was removed from the front, so the next
                        // append to resend moved down by one.
                        resend_index -= 1;
                    }
                    Some(Err(err)) => {
                        return Err(err.into());
                    }
                    None => {
                        return Err(AppendSessionError::StreamClosedEarly);
                    }
                }
            }
        }
    }

    assert_eq!(
        resend_index, 0,
        "resend_index should be 0 after resend completes"
    );
    debug!("finished resending inflight appends");
    Ok(())
}
624
625async fn connect(
626 client: &BasinClient,
627 stream: &StreamName,
628 buffer_size: usize,
629) -> Result<(mpsc::Sender<AppendInput>, Streaming<AppendAck>), AppendSessionError> {
630 let (input_tx, input_rx) = mpsc::channel::<AppendInput>(buffer_size);
631 let ack_stream = Box::pin(
632 client
633 .append_session(stream, ReceiverStream::new(input_rx).map(|i| i.into()))
634 .await?
635 .map(|ack| match ack {
636 Ok(ack) => Ok(ack.into()),
637 Err(err) => Err(err),
638 }),
639 );
640 Ok((input_tx, ack_stream))
641}
642
643fn process_ack(
644 ack: AppendAck,
645 state: &mut SessionState,
646 timer: Pin<&mut MuxTimer<N_TIMER_VARIANTS>>,
647 ack_timeout: Duration,
648) {
649 let corresponding_append = state
650 .inflight_appends
651 .pop_front()
652 .expect("corresponding append should be present for an ack");
653
654 assert!(
655 ack.end.seq_num >= ack.start.seq_num,
656 "ack end seq_num should be greater than or equal to start seq_num"
657 );
658
659 if let Some(end) = state.prev_ack_end {
660 assert!(
661 ack.end.seq_num > end.seq_num,
662 "ack end seq_num should be greater than previous ack end"
663 );
664 }
665
666 let num_acked_records = (ack.end.seq_num - ack.start.seq_num) as usize;
667 assert_eq!(
668 num_acked_records,
669 corresponding_append.input.records.len(),
670 "ack record count should match submitted batch size"
671 );
672
673 state.total_acked_records += num_acked_records;
674 state.inflight_bytes -= corresponding_append.input_metered_bytes;
675 state.prev_ack_end = Some(ack.end);
676
677 let _ = corresponding_append.ack_tx.send(Ok(ack));
678
679 if let Some(oldest_append) = state.inflight_appends.front() {
680 timer.fire_at(
681 TimerEvent::AckDeadline,
682 oldest_append.since + ack_timeout,
683 CoalesceMode::Latest,
684 );
685 } else {
686 timer.cancel(TimerEvent::AckDeadline);
687 assert_eq!(
688 state.total_records, state.total_acked_records,
689 "all records should be acked when inflight is empty"
690 );
691 }
692}
693
/// A submission taken from the command channel that is waiting for capacity on the
/// server input channel.
struct StashedSubmission {
    /// The batch to append.
    input: AppendInput,
    /// Metered size of `input`, computed when the submission was received.
    input_metered_bytes: usize,
    /// Delivers the batch's outcome to its ticket.
    ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
    /// Inflight-capacity reservation, if the submitter used one.
    permit: Option<AppendPermit>,
    /// When the submission was received; its ack deadline is measured from here.
    since: Instant,
}
701
/// An append awaiting acknowledgement from the server.
struct InflightAppend {
    /// The batch, kept so it can be resent after a reconnect.
    input: AppendInput,
    /// Metered size of `input`, counted in `SessionState::inflight_bytes`.
    input_metered_bytes: usize,
    /// Delivers the ack (or the terminal error) to the submitter's ticket.
    ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
    /// Start of this append's ack deadline; reset when the append is resent.
    since: Instant,
    /// Inflight-capacity reservation, released when this append is dropped.
    _permit: Option<AppendPermit>,
}
709
710impl From<StashedSubmission> for InflightAppend {
711 fn from(value: StashedSubmission) -> Self {
712 Self {
713 input: value.input,
714 input_metered_bytes: value.input_metered_bytes,
715 ack_tx: value.ack_tx,
716 since: value.since,
717 _permit: value.permit,
718 }
719 }
720}
721
722fn retry_policy_compliant(
723 policy: AppendRetryPolicy,
724 inflight_appends: &VecDeque<InflightAppend>,
725) -> bool {
726 if policy == AppendRetryPolicy::All {
727 return true;
728 }
729 inflight_appends
730 .iter()
731 .all(|ia| policy.is_compliant(&ia.input))
732}
733
/// Messages from session handles to the background session task.
enum Command {
    /// Append `input` and report its outcome on `ack_tx`.
    ///
    /// `permit` is the caller's inflight-capacity reservation (`None` for
    /// `AppendSessionInternal`). It is held until the batch is acknowledged or fails.
    Submit {
        input: AppendInput,
        ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
        permit: Option<AppendPermit>,
    },
    /// Reject further submissions, wait for inflight appends to be acknowledged, then
    /// report the outcome on `done_tx`.
    Close {
        done_tx: oneshot::Sender<Result<(), S2Error>>,
    },
}
744
/// Capacity of the command and server-input channels when no `max_inflight_batches`
/// limit is configured.
const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 100;
746
/// Events scheduled on the session's `MuxTimer`; each variant maps to an ordinal
/// (presumably one timer slot per variant).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TimerEvent {
    /// The oldest inflight append has not been acknowledged in time.
    AckDeadline,
}
751
/// Number of `TimerEvent` variants, used as the `MuxTimer` size; keep in sync with the enum.
const N_TIMER_VARIANTS: usize = 1;
753
754impl From<TimerEvent> for usize {
755 fn from(event: TimerEvent) -> Self {
756 match event {
757 TimerEvent::AckDeadline => 0,
758 }
759 }
760}
761
762impl From<usize> for TimerEvent {
763 fn from(value: usize) -> Self {
764 match value {
765 0 => TimerEvent::AckDeadline,
766 _ => panic!("invalid ordinal"),
767 }
768 }
769}