//! Batching of sequenced records (`s2_common/record/batcher.rs`):
//! [`RecordBatcher`] groups a fallible record stream into [`RecordBatch`]es,
//! honoring read limits, a read-until bound, and per-batch caps.
1use std::iter::FusedIterator;
2
3use crate::{
4    caps,
5    read_extent::{EvaluatedReadLimit, ReadLimit, ReadUntil},
6    record::{Metered, MeteredSize, Sequenced, StoredRecord},
7};
8
/// A batch of sequenced records produced by [`RecordBatcher`], together with
/// a flag indicating whether the overall read limit was exhausted by it.
pub struct RecordBatch<T = StoredRecord>
where
    T: MeteredSize,
{
    /// The records in this batch, with their aggregate metered size tracked.
    pub records: Metered<Vec<Sequenced<T>>>,
    /// True when the batcher's read limit became `Exhausted` while producing
    /// this batch — the batcher yields no further batches after a terminal one.
    pub is_terminal: bool,
}
16
17impl<T> std::fmt::Debug for RecordBatch<T>
18where
19    T: MeteredSize,
20{
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        f.debug_struct("RecordBatch")
23            .field("num_records", &self.records.len())
24            .field("metered_size", &self.records.metered_size())
25            .field("is_terminal", &self.is_terminal)
26            .finish()
27    }
28}
29
/// Groups a fallible stream of metered, sequenced records into
/// [`RecordBatch`]es, respecting an overall read limit, a read-until bound,
/// and the per-batch caps in `caps::RECORD_BATCH_MAX`.
pub struct RecordBatcher<I, E, T>
where
    T: MeteredSize,
    I: Iterator<Item = Result<Metered<Sequenced<T>>, E>>,
{
    /// Underlying source of records (or errors).
    record_iterator: I,
    /// Records accumulated for the batch currently being built.
    buffered_records: Metered<Vec<Sequenced<T>>>,
    /// An error from the source, held back until buffered records are drained.
    buffered_error: Option<E>,
    /// Remaining read budget across all batches.
    read_limit: EvaluatedReadLimit,
    /// Timestamp-based stopping condition (checked against each record's
    /// `position.timestamp`).
    until: ReadUntil,
    /// Set once `next()` has returned `None` or an error, making the
    /// iterator fused.
    is_terminated: bool,
}
42
43fn make_records<T>(read_limit: &EvaluatedReadLimit) -> Metered<Vec<Sequenced<T>>>
44where
45    T: MeteredSize,
46{
47    match read_limit {
48        EvaluatedReadLimit::Remaining(limit) => {
49            Metered::with_capacity(limit.count().map_or(caps::RECORD_BATCH_MAX.count, |n| {
50                n.min(caps::RECORD_BATCH_MAX.count)
51            }))
52        }
53        EvaluatedReadLimit::Exhausted => Metered::default(),
54    }
55}
56
impl<I, E, T> RecordBatcher<I, E, T>
where
    T: MeteredSize,
    I: Iterator<Item = Result<Metered<Sequenced<T>>, E>>,
{
    /// Creates a batcher over `record_iterator`, bounded by `read_limit` and
    /// the `until` condition.
    pub fn new(record_iterator: I, read_limit: ReadLimit, until: ReadUntil) -> Self {
        // Evaluate the limit with nothing consumed yet (0 records, 0 bytes).
        let read_limit = read_limit.remaining(0, 0);
        Self {
            record_iterator,
            buffered_records: make_records(&read_limit),
            buffered_error: None,
            read_limit,
            until,
            is_terminated: false,
        }
    }

    /// Builds and returns the next batch, a previously buffered error, or
    /// `None` when the source is drained or the read limit is exhausted.
    fn iter_next(&mut self) -> Option<Result<RecordBatch<T>, E>> {
        let EvaluatedReadLimit::Remaining(remaining_limit) = self.read_limit else {
            return None;
        };

        // A record pulled from the source that did not fit in the current
        // batch; it seeds the next batch's buffer instead of being dropped.
        let mut stashed_record = None;
        while self.buffered_error.is_none() {
            match self.record_iterator.next() {
                Some(Ok(record)) => {
                    // Stop reading entirely if accepting this record would
                    // exceed the overall read limit or cross the `until`
                    // timestamp bound.
                    if remaining_limit.deny(
                        self.buffered_records.len() + 1,
                        self.buffered_records.metered_size() + record.metered_size(),
                    ) || self.until.deny(record.position.timestamp)
                    {
                        self.read_limit = EvaluatedReadLimit::Exhausted;
                        break;
                    }

                    if self.buffered_records.len() == caps::RECORD_BATCH_MAX.count
                        || self.buffered_records.metered_size() + record.metered_size()
                            > caps::RECORD_BATCH_MAX.bytes
                    {
                        // It would violate the per-batch limits; stash it for
                        // the next batch.
                        stashed_record = Some(record);
                        break;
                    }

                    self.buffered_records.push(record);
                }
                Some(Err(err)) => {
                    // Hold the error so buffered records are surfaced first.
                    self.buffered_error = Some(err);
                    break;
                }
                None => {
                    break;
                }
            }
        }
        if !self.buffered_records.is_empty() {
            // Charge what this batch consumes against the read limit.
            self.read_limit = match self.read_limit {
                EvaluatedReadLimit::Remaining(read_limit) => read_limit.remaining(
                    self.buffered_records.len(),
                    self.buffered_records.metered_size(),
                ),
                EvaluatedReadLimit::Exhausted => EvaluatedReadLimit::Exhausted,
            };
            let is_terminal = self.read_limit == EvaluatedReadLimit::Exhausted;
            let records = std::mem::replace(
                &mut self.buffered_records,
                if is_terminal || self.buffered_error.is_some() {
                    // No further batch will be built; don't reserve capacity.
                    Metered::default()
                } else {
                    // Start the next batch, seeded with the stashed record
                    // (if any) that overflowed this one.
                    let mut buf = make_records(&self.read_limit);
                    if let Some(record) = stashed_record.take() {
                        buf.push(record);
                    }
                    buf
                },
            );
            return Some(Ok(RecordBatch {
                records,
                is_terminal,
            }));
        }
        if let Some(err) = self.buffered_error.take() {
            return Some(Err(err));
        }
        None
    }
}
144
145impl<I, E, T> Iterator for RecordBatcher<I, E, T>
146where
147    T: MeteredSize,
148    I: Iterator<Item = Result<Metered<Sequenced<T>>, E>>,
149{
150    type Item = Result<RecordBatch<T>, E>;
151
152    fn next(&mut self) -> Option<Self::Item> {
153        if self.is_terminated {
154            return None;
155        }
156        let item = self.iter_next();
157        self.is_terminated = matches!(&item, None | Some(Err(_)));
158        item
159    }
160}
161
// `next()` keeps returning `None` after `is_terminated` is set, so the
// fused-iterator contract holds.
impl<I, E, T> FusedIterator for RecordBatcher<I, E, T>
where
    T: MeteredSize,
    I: Iterator<Item = Result<Metered<Sequenced<T>>, E>>,
{
}
168
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;
    use crate::{
        caps,
        read_extent::{ReadLimit, ReadUntil},
        record::{
            CommandRecord, Encodable, EnvelopeRecord, Metered, MeteredExt, MeteredSize, Record,
            RecordDecodeError, SeqNum, Sequenced, SequencedRecord, StoredRecord,
            StoredRecordIterator, StoredSequencedBytes, StoredSequencedRecord, StreamPosition,
            Timestamp,
        },
    };

    /// Builds a small logical (non-stored) command record at the given position.
    fn test_logical_record(seq_num: SeqNum, timestamp: Timestamp) -> SequencedRecord {
        Record::Command(CommandRecord::Trim(seq_num))
            .metered()
            .sequenced(StreamPosition { seq_num, timestamp })
            .into_inner()
    }

    /// Builds a small stored command record at the given position.
    fn test_record(seq_num: SeqNum, timestamp: Timestamp) -> StoredSequencedRecord {
        Metered::from(StoredRecord::from(Record::Command(CommandRecord::Trim(
            seq_num,
        ))))
        .sequenced(StreamPosition { seq_num, timestamp })
        .into_inner()
    }

    /// Builds a stored envelope record with a body of `body_len` zero bytes,
    /// used to exercise byte-based caps and limits.
    fn test_large_record(
        seq_num: SeqNum,
        timestamp: Timestamp,
        body_len: usize,
    ) -> StoredSequencedRecord {
        Metered::from(StoredRecord::from(Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from(vec![0; body_len])).unwrap(),
        )))
        .sequenced(StreamPosition { seq_num, timestamp })
        .into_inner()
    }

    /// Wraps stored records in the `Ok(Metered<..>)` shape the batcher consumes.
    fn to_iter(
        records: Vec<StoredSequencedRecord>,
    ) -> impl Iterator<Item = Result<Metered<StoredSequencedRecord>, RecordDecodeError>> {
        records.into_iter().map(Metered::from).map(Ok)
    }

    /// Same as `to_iter`, but for logical records (generic `T` path).
    fn to_logical_iter(
        records: Vec<SequencedRecord>,
    ) -> impl Iterator<Item = Result<Metered<SequencedRecord>, RecordDecodeError>> {
        records.into_iter().map(Metered::from).map(Ok)
    }

    /// Encodes stored records to their byte form, for feeding through
    /// `StoredRecordIterator` (the decoding path).
    fn to_stored_bytes_iter(
        records: Vec<StoredSequencedRecord>,
    ) -> impl Iterator<Item = Result<StoredSequencedBytes, RecordDecodeError>> {
        records
            .into_iter()
            .map(|record| {
                let (position, record) = record.into_parts();
                Sequenced::new(position, (&record).metered().to_bytes())
            })
            .map(Ok)
    }

    /// Asserts a batch's terminal flag, record count, metered size, and
    /// element-wise contents against `expected`.
    fn assert_batch(batch: &RecordBatch, expected: &[StoredSequencedRecord], is_terminal: bool) {
        assert_eq!(batch.is_terminal, is_terminal);
        assert_eq!(batch.records.len(), expected.len());
        let expected_size: usize = expected.iter().map(|r| r.metered_size()).sum();
        assert_eq!(batch.records.metered_size(), expected_size);
        for (actual, expected) in batch.records.iter().zip(expected.iter()) {
            assert_eq!(actual, expected);
        }
    }

    /// Unbounded limits: all records land in one non-terminal batch.
    #[test]
    fn collects_records_until_iterator_ends() {
        let expected = vec![test_record(1, 10), test_record(2, 11), test_record(3, 12)];
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );
        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected, false);
        assert!(batcher.next().is_none());
    }

    /// The batcher is generic over `T`: logical records batch the same way.
    #[test]
    fn generic_batcher_collects_logical_records() {
        let expected = vec![
            test_logical_record(1, 10),
            test_logical_record(2, 11),
            test_logical_record(3, 12),
        ];
        let mut batcher = RecordBatcher::new(
            to_logical_iter(expected.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert!(!batch.is_terminal);
        assert_eq!(batch.records.len(), expected.len());
        let expected_size: usize = expected.iter().map(|r| r.metered_size()).sum();
        assert_eq!(batch.records.metered_size(), expected_size);
        for (actual, expected) in batch.records.iter().zip(expected.iter()) {
            assert_eq!(actual, expected);
        }
        assert!(batcher.next().is_none());
    }

    /// A count limit of 2 yields a terminal batch of the first two records.
    #[test]
    fn stops_at_count_read_limit() {
        let expected = vec![test_record(1, 10), test_record(2, 11), test_record(3, 12)];
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Count(2),
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected[..2], true);
        assert!(batcher.next().is_none());
    }

    /// A byte limit equal to the first record's size admits only that record.
    #[test]
    fn stops_at_byte_read_limit() {
        let expected = vec![test_record(1, 10), test_record(2, 11)];
        let first_size = expected[0].metered_size();
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Bytes(first_size),
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected[..1], true);
        assert!(batcher.next().is_none());
    }

    /// `ReadUntil::Timestamp(20)` excludes the record stamped 20 (exclusive bound).
    #[test]
    fn stops_at_timestamp_limit() {
        let expected = vec![test_record(1, 10), test_record(2, 19), test_record(3, 20)];
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Timestamp(20),
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected[..2], true);
        assert!(batcher.next().is_none());
    }

    /// One record past the per-batch count cap spills into a second,
    /// non-terminal batch.
    #[test]
    fn splits_batches_when_caps_are_hit() {
        let mut records = Vec::with_capacity(caps::RECORD_BATCH_MAX.count + 1);
        for index in 0..=(caps::RECORD_BATCH_MAX.count as SeqNum) {
            records.push(test_record(index, index + 10));
        }
        let mut batcher = RecordBatcher::new(
            to_iter(records.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let first_batch = batcher
            .next()
            .expect("first batch expected")
            .expect("first batch ok");
        assert_batch(
            &first_batch,
            &records[..caps::RECORD_BATCH_MAX.count],
            false,
        );

        let second_batch = batcher
            .next()
            .expect("second batch expected")
            .expect("second batch ok");
        assert_batch(
            &second_batch,
            &records[caps::RECORD_BATCH_MAX.count..],
            false,
        );
        assert!(batcher.next().is_none());
    }

    /// Two records that individually fit but jointly exceed the byte cap are
    /// split across two batches.
    #[test]
    fn splits_batches_when_byte_cap_is_hit() {
        let records = vec![
            test_large_record(1, 10, caps::RECORD_BATCH_MAX.bytes / 2 + 1),
            test_large_record(2, 11, caps::RECORD_BATCH_MAX.bytes / 2 + 1),
        ];
        assert!(records[0].metered_size() <= caps::RECORD_BATCH_MAX.bytes);
        assert!(records[1].metered_size() <= caps::RECORD_BATCH_MAX.bytes);
        assert!(
            records[0].metered_size() + records[1].metered_size() > caps::RECORD_BATCH_MAX.bytes
        );

        let mut batcher = RecordBatcher::new(
            to_iter(records.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let first_batch = batcher
            .next()
            .expect("first batch expected")
            .expect("first batch ok");
        assert_batch(&first_batch, &records[..1], false);

        let second_batch = batcher
            .next()
            .expect("second batch expected")
            .expect("second batch ok");
        assert_batch(&second_batch, &records[1..], false);
        assert!(batcher.next().is_none());
    }

    /// A decode error mid-stream is buffered: the good records come out as a
    /// batch first, then the error is surfaced on the following `next()`.
    #[test]
    fn surfaces_decode_errors_after_draining_buffer() {
        let records = vec![test_record(1, 10), test_record(2, 11)];
        let invalid_data = Sequenced::new(
            StreamPosition {
                seq_num: 3,
                timestamp: 12,
            },
            Bytes::new(),
        );

        let mut batcher = RecordBatcher::new(
            StoredRecordIterator::new(
                to_stored_bytes_iter(records.clone()).chain(std::iter::once(Ok(invalid_data))),
            ),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &records, false);

        let error = batcher
            .next()
            .expect("error expected")
            .expect_err("expected decode error");
        assert!(matches!(error, RecordDecodeError::Truncated("MagicByte")));
        assert!(batcher.next().is_none());
    }

    /// An error with no buffered records is returned on the first `next()`,
    /// after which the batcher is terminated.
    #[test]
    fn surfaces_iterator_errors_immediately() {
        let iterator = StoredRecordIterator::new(std::iter::once::<
            Result<StoredSequencedBytes, RecordDecodeError>,
        >(Err(RecordDecodeError::InvalidValue(
            "test", "boom",
        ))));
        let mut batcher = RecordBatcher::new(iterator, ReadLimit::Unbounded, ReadUntil::Unbounded);

        let error = batcher
            .next()
            .expect("error expected")
            .expect_err("expected iterator error");
        assert!(matches!(
            error,
            RecordDecodeError::InvalidValue("test", "boom")
        ));
        assert!(batcher.next().is_none());
    }
}