1use std::iter::FusedIterator;
2
3use crate::{
4 caps,
5 read_extent::{EvaluatedReadLimit, ReadLimit, ReadUntil},
6 record::{Metered, MeteredSize, Sequenced, StoredRecord},
7};
8
/// A group of sequenced records produced by [`RecordBatcher`], bounded by the
/// batcher's read limit and by `caps::RECORD_BATCH_MAX`.
pub struct RecordBatch<T = StoredRecord>
where
    T: MeteredSize,
{
    /// Records in the batch, with their aggregate metered size tracked.
    pub records: Metered<Vec<Sequenced<T>>>,
    /// True when this is the last batch the batcher will produce because its
    /// read limit (count/bytes) or `ReadUntil` bound was exhausted.
    pub is_terminal: bool,
}
16
17impl<T> std::fmt::Debug for RecordBatch<T>
18where
19 T: MeteredSize,
20{
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 f.debug_struct("RecordBatch")
23 .field("num_records", &self.records.len())
24 .field("metered_size", &self.records.metered_size())
25 .field("is_terminal", &self.is_terminal)
26 .finish()
27 }
28}
29
/// Groups records from a fallible iterator into [`RecordBatch`]es, honoring a
/// [`ReadLimit`] / [`ReadUntil`] extent and the `caps::RECORD_BATCH_MAX` caps.
pub struct RecordBatcher<I, E, T>
where
    T: MeteredSize,
    I: Iterator<Item = Result<Metered<Sequenced<T>>, E>>,
{
    // Upstream source of records (or errors).
    record_iterator: I,
    // Records accumulated for the batch currently being built.
    buffered_records: Metered<Vec<Sequenced<T>>>,
    // An upstream error held back until the buffered records are drained.
    buffered_error: Option<E>,
    // Remaining read budget; `Exhausted` once the extent is used up.
    read_limit: EvaluatedReadLimit,
    // Timestamp bound at which reading stops.
    until: ReadUntil,
    // Set by `Iterator::next` once the stream ends or an error is yielded.
    is_terminated: bool,
}
42
43fn make_records<T>(read_limit: &EvaluatedReadLimit) -> Metered<Vec<Sequenced<T>>>
44where
45 T: MeteredSize,
46{
47 match read_limit {
48 EvaluatedReadLimit::Remaining(limit) => {
49 Metered::with_capacity(limit.count().map_or(caps::RECORD_BATCH_MAX.count, |n| {
50 n.min(caps::RECORD_BATCH_MAX.count)
51 }))
52 }
53 EvaluatedReadLimit::Exhausted => Metered::default(),
54 }
55}
56
impl<I, E, T> RecordBatcher<I, E, T>
where
    T: MeteredSize,
    I: Iterator<Item = Result<Metered<Sequenced<T>>, E>>,
{
    /// Creates a batcher over `record_iterator`, bounded by `read_limit` and
    /// `until`. The limit is evaluated immediately against zero consumption so
    /// a `ReadLimit::Count(0)`/`Bytes(0)` starts out exhausted.
    pub fn new(record_iterator: I, read_limit: ReadLimit, until: ReadUntil) -> Self {
        let read_limit = read_limit.remaining(0, 0);
        Self {
            record_iterator,
            buffered_records: make_records(&read_limit),
            buffered_error: None,
            read_limit,
            until,
            is_terminated: false,
        }
    }

    /// Builds and returns the next batch, an upstream error (only after any
    /// buffered records have been flushed), or `None` when done.
    fn iter_next(&mut self) -> Option<Result<RecordBatch<T>, E>> {
        // Once the read extent is exhausted, nothing more can be produced.
        let EvaluatedReadLimit::Remaining(remaining_limit) = self.read_limit else {
            return None;
        };

        // A record that fits the read limit but would overflow the per-batch
        // caps; it is carried over into the next batch's buffer.
        let mut stashed_record = None;
        while self.buffered_error.is_none() {
            match self.record_iterator.next() {
                Some(Ok(record)) => {
                    // Reject the record if accepting it would exceed the read
                    // limit or cross the `until` timestamp bound; either way
                    // the extent is finished.
                    if remaining_limit.deny(
                        self.buffered_records.len() + 1,
                        self.buffered_records.metered_size() + record.metered_size(),
                    ) || self.until.deny(record.position.timestamp)
                    {
                        self.read_limit = EvaluatedReadLimit::Exhausted;
                        break;
                    }

                    // The record fits the extent but not the current batch's
                    // caps: stash it for the next batch and close this one.
                    //
                    // NOTE(review): if a single record alone exceeds
                    // RECORD_BATCH_MAX.bytes while the buffer is empty, it is
                    // stashed, the loop breaks with an empty buffer, and the
                    // record is dropped — presumably upstream guarantees no
                    // record exceeds the batch byte cap; TODO confirm.
                    if self.buffered_records.len() == caps::RECORD_BATCH_MAX.count
                        || self.buffered_records.metered_size() + record.metered_size()
                            > caps::RECORD_BATCH_MAX.bytes
                    {
                        stashed_record = Some(record);
                        break;
                    }

                    self.buffered_records.push(record);
                }
                Some(Err(err)) => {
                    // Hold the error; it is surfaced only after the records
                    // buffered so far have been returned as a batch.
                    self.buffered_error = Some(err);
                    break;
                }
                None => {
                    break;
                }
            }
        }
        if !self.buffered_records.is_empty() {
            // Account the batch against the read limit before deciding
            // whether it is terminal.
            self.read_limit = match self.read_limit {
                EvaluatedReadLimit::Remaining(read_limit) => read_limit.remaining(
                    self.buffered_records.len(),
                    self.buffered_records.metered_size(),
                ),
                EvaluatedReadLimit::Exhausted => EvaluatedReadLimit::Exhausted,
            };
            let is_terminal = self.read_limit == EvaluatedReadLimit::Exhausted;
            // Swap in the next batch's buffer (seeded with the stashed record,
            // if any); no replacement buffer is allocated when this is the
            // last batch or an error is pending.
            let records = std::mem::replace(
                &mut self.buffered_records,
                if is_terminal || self.buffered_error.is_some() {
                    Metered::default()
                } else {
                    let mut buf = make_records(&self.read_limit);
                    if let Some(record) = stashed_record.take() {
                        buf.push(record);
                    }
                    buf
                },
            );
            return Some(Ok(RecordBatch {
                records,
                is_terminal,
            }));
        }
        // Buffer was empty: surface any held-back error now.
        if let Some(err) = self.buffered_error.take() {
            return Some(Err(err));
        }
        None
    }
}
144
145impl<I, E, T> Iterator for RecordBatcher<I, E, T>
146where
147 T: MeteredSize,
148 I: Iterator<Item = Result<Metered<Sequenced<T>>, E>>,
149{
150 type Item = Result<RecordBatch<T>, E>;
151
152 fn next(&mut self) -> Option<Self::Item> {
153 if self.is_terminated {
154 return None;
155 }
156 let item = self.iter_next();
157 self.is_terminated = matches!(&item, None | Some(Err(_)));
158 item
159 }
160}
161
// Fused is sound: `next` latches `is_terminated` on the first `None`/`Err`,
// so every subsequent call returns `None`.
impl<I, E, T> FusedIterator for RecordBatcher<I, E, T>
where
    T: MeteredSize,
    I: Iterator<Item = Result<Metered<Sequenced<T>>, E>>,
{
}
168
#[cfg(test)]
mod tests {
    //! Unit tests covering batching behavior: extent limits (count, bytes,
    //! timestamp), per-batch caps, and error surfacing order.

    use bytes::Bytes;

    use super::*;
    use crate::{
        caps,
        read_extent::{ReadLimit, ReadUntil},
        record::{
            CommandRecord, Encodable, EnvelopeRecord, Metered, MeteredExt, MeteredSize, Record,
            RecordDecodeError, SeqNum, Sequenced, SequencedRecord, StoredRecord,
            StoredRecordIterator, StoredSequencedBytes, StoredSequencedRecord, StreamPosition,
            Timestamp,
        },
    };

    // Minimal logical (non-stored) record for the generic-batcher test.
    fn test_logical_record(seq_num: SeqNum, timestamp: Timestamp) -> SequencedRecord {
        Record::Command(CommandRecord::Trim(seq_num))
            .metered()
            .sequenced(StreamPosition { seq_num, timestamp })
            .into_inner()
    }

    // Minimal stored record (a Trim command) at the given position.
    fn test_record(seq_num: SeqNum, timestamp: Timestamp) -> StoredSequencedRecord {
        Metered::from(StoredRecord::from(Record::Command(CommandRecord::Trim(
            seq_num,
        ))))
        .sequenced(StreamPosition { seq_num, timestamp })
        .into_inner()
    }

    // Stored envelope record with a body of `body_len` zero bytes, used to
    // exercise byte-based limits and caps.
    fn test_large_record(
        seq_num: SeqNum,
        timestamp: Timestamp,
        body_len: usize,
    ) -> StoredSequencedRecord {
        Metered::from(StoredRecord::from(Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from(vec![0; body_len])).unwrap(),
        )))
        .sequenced(StreamPosition { seq_num, timestamp })
        .into_inner()
    }

    // Wraps records as the infallible `Ok` iterator the batcher consumes.
    fn to_iter(
        records: Vec<StoredSequencedRecord>,
    ) -> impl Iterator<Item = Result<Metered<StoredSequencedRecord>, RecordDecodeError>> {
        records.into_iter().map(Metered::from).map(Ok)
    }

    // Same as `to_iter` but for logical (non-stored) records.
    fn to_logical_iter(
        records: Vec<SequencedRecord>,
    ) -> impl Iterator<Item = Result<Metered<SequencedRecord>, RecordDecodeError>> {
        records.into_iter().map(Metered::from).map(Ok)
    }

    // Encodes records to their stored-bytes form, for tests that go through
    // `StoredRecordIterator` decoding.
    fn to_stored_bytes_iter(
        records: Vec<StoredSequencedRecord>,
    ) -> impl Iterator<Item = Result<StoredSequencedBytes, RecordDecodeError>> {
        records
            .into_iter()
            .map(|record| {
                let (position, record) = record.into_parts();
                Sequenced::new(position, (&record).metered().to_bytes())
            })
            .map(Ok)
    }

    // Asserts a batch's terminal flag, length, metered size, and contents.
    fn assert_batch(batch: &RecordBatch, expected: &[StoredSequencedRecord], is_terminal: bool) {
        assert_eq!(batch.is_terminal, is_terminal);
        assert_eq!(batch.records.len(), expected.len());
        let expected_size: usize = expected.iter().map(|r| r.metered_size()).sum();
        assert_eq!(batch.records.metered_size(), expected_size);
        for (actual, expected) in batch.records.iter().zip(expected.iter()) {
            assert_eq!(actual, expected);
        }
    }

    // Unbounded extent: everything ends up in one non-terminal batch.
    #[test]
    fn collects_records_until_iterator_ends() {
        let expected = vec![test_record(1, 10), test_record(2, 11), test_record(3, 12)];
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );
        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected, false);
        assert!(batcher.next().is_none());
    }

    // The batcher is generic over the record type, not just StoredRecord.
    #[test]
    fn generic_batcher_collects_logical_records() {
        let expected = vec![
            test_logical_record(1, 10),
            test_logical_record(2, 11),
            test_logical_record(3, 12),
        ];
        let mut batcher = RecordBatcher::new(
            to_logical_iter(expected.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert!(!batch.is_terminal);
        assert_eq!(batch.records.len(), expected.len());
        let expected_size: usize = expected.iter().map(|r| r.metered_size()).sum();
        assert_eq!(batch.records.metered_size(), expected_size);
        for (actual, expected) in batch.records.iter().zip(expected.iter()) {
            assert_eq!(actual, expected);
        }
        assert!(batcher.next().is_none());
    }

    // A count limit cuts the batch short and marks it terminal.
    #[test]
    fn stops_at_count_read_limit() {
        let expected = vec![test_record(1, 10), test_record(2, 11), test_record(3, 12)];
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Count(2),
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected[..2], true);
        assert!(batcher.next().is_none());
    }

    // A byte limit sized to one record admits exactly that record.
    #[test]
    fn stops_at_byte_read_limit() {
        let expected = vec![test_record(1, 10), test_record(2, 11)];
        let first_size = expected[0].metered_size();
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Bytes(first_size),
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected[..1], true);
        assert!(batcher.next().is_none());
    }

    // ReadUntil::Timestamp(20) is exclusive: the record at timestamp 20 is
    // denied, so only the first two records are returned.
    #[test]
    fn stops_at_timestamp_limit() {
        let expected = vec![test_record(1, 10), test_record(2, 19), test_record(3, 20)];
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Timestamp(20),
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected[..2], true);
        assert!(batcher.next().is_none());
    }

    // One record over the per-batch count cap spills into a second,
    // non-terminal batch.
    #[test]
    fn splits_batches_when_caps_are_hit() {
        let mut records = Vec::with_capacity(caps::RECORD_BATCH_MAX.count + 1);
        for index in 0..=(caps::RECORD_BATCH_MAX.count as SeqNum) {
            records.push(test_record(index, index + 10));
        }
        let mut batcher = RecordBatcher::new(
            to_iter(records.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let first_batch = batcher
            .next()
            .expect("first batch expected")
            .expect("first batch ok");
        assert_batch(
            &first_batch,
            &records[..caps::RECORD_BATCH_MAX.count],
            false,
        );

        let second_batch = batcher
            .next()
            .expect("second batch expected")
            .expect("second batch ok");
        assert_batch(
            &second_batch,
            &records[caps::RECORD_BATCH_MAX.count..],
            false,
        );
        assert!(batcher.next().is_none());
    }

    // Two records that each fit the byte cap but not together are split
    // across two batches.
    #[test]
    fn splits_batches_when_byte_cap_is_hit() {
        let records = vec![
            test_large_record(1, 10, caps::RECORD_BATCH_MAX.bytes / 2 + 1),
            test_large_record(2, 11, caps::RECORD_BATCH_MAX.bytes / 2 + 1),
        ];
        assert!(records[0].metered_size() <= caps::RECORD_BATCH_MAX.bytes);
        assert!(records[1].metered_size() <= caps::RECORD_BATCH_MAX.bytes);
        assert!(
            records[0].metered_size() + records[1].metered_size() > caps::RECORD_BATCH_MAX.bytes
        );

        let mut batcher = RecordBatcher::new(
            to_iter(records.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let first_batch = batcher
            .next()
            .expect("first batch expected")
            .expect("first batch ok");
        assert_batch(&first_batch, &records[..1], false);

        let second_batch = batcher
            .next()
            .expect("second batch expected")
            .expect("second batch ok");
        assert_batch(&second_batch, &records[1..], false);
        assert!(batcher.next().is_none());
    }

    // A decode error is held back until the already-buffered records have
    // been returned as a batch.
    #[test]
    fn surfaces_decode_errors_after_draining_buffer() {
        let records = vec![test_record(1, 10), test_record(2, 11)];
        let invalid_data = Sequenced::new(
            StreamPosition {
                seq_num: 3,
                timestamp: 12,
            },
            Bytes::new(),
        );

        let mut batcher = RecordBatcher::new(
            StoredRecordIterator::new(
                to_stored_bytes_iter(records.clone()).chain(std::iter::once(Ok(invalid_data))),
            ),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &records, false);

        let error = batcher
            .next()
            .expect("error expected")
            .expect_err("expected decode error");
        assert!(matches!(error, RecordDecodeError::Truncated("MagicByte")));
        assert!(batcher.next().is_none());
    }

    // An error with no buffered records is yielded immediately, then the
    // iterator terminates.
    #[test]
    fn surfaces_iterator_errors_immediately() {
        let iterator = StoredRecordIterator::new(std::iter::once::<
            Result<StoredSequencedBytes, RecordDecodeError>,
        >(Err(RecordDecodeError::InvalidValue(
            "test", "boom",
        ))));
        let mut batcher = RecordBatcher::new(iterator, ReadLimit::Unbounded, ReadUntil::Unbounded);

        let error = batcher
            .next()
            .expect("error expected")
            .expect_err("expected iterator error");
        assert!(matches!(
            error,
            RecordDecodeError::InvalidValue("test", "boom")
        ));
        assert!(batcher.next().is_none());
    }
}