Skip to main content

tensogram_core/
iter.rs

1// (C) Copyright 2026- ECMWF and individual contributors.
2//
3// This software is licensed under the terms of the Apache Licence Version 2.0
4// which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5// In applying this licence, ECMWF does not waive the privileges and immunities
6// granted to it by virtue of its status as an intergovernmental organisation nor
7// does it submit to any jurisdiction.
8
9//! Iterator types for lazy traversal of messages and objects.
10//!
11//! # Quick start
12//!
13//! ```ignore
14//! // Iterate over messages in a byte buffer (zero-copy)
15//! for msg in tensogram_core::iter::messages(&buf) {
16//!     let (meta, objs) = tensogram_core::decode(msg, &Default::default())?;
17//! }
18//!
19//! // Iterate over objects in a single message
20//! for result in tensogram_core::iter::objects(&msg_bytes, Default::default())? {
21//!     let (descriptor, data) = result?;
22//! }
23//!
24//! // File-based lazy iteration
25//! let mut file = TensogramFile::open("data.tgm")?;
26//! for raw in file.iter()? {
27//!     let raw = raw?;
28//! }
29//! ```
30
31use std::path::PathBuf;
32
33use crate::decode::DecodeOptions;
34use crate::encode::build_pipeline_config;
35use crate::error::Result;
36use crate::framing;
37use crate::types::DataObjectDescriptor;
38
39/// Create a zero-copy iterator over messages in a byte buffer.
40///
41/// Calls [`framing::scan`] once to locate all message boundaries, then yields
42/// `&[u8]` slices pointing into the original buffer on each `next()` call.
43/// Garbage between valid messages is silently skipped.
44pub fn messages(buf: &[u8]) -> MessageIter<'_> {
45    let offsets = framing::scan(buf);
46    MessageIter {
47        buf,
48        offsets,
49        pos: 0,
50    }
51}
52
53/// Create an iterator that decodes each object in a message on demand.
54///
55/// Parses the frame header and metadata once, then decodes objects lazily via
56/// the full pipeline (encoding + filter + decompression).
57pub fn objects(buf: &[u8], options: DecodeOptions) -> Result<ObjectIter> {
58    let msg = framing::decode_message(buf)?;
59    let object_data: Vec<(DataObjectDescriptor, Vec<u8>)> = msg
60        .objects
61        .into_iter()
62        .map(|(desc, payload, _)| (desc, payload.to_vec()))
63        .collect();
64    Ok(ObjectIter {
65        objects: object_data,
66        index: 0,
67        options,
68    })
69}
70
71/// Return an iterator over the [`DataObjectDescriptor`]s in a message without
72/// decoding any payload data.
73pub fn objects_metadata(buf: &[u8]) -> Result<impl Iterator<Item = DataObjectDescriptor> + use<>> {
74    let msg = framing::decode_message(buf)?;
75    Ok(msg
76        .objects
77        .into_iter()
78        .map(|(desc, _, _)| desc)
79        .collect::<Vec<_>>()
80        .into_iter())
81}
82
83// ── MessageIter ──────────────────────────────────────────────────────────────
84
/// Borrowing iterator over the messages contained in a byte buffer.
///
/// Every item is a `&[u8]` sub-slice of the buffer the iterator was created
/// from; no bytes are copied. All boundaries are computed before iteration
/// starts, which is why the type can implement [`ExactSizeIterator`].
pub struct MessageIter<'a> {
    /// The buffer the yielded slices point into.
    buf: &'a [u8],
    /// `(start, length)` of each message inside `buf`, in yield order.
    offsets: Vec<(usize, usize)>,
    /// Index into `offsets` of the next message to yield.
    pos: usize,
}

impl<'a> Iterator for MessageIter<'a> {
    type Item = &'a [u8];

    fn next(&mut self) -> Option<Self::Item> {
        // `get` returns `None` once every boundary has been consumed.
        let &(start, len) = self.offsets.get(self.pos)?;
        self.pos += 1;
        Some(&self.buf[start..start + len])
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.offsets.len() - self.pos;
        (left, Some(left))
    }
}

impl ExactSizeIterator for MessageIter<'_> {}
115
116// ── ObjectIter ───────────────────────────────────────────────────────────────
117
118/// Iterator over the decoded objects (tensors) in a single message.
119///
120/// Decodes each object through the full pipeline on demand.
121/// Yields `Result<(DataObjectDescriptor, Vec<u8>)>`.
122/// Implements [`ExactSizeIterator`].
123pub struct ObjectIter {
124    objects: Vec<(DataObjectDescriptor, Vec<u8>)>,
125    index: usize,
126    options: DecodeOptions,
127}
128
129impl Iterator for ObjectIter {
130    type Item = Result<(DataObjectDescriptor, Vec<u8>)>;
131
132    fn next(&mut self) -> Option<Self::Item> {
133        if self.index >= self.objects.len() {
134            return None;
135        }
136        let i = self.index;
137        self.index += 1;
138        let (ref desc, ref payload_bytes) = self.objects[i];
139
140        // Verify hash if requested
141        if self.options.verify_hash
142            && let Some(ref hash_desc) = desc.hash
143            && let Err(e) = crate::hash::verify_hash(payload_bytes, hash_desc)
144        {
145            return Some(Err(e));
146        }
147
148        let shape_product = match desc
149            .shape
150            .iter()
151            .try_fold(1u64, |acc, &x| acc.checked_mul(x))
152        {
153            Some(p) => p,
154            None => {
155                return Some(Err(crate::error::TensogramError::Metadata(
156                    "shape product overflow".to_string(),
157                )));
158            }
159        };
160        let num_elements = match usize::try_from(shape_product) {
161            Ok(n) => n,
162            Err(_) => {
163                return Some(Err(crate::error::TensogramError::Metadata(
164                    "element count overflows usize".to_string(),
165                )));
166            }
167        };
168
169        let config = match build_pipeline_config(desc, num_elements, desc.dtype) {
170            Ok(c) => c,
171            Err(e) => return Some(Err(e)),
172        };
173
174        let decoded = match tensogram_encodings::pipeline::decode_pipeline(
175            payload_bytes,
176            &config,
177            self.options.native_byte_order,
178        ) {
179            Ok(d) => d,
180            Err(e) => return Some(Err(crate::error::TensogramError::Encoding(e.to_string()))),
181        };
182
183        Some(Ok((desc.clone(), decoded)))
184    }
185
186    fn size_hint(&self) -> (usize, Option<usize>) {
187        let remaining = self.objects.len() - self.index;
188        (remaining, Some(remaining))
189    }
190}
191
192impl ExactSizeIterator for ObjectIter {}
193
194// ── FileMessageIter ──────────────────────────────────────────────────────────
195
196/// Lazy iterator over messages stored in a file.
197///
198/// Holds a persistent file handle and seeks to each message offset on demand,
199/// avoiding both full-file reads and repeated open/close syscalls.
200/// Constructed via [`TensogramFile::iter`].
201///
202/// [`TensogramFile::iter`]: crate::file::TensogramFile::iter
203pub struct FileMessageIter {
204    file: std::fs::File,
205    offsets: Vec<(usize, usize)>,
206    pos: usize,
207}
208
209impl FileMessageIter {
210    pub(crate) fn new(path: PathBuf, offsets: Vec<(usize, usize)>) -> Result<Self> {
211        let file = std::fs::File::open(&path)?;
212        Ok(Self {
213            file,
214            offsets,
215            pos: 0,
216        })
217    }
218}
219
220impl Iterator for FileMessageIter {
221    type Item = Result<Vec<u8>>;
222
223    fn next(&mut self) -> Option<Self::Item> {
224        use std::io::{Read, Seek, SeekFrom};
225
226        if self.pos >= self.offsets.len() {
227            return None;
228        }
229        let (offset, length) = self.offsets[self.pos];
230        self.pos += 1;
231
232        let result = (|| {
233            self.file.seek(SeekFrom::Start(offset as u64))?;
234            let mut buf = vec![0u8; length];
235            self.file.read_exact(&mut buf)?;
236            Ok(buf)
237        })();
238        Some(result)
239    }
240
241    fn size_hint(&self) -> (usize, Option<usize>) {
242        let remaining = self.offsets.len() - self.pos;
243        (remaining, Some(remaining))
244    }
245}
246
247impl ExactSizeIterator for FileMessageIter {}
248
249// ── Tests ────────────────────────────────────────────────────────────────────
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decode::DecodeOptions;
    use crate::dtype::Dtype;
    use crate::encode::{EncodeOptions, encode};
    use crate::types::{ByteOrder, DataObjectDescriptor, GlobalMetadata};
    use std::collections::BTreeMap;

    /// Encode options used by every test: hashing off, everything else default.
    fn unhashed_options() -> EncodeOptions {
        EncodeOptions {
            hash_algorithm: None,
            ..Default::default()
        }
    }

    /// Version-2 global metadata with no extra keys.
    fn make_global_meta() -> GlobalMetadata {
        GlobalMetadata {
            version: 2,
            extra: BTreeMap::new(),
            ..Default::default()
        }
    }

    /// Float32, little-endian, unencoded descriptor for a tensor of `shape`.
    fn make_descriptor(shape: Vec<u64>) -> DataObjectDescriptor {
        // Row-major strides: the last axis has stride 1, every earlier axis
        // is the product of the extents that follow it.
        let rank = shape.len();
        let mut strides = vec![1u64; rank];
        for axis in (1..rank).rev() {
            strides[axis - 1] = strides[axis] * shape[axis];
        }

        DataObjectDescriptor {
            obj_type: "ntensor".to_string(),
            ndim: rank as u64,
            shape,
            strides,
            dtype: Dtype::Float32,
            byte_order: ByteOrder::Little,
            encoding: "none".to_string(),
            filter: "none".to_string(),
            compression: "none".to_string(),
            params: BTreeMap::new(),
            hash: None,
        }
    }

    /// Encode a one-object message whose payload is the byte `fill` repeated
    /// for every byte of a Float32 tensor of the given `shape`.
    fn encode_msg(shape: Vec<u64>, fill: u8) -> Vec<u8> {
        let element_count = shape.iter().product::<u64>() as usize;
        // Four bytes per Float32 element.
        let payload = vec![fill; element_count * 4];
        let descriptor = make_descriptor(shape);
        encode(
            &make_global_meta(),
            &[(&descriptor, payload.as_slice())],
            &unhashed_options(),
        )
        .unwrap()
    }

    // ── MessageIter ──

    /// An empty buffer contains no messages.
    #[test]
    fn test_message_iter_empty_buffer() {
        let empty: Vec<u8> = Vec::new();
        let mut iter = messages(&empty);
        assert_eq!(iter.len(), 0);
        assert!(iter.next().is_none());
    }

    /// A buffer holding one message yields that message, byte for byte.
    #[test]
    fn test_message_iter_single_message() {
        let encoded = encode_msg(vec![4], 1);
        let found: Vec<&[u8]> = messages(&encoded).collect();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0], encoded.as_slice());
    }

    /// Three back-to-back messages are all found.
    #[test]
    fn test_message_iter_multiple_messages() {
        let stream: Vec<u8> = [
            encode_msg(vec![4], 0),
            encode_msg(vec![4], 1),
            encode_msg(vec![4], 2),
        ]
        .concat();

        assert_eq!(messages(&stream).count(), 3);
    }

    /// Junk bytes before and between messages do not hide the messages.
    #[test]
    fn test_message_iter_with_garbage() {
        let mut stream: Vec<u8> = vec![0xDE, 0xAD, 0xBE, 0xEF];
        stream.extend_from_slice(&encode_msg(vec![4], 0));
        stream.extend_from_slice(&[0xFF, 0xFF]);
        stream.extend_from_slice(&encode_msg(vec![4], 1));

        assert_eq!(messages(&stream).count(), 2);
    }

    /// Every yielded slice can be fed straight into `decode`.
    #[test]
    fn test_message_iter_yields_decodable_slices() {
        let stream: Vec<u8> = [encode_msg(vec![3], 0xAB), encode_msg(vec![5], 0xCD)].concat();

        for (idx, slice) in messages(&stream).enumerate() {
            let (meta, objs) = crate::decode::decode(slice, &DecodeOptions::default()).unwrap();
            assert_eq!(meta.version, 2);
            let expected_extent = match idx {
                0 => 3u64,
                _ => 5u64,
            };
            assert_eq!(objs[0].0.shape, vec![expected_extent]);
        }
    }

    // ── ObjectIter ──

    /// A message without objects produces an empty iterator.
    #[test]
    fn test_object_iter_zero_objects() {
        let msg = encode(&make_global_meta(), &[], &unhashed_options()).unwrap();
        let mut iter = objects(&msg, DecodeOptions::default()).unwrap();
        assert_eq!(iter.len(), 0);
        assert!(iter.next().is_none());
    }

    /// A single object round-trips with its shape and payload intact.
    #[test]
    fn test_object_iter_single_object() {
        let msg = encode_msg(vec![4], 42);
        let results: Vec<_> = objects(&msg, DecodeOptions::default()).unwrap().collect();
        assert_eq!(results.len(), 1);

        let (desc, data) = results[0].as_ref().unwrap();
        assert_eq!(desc.shape, vec![4u64]);
        assert_eq!(data.len(), 16);
        assert_eq!(data, &vec![42u8; 16]);
    }

    /// Two objects come back in order and `len()` counts down as they do.
    #[test]
    fn test_object_iter_multi_object() {
        let shape = vec![4u64];
        let payloads = [vec![0u8; 16], vec![1u8; 16]];
        let descriptors = [make_descriptor(shape.clone()), make_descriptor(shape.clone())];

        let msg = encode(
            &make_global_meta(),
            &[
                (&descriptors[0], payloads[0].as_slice()),
                (&descriptors[1], payloads[1].as_slice()),
            ],
            &unhashed_options(),
        )
        .unwrap();

        let mut iter = objects(&msg, DecodeOptions::default()).unwrap();
        assert_eq!(iter.len(), 2);
        for expected in &payloads {
            let (desc, data) = iter.next().unwrap().unwrap();
            assert_eq!(desc.shape, shape);
            assert_eq!(&data, expected);
        }
        assert_eq!(iter.len(), 0);
        assert!(iter.next().is_none());
    }

    /// The metadata-only view exposes shape and dtype.
    #[test]
    fn test_objects_metadata_only() {
        let msg = encode_msg(vec![3, 4], 7);
        let descriptors: Vec<DataObjectDescriptor> = objects_metadata(&msg).unwrap().collect();
        assert_eq!(descriptors.len(), 1);

        let only = &descriptors[0];
        assert_eq!(only.shape, vec![3u64, 4u64]);
        assert_eq!(only.dtype, Dtype::Float32);
    }

    // ── FileMessageIter ──

    /// An empty file with no offsets yields nothing.
    #[test]
    fn test_file_iter_empty() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("empty.tgm");
        std::fs::write(&path, b"").unwrap();

        let iter = FileMessageIter::new(path, Vec::new()).unwrap();
        assert_eq!(iter.len(), 0);
        assert_eq!(iter.count(), 0);
    }

    /// Three messages written to disk are read back unchanged and in order.
    #[test]
    fn test_file_iter_three_messages() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("three.tgm");

        let expected = [
            encode_msg(vec![4], 0),
            encode_msg(vec![4], 1),
            encode_msg(vec![4], 2),
        ];
        let content: Vec<u8> = expected.concat();
        std::fs::write(&path, &content).unwrap();

        let iter = FileMessageIter::new(path, framing::scan(&content)).unwrap();
        assert_eq!(iter.len(), 3);

        let read_back: Vec<_> = iter.collect();
        for (slot, want) in expected.iter().enumerate() {
            assert_eq!(read_back[slot].as_ref().unwrap(), want);
        }
    }

    /// Every message read from disk can be decoded.
    #[test]
    fn test_file_iter_each_decodable() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("decode.tgm");

        let content: Vec<u8> = (0u8..3)
            .flat_map(|fill| encode_msg(vec![2], fill))
            .collect();
        std::fs::write(&path, &content).unwrap();

        let iter = FileMessageIter::new(path, framing::scan(&content)).unwrap();
        for item in iter {
            let bytes = item.unwrap();
            let (meta, _) = crate::decode::decode(&bytes, &DecodeOptions::default()).unwrap();
            assert_eq!(meta.version, 2);
        }
    }
}
484}