s2_sdk/session/
append.rs

1use std::{
2    collections::VecDeque,
3    future::Future,
4    num::NonZeroU32,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8    time::Duration,
9};
10
11use futures::StreamExt;
12use tokio::{
13    sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot},
14    time::Instant,
15};
16use tokio_muxt::{CoalesceMode, MuxTimer};
17use tokio_stream::wrappers::ReceiverStream;
18use tokio_util::task::AbortOnDropHandle;
19use tracing::debug;
20
21use crate::{
22    api::{ApiError, BasinClient, Streaming, retry_builder},
23    retry::RetryBackoffBuilder,
24    types::{
25        AppendAck, AppendInput, AppendRetryPolicy, MeteredBytes, ONE_MIB, S2Error, StreamName,
26        StreamPosition, ValidationError,
27    },
28};
29
30#[derive(Debug, thiserror::Error)]
31pub enum AppendSessionError {
32    #[error(transparent)]
33    Api(#[from] ApiError),
34    #[error("append acknowledgement timed out")]
35    AckTimeout,
36    #[error("server disconnected")]
37    ServerDisconnected,
38    #[error("response stream closed early while appends in flight")]
39    StreamClosedEarly,
40    #[error("session already closed")]
41    SessionClosed,
42    #[error("session is closing")]
43    SessionClosing,
44    #[error("session dropped without calling close")]
45    SessionDropped,
46}
47
48impl AppendSessionError {
49    pub fn is_retryable(&self) -> bool {
50        match self {
51            Self::Api(err) => err.is_retryable(),
52            Self::AckTimeout => true,
53            Self::ServerDisconnected => true,
54            _ => false,
55        }
56    }
57}
58
59impl From<AppendSessionError> for S2Error {
60    fn from(err: AppendSessionError) -> Self {
61        match err {
62            AppendSessionError::Api(api_err) => api_err.into(),
63            other => S2Error::Client(other.to_string()),
64        }
65    }
66}
67
68/// A [`Future`] that resolves to an acknowledgement once the batch of records is appended.
69pub struct BatchSubmitTicket {
70    rx: oneshot::Receiver<Result<AppendAck, S2Error>>,
71}
72
73impl Future for BatchSubmitTicket {
74    type Output = Result<AppendAck, S2Error>;
75
76    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
77        match Pin::new(&mut self.rx).poll(cx) {
78            Poll::Ready(Ok(res)) => Poll::Ready(res),
79            Poll::Ready(Err(_)) => Poll::Ready(Err(AppendSessionError::SessionDropped.into())),
80            Poll::Pending => Poll::Pending,
81        }
82    }
83}
84
/// Configuration for an [`AppendSession`].
#[derive(Debug, Clone)]
pub struct AppendSessionConfig {
    /// Upper bound on metered bytes of unacknowledged batches held in memory.
    max_inflight_bytes: u32,
    /// Optional upper bound on the number of unacknowledged batches; `None` means unlimited.
    max_inflight_batches: Option<u32>,
}
91
92impl Default for AppendSessionConfig {
93    fn default() -> Self {
94        Self {
95            max_inflight_bytes: 10 * ONE_MIB,
96            max_inflight_batches: None,
97        }
98    }
99}
100
101impl AppendSessionConfig {
102    /// Create a new [`AppendSessionConfig`] with default settings.
103    pub fn new() -> Self {
104        Self::default()
105    }
106
107    /// Set the limit on total metered bytes of unacknowledged [`AppendInput`]s held in memory.
108    ///
109    /// **Note:** It must be at least `1MiB`.
110    ///
111    /// Defaults to `10MiB`.
112    pub fn with_max_inflight_bytes(self, max_inflight_bytes: u32) -> Result<Self, ValidationError> {
113        if max_inflight_bytes < ONE_MIB {
114            return Err(format!("max_inflight_bytes must be at least {ONE_MIB}").into());
115        }
116        Ok(Self {
117            max_inflight_bytes,
118            ..self
119        })
120    }
121
122    /// Set the limit on number of unacknowledged [`AppendInput`]s held in memory.
123    ///
124    /// Defaults to no limit.
125    pub fn with_max_inflight_batches(
126        self,
127        max_inflight_batches: NonZeroU32,
128    ) -> Result<Self, ValidationError> {
129        Ok(Self {
130            max_inflight_batches: Some(max_inflight_batches.get()),
131            ..self
132        })
133    }
134}
135
136struct SessionState {
137    cmd_rx: mpsc::Receiver<Command>,
138    inflight_appends: VecDeque<InflightAppend>,
139    inflight_bytes: usize,
140    close_tx: Option<oneshot::Sender<Result<(), S2Error>>>,
141    closing: bool,
142    total_records: usize,
143    total_acked_records: usize,
144    prev_ack_end: Option<StreamPosition>,
145}
146
147/// A session for high-throughput appending with backpressure control. It can be created from
148/// [`append_session`](crate::S2Stream::append_session).
149///
150/// Supports pipelining multiple [`AppendInput`]s while preserving submission order.
151pub struct AppendSession {
152    cmd_tx: mpsc::Sender<Command>,
153    permits: AppendPermits,
154    _handle: AbortOnDropHandle<()>,
155}
156
157impl AppendSession {
158    pub(crate) fn new(
159        client: BasinClient,
160        stream: StreamName,
161        config: AppendSessionConfig,
162    ) -> Self {
163        let buffer_size = config
164            .max_inflight_batches
165            .map(|mib| mib as usize)
166            .unwrap_or(DEFAULT_CHANNEL_BUFFER_SIZE);
167        let (cmd_tx, cmd_rx) = mpsc::channel(buffer_size);
168        let permits = AppendPermits::new(config.max_inflight_batches, config.max_inflight_bytes);
169        let retry_builder = retry_builder(&client.config.retry);
170        let handle = AbortOnDropHandle::new(tokio::spawn(run_session_with_retry(
171            client,
172            stream,
173            cmd_rx,
174            retry_builder,
175            buffer_size,
176        )));
177        Self {
178            cmd_tx,
179            permits,
180            _handle: handle,
181        }
182    }
183
184    /// Submit a batch of records for appending.
185    ///
186    /// Internally, it waits on [`reserve`](Self::reserve), then submits using the permit.
187    /// This provides backpressure when inflight limits are reached.
188    /// For explicit control, use [`reserve`](Self::reserve) followed by
189    /// [`BatchSubmitPermit::submit`].
190    ///
191    /// **Note**: After all submits, you must call [`close`](Self::close) to ensure all batches are
192    /// appended.
193    pub async fn submit(&self, input: AppendInput) -> Result<BatchSubmitTicket, S2Error> {
194        let permit = self.reserve(input.records.metered_bytes() as u32).await?;
195        Ok(permit.submit(input))
196    }
197
198    /// Reserve capacity for a batch to be submitted. Useful in [`select!`](tokio::select) loops
199    /// where you want to interleave submission with other async work. See [`submit`](Self::submit)
200    /// for a simpler API.
201    ///
202    /// Waits when inflight limits are reached, providing explicit backpressure control.
203    /// The returned permit must be used to submit the batch.
204    ///
205    /// **Note**: After all submits, you must call [`close`](Self::close) to ensure all batches are
206    /// appended.
207    ///
208    /// # Cancel safety
209    ///
210    /// This method is cancel safe. Internally, it only awaits
211    /// [`Semaphore::acquire_many_owned`](tokio::sync::Semaphore::acquire_many_owned) and
212    /// [`Sender::reserve_owned`](tokio::sync::mpsc::Sender::reserve), both of which are cancel
213    /// safe.
214    pub async fn reserve(&self, bytes: u32) -> Result<BatchSubmitPermit, S2Error> {
215        let append_permit = self.permits.acquire(bytes).await;
216        let cmd_tx_permit = self
217            .cmd_tx
218            .clone()
219            .reserve_owned()
220            .await
221            .map_err(|_| AppendSessionError::SessionClosed)?;
222        Ok(BatchSubmitPermit {
223            append_permit,
224            cmd_tx_permit,
225        })
226    }
227
228    /// Close the session and wait for all submitted batch of records to be appended.
229    pub async fn close(self) -> Result<(), S2Error> {
230        let (done_tx, done_rx) = oneshot::channel();
231        self.cmd_tx
232            .send(Command::Close { done_tx })
233            .await
234            .map_err(|_| AppendSessionError::SessionClosed)?;
235        done_rx
236            .await
237            .map_err(|_| AppendSessionError::SessionClosed)??;
238        Ok(())
239    }
240}
241
242/// A permit to submit a batch after reserving capacity.
243pub struct BatchSubmitPermit {
244    append_permit: AppendPermit,
245    cmd_tx_permit: mpsc::OwnedPermit<Command>,
246}
247
248impl BatchSubmitPermit {
249    /// Submit the batch using this permit.
250    pub fn submit(self, input: AppendInput) -> BatchSubmitTicket {
251        let (ack_tx, ack_rx) = oneshot::channel();
252        self.cmd_tx_permit.send(Command::Submit {
253            input,
254            ack_tx,
255            permit: Some(self.append_permit),
256        });
257        BatchSubmitTicket { rx: ack_rx }
258    }
259}
260
261pub(crate) struct AppendSessionInternal {
262    cmd_tx: mpsc::Sender<Command>,
263    _handle: AbortOnDropHandle<()>,
264}
265
266impl AppendSessionInternal {
267    pub(crate) fn new(client: BasinClient, stream: StreamName) -> Self {
268        let buffer_size = DEFAULT_CHANNEL_BUFFER_SIZE;
269        let (cmd_tx, cmd_rx) = mpsc::channel(buffer_size);
270        let retry_builder = retry_builder(&client.config.retry);
271        let handle = AbortOnDropHandle::new(tokio::spawn(run_session_with_retry(
272            client,
273            stream,
274            cmd_rx,
275            retry_builder,
276            buffer_size,
277        )));
278        Self {
279            cmd_tx,
280            _handle: handle,
281        }
282    }
283
284    pub(crate) fn submit(
285        &self,
286        input: AppendInput,
287    ) -> impl Future<Output = Result<BatchSubmitTicket, S2Error>> + Send + 'static {
288        let cmd_tx = self.cmd_tx.clone();
289        async move {
290            let (ack_tx, ack_rx) = oneshot::channel();
291            cmd_tx
292                .send(Command::Submit {
293                    input,
294                    ack_tx,
295                    permit: None,
296                })
297                .await
298                .map_err(|_| AppendSessionError::SessionClosed)?;
299            Ok(BatchSubmitTicket { rx: ack_rx })
300        }
301    }
302
303    pub(crate) async fn close(self) -> Result<(), S2Error> {
304        let (done_tx, done_rx) = oneshot::channel();
305        self.cmd_tx
306            .send(Command::Close { done_tx })
307            .await
308            .map_err(|_| AppendSessionError::SessionClosed)?;
309        done_rx
310            .await
311            .map_err(|_| AppendSessionError::SessionClosed)??;
312        Ok(())
313    }
314}
315
316#[derive(Debug)]
317pub(crate) struct AppendPermit {
318    _count: Option<OwnedSemaphorePermit>,
319    _bytes: OwnedSemaphorePermit,
320}
321
322#[derive(Clone)]
323pub(crate) struct AppendPermits {
324    count: Option<Arc<Semaphore>>,
325    bytes: Arc<Semaphore>,
326}
327
328impl AppendPermits {
329    pub(crate) fn new(count_permits: Option<u32>, bytes_permits: u32) -> Self {
330        Self {
331            count: count_permits.map(|permits| Arc::new(Semaphore::new(permits as usize))),
332            bytes: Arc::new(Semaphore::new(bytes_permits as usize)),
333        }
334    }
335
336    pub(crate) async fn acquire(&self, bytes: u32) -> AppendPermit {
337        AppendPermit {
338            _count: if let Some(count) = self.count.as_ref() {
339                Some(
340                    count
341                        .clone()
342                        .acquire_many_owned(1)
343                        .await
344                        .expect("semaphore should not be closed"),
345                )
346            } else {
347                None
348            },
349            _bytes: self
350                .bytes
351                .clone()
352                .acquire_many_owned(bytes)
353                .await
354                .expect("semaphore should not be closed"),
355        }
356    }
357}
358
359async fn run_session_with_retry(
360    client: BasinClient,
361    stream: StreamName,
362    cmd_rx: mpsc::Receiver<Command>,
363    retry_builder: RetryBackoffBuilder,
364    buffer_size: usize,
365) {
366    let mut state = SessionState {
367        cmd_rx,
368        inflight_appends: VecDeque::new(),
369        inflight_bytes: 0,
370        close_tx: None,
371        closing: false,
372        total_records: 0,
373        total_acked_records: 0,
374        prev_ack_end: None,
375    };
376    let mut prev_total_acked_records = 0;
377    let mut retry_backoffs = retry_builder.build();
378
379    loop {
380        let result = run_session(&client, &stream, &mut state, buffer_size).await;
381
382        match result {
383            Ok(()) => {
384                break;
385            }
386            Err(err) => {
387                if prev_total_acked_records < state.total_acked_records {
388                    prev_total_acked_records = state.total_acked_records;
389                    retry_backoffs.reset();
390                }
391
392                let retry_policy_compliant = retry_policy_compliant(
393                    client.config.retry.append_retry_policy,
394                    &state.inflight_appends,
395                );
396
397                if retry_policy_compliant
398                    && err.is_retryable()
399                    && let Some(backoff) = retry_backoffs.next()
400                {
401                    debug!(
402                        %err,
403                        ?backoff,
404                        num_retries_remaining = retry_backoffs.remaining(),
405                        "retrying append session"
406                    );
407                    tokio::time::sleep(backoff).await;
408                } else {
409                    debug!(
410                        %err,
411                        retry_policy_compliant,
412                        retries_exhausted = retry_backoffs.is_exhausted(),
413                        "not retrying append session"
414                    );
415
416                    let err: S2Error = err.into();
417                    for inflight_append in state.inflight_appends.drain(..) {
418                        let _ = inflight_append.ack_tx.send(Err(err.clone()));
419                    }
420
421                    if let Some(done_tx) = state.close_tx.take() {
422                        let _ = done_tx.send(Err(err));
423                    }
424                    break;
425                }
426            }
427        }
428    }
429
430    if let Some(done_tx) = state.close_tx.take() {
431        let _ = done_tx.send(Ok(()));
432    }
433}
434
435async fn run_session(
436    client: &BasinClient,
437    stream: &StreamName,
438    state: &mut SessionState,
439    buffer_size: usize,
440) -> Result<(), AppendSessionError> {
441    let (input_tx, mut acks) = connect(client, stream, buffer_size).await?;
442    let ack_timeout = client.config.request_timeout;
443
444    if !state.inflight_appends.is_empty() {
445        resend(state, &input_tx, &mut acks, ack_timeout).await?;
446
447        assert!(state.inflight_appends.is_empty());
448        assert_eq!(state.inflight_bytes, 0);
449    }
450
451    let timer = MuxTimer::<N_TIMER_VARIANTS>::default();
452    tokio::pin!(timer);
453
454    let mut stashed_submission: Option<StashedSubmission> = None;
455
456    loop {
457        tokio::select! {
458            (event_ord, _deadline) = &mut timer, if timer.is_armed() => {
459                match TimerEvent::from(event_ord) {
460                    TimerEvent::AckDeadline => {
461                        return Err(AppendSessionError::AckTimeout);
462                    }
463                }
464            }
465
466            input_tx_permit = input_tx.reserve(), if stashed_submission.is_some() => {
467                let input_tx_permit = input_tx_permit
468                    .map_err(|_| AppendSessionError::ServerDisconnected)?;
469                let submission = stashed_submission
470                    .take()
471                    .expect("stashed_submission should not be None");
472
473                input_tx_permit.send(submission.input.clone());
474
475                state.total_records += submission.input.records.len();
476                state.inflight_bytes += submission.input_metered_bytes;
477
478                timer.as_mut().fire_at(
479                    TimerEvent::AckDeadline,
480                    submission.since + ack_timeout,
481                    CoalesceMode::Earliest,
482                );
483                state.inflight_appends.push_back(submission.into());
484            }
485
486            cmd = state.cmd_rx.recv(), if stashed_submission.is_none() => {
487                match cmd {
488                    Some(Command::Submit { input, ack_tx, permit }) => {
489                        if state.closing {
490                            let _ = ack_tx.send(
491                                Err(AppendSessionError::SessionClosing.into())
492                            );
493                        } else {
494                            let input_metered_bytes = input.records.metered_bytes();
495                            stashed_submission = Some(StashedSubmission {
496                                input,
497                                input_metered_bytes,
498                                ack_tx,
499                                permit,
500                                since: Instant::now(),
501                            });
502                        }
503                    }
504                    Some(Command::Close { done_tx }) => {
505                        state.closing = true;
506                        state.close_tx = Some(done_tx);
507                    }
508                    None => {
509                        return Err(AppendSessionError::SessionDropped);
510                    }
511                }
512            }
513
514            ack = acks.next() => {
515                match ack {
516                    Some(Ok(ack)) => {
517                        process_ack(
518                            ack,
519                            state,
520                            timer.as_mut(),
521                            ack_timeout,
522                        );
523                    }
524                    Some(Err(err)) => {
525                        return Err(err.into());
526                    }
527                    None => {
528                        if !state.inflight_appends.is_empty() || stashed_submission.is_some() {
529                            return Err(AppendSessionError::StreamClosedEarly);
530                        }
531                        break;
532                    }
533                }
534            }
535        }
536
537        if state.closing && state.inflight_appends.is_empty() && stashed_submission.is_none() {
538            break;
539        }
540    }
541
542    assert!(state.inflight_appends.is_empty());
543    assert_eq!(state.inflight_bytes, 0);
544    assert!(stashed_submission.is_none());
545
546    Ok(())
547}
548
/// Resends appends left unacknowledged by a previous connection and waits for all their acks.
///
/// Appends are rewritten to `input_tx` in their original order while acks are processed as they
/// arrive. Returns once `state.inflight_appends` is empty, i.e. every resent append has been
/// acknowledged.
///
/// # Errors
///
/// - [`AppendSessionError::AckTimeout`] if an ack deadline passes.
/// - [`AppendSessionError::ServerDisconnected`] if `input_tx` is closed.
/// - The API error, converted, if the ack stream yields one.
/// - [`AppendSessionError::StreamClosedEarly`] if the ack stream ends first.
async fn resend(
    state: &mut SessionState,
    input_tx: &mpsc::Sender<AppendInput>,
    acks: &mut Streaming<AppendAck>,
    ack_timeout: Duration,
) -> Result<(), AppendSessionError> {
    debug!(
        inflight_appends_len = state.inflight_appends.len(),
        inflight_bytes = state.inflight_bytes,
        "resending inflight appends"
    );

    // Position (within the current `inflight_appends`) of the next append to resend.
    // Acks pop the front of the queue, so this index is decremented on every ack.
    let mut resend_index = 0;
    // Set once every inflight append has been written to the new connection.
    let mut resend_finished = false;

    // Ack-deadline timer local to this resend phase.
    let timer = MuxTimer::<N_TIMER_VARIANTS>::default();
    tokio::pin!(timer);

    while !state.inflight_appends.is_empty() {
        tokio::select! {
            (event_ord, _deadline) = &mut timer, if timer.is_armed() => {
                match TimerEvent::from(event_ord) {
                    TimerEvent::AckDeadline => {
                        return Err(AppendSessionError::AckTimeout);
                    }
                }
            }

            input_tx_permit = input_tx.reserve(), if !resend_finished => {
                let input_tx_permit = input_tx_permit
                    .map_err(|_| AppendSessionError::ServerDisconnected)?;

                if let Some(inflight_append) = state.inflight_appends.get_mut(resend_index) {
                    // Restart this append's ack clock now that it is being sent again.
                    inflight_append.since = Instant::now();
                    timer.as_mut().fire_at(
                        TimerEvent::AckDeadline,
                        inflight_append.since + ack_timeout,
                        CoalesceMode::Earliest,
                    );
                    input_tx_permit.send(inflight_append.input.clone());
                    resend_index += 1;
                } else {
                    // Nothing left to resend; the reserved slot is dropped unused.
                    resend_finished = true;
                }
            }

            ack = acks.next() => {
                match ack {
                    Some(Ok(ack)) => {
                        process_ack(
                            ack,
                            state,
                            timer.as_mut(),
                            ack_timeout,
                        );
                        // The acked append was popped from the front, shifting indices down.
                        // NOTE(review): assumes acks on this connection only cover appends
                        // already resent on it, so this cannot underflow — confirm.
                        resend_index -= 1;
                    }
                    Some(Err(err)) => {
                        return Err(err.into());
                    }
                    None => {
                        return Err(AppendSessionError::StreamClosedEarly);
                    }
                }
            }
        }
    }

    assert_eq!(
        resend_index, 0,
        "resend_index should be 0 after resend completes"
    );
    debug!("finished resending inflight appends");
    Ok(())
}
624
625async fn connect(
626    client: &BasinClient,
627    stream: &StreamName,
628    buffer_size: usize,
629) -> Result<(mpsc::Sender<AppendInput>, Streaming<AppendAck>), AppendSessionError> {
630    let (input_tx, input_rx) = mpsc::channel::<AppendInput>(buffer_size);
631    let ack_stream = Box::pin(
632        client
633            .append_session(stream, ReceiverStream::new(input_rx).map(|i| i.into()))
634            .await?
635            .map(|ack| match ack {
636                Ok(ack) => Ok(ack.into()),
637                Err(err) => Err(err),
638            }),
639    );
640    Ok((input_tx, ack_stream))
641}
642
643fn process_ack(
644    ack: AppendAck,
645    state: &mut SessionState,
646    timer: Pin<&mut MuxTimer<N_TIMER_VARIANTS>>,
647    ack_timeout: Duration,
648) {
649    let corresponding_append = state
650        .inflight_appends
651        .pop_front()
652        .expect("corresponding append should be present for an ack");
653
654    assert!(
655        ack.end.seq_num >= ack.start.seq_num,
656        "ack end seq_num should be greater than or equal to start seq_num"
657    );
658
659    if let Some(end) = state.prev_ack_end {
660        assert!(
661            ack.end.seq_num > end.seq_num,
662            "ack end seq_num should be greater than previous ack end"
663        );
664    }
665
666    let num_acked_records = (ack.end.seq_num - ack.start.seq_num) as usize;
667    assert_eq!(
668        num_acked_records,
669        corresponding_append.input.records.len(),
670        "ack record count should match submitted batch size"
671    );
672
673    state.total_acked_records += num_acked_records;
674    state.inflight_bytes -= corresponding_append.input_metered_bytes;
675    state.prev_ack_end = Some(ack.end);
676
677    let _ = corresponding_append.ack_tx.send(Ok(ack));
678
679    if let Some(oldest_append) = state.inflight_appends.front() {
680        timer.fire_at(
681            TimerEvent::AckDeadline,
682            oldest_append.since + ack_timeout,
683            CoalesceMode::Latest,
684        );
685    } else {
686        timer.cancel(TimerEvent::AckDeadline);
687        assert_eq!(
688            state.total_records, state.total_acked_records,
689            "all records should be acked when inflight is empty"
690        );
691    }
692}
693
694struct StashedSubmission {
695    input: AppendInput,
696    input_metered_bytes: usize,
697    ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
698    permit: Option<AppendPermit>,
699    since: Instant,
700}
701
702struct InflightAppend {
703    input: AppendInput,
704    input_metered_bytes: usize,
705    ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
706    since: Instant,
707    _permit: Option<AppendPermit>,
708}
709
710impl From<StashedSubmission> for InflightAppend {
711    fn from(value: StashedSubmission) -> Self {
712        Self {
713            input: value.input,
714            input_metered_bytes: value.input_metered_bytes,
715            ack_tx: value.ack_tx,
716            since: value.since,
717            _permit: value.permit,
718        }
719    }
720}
721
722fn retry_policy_compliant(
723    policy: AppendRetryPolicy,
724    inflight_appends: &VecDeque<InflightAppend>,
725) -> bool {
726    if policy == AppendRetryPolicy::All {
727        return true;
728    }
729    inflight_appends
730        .iter()
731        .all(|ia| policy.is_compliant(&ia.input))
732}
733
734enum Command {
735    Submit {
736        input: AppendInput,
737        ack_tx: oneshot::Sender<Result<AppendAck, S2Error>>,
738        permit: Option<AppendPermit>,
739    },
740    Close {
741        done_tx: oneshot::Sender<Result<(), S2Error>>,
742    },
743}
744
/// Capacity of the command and input channels when no batch-count limit sizes them.
const DEFAULT_CHANNEL_BUFFER_SIZE: usize = 100;
746
/// Events multiplexed onto a session's `MuxTimer`; each variant's ordinal is its slot.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TimerEvent {
    /// The deadline for receiving an acknowledgement has passed.
    AckDeadline,
}

/// Number of [`TimerEvent`] variants, i.e. the number of timer slots.
const N_TIMER_VARIANTS: usize = 1;

impl TimerEvent {
    /// Every variant, indexed by its ordinal.
    const ALL: [TimerEvent; N_TIMER_VARIANTS] = [TimerEvent::AckDeadline];
}

impl From<TimerEvent> for usize {
    fn from(event: TimerEvent) -> Self {
        // Fieldless enum: the discriminant is the declaration index, matching `TimerEvent::ALL`.
        event as usize
    }
}

impl From<usize> for TimerEvent {
    /// # Panics
    ///
    /// Panics if `ordinal` does not name a variant.
    fn from(ordinal: usize) -> Self {
        match TimerEvent::ALL.get(ordinal) {
            Some(&event) => event,
            None => panic!("invalid ordinal"),
        }
    }
}