//! s2_common/record/batcher.rs

1use std::iter::FusedIterator;
2
3use bytes::Bytes;
4
5use super::InternalRecordError;
6use crate::{
7    caps,
8    read_extent::{EvaluatedReadLimit, ReadLimit, ReadUntil},
9    record::{Metered, MeteredSize, SequencedRecord, StreamPosition},
10};
11
/// A batch of sequenced records produced by a `RecordBatcher`.
#[derive(Debug)]
pub struct RecordBatch {
    /// The records in this batch, together with their accumulated metered size.
    pub records: Metered<Vec<SequencedRecord>>,
    /// True when the read limit was exhausted while producing this batch;
    /// no further batches will follow a terminal one.
    pub is_terminal: bool,
}
17
/// Groups `(position, bytes)` pairs from an underlying iterator into
/// `RecordBatch`es, honoring the per-batch caps (`caps::RECORD_BATCH_MAX`),
/// a cumulative read limit, and a `ReadUntil` timestamp bound.
pub struct RecordBatcher<I, E>
where
    I: Iterator<Item = Result<(StreamPosition, Bytes), E>>,
    E: Into<InternalRecordError>,
{
    /// Underlying source of sequenced raw record data.
    record_iterator: I,
    /// Records accumulated for the next batch; may be seeded with one record
    /// stashed from a previous call that overflowed the per-batch caps.
    buffered_records: Metered<Vec<SequencedRecord>>,
    /// An error held back until the already-buffered records have been emitted.
    buffered_error: Option<InternalRecordError>,
    /// Cumulative limit remaining across all batches; `Exhausted` ends iteration.
    read_limit: EvaluatedReadLimit,
    /// Timestamp bound; a record whose timestamp this denies ends the read.
    until: ReadUntil,
    /// Latched once `next()` yields `None` or an error, making the iterator fused.
    is_terminated: bool,
}
30
31fn make_records(read_limit: &EvaluatedReadLimit) -> Metered<Vec<SequencedRecord>> {
32    match read_limit {
33        EvaluatedReadLimit::Remaining(limit) => {
34            Metered::with_capacity(limit.count().map_or(caps::RECORD_BATCH_MAX.count, |n| {
35                n.min(caps::RECORD_BATCH_MAX.count)
36            }))
37        }
38        EvaluatedReadLimit::Exhausted => Metered::default(),
39    }
40}
41
42impl<I, E> RecordBatcher<I, E>
43where
44    I: Iterator<Item = Result<(StreamPosition, Bytes), E>>,
45    E: std::fmt::Debug + Into<InternalRecordError>,
46{
47    pub fn new(record_iterator: I, read_limit: ReadLimit, until: ReadUntil) -> Self {
48        let read_limit = read_limit.remaining(0, 0);
49        Self {
50            record_iterator,
51            buffered_records: make_records(&read_limit),
52            buffered_error: None,
53            read_limit,
54            until,
55            is_terminated: false,
56        }
57    }
58
59    fn iter_next(&mut self) -> Option<Result<RecordBatch, InternalRecordError>> {
60        let EvaluatedReadLimit::Remaining(remaining_limit) = self.read_limit else {
61            return None;
62        };
63
64        let mut stashed_record = None;
65        while self.buffered_error.is_none() {
66            match self.record_iterator.next() {
67                Some(Ok((position, data))) => {
68                    let record = match Metered::try_from(data) {
69                        Ok(record) => record.sequenced(position),
70                        Err(err) => {
71                            self.buffered_error = Some(err);
72                            break;
73                        }
74                    };
75
76                    if remaining_limit.deny(
77                        self.buffered_records.len() + 1,
78                        self.buffered_records.metered_size() + record.metered_size(),
79                    ) || self.until.deny(position.timestamp)
80                    {
81                        self.read_limit = EvaluatedReadLimit::Exhausted;
82                        break;
83                    }
84
85                    if self.buffered_records.len() == caps::RECORD_BATCH_MAX.count
86                        || self.buffered_records.metered_size() + record.metered_size()
87                            > caps::RECORD_BATCH_MAX.bytes
88                    {
89                        // It would would violate the per-batch limits.
90                        stashed_record = Some(record);
91                        break;
92                    }
93
94                    self.buffered_records.push(record);
95                }
96                Some(Err(err)) => {
97                    self.buffered_error = Some(err.into());
98                    break;
99                }
100                None => {
101                    break;
102                }
103            }
104        }
105        if !self.buffered_records.is_empty() {
106            self.read_limit = match self.read_limit {
107                EvaluatedReadLimit::Remaining(read_limit) => read_limit.remaining(
108                    self.buffered_records.len(),
109                    self.buffered_records.metered_size(),
110                ),
111                EvaluatedReadLimit::Exhausted => EvaluatedReadLimit::Exhausted,
112            };
113            let is_terminal = self.read_limit == EvaluatedReadLimit::Exhausted;
114            let records = std::mem::replace(
115                &mut self.buffered_records,
116                if is_terminal || self.buffered_error.is_some() {
117                    Metered::default()
118                } else {
119                    let mut buf = make_records(&self.read_limit);
120                    if let Some(record) = stashed_record.take() {
121                        buf.push(record);
122                    }
123                    buf
124                },
125            );
126            return Some(Ok(RecordBatch {
127                records,
128                is_terminal,
129            }));
130        }
131        if let Some(err) = self.buffered_error.take() {
132            return Some(Err(err));
133        }
134        None
135    }
136}
137
138impl<I, E> Iterator for RecordBatcher<I, E>
139where
140    I: Iterator<Item = Result<(StreamPosition, Bytes), E>>,
141    E: std::fmt::Debug + Into<InternalRecordError>,
142{
143    type Item = Result<RecordBatch, InternalRecordError>;
144
145    fn next(&mut self) -> Option<Self::Item> {
146        if self.is_terminated {
147            return None;
148        }
149        let item = self.iter_next();
150        self.is_terminated = matches!(&item, None | Some(Err(_)));
151        item
152    }
153}
154
// Safe to mark fused: `next()` checks `is_terminated` and keeps returning
// `None` forever once it has yielded `None` or an error.
impl<I, E> FusedIterator for RecordBatcher<I, E>
where
    I: Iterator<Item = Result<(StreamPosition, Bytes), E>>,
    E: std::fmt::Debug + Into<InternalRecordError>,
{
}
161
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;
    use crate::{
        caps,
        read_extent::{ReadLimit, ReadUntil},
        record::{
            CommandRecord, Encodable, MeteredSize, Record, SeqNum, SequencedRecord, Timestamp,
        },
    };

    /// Builds a small command record sequenced at the given position.
    fn test_record(seq_num: SeqNum, timestamp: Timestamp) -> SequencedRecord {
        Record::Command(CommandRecord::Trim(seq_num))
            .sequenced(StreamPosition { seq_num, timestamp })
    }

    /// Encodes records into the `(position, bytes)` items the batcher consumes.
    fn to_iter(
        records: Vec<SequencedRecord>,
    ) -> impl Iterator<Item = Result<(StreamPosition, Bytes), InternalRecordError>> {
        records
            .into_iter()
            .map(|SequencedRecord { position, record }| {
                (position, Metered::from(record).as_ref().to_bytes())
            })
            .map(Ok)
    }

    /// Asserts a batch's terminality, length, metered size, and record-by-record
    /// contents against the expected slice.
    fn assert_batch(batch: &RecordBatch, expected: &[SequencedRecord], is_terminal: bool) {
        assert_eq!(batch.is_terminal, is_terminal);
        assert_eq!(batch.records.len(), expected.len());
        let expected_size: usize = expected.iter().map(|r| r.metered_size()).sum();
        assert_eq!(batch.records.metered_size(), expected_size);
        for (actual, expected) in batch.records.iter().zip(expected.iter()) {
            assert_eq!(actual, expected);
        }
    }

    // With no limits, everything arrives in one non-terminal batch and the
    // iterator then ends.
    #[test]
    fn collects_records_until_iterator_ends() {
        let expected = vec![test_record(1, 10), test_record(2, 11), test_record(3, 12)];
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );
        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected, false);
        assert!(batcher.next().is_none());
    }

    // A count limit of 2 yields the first two records as a terminal batch.
    #[test]
    fn stops_at_count_read_limit() {
        let expected = vec![test_record(1, 10), test_record(2, 11), test_record(3, 12)];
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Count(2),
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected[..2], true);
        assert!(batcher.next().is_none());
    }

    // A byte limit equal to the first record's size admits only that record.
    #[test]
    fn stops_at_byte_read_limit() {
        let expected = vec![test_record(1, 10), test_record(2, 11)];
        let first_size = expected[0].metered_size();
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Bytes(first_size),
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected[..1], true);
        assert!(batcher.next().is_none());
    }

    // The `until` timestamp is exclusive: the record at timestamp 20 is
    // denied and the read terminates before it.
    #[test]
    fn stops_at_timestamp_limit() {
        let expected = vec![test_record(1, 10), test_record(2, 19), test_record(3, 20)];
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Timestamp(20),
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected[..2], true);
        assert!(batcher.next().is_none());
    }

    // One record over the per-batch count cap: the overflow record is stashed
    // and emitted alone as a second, still non-terminal batch.
    #[test]
    fn splits_batches_when_caps_are_hit() {
        let mut records = Vec::with_capacity(caps::RECORD_BATCH_MAX.count + 1);
        for index in 0..=(caps::RECORD_BATCH_MAX.count as SeqNum) {
            records.push(test_record(index, index + 10));
        }
        let mut batcher = RecordBatcher::new(
            to_iter(records.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let first_batch = batcher
            .next()
            .expect("first batch expected")
            .expect("first batch ok");
        assert_batch(
            &first_batch,
            &records[..caps::RECORD_BATCH_MAX.count],
            false,
        );

        let second_batch = batcher
            .next()
            .expect("second batch expected")
            .expect("second batch ok");
        assert_batch(
            &second_batch,
            &records[caps::RECORD_BATCH_MAX.count..],
            false,
        );
        assert!(batcher.next().is_none());
    }

    // A decode error is buffered: the good records come out first, then the
    // error, then the iterator is fused.
    #[test]
    fn surfaces_decode_errors_after_draining_buffer() {
        let records = vec![test_record(1, 10), test_record(2, 11)];
        let invalid_data = (
            StreamPosition {
                seq_num: 3,
                timestamp: 12,
            },
            Bytes::new(),
        );

        let mut batcher = RecordBatcher::new(
            to_iter(records.clone()).chain(std::iter::once(Ok(invalid_data))),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &records, false);

        let error = batcher
            .next()
            .expect("error expected")
            .expect_err("expected decode error");
        assert!(matches!(error, InternalRecordError::Truncated("MagicByte")));
        assert!(batcher.next().is_none());
    }

    // With nothing buffered, an error from the source iterator is yielded on
    // the very first `next()`.
    #[test]
    fn surfaces_iterator_errors_immediately() {
        let iterator = std::iter::once::<Result<(StreamPosition, Bytes), InternalRecordError>>(
            Err(InternalRecordError::InvalidValue("test", "boom")),
        );
        let mut batcher = RecordBatcher::new(iterator, ReadLimit::Unbounded, ReadUntil::Unbounded);

        let error = batcher
            .next()
            .expect("error expected")
            .expect_err("expected iterator error");
        assert!(matches!(
            error,
            InternalRecordError::InvalidValue("test", "boom")
        ));
        assert!(batcher.next().is_none());
    }
}