1use std::{
50 pin::Pin,
51 task::{Context, Poll},
52 time::Duration,
53};
54
55use futures::{Stream, StreamExt};
56
57use crate::types;
58
/// Options controlling how [`AppendRecordsBatchingStream`] groups records
/// into [`types::AppendInput`] batches.
///
/// Construct with [`AppendRecordsBatchingOpts::new`] (or `Default`) and
/// refine with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct AppendRecordsBatchingOpts {
    /// Maximum number of records in a single batch.
    max_batch_records: usize,
    /// Maximum size of a single batch in bytes (only configurable in test builds).
    #[cfg(test)]
    max_batch_bytes: u64,
    /// `match_seq_num` attached to the first emitted batch; each later batch
    /// gets it advanced by the number of records emitted before it.
    match_seq_num: Option<u64>,
    /// Fencing token attached (cloned) to every emitted batch.
    fencing_token: Option<types::FencingToken>,
    /// How long to wait for further records, once a batch holds its first
    /// record, before emitting a partially filled batch.
    linger_duration: Duration,
}
69
70impl Default for AppendRecordsBatchingOpts {
71 fn default() -> Self {
72 Self {
73 max_batch_records: 1000,
74 #[cfg(test)]
75 max_batch_bytes: types::AppendRecordBatch::MAX_BYTES,
76 match_seq_num: None,
77 fencing_token: None,
78 linger_duration: Duration::from_millis(5),
79 }
80 }
81}
82
83impl AppendRecordsBatchingOpts {
84 pub fn new() -> Self {
86 Self::default()
87 }
88
89 pub fn with_max_batch_records(self, max_batch_records: usize) -> Self {
91 assert!(
92 max_batch_records > 0 && max_batch_records <= types::AppendRecordBatch::MAX_CAPACITY,
93 "Batch capacity must be between 1 and 1000"
94 );
95 Self {
96 max_batch_records,
97 ..self
98 }
99 }
100
101 #[cfg(test)]
103 pub fn with_max_batch_bytes(self, max_batch_bytes: u64) -> Self {
104 assert!(
105 max_batch_bytes > 0 && max_batch_bytes <= types::AppendRecordBatch::MAX_BYTES,
106 "Batch capacity must be between 1 byte and 1 MiB"
107 );
108 Self {
109 max_batch_bytes,
110 ..self
111 }
112 }
113
114 pub fn with_match_seq_num(self, match_seq_num: Option<u64>) -> Self {
118 Self {
119 match_seq_num,
120 ..self
121 }
122 }
123
124 pub fn with_fencing_token(self, fencing_token: Option<types::FencingToken>) -> Self {
126 Self {
127 fencing_token,
128 ..self
129 }
130 }
131
132 pub fn with_linger(self, linger_duration: impl Into<Duration>) -> Self {
137 Self {
138 linger_duration: linger_duration.into(),
139 ..self
140 }
141 }
142}
143
/// A stream of [`types::AppendInput`] batches built from the records of an
/// inner stream, grouped according to [`AppendRecordsBatchingOpts`].
pub struct AppendRecordsBatchingStream(Pin<Box<dyn Stream<Item = types::AppendInput> + Send>>);
153
154impl AppendRecordsBatchingStream {
155 pub fn new<R, S>(stream: S, opts: AppendRecordsBatchingOpts) -> Self
157 where
158 R: 'static + Into<types::AppendRecord>,
159 S: 'static + Send + Stream<Item = R> + Unpin,
160 {
161 Self(Box::pin(append_records_batching_stream(stream, opts)))
162 }
163}
164
165impl Stream for AppendRecordsBatchingStream {
166 type Item = types::AppendInput;
167
168 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
169 self.0.poll_next_unpin(cx)
170 }
171}
172
/// Build the stream that groups the records of `stream` into batches per `opts`.
///
/// A batch is emitted when:
/// - the current batch is full, or the next record does not fit into it
///   (that record is carried over as the first record of the next batch);
/// - `opts.linger_duration` has elapsed since the batch received its first
///   record; or
/// - the input stream ends (the remaining, possibly partial, batch is flushed).
///
/// Empty batches are never emitted.
///
/// # Panics
///
/// The returned stream panics (inside `BatchBuilder::flush`) if a single
/// record does not fit into an empty batch.
fn append_records_batching_stream<R, S>(
    mut stream: S,
    opts: AppendRecordsBatchingOpts,
) -> impl Stream<Item = types::AppendInput> + Send
where
    R: Into<types::AppendRecord>,
    S: 'static + Send + Stream<Item = R> + Unpin,
{
    async_stream::stream! {
        // Set once the input stream returns `None`; ends the outer loop after
        // the final flush.
        let mut terminated = false;
        let mut batch_builder = BatchBuilder::new(&opts);

        // Placeholder timer: it is only awaited while the batch is non-empty,
        // and it is re-armed once a batch holds its first record.
        let batch_deadline = tokio::time::sleep(Duration::ZERO);
        tokio::pin!(batch_deadline);

        while !terminated {
            while !batch_builder.is_full() {
                if batch_builder.len() == 1 {
                    // Start the linger window from the batch's first record.
                    batch_deadline
                        .as_mut()
                        .reset(tokio::time::Instant::now() + opts.linger_duration);
                }

                tokio::select! {
                    // Poll branches in order: prefer taking a ready record over
                    // an already-expired deadline.
                    biased;
                    next = stream.next() => {
                        if let Some(record) = next {
                            batch_builder.push(record);
                        } else {
                            terminated = true;
                            break;
                        }
                    },
                    // Disabled while the batch is empty so an idle stream never
                    // produces a flush.
                    _ = &mut batch_deadline, if !batch_builder.is_empty() => {
                        break;
                    }
                };
            }

            // `flush` returns `None` for an empty batch, so nothing is yielded then.
            if let Some(input) = batch_builder.flush() {
                yield input;
            }
        }
    }
}
219
/// Accumulates records into a [`types::AppendRecordBatch`] and turns it into
/// [`types::AppendInput`]s, tracking `match_seq_num` across batches.
struct BatchBuilder<'a> {
    /// Options the builder was created with (limits, fencing token).
    opts: &'a AppendRecordsBatchingOpts,
    /// A record rejected by the current batch; it becomes the first record
    /// of the next batch on `flush`.
    peeked_record: Option<types::AppendRecord>,
    /// `match_seq_num` to attach to the next emitted batch, if configured.
    next_match_seq_num: Option<u64>,
    /// The batch currently being filled.
    batch: types::AppendRecordBatch,
}
226
227impl<'a> BatchBuilder<'a> {
228 pub fn new<'b: 'a>(opts: &'b AppendRecordsBatchingOpts) -> Self {
229 Self {
230 peeked_record: None,
231 next_match_seq_num: opts.match_seq_num,
232 batch: Self::new_batch(opts),
233 opts,
234 }
235 }
236
237 #[cfg(not(test))]
238 fn new_batch(opts: &AppendRecordsBatchingOpts) -> types::AppendRecordBatch {
239 types::AppendRecordBatch::with_max_capacity(opts.max_batch_records)
240 }
241
242 #[cfg(test)]
243 fn new_batch(opts: &AppendRecordsBatchingOpts) -> types::AppendRecordBatch {
244 types::AppendRecordBatch::with_max_capacity_and_bytes(
245 opts.max_batch_records,
246 opts.max_batch_bytes,
247 )
248 }
249
250 pub fn push(&mut self, record: impl Into<types::AppendRecord>) {
251 if let Err(record) = self.batch.push(record) {
252 let ret = self.peeked_record.replace(record);
253 assert_eq!(ret, None);
254 }
255 }
256
257 pub fn is_empty(&self) -> bool {
258 self.batch.is_empty()
259 }
260
261 pub fn len(&self) -> usize {
262 self.batch.len()
263 }
264
265 pub fn is_full(&self) -> bool {
266 self.batch.is_full() || self.peeked_record.is_some()
267 }
268
269 pub fn flush(&mut self) -> Option<types::AppendInput> {
270 let ret = if self.batch.is_empty() {
271 None
272 } else {
273 let match_seq_num = self.next_match_seq_num;
274 if let Some(next_match_seq_num) = self.next_match_seq_num.as_mut() {
275 *next_match_seq_num += self.batch.len() as u64;
276 }
277
278 let records = std::mem::replace(&mut self.batch, Self::new_batch(self.opts));
279 Some(types::AppendInput {
280 records,
281 match_seq_num,
282 fencing_token: self.opts.fencing_token.clone(),
283 })
284 };
285
286 if let Some(record) = self.peeked_record.take() {
287 self.push(record);
288 }
289
290 assert_eq!(self.peeked_record, None);
294
295 ret
296 }
297}
298
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use bytes::Bytes;
    use futures::StreamExt as _;
    use rstest::rstest;
    use tokio::sync::mpsc;
    use tokio_stream::wrappers::UnboundedReceiverStream;

    use super::{AppendRecordsBatchingOpts, AppendRecordsBatchingStream};
    use crate::types::{self, AppendInput, AppendRecordBatch};

    /// Each case is configured so that every emitted batch holds exactly two
    /// records, either through the record-count limit or the byte limit.
    #[rstest]
    #[case(Some(2), None)]
    #[case(None, Some(30))]
    #[case(Some(2), Some(100))]
    #[case(Some(10), Some(30))]
    #[tokio::test]
    async fn test_append_record_batching_mechanics(
        #[case] max_batch_records: Option<usize>,
        #[case] max_batch_bytes: Option<u64>,
    ) {
        let stream_iter = (0..100)
            .map(|i| {
                let body = format!("r_{i}");
                if let Some(max_batch_size) = max_batch_bytes {
                    types::AppendRecord::with_max_bytes(max_batch_size, body)
                } else {
                    types::AppendRecord::new(body)
                }
                .unwrap()
            })
            .collect::<Vec<_>>();
        let stream = futures::stream::iter(stream_iter);

        let mut opts = AppendRecordsBatchingOpts::new().with_linger(Duration::ZERO);
        if let Some(max_batch_records) = max_batch_records {
            opts = opts.with_max_batch_records(max_batch_records);
        }
        if let Some(max_batch_size) = max_batch_bytes {
            opts = opts.with_max_batch_bytes(max_batch_size);
        }

        let batch_stream = AppendRecordsBatchingStream::new(stream, opts);

        let batches = batch_stream
            .map(|batch| batch.records)
            .collect::<Vec<_>>()
            .await;

        let mut i = 0;
        for batch in batches {
            assert_eq!(batch.len(), 2);
            for record in batch {
                assert_eq!(record.into_parts().body, format!("r_{i}").into_bytes());
                i += 1;
            }
        }

        // Guard against a vacuous pass: every input record must come out, in order.
        assert_eq!(i, 100);
    }

    #[tokio::test(start_paused = true)]
    async fn test_append_record_batching_linger() {
        let (stream_tx, stream_rx) = mpsc::unbounded_channel::<types::AppendRecord>();
        let mut i = 0;

        let size_limit = 40;

        let collect_batches_handle = tokio::spawn(async move {
            let batch_stream = AppendRecordsBatchingStream::new(
                UnboundedReceiverStream::new(stream_rx),
                AppendRecordsBatchingOpts::new()
                    .with_linger(Duration::from_secs(2))
                    .with_max_batch_records(3)
                    .with_max_batch_bytes(size_limit),
            );

            batch_stream
                .map(|batch| {
                    batch
                        .records
                        .into_iter()
                        .map(|rec| rec.into_parts().body)
                        .collect::<Vec<_>>()
                })
                .collect::<Vec<_>>()
                .await
        });

        let mut send_next = |padding: Option<&str>| {
            let mut record =
                types::AppendRecord::with_max_bytes(size_limit, format!("r_{i}")).unwrap();
            if let Some(padding) = padding {
                record = record
                    .with_headers(vec![types::Header::new("padding", padding.to_owned())])
                    .unwrap();
            }
            stream_tx.send(record).unwrap();
            i += 1;
        };

        // Sleep slightly longer than `secs` so linger deadlines at exact
        // second boundaries have definitely fired.
        async fn sleep_secs(secs: u64) {
            let dur = Duration::from_secs(secs) + Duration::from_millis(10);
            tokio::time::sleep(dur).await;
        }

        // r_0 and r_1 arrive together; their 2s linger expires -> [r_0, r_1].
        send_next(None);
        send_next(None);

        sleep_secs(2).await;

        // r_2 opens a new linger window, r_3 arrives inside it, and the window
        // expires before the next send -> [r_2, r_3].
        send_next(None);

        sleep_secs(1).await;

        send_next(None);

        sleep_secs(1).await;

        // r_4..=r_6 hit the 3-record limit -> [r_4, r_5, r_6]; r_7 starts a new
        // batch that is flushed when its linger expires during the long sleep.
        send_next(None);
        send_next(None);
        send_next(None);
        send_next(None);

        sleep_secs(200).await;

        // The padded r_8 leaves no room for r_9 within `size_limit`, so r_8 is
        // flushed alone and r_9 is flushed when the input ends.
        send_next(Some("large string"));
        send_next(None);

        std::mem::drop(stream_tx);
        let batches = collect_batches_handle.await.unwrap();

        let expected_batches: Vec<Vec<Bytes>> = vec![
            vec!["r_0".into(), "r_1".into()],
            vec!["r_2".into(), "r_3".into()],
            vec!["r_4".into(), "r_5".into(), "r_6".into()],
            vec!["r_7".into()],
            vec!["r_8".into()],
            vec!["r_9".into()],
        ];

        assert_eq!(batches, expected_batches);
    }

    /// A record that cannot fit into an empty batch must make the stream panic.
    #[tokio::test]
    #[should_panic]
    async fn test_append_record_batching_panic_size_limits() {
        let size_limit = 1;
        let record =
            types::AppendRecord::with_max_bytes(size_limit, "too long to fit into size limits")
                .unwrap();

        let mut batch_stream = AppendRecordsBatchingStream::new(
            futures::stream::iter([record]),
            AppendRecordsBatchingOpts::new().with_max_batch_bytes(size_limit),
        );

        let _ = batch_stream.next().await;
    }

    /// Every batch carries the fencing token, and `match_seq_num` advances by
    /// the size of each previously emitted batch.
    #[tokio::test]
    async fn test_append_record_batching_append_input_opts() {
        let test_record = types::AppendRecord::new("a").unwrap();

        let total_records = 12;
        let test_records = (0..total_records)
            .map(|_| test_record.clone())
            .collect::<Vec<_>>();

        let expected_fencing_token: types::FencingToken = "hello".parse().unwrap();
        let mut expected_match_seq_num = 10;

        let num_batch_records = 3;

        let batch_stream = AppendRecordsBatchingStream::new(
            futures::stream::iter(test_records),
            AppendRecordsBatchingOpts::new()
                .with_max_batch_records(num_batch_records)
                .with_fencing_token(Some(expected_fencing_token.clone()))
                .with_match_seq_num(Some(expected_match_seq_num)),
        );

        let batches = batch_stream.collect::<Vec<_>>().await;

        assert_eq!(batches.len(), total_records / num_batch_records);

        let expected_batch =
            AppendRecordBatch::try_from_iter((0..num_batch_records).map(|_| test_record.clone()))
                .unwrap();

        for input in batches {
            let AppendInput {
                records,
                match_seq_num,
                fencing_token,
            } = input;
            assert_eq!(records, expected_batch);
            assert_eq!(fencing_token.as_ref(), Some(&expected_fencing_token));
            assert_eq!(match_seq_num, Some(expected_match_seq_num));
            expected_match_seq_num += num_batch_records as u64;
        }
    }
}