1use std::path::PathBuf;
32
33use crate::decode::DecodeOptions;
34use crate::encode::build_pipeline_config;
35use crate::error::Result;
36use crate::framing;
37use crate::types::DataObjectDescriptor;
38
39pub fn messages(buf: &[u8]) -> MessageIter<'_> {
45 let offsets = framing::scan(buf);
46 MessageIter {
47 buf,
48 offsets,
49 pos: 0,
50 }
51}
52
53pub fn objects(buf: &[u8], options: DecodeOptions) -> Result<ObjectIter> {
58 let msg = framing::decode_message(buf)?;
59 let object_data: Vec<(DataObjectDescriptor, Vec<u8>)> = msg
60 .objects
61 .into_iter()
62 .map(|(desc, payload, _)| (desc, payload.to_vec()))
63 .collect();
64 Ok(ObjectIter {
65 objects: object_data,
66 index: 0,
67 options,
68 })
69}
70
71pub fn objects_metadata(buf: &[u8]) -> Result<impl Iterator<Item = DataObjectDescriptor> + use<>> {
74 let msg = framing::decode_message(buf)?;
75 Ok(msg
76 .objects
77 .into_iter()
78 .map(|(desc, _, _)| desc)
79 .collect::<Vec<_>>()
80 .into_iter())
81}
82
/// Borrowing iterator over the messages located inside a byte buffer.
///
/// Every entry of `offsets` describes one message as `(offset, length)`;
/// the iterator yields the corresponding sub-slice of `buf` without copying.
pub struct MessageIter<'a> {
    /// Buffer the yielded slices point into.
    buf: &'a [u8],
    /// `(offset, length)` of each message, in iteration order.
    offsets: Vec<(usize, usize)>,
    /// Index into `offsets` of the next entry to yield.
    pos: usize,
}

impl<'a> Iterator for MessageIter<'a> {
    type Item = &'a [u8];

    /// Yields the next message slice, or `None` once every entry was visited.
    ///
    /// # Panics
    ///
    /// Panics if an entry's range does not lie inside the buffer.
    fn next(&mut self) -> Option<Self::Item> {
        let &(start, len) = self.offsets.get(self.pos)?;
        self.pos += 1;

        let whole: &'a [u8] = self.buf;
        Some(&whole[start..start + len])
    }

    /// Exact number of messages still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.offsets.len() - self.pos;
        (left, Some(left))
    }
}

// `size_hint` is exact, so `len()` can be offered as well.
impl ExactSizeIterator for MessageIter<'_> {}
115
116pub struct ObjectIter {
124 objects: Vec<(DataObjectDescriptor, Vec<u8>)>,
125 index: usize,
126 options: DecodeOptions,
127}
128
129impl Iterator for ObjectIter {
130 type Item = Result<(DataObjectDescriptor, Vec<u8>)>;
131
132 fn next(&mut self) -> Option<Self::Item> {
133 if self.index >= self.objects.len() {
134 return None;
135 }
136 let i = self.index;
137 self.index += 1;
138 let (ref desc, ref payload_bytes) = self.objects[i];
139
140 if self.options.verify_hash
142 && let Some(ref hash_desc) = desc.hash
143 && let Err(e) = crate::hash::verify_hash(payload_bytes, hash_desc)
144 {
145 return Some(Err(e));
146 }
147
148 let shape_product = match desc
149 .shape
150 .iter()
151 .try_fold(1u64, |acc, &x| acc.checked_mul(x))
152 {
153 Some(p) => p,
154 None => {
155 return Some(Err(crate::error::TensogramError::Metadata(
156 "shape product overflow".to_string(),
157 )));
158 }
159 };
160 let num_elements = match usize::try_from(shape_product) {
161 Ok(n) => n,
162 Err(_) => {
163 return Some(Err(crate::error::TensogramError::Metadata(
164 "element count overflows usize".to_string(),
165 )));
166 }
167 };
168
169 let config = match build_pipeline_config(desc, num_elements, desc.dtype) {
170 Ok(c) => c,
171 Err(e) => return Some(Err(e)),
172 };
173
174 let decoded = match tensogram_encodings::pipeline::decode_pipeline(
175 payload_bytes,
176 &config,
177 self.options.native_byte_order,
178 ) {
179 Ok(d) => d,
180 Err(e) => return Some(Err(crate::error::TensogramError::Encoding(e.to_string()))),
181 };
182
183 Some(Ok((desc.clone(), decoded)))
184 }
185
186 fn size_hint(&self) -> (usize, Option<usize>) {
187 let remaining = self.objects.len() - self.index;
188 (remaining, Some(remaining))
189 }
190}
191
192impl ExactSizeIterator for ObjectIter {}
193
194pub struct FileMessageIter {
204 file: std::fs::File,
205 offsets: Vec<(usize, usize)>,
206 pos: usize,
207}
208
209impl FileMessageIter {
210 pub(crate) fn new(path: PathBuf, offsets: Vec<(usize, usize)>) -> Result<Self> {
211 let file = std::fs::File::open(&path)?;
212 Ok(Self {
213 file,
214 offsets,
215 pos: 0,
216 })
217 }
218}
219
220impl Iterator for FileMessageIter {
221 type Item = Result<Vec<u8>>;
222
223 fn next(&mut self) -> Option<Self::Item> {
224 use std::io::{Read, Seek, SeekFrom};
225
226 if self.pos >= self.offsets.len() {
227 return None;
228 }
229 let (offset, length) = self.offsets[self.pos];
230 self.pos += 1;
231
232 let result = (|| {
233 self.file.seek(SeekFrom::Start(offset as u64))?;
234 let mut buf = vec![0u8; length];
235 self.file.read_exact(&mut buf)?;
236 Ok(buf)
237 })();
238 Some(result)
239 }
240
241 fn size_hint(&self) -> (usize, Option<usize>) {
242 let remaining = self.offsets.len() - self.pos;
243 (remaining, Some(remaining))
244 }
245}
246
247impl ExactSizeIterator for FileMessageIter {}
248
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decode::DecodeOptions;
    use crate::dtype::Dtype;
    use crate::encode::{EncodeOptions, encode};
    use crate::types::{ByteOrder, DataObjectDescriptor, GlobalMetadata};
    use std::collections::BTreeMap;

    // ---- fixtures -------------------------------------------------------

    /// Global metadata with version 2 and no extra keys.
    fn make_global_meta() -> GlobalMetadata {
        GlobalMetadata {
            extra: BTreeMap::new(),
            version: 2,
            ..Default::default()
        }
    }

    /// Encode options with hashing switched off; everything else default.
    fn unhashed_options() -> EncodeOptions {
        EncodeOptions {
            hash_algorithm: None,
            ..Default::default()
        }
    }

    /// Unencoded little-endian float32 descriptor with strides derived from
    /// `shape` (last axis has stride 1).
    fn make_descriptor(shape: Vec<u64>) -> DataObjectDescriptor {
        let rank = shape.len();
        let mut strides = vec![1u64; rank];
        for axis in (1..rank).rev() {
            strides[axis - 1] = strides[axis] * shape[axis];
        }

        DataObjectDescriptor {
            obj_type: "ntensor".to_string(),
            ndim: rank as u64,
            shape,
            strides,
            dtype: Dtype::Float32,
            byte_order: ByteOrder::Little,
            encoding: "none".to_string(),
            filter: "none".to_string(),
            compression: "none".to_string(),
            params: BTreeMap::new(),
            hash: None,
        }
    }

    /// Encodes a single-object message whose payload is `fill` repeated for
    /// four bytes per element of `shape`.
    fn encode_msg(shape: Vec<u64>, fill: u8) -> Vec<u8> {
        let element_count = shape.iter().product::<u64>() as usize;
        let payload = vec![fill; element_count * 4];
        let descriptor = make_descriptor(shape);
        encode(
            &make_global_meta(),
            &[(&descriptor, payload.as_slice())],
            &unhashed_options(),
        )
        .unwrap()
    }

    // ---- MessageIter ----------------------------------------------------

    #[test]
    fn test_message_iter_empty_buffer() {
        let empty: Vec<u8> = Vec::new();
        let mut iter = messages(&empty);
        assert_eq!(iter.len(), 0);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_message_iter_single_message() {
        let encoded = encode_msg(vec![4], 1);
        let frames: Vec<&[u8]> = messages(&encoded).collect();
        assert_eq!(frames, [encoded.as_slice()]);
    }

    #[test]
    fn test_message_iter_multiple_messages() {
        let stream: Vec<u8> = [
            encode_msg(vec![4], 0),
            encode_msg(vec![4], 1),
            encode_msg(vec![4], 2),
        ]
        .concat();

        assert_eq!(messages(&stream).count(), 3);
    }

    #[test]
    fn test_message_iter_with_garbage() {
        let first = encode_msg(vec![4], 0);
        let second = encode_msg(vec![4], 1);
        // Junk bytes before the first message and between the two.
        let parts: [&[u8]; 4] = [&[0xDE, 0xAD, 0xBE, 0xEF], &first, &[0xFF, 0xFF], &second];
        let stream: Vec<u8> = parts.concat();

        assert_eq!(messages(&stream).count(), 2);
    }

    #[test]
    fn test_message_iter_yields_decodable_slices() {
        let stream: Vec<u8> = [encode_msg(vec![3], 0xAB), encode_msg(vec![5], 0xCD)].concat();

        for (position, frame) in messages(&stream).enumerate() {
            let (meta, objs) = crate::decode::decode(frame, &DecodeOptions::default()).unwrap();
            assert_eq!(meta.version, 2);
            let expected_shape = match position {
                0 => vec![3u64],
                _ => vec![5u64],
            };
            assert_eq!(objs[0].0.shape, expected_shape);
        }
    }

    // ---- ObjectIter -----------------------------------------------------

    #[test]
    fn test_object_iter_zero_objects() {
        let encoded = encode(&make_global_meta(), &[], &unhashed_options()).unwrap();
        let mut iter = objects(&encoded, DecodeOptions::default()).unwrap();
        assert_eq!(iter.len(), 0);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_object_iter_single_object() {
        let encoded = encode_msg(vec![4], 42);
        let mut iter = objects(&encoded, DecodeOptions::default()).unwrap();

        let (descriptor, bytes) = iter.next().expect("exactly one object").unwrap();
        assert!(iter.next().is_none());

        assert_eq!(descriptor.shape, [4u64]);
        assert_eq!(bytes.len(), 16);
        assert_eq!(bytes, vec![42u8; 16]);
    }

    #[test]
    fn test_object_iter_multi_object() {
        let shape = vec![4u64];
        let payloads = [vec![0u8; 16], vec![1u8; 16]];
        let first_desc = make_descriptor(shape.clone());
        let second_desc = make_descriptor(shape.clone());

        let encoded = encode(
            &make_global_meta(),
            &[
                (&first_desc, payloads[0].as_slice()),
                (&second_desc, payloads[1].as_slice()),
            ],
            &unhashed_options(),
        )
        .unwrap();

        let mut iter = objects(&encoded, DecodeOptions::default()).unwrap();
        assert_eq!(iter.len(), 2);
        for expected in &payloads {
            let (descriptor, bytes) = iter.next().unwrap().unwrap();
            assert_eq!(descriptor.shape, shape);
            assert_eq!(&bytes, expected);
        }
        assert_eq!(iter.len(), 0);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_objects_metadata_only() {
        let encoded = encode_msg(vec![3, 4], 7);
        let mut descriptors = objects_metadata(&encoded).unwrap();

        let only = descriptors.next().expect("exactly one descriptor");
        assert!(descriptors.next().is_none());

        assert_eq!(only.shape, [3u64, 4u64]);
        assert_eq!(only.dtype, Dtype::Float32);
    }

    // ---- FileMessageIter ------------------------------------------------

    #[test]
    fn test_file_iter_empty() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("empty.tgm");
        std::fs::write(&path, b"").unwrap();

        let iter = FileMessageIter::new(path, Vec::new()).unwrap();
        assert_eq!(iter.len(), 0);
        assert_eq!(iter.count(), 0);
    }

    #[test]
    fn test_file_iter_three_messages() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("three.tgm");

        let originals = [
            encode_msg(vec![4], 0),
            encode_msg(vec![4], 1),
            encode_msg(vec![4], 2),
        ];
        let content: Vec<u8> = originals.concat();
        std::fs::write(&path, &content).unwrap();

        let iter = FileMessageIter::new(path, framing::scan(&content)).unwrap();
        assert_eq!(iter.len(), 3);

        let read_back: Vec<Vec<u8>> = iter.map(|item| item.unwrap()).collect();
        assert_eq!(read_back, originals);
    }

    #[test]
    fn test_file_iter_each_decodable() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("decode.tgm");

        let originals: Vec<Vec<u8>> = (0u8..3).map(|fill| encode_msg(vec![2], fill)).collect();
        let content: Vec<u8> = originals.concat();
        std::fs::write(&path, &content).unwrap();

        let spans = framing::scan(&content);
        FileMessageIter::new(path, spans)
            .unwrap()
            .for_each(|item| {
                let bytes = item.unwrap();
                let (meta, _) = crate::decode::decode(&bytes, &DecodeOptions::default()).unwrap();
                assert_eq!(meta.version, 2);
            });
    }
}