use std::iter::FusedIterator;

use bytes::Bytes;

use super::InternalRecordError;
use crate::{
    caps,
    read_extent::{EvaluatedReadLimit, ReadLimit, ReadUntil},
    record::{Metered, MeteredSize, SequencedRecord, StreamPosition},
};

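/// A batch of sequenced records ready to be returned to a reader, together
/// with a flag indicating whether the read extent was exhausted while the
/// batch was being built.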
#[derive(Debug)]
pub struct RecordBatch {
    pub records: Metered<Vec<SequencedRecord>>,
    pub is_terminal: bool,
}

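/// Iterator adapter that decodes `(StreamPosition, Bytes)` pairs from an
/// underlying iterator into [`SequencedRecord`]s and groups them into
/// [`RecordBatch`]es, honoring the requested read limit, the `until` bound,
/// and the per-batch caps in `caps::RECORD_BATCH_MAX`.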
pub struct RecordBatcher<I, E>
where
    I: Iterator<Item = Result<(StreamPosition, Bytes), E>>,
    E: Into<InternalRecordError>,
{
    record_iterator: I,
    buffered_records: Metered<Vec<SequencedRecord>>,
    buffered_error: Option<InternalRecordError>,
    read_limit: EvaluatedReadLimit,
    until: ReadUntil,
    is_terminated: bool,
}

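/// Allocates a record buffer with capacity for the smaller of the remaining
/// count limit and the per-batch record cap; an exhausted limit yields an
/// empty buffer.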
fn make_records(read_limit: &EvaluatedReadLimit) -> Metered<Vec<SequencedRecord>> {
    match read_limit {
        EvaluatedReadLimit::Remaining(limit) => {
            Metered::with_capacity(limit.count().map_or(caps::RECORD_BATCH_MAX.count, |n| {
                n.min(caps::RECORD_BATCH_MAX.count)
            }))
        }
        EvaluatedReadLimit::Exhausted => Metered::default(),
    }
}

impl<I, E> RecordBatcher<I, E>
where
    I: Iterator<Item = Result<(StreamPosition, Bytes), E>>,
    E: std::fmt::Debug + Into<InternalRecordError>,
{
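    /// Creates a batcher over `record_iterator`, evaluating `read_limit`
    /// against an empty read (no records and no bytes consumed yet).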
    pub fn new(record_iterator: I, read_limit: ReadLimit, until: ReadUntil) -> Self {
        let read_limit = read_limit.remaining(0, 0);
        Self {
            record_iterator,
            buffered_records: make_records(&read_limit),
            buffered_error: None,
            read_limit,
            until,
            is_terminated: false,
        }
    }

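    /// Fills the internal buffer from the underlying iterator until the read
    /// limit, the `until` bound, a per-batch cap, an error, or the end of
    /// input is reached, then returns the buffered records as a batch. A
    /// buffered error is surfaced only after the records read before it have
    /// been emitted.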
    fn iter_next(&mut self) -> Option<Result<RecordBatch, InternalRecordError>> {
        let EvaluatedReadLimit::Remaining(remaining_limit) = self.read_limit else {
            return None;
        };

        let mut stashed_record = None;
        while self.buffered_error.is_none() {
            match self.record_iterator.next() {
                Some(Ok((position, data))) => {
                    let record = match Metered::try_from(data) {
                        Ok(record) => record.sequenced(position),
                        Err(err) => {
                            self.buffered_error = Some(err);
                            break;
                        }
                    };

                    // This record would exceed the read extent: drop it and
                    // mark the limit as exhausted.
                    if remaining_limit.deny(
                        self.buffered_records.len() + 1,
                        self.buffered_records.metered_size() + record.metered_size(),
                    ) || self.until.deny(position.timestamp)
                    {
                        self.read_limit = EvaluatedReadLimit::Exhausted;
                        break;
                    }

                    // The current batch is full: stash the record so it seeds
                    // the next batch instead.
                    if self.buffered_records.len() == caps::RECORD_BATCH_MAX.count
                        || self.buffered_records.metered_size() + record.metered_size()
                            > caps::RECORD_BATCH_MAX.bytes
                    {
                        stashed_record = Some(record);
                        break;
                    }

                    self.buffered_records.push(record);
                }
                Some(Err(err)) => {
                    self.buffered_error = Some(err.into());
                    break;
                }
                None => {
                    break;
                }
            }
        }
        if !self.buffered_records.is_empty() {
            self.read_limit = match self.read_limit {
                EvaluatedReadLimit::Remaining(read_limit) => read_limit.remaining(
                    self.buffered_records.len(),
                    self.buffered_records.metered_size(),
                ),
                EvaluatedReadLimit::Exhausted => EvaluatedReadLimit::Exhausted,
            };
            let is_terminal = self.read_limit == EvaluatedReadLimit::Exhausted;
            let records = std::mem::replace(
                &mut self.buffered_records,
                if is_terminal || self.buffered_error.is_some() {
                    Metered::default()
                } else {
                    let mut buf = make_records(&self.read_limit);
                    if let Some(record) = stashed_record.take() {
                        buf.push(record);
                    }
                    buf
                },
            );
            return Some(Ok(RecordBatch {
                records,
                is_terminal,
            }));
        }
        if let Some(err) = self.buffered_error.take() {
            return Some(Err(err));
        }
        None
    }
}

impl<I, E> Iterator for RecordBatcher<I, E>
where
    I: Iterator<Item = Result<(StreamPosition, Bytes), E>>,
    E: std::fmt::Debug + Into<InternalRecordError>,
{
    type Item = Result<RecordBatch, InternalRecordError>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.is_terminated {
            return None;
        }
        let item = self.iter_next();
        self.is_terminated = matches!(&item, None | Some(Err(_)));
        item
    }
}

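// Once `next` has returned `None` or an error, `is_terminated` ensures every
// subsequent call returns `None`, so the `FusedIterator` contract holds.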
impl<I, E> FusedIterator for RecordBatcher<I, E>
where
    I: Iterator<Item = Result<(StreamPosition, Bytes), E>>,
    E: std::fmt::Debug + Into<InternalRecordError>,
{
}

#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use super::*;
    use crate::{
        caps,
        read_extent::{ReadLimit, ReadUntil},
        record::{
            CommandRecord, Encodable, MeteredSize, Record, SeqNum, SequencedRecord, Timestamp,
        },
    };

    fn test_record(seq_num: SeqNum, timestamp: Timestamp) -> SequencedRecord {
        Record::Command(CommandRecord::Trim(seq_num))
            .sequenced(StreamPosition { seq_num, timestamp })
    }

    fn to_iter(
        records: Vec<SequencedRecord>,
    ) -> impl Iterator<Item = Result<(StreamPosition, Bytes), InternalRecordError>> {
        records
            .into_iter()
            .map(|SequencedRecord { position, record }| {
                (position, Metered::from(record).as_ref().to_bytes())
            })
            .map(Ok)
    }

    fn assert_batch(batch: &RecordBatch, expected: &[SequencedRecord], is_terminal: bool) {
        assert_eq!(batch.is_terminal, is_terminal);
        assert_eq!(batch.records.len(), expected.len());
        let expected_size: usize = expected.iter().map(|r| r.metered_size()).sum();
        assert_eq!(batch.records.metered_size(), expected_size);
        for (actual, expected) in batch.records.iter().zip(expected.iter()) {
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn collects_records_until_iterator_ends() {
        let expected = vec![test_record(1, 10), test_record(2, 11), test_record(3, 12)];
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );
        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected, false);
        assert!(batcher.next().is_none());
    }

    #[test]
    fn stops_at_count_read_limit() {
        let expected = vec![test_record(1, 10), test_record(2, 11), test_record(3, 12)];
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Count(2),
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected[..2], true);
        assert!(batcher.next().is_none());
    }

    #[test]
    fn stops_at_byte_read_limit() {
        let expected = vec![test_record(1, 10), test_record(2, 11)];
        let first_size = expected[0].metered_size();
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Bytes(first_size),
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected[..1], true);
        assert!(batcher.next().is_none());
    }

    #[test]
    fn stops_at_timestamp_limit() {
        let expected = vec![test_record(1, 10), test_record(2, 19), test_record(3, 20)];
        let mut batcher = RecordBatcher::new(
            to_iter(expected.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Timestamp(20),
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &expected[..2], true);
        assert!(batcher.next().is_none());
    }

    #[test]
    fn splits_batches_when_caps_are_hit() {
        let mut records = Vec::with_capacity(caps::RECORD_BATCH_MAX.count + 1);
        for index in 0..=(caps::RECORD_BATCH_MAX.count as SeqNum) {
            records.push(test_record(index, index + 10));
        }
        let mut batcher = RecordBatcher::new(
            to_iter(records.clone()),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let first_batch = batcher
            .next()
            .expect("first batch expected")
            .expect("first batch ok");
        assert_batch(
            &first_batch,
            &records[..caps::RECORD_BATCH_MAX.count],
            false,
        );

        let second_batch = batcher
            .next()
            .expect("second batch expected")
            .expect("second batch ok");
        assert_batch(
            &second_batch,
            &records[caps::RECORD_BATCH_MAX.count..],
            false,
        );
        assert!(batcher.next().is_none());
    }

    #[test]
    fn surfaces_decode_errors_after_draining_buffer() {
        let records = vec![test_record(1, 10), test_record(2, 11)];
        let invalid_data = (
            StreamPosition {
                seq_num: 3,
                timestamp: 12,
            },
            Bytes::new(),
        );

        let mut batcher = RecordBatcher::new(
            to_iter(records.clone()).chain(std::iter::once(Ok(invalid_data))),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&batch, &records, false);

        let error = batcher
            .next()
            .expect("error expected")
            .expect_err("expected decode error");
        assert!(matches!(error, InternalRecordError::Truncated("MagicByte")));
        assert!(batcher.next().is_none());
    }

    #[test]
    fn surfaces_iterator_errors_immediately() {
        let iterator = std::iter::once::<Result<(StreamPosition, Bytes), InternalRecordError>>(
            Err(InternalRecordError::InvalidValue("test", "boom")),
        );
        let mut batcher = RecordBatcher::new(iterator, ReadLimit::Unbounded, ReadUntil::Unbounded);

        let error = batcher
            .next()
            .expect("error expected")
            .expect_err("expected iterator error");
        assert!(matches!(
            error,
            InternalRecordError::InvalidValue("test", "boom")
        ));
        assert!(batcher.next().is_none());
    }
}