use std::{
    pin::Pin,
    task::{Context, Poll},
    time::Duration,
};

use futures::{Stream, StreamExt};
use s2_common::caps::RECORD_BATCH_MAX;
use tokio::time::Instant;

use crate::types::{
    AppendInput, AppendRecord, AppendRecordBatch, FencingToken, MeteredBytes, ValidationError,
};

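/// Configuration for how individual append records are grouped into batches.
///
/// A batch is flushed once it reaches `max_batch_records` records or
/// `max_batch_bytes` metered bytes, or once `linger` has elapsed since its
/// first record arrived.
///
/// A minimal usage sketch (assuming the type is in scope and the caller can
/// propagate `ValidationError`):
///
/// ```ignore
/// use std::time::Duration;
///
/// let config = BatchingConfig::new()
///     .with_linger(Duration::from_millis(10))
///     .with_max_batch_records(100)?
///     .with_max_batch_bytes(64 * 1024)?;
/// ```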
#[derive(Debug, Clone)]
pub struct BatchingConfig {
    linger: Duration,
    max_batch_bytes: usize,
    max_batch_records: usize,
}

impl Default for BatchingConfig {
    fn default() -> Self {
        Self {
            linger: Duration::from_millis(5),
            max_batch_bytes: RECORD_BATCH_MAX.bytes,
            max_batch_records: RECORD_BATCH_MAX.count,
        }
    }
}

impl BatchingConfig {
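    /// Creates a configuration with the default limits (equivalent to `Self::default()`).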
    pub fn new() -> Self {
        Self::default()
    }

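    /// Sets how long a partially filled batch may wait for additional records
    /// before it is flushed.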
    pub fn with_linger(self, linger: Duration) -> Self {
        Self { linger, ..self }
    }

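    /// Sets the maximum metered size of a batch in bytes.
    ///
    /// Returns a `ValidationError` if the value exceeds `RECORD_BATCH_MAX.bytes`.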
    pub fn with_max_batch_bytes(self, max_batch_bytes: usize) -> Result<Self, ValidationError> {
        if max_batch_bytes > RECORD_BATCH_MAX.bytes {
            return Err(ValidationError(format!(
                "max_batch_bytes ({max_batch_bytes}) exceeds {}",
                RECORD_BATCH_MAX.bytes
            )));
        }
        Ok(Self {
            max_batch_bytes,
            ..self
        })
    }

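    /// Sets the maximum number of records in a batch.
    ///
    /// Returns a `ValidationError` if the value exceeds `RECORD_BATCH_MAX.count`.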
    pub fn with_max_batch_records(self, max_batch_records: usize) -> Result<Self, ValidationError> {
        if max_batch_records > RECORD_BATCH_MAX.count {
            return Err(ValidationError(format!(
                "max_batch_records ({max_batch_records}) exceeds {}",
                RECORD_BATCH_MAX.count
            )));
        }
        Ok(Self {
            max_batch_records,
            ..self
        })
    }
}

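/// A stream of [`AppendInput`]s assembled from an underlying stream of records.
///
/// Each yielded input carries one record batch along with the configured
/// fencing token, and the optional `match_seq_num` advances by the length of
/// every batch so that consecutive inputs chain together.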
pub struct AppendInputs {
    pub(crate) batches: AppendRecordBatches,
    pub(crate) fencing_token: Option<FencingToken>,
    pub(crate) match_seq_num: Option<u64>,
}

impl AppendInputs {
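    /// Builds append inputs by batching `records` according to `config`.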
    pub fn new(
        records: impl Stream<Item = impl Into<AppendRecord> + Send> + Send + Unpin + 'static,
        config: BatchingConfig,
    ) -> Self {
        Self {
            batches: AppendRecordBatches::new(records, config),
            fencing_token: None,
            match_seq_num: None,
        }
    }

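    /// Attaches a fencing token to every yielded append input.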
    pub fn with_fencing_token(self, fencing_token: FencingToken) -> Self {
        Self {
            fencing_token: Some(fencing_token),
            ..self
        }
    }

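    /// Sets the sequence number to match for the first yielded batch; it is
    /// advanced by the length of each batch for subsequent inputs.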
    pub fn with_match_seq_num(self, seq_num: u64) -> Self {
        Self {
            match_seq_num: Some(seq_num),
            ..self
        }
    }
}

impl Stream for AppendInputs {
    type Item = Result<AppendInput, ValidationError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match self.batches.poll_next_unpin(cx) {
            Poll::Ready(Some(Ok(batch))) => {
                // Attach the current match sequence number to this input, then
                // advance it by the batch length so the next input lines up.
                let match_seq_num = self.match_seq_num;
                if let Some(seq_num) = self.match_seq_num.as_mut() {
                    *seq_num += batch.len() as u64;
                }
                Poll::Ready(Some(Ok(AppendInput {
                    records: batch,
                    match_seq_num,
                    fencing_token: self.fencing_token.clone(),
                })))
            }
            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}

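/// A stream of [`AppendRecordBatch`]es built from individual records
/// according to a [`BatchingConfig`].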
pub struct AppendRecordBatches {
    inner: Pin<Box<dyn Stream<Item = Result<AppendRecordBatch, ValidationError>> + Send>>,
}

impl AppendRecordBatches {
    pub fn new(
        records: impl Stream<Item = impl Into<AppendRecord> + Send> + Send + Unpin + 'static,
        config: BatchingConfig,
    ) -> Self {
        Self {
            inner: Box::pin(append_record_batches(records, config)),
        }
    }
}

impl Stream for AppendRecordBatches {
    type Item = Result<AppendRecordBatch, ValidationError>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.inner.as_mut().poll_next(cx)
    }
}

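/// Returns true once the batch has reached either configured limit.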
fn is_batch_full(config: &BatchingConfig, count: usize, bytes: usize) -> bool {
    count >= config.max_batch_records || bytes >= config.max_batch_bytes
}

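/// Returns true if adding `record` would push the batch past either limit.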
fn would_overflow_batch(
    config: &BatchingConfig,
    count: usize,
    bytes: usize,
    record: &AppendRecord,
) -> bool {
    count + 1 > config.max_batch_records || bytes + record.metered_bytes() > config.max_batch_bytes
}

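/// Core batching loop.
///
/// Records accumulate into a batch until it is full, the next record would
/// overflow it (that record then seeds the following batch), the linger
/// deadline fires, or the input stream ends. A single record whose metered
/// size exceeds `max_batch_bytes` terminates the stream with a
/// `ValidationError`.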
fn append_record_batches(
    mut records: impl Stream<Item = impl Into<AppendRecord> + Send> + Send + Unpin + 'static,
    config: BatchingConfig,
) -> impl Stream<Item = Result<AppendRecordBatch, ValidationError>> + Send + 'static {
    async_stream::try_stream! {
        let mut batch = AppendRecordBatch::with_capacity(config.max_batch_records);
        // A record that did not fit into the previous batch and must seed the next one.
        let mut overflowed_record: Option<AppendRecord> = None;

        let linger_deadline = tokio::time::sleep(config.linger);
        tokio::pin!(linger_deadline);

        'outer: loop {
            // Seed the batch with the carried-over record, or await the next one.
            let first_record = match overflowed_record.take() {
                Some(record) => record,
                None => match records.next().await {
                    Some(item) => item.into(),
                    None => break,
                },
            };

            let record_bytes = first_record.metered_bytes();
            if record_bytes > config.max_batch_bytes {
                Err(ValidationError(format!(
                    "record size in metered bytes ({record_bytes}) exceeds max_batch_bytes ({})",
                    config.max_batch_bytes
                )))?;
            }
            batch.push(first_record);

            // Keep filling the batch until it is full, a record overflows it,
            // the linger deadline fires, or the input stream ends.
            while !is_batch_full(&config, batch.len(), batch.metered_bytes())
                && overflowed_record.is_none()
            {
                if batch.len() == 1 {
                    // The linger window starts when the first record of a batch arrives.
                    linger_deadline
                        .as_mut()
                        .reset(Instant::now() + config.linger);
                }

                tokio::select! {
                    next_record = records.next() => {
                        match next_record {
                            Some(record) => {
                                let record: AppendRecord = record.into();
                                if would_overflow_batch(&config, batch.len(), batch.metered_bytes(), &record) {
                                    // Flush the current batch and carry this record into the next one.
                                    overflowed_record = Some(record);
                                } else {
                                    batch.push(record);
                                }
                            }
                            None => {
                                yield std::mem::replace(&mut batch, AppendRecordBatch::with_capacity(config.max_batch_records));
                                break 'outer;
                            }
                        }
                    },
                    _ = &mut linger_deadline, if !batch.is_empty() => {
                        break;
                    }
                };
            }

            yield std::mem::replace(
                &mut batch,
                AppendRecordBatch::with_capacity(config.max_batch_records),
            );
        }
    }
}

#[cfg(test)]
mod tests {
    use assert_matches::assert_matches;
    use futures::TryStreamExt;

    use super::*;

    #[tokio::test]
    async fn batches_should_be_empty_when_record_stream_is_empty() {
        let batches: Vec<_> = AppendRecordBatches::new(
            futures::stream::iter::<Vec<AppendRecord>>(vec![]),
            BatchingConfig::default(),
        )
        .collect()
        .await;
        assert_eq!(batches.len(), 0);
    }

    #[tokio::test]
    async fn batches_respect_count_limit() -> Result<(), ValidationError> {
        let records: Vec<_> = (0..10)
            .map(|i| AppendRecord::new(format!("record{i}")))
            .collect::<Result<_, _>>()?;
        let config = BatchingConfig::default().with_max_batch_records(3)?;
        let batches: Vec<_> = AppendRecordBatches::new(futures::stream::iter(records), config)
            .try_collect()
            .await?;

        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0].len(), 3);
        assert_eq!(batches[1].len(), 3);
        assert_eq!(batches[2].len(), 3);
        assert_eq!(batches[3].len(), 1);

        Ok(())
    }

    #[tokio::test]
    async fn batches_respect_bytes_limit() -> Result<(), ValidationError> {
        let records: Vec<_> = (0..10)
            .map(|i| AppendRecord::new(format!("record{i}")))
            .collect::<Result<_, _>>()?;
        let single_record_bytes = records[0].metered_bytes();
        let max_batch_bytes = single_record_bytes * 3;

        let config = BatchingConfig::default().with_max_batch_bytes(max_batch_bytes)?;
        let batches: Vec<_> = AppendRecordBatches::new(futures::stream::iter(records), config)
            .try_collect()
            .await?;

        assert_eq!(batches.len(), 4);
        assert_eq!(batches[0].metered_bytes(), max_batch_bytes);
        assert_eq!(batches[1].metered_bytes(), max_batch_bytes);
        assert_eq!(batches[2].metered_bytes(), max_batch_bytes);
        assert_eq!(batches[3].metered_bytes(), single_record_bytes);

        Ok(())
    }

    #[tokio::test]
    async fn batching_should_error_when_it_sees_oversized_record() -> Result<(), ValidationError> {
        let record = AppendRecord::new("hello")?;
        let record_bytes = record.metered_bytes();
        let max_batch_bytes = 1;

        let config = BatchingConfig::default().with_max_batch_bytes(max_batch_bytes)?;
        let results: Vec<_> = AppendRecordBatches::new(futures::stream::iter(vec![record]), config)
            .collect()
            .await;

        assert_eq!(results.len(), 1);
        assert_matches!(&results[0], Err(err) => {
            assert_eq!(
                err.to_string(),
                format!("record size in metered bytes ({record_bytes}) exceeds max_batch_bytes ({max_batch_bytes})")
            );
        });

        Ok(())
    }
}