Skip to main content

s2_common/record/
iterator.rs

1use std::iter::FusedIterator;
2
3use super::{
4    Metered, RecordDecodeError, StoredRecord, StoredSequencedBytes, StoredSequencedRecord,
5};
6
/// Iterator adapter that turns stored, sequenced record bytes into decoded records.
///
/// It is meant to wrap an inner iterator whose items are
/// `Result<StoredSequencedBytes, E>`; see the `Iterator` implementation for how each
/// item is decoded.
#[derive(Debug, Clone)]
pub struct StoredRecordIterator<I> {
    inner: I,
}

impl<I> StoredRecordIterator<I> {
    /// Wraps `inner` in a decoding iterator.
    ///
    /// Nothing is read from `inner` until the returned iterator is advanced.
    pub fn new(inner: I) -> Self {
        Self { inner }
    }
}
16
17impl<I, E> Iterator for StoredRecordIterator<I>
18where
19    I: Iterator<Item = Result<StoredSequencedBytes, E>>,
20    E: std::fmt::Debug + Into<RecordDecodeError>,
21{
22    type Item = Result<Metered<StoredSequencedRecord>, RecordDecodeError>;
23
24    fn next(&mut self) -> Option<Self::Item> {
25        self.inner.next().map(|result| {
26            let (position, bytes) = result.map_err(Into::into)?.into_parts();
27            let record: Metered<StoredRecord> = bytes.try_into()?;
28            Ok(record.sequenced(position))
29        })
30    }
31}
32
33impl<I, E> FusedIterator for StoredRecordIterator<I>
34where
35    I: FusedIterator<Item = Result<StoredSequencedBytes, E>>,
36    E: std::fmt::Debug + Into<RecordDecodeError>,
37{
38}
39
#[cfg(test)]
mod tests {
    use bytes::{BufMut, Bytes, BytesMut};

    use super::*;
    use crate::record::{
        Encodable, EncryptedRecord, EnvelopeRecord, Metered, MeteredExt, MeteredSize, Record,
        SeqNum, Sequenced, StoredRecord, StoredSequencedBytes, StoredSequencedRecord,
        StreamPosition, Timestamp,
    };

    /// Builds a plaintext envelope record with no headers and the given body, at the given
    /// position.
    fn test_stored_plaintext_record(
        seq_num: SeqNum,
        timestamp: Timestamp,
        body: &'static [u8],
    ) -> Metered<StoredSequencedRecord> {
        StoredRecord::Plaintext(Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(body)).unwrap(),
        ))
        .metered()
        .sequenced(StreamPosition { seq_num, timestamp })
    }

    /// Builds an encrypted stored record from a hand-assembled encoding. The metered size
    /// is taken from an equivalent plaintext envelope.
    fn test_stored_encrypted_record(
        seq_num: SeqNum,
        timestamp: Timestamp,
    ) -> Metered<StoredSequencedRecord> {
        let metered_size = Record::Envelope(
            EnvelopeRecord::try_from_parts(vec![], Bytes::from_static(b"secret payload")).unwrap(),
        )
        .metered_size();

        // Layout: 1 leading byte, 12 filler bytes, 10-byte "ciphertext", 16 filler bytes.
        let mut encoded = BytesMut::with_capacity(1 + 12 + 10 + 16);
        encoded.put_u8(0x02);
        encoded.put_bytes(0xAB, 12);
        encoded.put_slice(b"ciphertext");
        encoded.put_bytes(0xCD, 16);
        let record = EncryptedRecord::try_from(encoded.freeze()).unwrap();

        StoredRecord::Encrypted {
            metered_size,
            record,
        }
        .metered()
        .sequenced(StreamPosition { seq_num, timestamp })
    }

    /// Encodes each record to bytes and wraps them as `Ok` items, simulating a storage
    /// source.
    fn to_stored_bytes_iter(
        records: Vec<Metered<StoredSequencedRecord>>,
    ) -> impl Iterator<Item = Result<StoredSequencedBytes, RecordDecodeError>> {
        records
            .into_iter()
            .map(|record| {
                let (position, record) = record.into_parts();
                Sequenced::new(position, record.as_ref().to_bytes())
            })
            .map(Ok)
    }

    #[test]
    fn stored_iterator_decodes_plaintext_records() {
        let expected = vec![
            test_stored_plaintext_record(1, 10, b"p0"),
            test_stored_plaintext_record(2, 11, b"p1"),
        ];
        let actual = StoredRecordIterator::new(to_stored_bytes_iter(expected.clone()))
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        assert_eq!(actual, expected);
    }

    #[test]
    fn stored_iterator_preserves_encrypted_records() {
        let expected = vec![test_stored_encrypted_record(1, 10)];

        let actual = StoredRecordIterator::new(to_stored_bytes_iter(expected.clone()))
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        assert_eq!(actual, expected);
    }

    #[test]
    fn stored_iterator_surfaces_decode_errors() {
        let invalid_data = Sequenced::new(
            StreamPosition {
                seq_num: 1,
                timestamp: 10,
            },
            Bytes::new(),
        );
        let mut iter = StoredRecordIterator::new(std::iter::once::<
            Result<StoredSequencedBytes, RecordDecodeError>,
        >(Ok(invalid_data)));

        let error = iter
            .next()
            .expect("error expected")
            .expect_err("expected error");
        assert!(matches!(error, RecordDecodeError::Truncated("MagicByte")));
        assert!(iter.next().is_none());
    }

    #[test]
    fn stored_iterator_preserves_source_errors() {
        let mut iter = StoredRecordIterator::new(std::iter::once::<
            Result<StoredSequencedBytes, RecordDecodeError>,
        >(Err(RecordDecodeError::InvalidValue(
            "test", "boom",
        ))));

        let error = iter
            .next()
            .expect("error expected")
            .expect_err("expected error");
        assert!(matches!(
            error,
            RecordDecodeError::InvalidValue("test", "boom")
        ));
        assert!(iter.next().is_none());
    }

    #[test]
    fn stored_iterator_continues_after_decode_error() {
        let valid = test_stored_plaintext_record(2, 11, b"p1");
        let invalid_data = Sequenced::new(
            StreamPosition {
                seq_num: 1,
                timestamp: 10,
            },
            Bytes::new(),
        );
        let source = std::iter::once::<Result<StoredSequencedBytes, RecordDecodeError>>(Ok(
            invalid_data,
        ))
        .chain(to_stored_bytes_iter(vec![valid.clone()]));
        let mut iter = StoredRecordIterator::new(source);

        // A bad entry yields an error item, but the next entry is still decoded.
        let error = iter
            .next()
            .expect("error expected")
            .expect_err("expected error");
        assert!(matches!(error, RecordDecodeError::Truncated("MagicByte")));

        let record = iter
            .next()
            .expect("record expected")
            .expect("record should decode");
        assert_eq!(record, valid);
        assert!(iter.next().is_none());
    }
}