// s2_sdk/batching.rs

1//! Utilities for batching [AppendRecord]s.
2
3use std::{
4    pin::Pin,
5    task::{Context, Poll},
6    time::Duration,
7};
8
9use futures::{Stream, StreamExt};
10use s2_common::{caps::RECORD_BATCH_MAX, read_extent::CountOrBytes};
11use tokio::time::Instant;
12
13use crate::types::{
14    AppendInput, AppendRecord, AppendRecordBatch, FencingToken, MeteredBytes, ValidationError,
15};
16
17const RECORD_BATCH_MIN: CountOrBytes = CountOrBytes { count: 1, bytes: 8 };
18
/// Configuration for batching [`AppendRecord`]s.
#[derive(Debug, Clone)]
pub struct BatchingConfig {
    /// How long to wait for more records before flushing a batch.
    linger: Duration,
    /// Upper bound on the metered bytes of a single batch.
    max_batch_bytes: usize,
    /// Upper bound on the number of records in a single batch.
    max_batch_records: usize,
}
26
27impl Default for BatchingConfig {
28    fn default() -> Self {
29        Self {
30            linger: Duration::from_millis(5),
31            max_batch_bytes: RECORD_BATCH_MAX.bytes,
32            max_batch_records: RECORD_BATCH_MAX.count,
33        }
34    }
35}
36
37impl BatchingConfig {
38    /// Create a new [`BatchingConfig`] with default settings.
39    pub fn new() -> Self {
40        Self::default()
41    }
42
43    /// Set the duration for how long to wait for more records before flushing a batch.
44    ///
45    /// Defaults to `5ms`.
46    pub fn with_linger(self, linger: Duration) -> Self {
47        Self { linger, ..self }
48    }
49
50    /// Set the maximum metered bytes per batch.
51    ///
52    /// **Note:** It must be at least `8B` and must not exceed `1MiB`.
53    ///
54    /// Defaults to `1MiB`.
55    pub fn with_max_batch_bytes(self, max_batch_bytes: usize) -> Result<Self, ValidationError> {
56        if max_batch_bytes < RECORD_BATCH_MIN.bytes {
57            return Err(ValidationError(format!(
58                "max_batch_bytes ({max_batch_bytes}) must be at least {}",
59                RECORD_BATCH_MIN.bytes
60            )));
61        }
62        if max_batch_bytes > RECORD_BATCH_MAX.bytes {
63            return Err(ValidationError(format!(
64                "max_batch_bytes ({max_batch_bytes}) must not exceed {}",
65                RECORD_BATCH_MAX.bytes
66            )));
67        }
68        Ok(Self {
69            max_batch_bytes,
70            ..self
71        })
72    }
73
74    /// Set the maximum number of records per batch.
75    ///
76    /// **Note:** It must be at least `1` and must not exceed `1000`.
77    ///
78    /// Defaults to `1000`.
79    pub fn with_max_batch_records(self, max_batch_records: usize) -> Result<Self, ValidationError> {
80        if max_batch_records < RECORD_BATCH_MIN.count {
81            return Err(ValidationError(format!(
82                "max_batch_records ({max_batch_records}) must be at least {}",
83                RECORD_BATCH_MIN.count
84            )));
85        }
86        if max_batch_records > RECORD_BATCH_MAX.count {
87            return Err(ValidationError(format!(
88                "max_batch_records ({max_batch_records}) must not exceed {}",
89                RECORD_BATCH_MAX.count
90            )));
91        }
92        Ok(Self {
93            max_batch_records,
94            ..self
95        })
96    }
97}
98
99/// A [`Stream`] that batches [`AppendRecord`]s into [`AppendInput`]s.
100pub struct AppendInputs {
101    pub(crate) batches: AppendRecordBatches,
102    pub(crate) fencing_token: Option<FencingToken>,
103    pub(crate) match_seq_num: Option<u64>,
104}
105
106impl AppendInputs {
107    /// Create a new [`AppendInputs`] with the given records and config.
108    pub fn new(
109        records: impl Stream<Item = impl Into<AppendRecord> + Send> + Send + Unpin + 'static,
110        config: BatchingConfig,
111    ) -> Self {
112        Self {
113            batches: AppendRecordBatches::new(records, config),
114            fencing_token: None,
115            match_seq_num: None,
116        }
117    }
118
119    /// Set the fencing token for all [`AppendInput`]s.
120    pub fn with_fencing_token(self, fencing_token: FencingToken) -> Self {
121        Self {
122            fencing_token: Some(fencing_token),
123            ..self
124        }
125    }
126
127    /// Set the match sequence number for the initial [`AppendInput`]. It will be auto-incremented
128    /// for the subsequent ones.
129    pub fn with_match_seq_num(self, seq_num: u64) -> Self {
130        Self {
131            match_seq_num: Some(seq_num),
132            ..self
133        }
134    }
135}
136
137impl Stream for AppendInputs {
138    type Item = Result<AppendInput, ValidationError>;
139
140    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
141        match self.batches.poll_next_unpin(cx) {
142            Poll::Ready(Some(Ok(batch))) => {
143                let match_seq_num = self.match_seq_num;
144                if let Some(seq_num) = self.match_seq_num.as_mut() {
145                    *seq_num += batch.len() as u64;
146                }
147                Poll::Ready(Some(Ok(AppendInput {
148                    records: batch,
149                    match_seq_num,
150                    fencing_token: self.fencing_token.clone(),
151                })))
152            }
153            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
154            Poll::Ready(None) => Poll::Ready(None),
155            Poll::Pending => Poll::Pending,
156        }
157    }
158}
159
160/// A [`Stream`] that batches [`AppendRecord`]s into [`AppendRecordBatch`]es.
161pub struct AppendRecordBatches {
162    inner: Pin<Box<dyn Stream<Item = Result<AppendRecordBatch, ValidationError>> + Send>>,
163}
164
165impl AppendRecordBatches {
166    /// Create a new [`AppendRecordBatches`] with the given records and config.
167    pub fn new(
168        records: impl Stream<Item = impl Into<AppendRecord> + Send> + Send + Unpin + 'static,
169        config: BatchingConfig,
170    ) -> Self {
171        Self {
172            inner: Box::pin(append_record_batches(records, config)),
173        }
174    }
175}
176
177impl Stream for AppendRecordBatches {
178    type Item = Result<AppendRecordBatch, ValidationError>;
179
180    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
181        self.inner.as_mut().poll_next(cx)
182    }
183}
184
185fn is_batch_full(config: &BatchingConfig, count: usize, bytes: usize) -> bool {
186    count >= config.max_batch_records || bytes >= config.max_batch_bytes
187}
188
189fn would_overflow_batch(
190    config: &BatchingConfig,
191    count: usize,
192    bytes: usize,
193    record: &AppendRecord,
194) -> bool {
195    count + 1 > config.max_batch_records || bytes + record.metered_bytes() > config.max_batch_bytes
196}
197
198fn append_record_batches(
199    mut records: impl Stream<Item = impl Into<AppendRecord> + Send> + Send + Unpin + 'static,
200    config: BatchingConfig,
201) -> impl Stream<Item = Result<AppendRecordBatch, ValidationError>> + Send + 'static {
202    async_stream::try_stream! {
203        let mut batch = AppendRecordBatch::with_capacity(config.max_batch_records);
204        let mut overflowed_record: Option<AppendRecord> = None;
205
206        let linger_deadline = tokio::time::sleep(config.linger);
207        tokio::pin!(linger_deadline);
208
209        'outer: loop {
210            let first_record = match overflowed_record.take() {
211                Some(record) => record,
212                None => match records.next().await {
213                    Some(item) => item.into(),
214                    None => break,
215                },
216            };
217
218            let record_bytes = first_record.metered_bytes();
219            if record_bytes > config.max_batch_bytes {
220                Err(ValidationError(format!(
221                    "record size in metered bytes ({record_bytes}) exceeds max_batch_bytes ({})",
222                    config.max_batch_bytes
223                )))?;
224            }
225            batch.push(first_record);
226
227            while !is_batch_full(&config, batch.len(), batch.metered_bytes())
228                && overflowed_record.is_none()
229            {
230                if batch.len() == 1 {
231                    linger_deadline
232                        .as_mut()
233                        .reset(Instant::now() + config.linger);
234                }
235
236                tokio::select! {
237                    next_record = records.next() => {
238                        match next_record {
239                            Some(record) => {
240                                let record: AppendRecord = record.into();
241                                if would_overflow_batch(&config, batch.len(), batch.metered_bytes(), &record) {
242                                    overflowed_record = Some(record);
243                                } else {
244                                    batch.push(record);
245                                }
246                            }
247                            None => {
248                                yield std::mem::replace(&mut batch, AppendRecordBatch::with_capacity(config.max_batch_records));
249                                break 'outer;
250                            }
251                        }
252                    },
253                    _ = &mut linger_deadline, if !batch.is_empty() => {
254                        break;
255                    }
256                };
257            }
258
259            yield std::mem::replace(
260                &mut batch,
261                AppendRecordBatch::with_capacity(config.max_batch_records),
262            );
263        }
264    }
265}
266
#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::TryStreamExt;

    use super::*;

    /// An empty record stream must not produce any batch.
    #[tokio::test]
    async fn batches_should_be_empty_when_record_stream_is_empty() {
        let no_records = futures::stream::iter(Vec::<AppendRecord>::new());
        let batches: Vec<_> = AppendRecordBatches::new(no_records, BatchingConfig::default())
            .collect()
            .await;
        assert!(batches.is_empty());
    }

    /// Ten records with a limit of three records per batch split into 3 + 3 + 3 + 1.
    #[tokio::test]
    async fn batches_respect_count_limit() -> Result<(), ValidationError> {
        let records: Vec<_> = (0..10)
            .map(|i| AppendRecord::new(format!("record{i}")))
            .collect::<Result<_, _>>()?;
        let config = BatchingConfig::new().with_max_batch_records(3)?;

        let batches: Vec<_> = AppendRecordBatches::new(futures::stream::iter(records), config)
            .try_collect()
            .await?;

        let batch_lens: Vec<usize> = batches.iter().map(|batch| batch.len()).collect();
        assert_eq!(batch_lens, [3, 3, 3, 1]);

        Ok(())
    }

    /// Ten equally sized records with a byte limit of three records' worth split into
    /// three full batches and one single-record batch.
    #[tokio::test]
    async fn batches_respect_bytes_limit() -> Result<(), ValidationError> {
        let records: Vec<_> = (0..10)
            .map(|i| AppendRecord::new(format!("record{i}")))
            .collect::<Result<_, _>>()?;
        let single_record_bytes = records[0].metered_bytes();
        let max_batch_bytes = single_record_bytes * 3;
        let config = BatchingConfig::new().with_max_batch_bytes(max_batch_bytes)?;

        let batches: Vec<_> = AppendRecordBatches::new(futures::stream::iter(records), config)
            .try_collect()
            .await?;

        let batch_bytes: Vec<usize> = batches
            .iter()
            .map(|batch| batch.metered_bytes())
            .collect();
        assert_eq!(
            batch_bytes,
            [
                max_batch_bytes,
                max_batch_bytes,
                max_batch_bytes,
                single_record_bytes
            ]
        );

        Ok(())
    }

    /// A record larger than `max_batch_bytes` yields exactly one item: a validation error.
    #[tokio::test]
    async fn batching_should_error_when_it_sees_oversized_record() -> Result<(), ValidationError> {
        let max_batch_bytes = 10;
        let record = AppendRecord::new("hello-world")?;
        let record_bytes = record.metered_bytes();
        let config = BatchingConfig::new().with_max_batch_bytes(max_batch_bytes)?;

        let oversized = futures::stream::iter(vec![record]);
        let results: Vec<_> = AppendRecordBatches::new(oversized, config).collect().await;

        assert_eq!(results.len(), 1);
        assert_matches!(&results[0], Err(err) => {
            assert_eq!(
                err.to_string(),
                format!("record size in metered bytes ({record_bytes}) exceeds max_batch_bytes ({max_batch_bytes})")
            );
        });

        Ok(())
    }
}
347}