Skip to main content

s2_storage/record/
batcher.rs

1use std::iter::FusedIterator;
2
3use s2_common::{
4    caps,
5    read_extent::{EvaluatedReadLimit, ReadLimit, ReadUntil},
6    record::{Metered, MeteredSize, Sequenced},
7};
8
9use super::StoredRecord;
10
11pub struct RecordBatch<T = StoredRecord>
12where
13    T: MeteredSize,
14{
15    pub records: Metered<Vec<Sequenced<T>>>,
16    pub is_terminal: bool,
17}
18
19impl<T> std::fmt::Debug for RecordBatch<T>
20where
21    T: MeteredSize,
22{
23    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24        f.debug_struct("RecordBatch")
25            .field("num_records", &self.records.len())
26            .field("metered_size", &self.records.metered_size())
27            .field("is_terminal", &self.is_terminal)
28            .finish()
29    }
30}
31
32pub struct RecordBatcher<I, E, T>
33where
34    T: MeteredSize,
35    I: Iterator<Item = Result<Metered<Sequenced<T>>, E>>,
36{
37    record_iterator: I,
38    buffered_records: Metered<Vec<Sequenced<T>>>,
39    buffered_error: Option<E>,
40    read_limit: EvaluatedReadLimit,
41    until: ReadUntil,
42    is_terminated: bool,
43}
44
45fn make_records<T>(read_limit: &EvaluatedReadLimit) -> Metered<Vec<Sequenced<T>>>
46where
47    T: MeteredSize,
48{
49    match read_limit {
50        EvaluatedReadLimit::Remaining(limit) => {
51            Metered::with_capacity(limit.count().map_or(caps::RECORD_BATCH_MAX.count, |n| {
52                n.min(caps::RECORD_BATCH_MAX.count)
53            }))
54        }
55        EvaluatedReadLimit::Exhausted => Metered::default(),
56    }
57}
58
59impl<I, E, T> RecordBatcher<I, E, T>
60where
61    T: MeteredSize,
62    I: Iterator<Item = Result<Metered<Sequenced<T>>, E>>,
63{
64    pub fn new(record_iterator: I, read_limit: ReadLimit, until: ReadUntil) -> Self {
65        let read_limit = read_limit.remaining(0, 0);
66        Self {
67            record_iterator,
68            buffered_records: make_records(&read_limit),
69            buffered_error: None,
70            read_limit,
71            until,
72            is_terminated: false,
73        }
74    }
75
76    fn iter_next(&mut self) -> Option<Result<RecordBatch<T>, E>> {
77        let EvaluatedReadLimit::Remaining(remaining_limit) = self.read_limit else {
78            return None;
79        };
80
81        let mut stashed_record = None;
82        while self.buffered_error.is_none() {
83            match self.record_iterator.next() {
84                Some(Ok(record)) => {
85                    if remaining_limit.deny(
86                        self.buffered_records.len() + 1,
87                        self.buffered_records.metered_size() + record.metered_size(),
88                    ) || self.until.deny(record.position().timestamp)
89                    {
90                        self.read_limit = EvaluatedReadLimit::Exhausted;
91                        break;
92                    }
93
94                    if self.buffered_records.len() == caps::RECORD_BATCH_MAX.count
95                        || self.buffered_records.metered_size() + record.metered_size()
96                            > caps::RECORD_BATCH_MAX.bytes
97                    {
98                        // It would would violate the per-batch limits.
99                        stashed_record = Some(record);
100                        break;
101                    }
102
103                    self.buffered_records.push(record);
104                }
105                Some(Err(err)) => {
106                    self.buffered_error = Some(err);
107                    break;
108                }
109                None => {
110                    break;
111                }
112            }
113        }
114        if !self.buffered_records.is_empty() {
115            self.read_limit = match self.read_limit {
116                EvaluatedReadLimit::Remaining(read_limit) => read_limit.remaining(
117                    self.buffered_records.len(),
118                    self.buffered_records.metered_size(),
119                ),
120                EvaluatedReadLimit::Exhausted => EvaluatedReadLimit::Exhausted,
121            };
122            let is_terminal = self.read_limit == EvaluatedReadLimit::Exhausted;
123            let records = std::mem::replace(
124                &mut self.buffered_records,
125                if is_terminal || self.buffered_error.is_some() {
126                    Metered::default()
127                } else {
128                    let mut buf = make_records(&self.read_limit);
129                    if let Some(record) = stashed_record.take() {
130                        buf.push(record);
131                    }
132                    buf
133                },
134            );
135            return Some(Ok(RecordBatch {
136                records,
137                is_terminal,
138            }));
139        }
140        if let Some(err) = self.buffered_error.take() {
141            return Some(Err(err));
142        }
143        None
144    }
145}
146
147impl<I, E, T> Iterator for RecordBatcher<I, E, T>
148where
149    T: MeteredSize,
150    I: Iterator<Item = Result<Metered<Sequenced<T>>, E>>,
151{
152    type Item = Result<RecordBatch<T>, E>;
153
154    fn next(&mut self) -> Option<Self::Item> {
155        if self.is_terminated {
156            return None;
157        }
158        let item = self.iter_next();
159        self.is_terminated = matches!(&item, None | Some(Err(_)));
160        item
161    }
162}
163
164impl<I, E, T> FusedIterator for RecordBatcher<I, E, T>
165where
166    T: MeteredSize,
167    I: Iterator<Item = Result<Metered<Sequenced<T>>, E>>,
168{
169}
170
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use s2_common::{
        caps,
        read_extent::{ReadLimit, ReadUntil},
        record::{
            CommandRecord, EnvelopeRecord, Metered, MeteredExt, MeteredSize, Record, SeqNum,
            Sequenced, SequencedRecord, StreamPosition, Timestamp,
        },
    };

    use crate::record::{
        RecordBatch, RecordBatcher, StoredRecord, StoredRecordDecodeError, StoredRecordIterator,
        StoredSequencedBytes, StoredSequencedRecord, encode_stored_record,
    };

    /// Builds the stream position `(seq_num, timestamp)`.
    fn stream_position(seq_num: SeqNum, timestamp: Timestamp) -> StreamPosition {
        StreamPosition { seq_num, timestamp }
    }

    /// A logical trim-command record placed at `(seq_num, timestamp)`.
    fn test_logical_record(seq_num: SeqNum, timestamp: Timestamp) -> SequencedRecord {
        let record = Record::Command(CommandRecord::Trim(seq_num));
        record
            .metered()
            .sequenced(stream_position(seq_num, timestamp))
            .into_inner()
    }

    /// A stored trim-command record placed at `(seq_num, timestamp)`.
    fn test_record(seq_num: SeqNum, timestamp: Timestamp) -> StoredSequencedRecord {
        let record = StoredRecord::from(Record::Command(CommandRecord::Trim(seq_num)));
        record
            .metered()
            .sequenced(stream_position(seq_num, timestamp))
            .into_inner()
    }

    /// Stored records with seq nums `1..=count` and timestamps `10, 11, ...`.
    fn test_records(count: SeqNum) -> Vec<StoredSequencedRecord> {
        (1..=count)
            .map(|seq_num| test_record(seq_num, seq_num + 9))
            .collect()
    }

    /// A stored envelope record with no headers and a zero-filled body of `body_len` bytes.
    fn test_large_record(
        seq_num: SeqNum,
        timestamp: Timestamp,
        body_len: usize,
    ) -> StoredSequencedRecord {
        let body = Bytes::from(vec![0; body_len]);
        let envelope = EnvelopeRecord::try_from_parts(vec![], body).unwrap();
        StoredRecord::from(Record::Envelope(envelope))
            .metered()
            .sequenced(stream_position(seq_num, timestamp))
            .into_inner()
    }

    /// Wraps stored records as an infallible record source for the batcher.
    fn to_iter(
        records: Vec<StoredSequencedRecord>,
    ) -> impl Iterator<Item = Result<Metered<StoredSequencedRecord>, StoredRecordDecodeError>> {
        records.into_iter().map(|record| Ok(Metered::from(record)))
    }

    /// Wraps logical records as an infallible record source for the batcher.
    fn to_logical_iter(
        records: Vec<SequencedRecord>,
    ) -> impl Iterator<Item = Result<Metered<SequencedRecord>, StoredRecordDecodeError>> {
        records.into_iter().map(|record| Ok(Metered::from(record)))
    }

    /// Encodes stored records into sequenced bytes, the input `StoredRecordIterator` decodes.
    fn to_stored_bytes_iter(
        records: Vec<StoredSequencedRecord>,
    ) -> impl Iterator<Item = Result<StoredSequencedBytes, StoredRecordDecodeError>> {
        records.into_iter().map(|sequenced| {
            let (position, record) = sequenced.into_parts();
            let encoded = encode_stored_record((&record).metered());
            Ok(Sequenced::new(position, encoded))
        })
    }

    /// Asserts that `batch` has the given terminal flag and holds exactly `expected`:
    /// the same count, the same total metered size and equal records in order.
    fn assert_batch(batch: &RecordBatch, expected: &[StoredSequencedRecord], is_terminal: bool) {
        assert_eq!(batch.is_terminal, is_terminal);
        assert_eq!(batch.records.len(), expected.len());
        let expected_size = expected
            .iter()
            .map(|record| record.metered_size())
            .sum::<usize>();
        assert_eq!(batch.records.metered_size(), expected_size);
        for (actual, wanted) in batch.records.iter().zip(expected) {
            assert_eq!(actual, wanted);
        }
    }

    #[test]
    fn collects_records_until_iterator_ends() {
        let expected = test_records(3);
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected, false);
        assert!(batcher.next().is_none());
    }

    #[test]
    fn generic_batcher_collects_logical_records() {
        let expected: Vec<_> = (1..=3)
            .map(|seq_num| test_logical_record(seq_num, seq_num + 9))
            .collect();
        let mut batcher = RecordBatcher::new(
            to_logical_iter(expected.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert!(!batch.is_terminal);
        assert_eq!(batch.records.len(), expected.len());
        let expected_size = expected
            .iter()
            .map(|record| record.metered_size())
            .sum::<usize>();
        assert_eq!(batch.records.metered_size(), expected_size);
        for (actual, wanted) in batch.records.iter().zip(&expected) {
            assert_eq!(actual, wanted);
        }
        assert!(batcher.next().is_none());
    }

    #[test]
    fn stops_at_count_read_limit() {
        let expected = test_records(3);
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Count(2),
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected[..2], true);
        assert!(batcher.next().is_none());
    }

    #[test]
    fn stops_at_byte_read_limit() {
        let expected = test_records(2);
        // Budget exactly enough bytes for the first record.
        let first_size = expected[0].metered_size();
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Bytes(first_size),
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected[..1], true);
        assert!(batcher.next().is_none());
    }

    #[test]
    fn stops_at_timestamp_limit() {
        let expected = vec![test_record(1, 10), test_record(2, 19), test_record(3, 20)];
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Timestamp(20),
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected[..2], true);
        assert!(batcher.next().is_none());
    }

    #[test]
    fn splits_batches_when_caps_are_hit() {
        let max_count = caps::RECORD_BATCH_MAX.count;
        // One record more than a single batch may hold.
        let records: Vec<_> = (0..=(max_count as SeqNum))
            .map(|seq_num| test_record(seq_num, seq_num + 10))
            .collect();
        let (first_expected, second_expected) = records.split_at(max_count);
        let mut batcher = RecordBatcher::new(
            to_iter(records.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let first_batch = batcher
            .next()
            .expect("first batch expected")
            .expect("first batch ok");
        assert_batch(&first_batch, first_expected, false);

        let second_batch = batcher
            .next()
            .expect("second batch expected")
            .expect("second batch ok");
        assert_batch(&second_batch, second_expected, false);
        assert!(batcher.next().is_none());
    }

    #[test]
    fn splits_batches_when_byte_cap_is_hit() {
        let byte_cap = caps::RECORD_BATCH_MAX.bytes;
        let body_len = byte_cap / 2 + 1;
        let records = vec![
            test_large_record(1, 10, body_len),
            test_large_record(2, 11, body_len),
        ];
        // Each record fits in a batch on its own, but the pair does not.
        assert!(records[0].metered_size() <= byte_cap);
        assert!(records[1].metered_size() <= byte_cap);
        assert!(records[0].metered_size() + records[1].metered_size() > byte_cap);

        let mut batcher = RecordBatcher::new(
            to_iter(records.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let first_batch = batcher
            .next()
            .expect("first batch expected")
            .expect("first batch ok");
        assert_batch(&first_batch, &records[..1], false);

        let second_batch = batcher
            .next()
            .expect("second batch expected")
            .expect("second batch ok");
        assert_batch(&second_batch, &records[1..], false);
        assert!(batcher.next().is_none());
    }

    #[test]
    fn surfaces_decode_errors_after_draining_buffer() {
        let records = test_records(2);
        // An empty payload, which the test expects to fail decoding as truncated.
        let invalid_data = Sequenced::new(stream_position(3, 12), Bytes::new());
        let source = to_stored_bytes_iter(records.clone()).chain(std::iter::once(Ok(invalid_data)));
        let mut batcher = RecordBatcher::new(
            StoredRecordIterator::new(source),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &records, false);

        let error = batcher
            .next()
            .expect("error expected")
            .expect_err("expected decode error");
        assert!(matches!(
            error,
            StoredRecordDecodeError::Truncated("MagicByte")
        ));
        assert!(batcher.next().is_none());
    }

    #[test]
    fn surfaces_iterator_errors_immediately() {
        let failure: Result<StoredSequencedBytes, StoredRecordDecodeError> =
            Err(StoredRecordDecodeError::InvalidValue("test", "boom"));
        let iterator = StoredRecordIterator::new(std::iter::once(failure));
        let mut batcher = RecordBatcher::new(iterator, ReadLimit::Unbounded, ReadUntil::Unbounded);

        let error = batcher
            .next()
            .expect("error expected")
            .expect_err("expected iterator error");
        assert!(matches!(
            error,
            StoredRecordDecodeError::InvalidValue("test", "boom")
        ));
        assert!(batcher.next().is_none());
    }
}