1use std::iter::FusedIterator;
2
3use s2_common::{
4 caps,
5 read_extent::{EvaluatedReadLimit, ReadLimit, ReadUntil},
6 record::{Metered, MeteredSize, Sequenced},
7};
8
9use super::StoredRecord;
10
11pub struct RecordBatch<T = StoredRecord>
12where
13 T: MeteredSize,
14{
15 pub records: Metered<Vec<Sequenced<T>>>,
16 pub is_terminal: bool,
17}
18
19impl<T> std::fmt::Debug for RecordBatch<T>
20where
21 T: MeteredSize,
22{
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 f.debug_struct("RecordBatch")
25 .field("num_records", &self.records.len())
26 .field("metered_size", &self.records.metered_size())
27 .field("is_terminal", &self.is_terminal)
28 .finish()
29 }
30}
31
32pub struct RecordBatcher<I, E, T>
33where
34 T: MeteredSize,
35 I: Iterator<Item = Result<Metered<Sequenced<T>>, E>>,
36{
37 record_iterator: I,
38 buffered_records: Metered<Vec<Sequenced<T>>>,
39 buffered_error: Option<E>,
40 read_limit: EvaluatedReadLimit,
41 until: ReadUntil,
42 is_terminated: bool,
43}
44
45fn make_records<T>(read_limit: &EvaluatedReadLimit) -> Metered<Vec<Sequenced<T>>>
46where
47 T: MeteredSize,
48{
49 match read_limit {
50 EvaluatedReadLimit::Remaining(limit) => {
51 Metered::with_capacity(limit.count().map_or(caps::RECORD_BATCH_MAX.count, |n| {
52 n.min(caps::RECORD_BATCH_MAX.count)
53 }))
54 }
55 EvaluatedReadLimit::Exhausted => Metered::default(),
56 }
57}
58
59impl<I, E, T> RecordBatcher<I, E, T>
60where
61 T: MeteredSize,
62 I: Iterator<Item = Result<Metered<Sequenced<T>>, E>>,
63{
64 pub fn new(record_iterator: I, read_limit: ReadLimit, until: ReadUntil) -> Self {
65 let read_limit = read_limit.remaining(0, 0);
66 Self {
67 record_iterator,
68 buffered_records: make_records(&read_limit),
69 buffered_error: None,
70 read_limit,
71 until,
72 is_terminated: false,
73 }
74 }
75
76 fn iter_next(&mut self) -> Option<Result<RecordBatch<T>, E>> {
77 let EvaluatedReadLimit::Remaining(remaining_limit) = self.read_limit else {
78 return None;
79 };
80
81 let mut stashed_record = None;
82 while self.buffered_error.is_none() {
83 match self.record_iterator.next() {
84 Some(Ok(record)) => {
85 if remaining_limit.deny(
86 self.buffered_records.len() + 1,
87 self.buffered_records.metered_size() + record.metered_size(),
88 ) || self.until.deny(record.position().timestamp)
89 {
90 self.read_limit = EvaluatedReadLimit::Exhausted;
91 break;
92 }
93
94 if self.buffered_records.len() == caps::RECORD_BATCH_MAX.count
95 || self.buffered_records.metered_size() + record.metered_size()
96 > caps::RECORD_BATCH_MAX.bytes
97 {
98 stashed_record = Some(record);
100 break;
101 }
102
103 self.buffered_records.push(record);
104 }
105 Some(Err(err)) => {
106 self.buffered_error = Some(err);
107 break;
108 }
109 None => {
110 break;
111 }
112 }
113 }
114 if !self.buffered_records.is_empty() {
115 self.read_limit = match self.read_limit {
116 EvaluatedReadLimit::Remaining(read_limit) => read_limit.remaining(
117 self.buffered_records.len(),
118 self.buffered_records.metered_size(),
119 ),
120 EvaluatedReadLimit::Exhausted => EvaluatedReadLimit::Exhausted,
121 };
122 let is_terminal = self.read_limit == EvaluatedReadLimit::Exhausted;
123 let records = std::mem::replace(
124 &mut self.buffered_records,
125 if is_terminal || self.buffered_error.is_some() {
126 Metered::default()
127 } else {
128 let mut buf = make_records(&self.read_limit);
129 if let Some(record) = stashed_record.take() {
130 buf.push(record);
131 }
132 buf
133 },
134 );
135 return Some(Ok(RecordBatch {
136 records,
137 is_terminal,
138 }));
139 }
140 if let Some(err) = self.buffered_error.take() {
141 return Some(Err(err));
142 }
143 None
144 }
145}
146
147impl<I, E, T> Iterator for RecordBatcher<I, E, T>
148where
149 T: MeteredSize,
150 I: Iterator<Item = Result<Metered<Sequenced<T>>, E>>,
151{
152 type Item = Result<RecordBatch<T>, E>;
153
154 fn next(&mut self) -> Option<Self::Item> {
155 if self.is_terminated {
156 return None;
157 }
158 let item = self.iter_next();
159 self.is_terminated = matches!(&item, None | Some(Err(_)));
160 item
161 }
162}
163
164impl<I, E, T> FusedIterator for RecordBatcher<I, E, T>
165where
166 T: MeteredSize,
167 I: Iterator<Item = Result<Metered<Sequenced<T>>, E>>,
168{
169}
170
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use s2_common::{
        caps,
        read_extent::{ReadLimit, ReadUntil},
        record::{
            CommandRecord, EnvelopeRecord, Metered, MeteredExt, MeteredSize, Record, SeqNum,
            Sequenced, SequencedRecord, StreamPosition, Timestamp,
        },
    };

    use crate::record::{
        RecordBatch, RecordBatcher, StoredRecord, StoredRecordDecodeError, StoredRecordIterator,
        StoredSequencedBytes, StoredSequencedRecord, encode_stored_record,
    };

    /// Builds a logical trim-command record at the given position.
    fn test_logical_record(seq_num: SeqNum, timestamp: Timestamp) -> SequencedRecord {
        let position = StreamPosition { seq_num, timestamp };
        let command = Record::Command(CommandRecord::Trim(seq_num));
        command.metered().sequenced(position).into_inner()
    }

    /// Builds a stored trim-command record at the given position.
    fn test_record(seq_num: SeqNum, timestamp: Timestamp) -> StoredSequencedRecord {
        let position = StreamPosition { seq_num, timestamp };
        let stored = StoredRecord::from(Record::Command(CommandRecord::Trim(seq_num)));
        stored.metered().sequenced(position).into_inner()
    }

    /// Builds a stored envelope record with no headers and a zero-filled body of `body_len`
    /// bytes.
    fn test_large_record(
        seq_num: SeqNum,
        timestamp: Timestamp,
        body_len: usize,
    ) -> StoredSequencedRecord {
        let position = StreamPosition { seq_num, timestamp };
        let envelope =
            EnvelopeRecord::try_from_parts(vec![], Bytes::from(vec![0; body_len])).unwrap();
        StoredRecord::from(Record::Envelope(envelope))
            .metered()
            .sequenced(position)
            .into_inner()
    }

    /// Wraps stored records into the infallible metered item stream the batcher consumes.
    fn to_iter(
        records: Vec<StoredSequencedRecord>,
    ) -> impl Iterator<Item = Result<Metered<StoredSequencedRecord>, StoredRecordDecodeError>> {
        records.into_iter().map(|record| Ok(Metered::from(record)))
    }

    /// Same as `to_iter`, for logical records.
    fn to_logical_iter(
        records: Vec<SequencedRecord>,
    ) -> impl Iterator<Item = Result<Metered<SequencedRecord>, StoredRecordDecodeError>> {
        records.into_iter().map(|record| Ok(Metered::from(record)))
    }

    /// Encodes each stored record to bytes, keeping its position, as input for
    /// `StoredRecordIterator`.
    fn to_stored_bytes_iter(
        records: Vec<StoredSequencedRecord>,
    ) -> impl Iterator<Item = Result<StoredSequencedBytes, StoredRecordDecodeError>> {
        records.into_iter().map(|sequenced| {
            let (position, record) = sequenced.into_parts();
            let encoded = encode_stored_record((&record).metered());
            Ok(Sequenced::new(position, encoded))
        })
    }

    /// Checks a batch's terminal flag, record count, metered size, and record contents.
    fn assert_batch(batch: &RecordBatch, expected: &[StoredSequencedRecord], is_terminal: bool) {
        assert_eq!(batch.is_terminal, is_terminal);
        assert_eq!(batch.records.len(), expected.len());
        let expected_size: usize = expected.iter().map(MeteredSize::metered_size).sum();
        assert_eq!(batch.records.metered_size(), expected_size);
        batch
            .records
            .iter()
            .zip(expected)
            .for_each(|(actual, expected)| assert_eq!(actual, expected));
    }

    #[test]
    fn collects_records_until_iterator_ends() {
        let input = vec![test_record(1, 10), test_record(2, 11), test_record(3, 12)];
        let source = to_iter(input.clone());
        let mut batcher = RecordBatcher::new(source, ReadLimit::Unbounded, ReadUntil::Unbounded);

        let only_batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&only_batch, &input, false);
        assert!(batcher.next().is_none());
    }

    #[test]
    fn generic_batcher_collects_logical_records() {
        let input = vec![
            test_logical_record(1, 10),
            test_logical_record(2, 11),
            test_logical_record(3, 12),
        ];
        let source = to_logical_iter(input.clone());
        let mut batcher = RecordBatcher::new(source, ReadLimit::Unbounded, ReadUntil::Unbounded);

        let only_batch = batcher.next().expect("batch expected").expect("ok batch");
        assert!(!only_batch.is_terminal);
        assert_eq!(only_batch.records.len(), input.len());
        let expected_size: usize = input.iter().map(MeteredSize::metered_size).sum();
        assert_eq!(only_batch.records.metered_size(), expected_size);
        only_batch
            .records
            .iter()
            .zip(&input)
            .for_each(|(actual, expected)| assert_eq!(actual, expected));
        assert!(batcher.next().is_none());
    }

    #[test]
    fn stops_at_count_read_limit() {
        let input = vec![test_record(1, 10), test_record(2, 11), test_record(3, 12)];
        let source = to_iter(input.clone());
        let mut batcher = RecordBatcher::new(source, ReadLimit::Count(2), ReadUntil::Unbounded);

        let only_batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&only_batch, &input[..2], true);
        assert!(batcher.next().is_none());
    }

    #[test]
    fn stops_at_byte_read_limit() {
        let input = vec![test_record(1, 10), test_record(2, 11)];
        // Allow exactly the first record's worth of bytes.
        let byte_budget = input[0].metered_size();
        let source = to_iter(input.clone());
        let mut batcher =
            RecordBatcher::new(source, ReadLimit::Bytes(byte_budget), ReadUntil::Unbounded);

        let only_batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&only_batch, &input[..1], true);
        assert!(batcher.next().is_none());
    }

    #[test]
    fn stops_at_timestamp_limit() {
        let input = vec![test_record(1, 10), test_record(2, 19), test_record(3, 20)];
        let source = to_iter(input.clone());
        let mut batcher =
            RecordBatcher::new(source, ReadLimit::Unbounded, ReadUntil::Timestamp(20));

        let only_batch = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&only_batch, &input[..2], true);
        assert!(batcher.next().is_none());
    }

    #[test]
    fn splits_batches_when_caps_are_hit() {
        // One record more than fits in a single batch by count.
        let max_count = caps::RECORD_BATCH_MAX.count;
        let input: Vec<StoredSequencedRecord> = (0..=(max_count as SeqNum))
            .map(|index| test_record(index, index + 10))
            .collect();
        let source = to_iter(input.clone());
        let mut batcher = RecordBatcher::new(source, ReadLimit::Unbounded, ReadUntil::Unbounded);

        let head = batcher
            .next()
            .expect("first batch expected")
            .expect("first batch ok");
        assert_batch(&head, &input[..max_count], false);

        let tail = batcher
            .next()
            .expect("second batch expected")
            .expect("second batch ok");
        assert_batch(&tail, &input[max_count..], false);
        assert!(batcher.next().is_none());
    }

    #[test]
    fn splits_batches_when_byte_cap_is_hit() {
        let max_bytes = caps::RECORD_BATCH_MAX.bytes;
        let input = vec![
            test_large_record(1, 10, max_bytes / 2 + 1),
            test_large_record(2, 11, max_bytes / 2 + 1),
        ];
        // Each record fits in a batch on its own, but the pair does not.
        assert!(input[0].metered_size() <= max_bytes);
        assert!(input[1].metered_size() <= max_bytes);
        assert!(input[0].metered_size() + input[1].metered_size() > max_bytes);

        let source = to_iter(input.clone());
        let mut batcher = RecordBatcher::new(source, ReadLimit::Unbounded, ReadUntil::Unbounded);

        let head = batcher
            .next()
            .expect("first batch expected")
            .expect("first batch ok");
        assert_batch(&head, &input[..1], false);

        let tail = batcher
            .next()
            .expect("second batch expected")
            .expect("second batch ok");
        assert_batch(&tail, &input[1..], false);
        assert!(batcher.next().is_none());
    }

    #[test]
    fn surfaces_decode_errors_after_draining_buffer() {
        let valid = vec![test_record(1, 10), test_record(2, 11)];
        // Empty bytes cannot be decoded as a stored record.
        let undecodable = Sequenced::new(
            StreamPosition {
                seq_num: 3,
                timestamp: 12,
            },
            Bytes::new(),
        );

        let encoded = to_stored_bytes_iter(valid.clone()).chain(std::iter::once(Ok(undecodable)));
        let mut batcher = RecordBatcher::new(
            StoredRecordIterator::new(encoded),
            ReadLimit::Unbounded,
            ReadUntil::Unbounded,
        );

        let drained = batcher.next().expect("batch expected").expect("ok batch");
        assert_batch(&drained, &valid, false);

        let error = batcher
            .next()
            .expect("error expected")
            .expect_err("expected decode error");
        assert!(matches!(
            error,
            StoredRecordDecodeError::Truncated("MagicByte")
        ));
        assert!(batcher.next().is_none());
    }

    #[test]
    fn surfaces_iterator_errors_immediately() {
        let failure: Result<StoredSequencedBytes, StoredRecordDecodeError> =
            Err(StoredRecordDecodeError::InvalidValue("test", "boom"));
        let iterator = StoredRecordIterator::new(std::iter::once(failure));
        let mut batcher = RecordBatcher::new(iterator, ReadLimit::Unbounded, ReadUntil::Unbounded);

        let error = batcher
            .next()
            .expect("error expected")
            .expect_err("expected iterator error");
        assert!(matches!(
            error,
            StoredRecordDecodeError::InvalidValue("test", "boom")
        ));
        assert!(batcher.next().is_none());
    }
}