s2_sdk/
batching.rs

//! Utilities for batching [`AppendRecord`]s.
2
3use std::{
4    pin::Pin,
5    task::{Context, Poll},
6    time::Duration,
7};
8
9use futures::{Stream, StreamExt};
10use s2_common::caps::RECORD_BATCH_MAX;
11use tokio::time::Instant;
12
13use crate::types::{
14    AppendInput, AppendRecord, AppendRecordBatch, FencingToken, MeteredBytes, ValidationError,
15};
16
/// Configuration for batching [`AppendRecord`]s.
///
/// A batch is flushed when its linger period elapses, or when it reaches (or the next
/// record would exceed) either the record-count or the metered-byte limit.
#[derive(Debug, Clone)]
pub struct BatchingConfig {
    /// How long to wait for more records after a batch receives its first record.
    linger: Duration,
    /// Upper bound on the total metered bytes of a single batch.
    max_batch_bytes: usize,
    /// Upper bound on the number of records in a single batch.
    max_batch_records: usize,
}
24
25impl Default for BatchingConfig {
26    fn default() -> Self {
27        Self {
28            linger: Duration::from_millis(5),
29            max_batch_bytes: RECORD_BATCH_MAX.bytes,
30            max_batch_records: RECORD_BATCH_MAX.count,
31        }
32    }
33}
34
35impl BatchingConfig {
36    /// Create a new [`BatchingConfig`] with default settings.
37    pub fn new() -> Self {
38        Self::default()
39    }
40
41    /// Set the duration for how long to wait for more records before flushing a batch.
42    ///
43    /// Defaults to `5ms`.
44    pub fn with_linger(self, linger: Duration) -> Self {
45        Self { linger, ..self }
46    }
47
48    /// Set the maximum metered bytes per batch.
49    ///
50    /// **Note:** It must not exceed `1MiB`.
51    ///
52    /// Defaults to `1MiB`.
53    pub fn with_max_batch_bytes(self, max_batch_bytes: usize) -> Result<Self, ValidationError> {
54        if max_batch_bytes > RECORD_BATCH_MAX.bytes {
55            return Err(ValidationError(format!(
56                "max_batch_bytes ({max_batch_bytes}) exceeds {}",
57                RECORD_BATCH_MAX.bytes
58            )));
59        }
60        Ok(Self {
61            max_batch_bytes,
62            ..self
63        })
64    }
65
66    /// Set the maximum number of records per batch.
67    ///
68    /// **Note:** It must not exceed `1000`.
69    ///
70    /// Defaults to `1000`.
71    pub fn with_max_batch_records(self, max_batch_records: usize) -> Result<Self, ValidationError> {
72        if max_batch_records > RECORD_BATCH_MAX.count {
73            return Err(ValidationError(format!(
74                "max_batch_records ({max_batch_records}) exceeds {}",
75                RECORD_BATCH_MAX.count
76            )));
77        }
78        Ok(Self {
79            max_batch_records,
80            ..self
81        })
82    }
83}
84
85/// A [`Stream`] that batches [`AppendRecord`]s into [`AppendInput`]s.
86pub struct AppendInputs {
87    pub(crate) batches: AppendRecordBatches,
88    pub(crate) fencing_token: Option<FencingToken>,
89    pub(crate) match_seq_num: Option<u64>,
90}
91
92impl AppendInputs {
93    /// Create a new [`AppendInputs`] with the given records and config.
94    pub fn new(
95        records: impl Stream<Item = impl Into<AppendRecord> + Send> + Send + Unpin + 'static,
96        config: BatchingConfig,
97    ) -> Self {
98        Self {
99            batches: AppendRecordBatches::new(records, config),
100            fencing_token: None,
101            match_seq_num: None,
102        }
103    }
104
105    /// Set the fencing token for all [`AppendInput`]s.
106    pub fn with_fencing_token(self, fencing_token: FencingToken) -> Self {
107        Self {
108            fencing_token: Some(fencing_token),
109            ..self
110        }
111    }
112
113    /// Set the match sequence number for the initial [`AppendInput`]. It will be auto-incremented
114    /// for the subsequent ones.
115    pub fn with_match_seq_num(self, seq_num: u64) -> Self {
116        Self {
117            match_seq_num: Some(seq_num),
118            ..self
119        }
120    }
121}
122
123impl Stream for AppendInputs {
124    type Item = Result<AppendInput, ValidationError>;
125
126    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
127        match self.batches.poll_next_unpin(cx) {
128            Poll::Ready(Some(Ok(batch))) => {
129                let match_seq_num = self.match_seq_num;
130                if let Some(seq_num) = self.match_seq_num.as_mut() {
131                    *seq_num += batch.len() as u64;
132                }
133                Poll::Ready(Some(Ok(AppendInput {
134                    records: batch,
135                    match_seq_num,
136                    fencing_token: self.fencing_token.clone(),
137                })))
138            }
139            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
140            Poll::Ready(None) => Poll::Ready(None),
141            Poll::Pending => Poll::Pending,
142        }
143    }
144}
145
146/// A [`Stream`] that batches [`AppendRecord`]s into [`AppendRecordBatch`]es.
147pub struct AppendRecordBatches {
148    inner: Pin<Box<dyn Stream<Item = Result<AppendRecordBatch, ValidationError>> + Send>>,
149}
150
151impl AppendRecordBatches {
152    /// Create a new [`AppendRecordBatches`] with the given records and config.
153    pub fn new(
154        records: impl Stream<Item = impl Into<AppendRecord> + Send> + Send + Unpin + 'static,
155        config: BatchingConfig,
156    ) -> Self {
157        Self {
158            inner: Box::pin(append_record_batches(records, config)),
159        }
160    }
161}
162
163impl Stream for AppendRecordBatches {
164    type Item = Result<AppendRecordBatch, ValidationError>;
165
166    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
167        self.inner.as_mut().poll_next(cx)
168    }
169}
170
171fn is_batch_full(config: &BatchingConfig, count: usize, bytes: usize) -> bool {
172    count >= config.max_batch_records || bytes >= config.max_batch_bytes
173}
174
175fn would_overflow_batch(
176    config: &BatchingConfig,
177    count: usize,
178    bytes: usize,
179    record: &AppendRecord,
180) -> bool {
181    count + 1 > config.max_batch_records || bytes + record.metered_bytes() > config.max_batch_bytes
182}
183
184fn append_record_batches(
185    mut records: impl Stream<Item = impl Into<AppendRecord> + Send> + Send + Unpin + 'static,
186    config: BatchingConfig,
187) -> impl Stream<Item = Result<AppendRecordBatch, ValidationError>> + Send + 'static {
188    async_stream::try_stream! {
189        let mut batch = AppendRecordBatch::with_capacity(config.max_batch_records);
190        let mut overflowed_record: Option<AppendRecord> = None;
191
192        let linger_deadline = tokio::time::sleep(config.linger);
193        tokio::pin!(linger_deadline);
194
195        'outer: loop {
196            let first_record = match overflowed_record.take() {
197                Some(record) => record,
198                None => match records.next().await {
199                    Some(item) => item.into(),
200                    None => break,
201                },
202            };
203
204            let record_bytes = first_record.metered_bytes();
205            if record_bytes > config.max_batch_bytes {
206                Err(ValidationError(format!(
207                    "record size in metered bytes ({record_bytes}) exceeds max_batch_bytes ({})",
208                    config.max_batch_bytes
209                )))?;
210            }
211            batch.push(first_record);
212
213            while !is_batch_full(&config, batch.len(), batch.metered_bytes())
214                && overflowed_record.is_none()
215            {
216                if batch.len() == 1 {
217                    linger_deadline
218                        .as_mut()
219                        .reset(Instant::now() + config.linger);
220                }
221
222                tokio::select! {
223                    next_record = records.next() => {
224                        match next_record {
225                            Some(record) => {
226                                let record: AppendRecord = record.into();
227                                if would_overflow_batch(&config, batch.len(), batch.metered_bytes(), &record) {
228                                    overflowed_record = Some(record);
229                                } else {
230                                    batch.push(record);
231                                }
232                            }
233                            None => {
234                                yield std::mem::replace(&mut batch, AppendRecordBatch::with_capacity(config.max_batch_records));
235                                break 'outer;
236                            }
237                        }
238                    },
239                    _ = &mut linger_deadline, if !batch.is_empty() => {
240                        break;
241                    }
242                };
243            }
244
245            yield std::mem::replace(
246                &mut batch,
247                AppendRecordBatch::with_capacity(config.max_batch_records),
248            );
249        }
250    }
251}
252
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::TryStreamExt;

    use super::*;

    #[tokio::test]
    async fn batches_should_be_empty_when_record_stream_is_empty() {
        let batches: Vec<_> = AppendRecordBatches::new(
            futures::stream::iter::<Vec<AppendRecord>>(vec![]),
            BatchingConfig::default(),
        )
        .collect()
        .await;
        assert_eq!(batches.len(), 0);
    }

    #[tokio::test]
    async fn batches_respect_count_limit() -> Result<(), ValidationError> {
        let records: Vec<_> = (0..10)
            .map(|i| AppendRecord::new(format!("record{i}")))
            .collect::<Result<_, _>>()?;
        let config = BatchingConfig::default().with_max_batch_records(3)?;
        let batches: Vec<_> = AppendRecordBatches::new(futures::stream::iter(records), config)
            .try_collect()
            .await?;

        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 3);
        assert_eq!(batches[3].len(), 1);

        Ok(())
    }

    #[tokio::test]
    async fn batches_respect_bytes_limit() -> Result<(), ValidationError> {
        let records: Vec<_> = (0..10)
            .map(|i| AppendRecord::new(format!("record{i}")))
            .collect::<Result<_, _>>()?;
        let single_record_bytes = records[0].metered_bytes();
        let max_batch_bytes = single_record_bytes * 3;

        let config = BatchingConfig::default().with_max_batch_bytes(max_batch_bytes)?;
        let batches: Vec<_> = AppendRecordBatches::new(futures::stream::iter(records), config)
            .try_collect()
            .await?;

        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0].metered_bytes(), max_batch_bytes);
        assert_eq!(batches[1].metered_bytes(), max_batch_bytes);
        assert_eq!(batches[2].metered_bytes(), max_batch_bytes);
        assert_eq!(batches[3].metered_bytes(), single_record_bytes);

        Ok(())
    }

    #[tokio::test]
    async fn batching_should_error_when_it_sees_oversized_record() -> Result<(), ValidationError> {
        let record = AppendRecord::new("hello")?;
        let record_bytes = record.metered_bytes();
        let max_batch_bytes = 1;

        let config = BatchingConfig::default().with_max_batch_bytes(max_batch_bytes)?;
        let results: Vec<_> = AppendRecordBatches::new(futures::stream::iter(vec![record]), config)
            .collect()
            .await;

        assert_eq!(results.len(), 1);
        assert_matches!(&results[0], Err(err) => {
            assert_eq!(
                err.to_string(),
                format!("record size in metered bytes ({record_bytes}) exceeds max_batch_bytes ({max_batch_bytes})")
            );
        });

        Ok(())
    }

    /// An oversized record arriving mid-batch must not discard the records already
    /// collected: the pending batch is yielded before the error.
    #[tokio::test]
    async fn pending_batch_is_flushed_before_oversized_record_error() -> Result<(), ValidationError>
    {
        let small = AppendRecord::new("a")?;
        let oversized = AppendRecord::new("a considerably longer record body")?;
        let max_batch_bytes = small.metered_bytes() + 1;
        assert!(oversized.metered_bytes() > max_batch_bytes);

        let config = BatchingConfig::default().with_max_batch_bytes(max_batch_bytes)?;
        let results: Vec<_> =
            AppendRecordBatches::new(futures::stream::iter(vec![small, oversized]), config)
                .collect()
                .await;

        assert_eq!(results.len(), 2);
        assert_matches!(&results[0], Ok(batch) => assert_eq!(batch.len(), 1));
        assert_matches!(&results[1], Err(_));

        Ok(())
    }

    /// The match sequence number starts at the configured value and advances by each
    /// emitted batch's record count.
    #[tokio::test]
    async fn append_inputs_advance_match_seq_num_by_batch_len() -> Result<(), ValidationError> {
        let records: Vec<_> = (0..10)
            .map(|i| AppendRecord::new(format!("record{i}")))
            .collect::<Result<_, _>>()?;
        let config = BatchingConfig::default().with_max_batch_records(3)?;
        let inputs: Vec<AppendInput> = AppendInputs::new(futures::stream::iter(records), config)
            .with_match_seq_num(10)
            .try_collect()
            .await?;

        let seq_nums: Vec<_> = inputs.iter().map(|input| input.match_seq_num).collect();
        assert_eq!(seq_nums, vec![Some(10), Some(13), Some(16), Some(19)]);
        assert!(inputs.iter().all(|input| input.fencing_token.is_none()));

        Ok(())
    }
}