Skip to main content

tensogram_core/
streaming.rs

1// (C) Copyright 2026- ECMWF and individual contributors.
2//
3// This software is licensed under the terms of the Apache Licence Version 2.0
4// which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5// In applying this licence, ECMWF does not waive the privileges and immunities
6// granted to it by virtue of its status as an intergovernmental organisation nor
7// does it submit to any jurisdiction.
8
9use std::collections::BTreeMap;
10use std::io::Write;
11
12use crate::encode::{
13    EncodeOptions, build_pipeline_config, populate_base_entries, populate_reserved_provenance,
14    validate_no_szip_offsets_for_non_szip, validate_object, validate_szip_block_offsets,
15};
16use crate::error::{Result, TensogramError};
17use crate::framing::EncodedObject;
18use crate::hash::{HashAlgorithm, compute_hash};
19use crate::metadata::{self, RESERVED_KEY};
20use crate::types::{DataObjectDescriptor, GlobalMetadata, HashDescriptor, HashFrame, IndexFrame};
21use crate::wire::{
22    FRAME_END, FRAME_HEADER_SIZE, FrameHeader, FrameType, MessageFlags, PREAMBLE_SIZE, Postamble,
23    Preamble,
24};
25use tensogram_encodings::pipeline;
26
/// A streaming encoder that writes Tensogram frames progressively to a sink.
///
/// Unlike [`crate::encode::encode`], which builds the entire message in memory,
/// `StreamingEncoder` writes each data object frame immediately. This allows
/// encoding to a socket or pipe without buffering the full message.
///
/// The trade-off is that header-based index and hash frames are not possible;
/// instead, these are written as footer frames when [`finish`](StreamingEncoder::finish)
/// is called.
///
/// Call order: [`new`](StreamingEncoder::new) writes the preamble and the
/// header metadata frame; each object is then written with `write_object` or
/// `write_object_pre_encoded`, optionally preceded by one `write_preceder`
/// call; `finish` writes the footer frames and the postamble and hands the
/// writer back.
///
/// # Example
/// ```no_run
/// use std::io::BufWriter;
/// use std::fs::File;
/// use tensogram_core::streaming::StreamingEncoder;
/// use tensogram_core::{GlobalMetadata, EncodeOptions};
///
/// let file = BufWriter::new(File::create("output.tgm").unwrap());
/// let meta = GlobalMetadata::default();
/// let mut enc = StreamingEncoder::new(file, &meta, &EncodeOptions::default()).unwrap();
/// // enc.write_object(&desc, &data).unwrap();
/// // enc.finish().unwrap();
/// ```
pub struct StreamingEncoder<W: Write> {
    /// Destination sink; returned to the caller by `finish`.
    writer: W,
    /// Byte offsets of each data object frame from message start.
    /// One entry per written object, in write order.
    object_offsets: Vec<u64>,
    /// Total byte length of each data object frame, excluding alignment padding.
    /// Parallel to `object_offsets`.
    object_lengths: Vec<u64>,
    /// Per-object hash entries: (hash_type, hash_value).
    /// `None` for objects written while no hash algorithm is configured.
    hash_entries: Vec<Option<(String, String)>>,
    /// Descriptors of completed objects (payloads not retained) — used to
    /// populate per-object payload entries in the footer metadata frame.
    completed_objects: Vec<EncodedObject>,
    /// Total bytes written so far, including alignment padding.
    bytes_written: u64,
    /// Hash algorithm to use for payload integrity; `None` disables hashing.
    hash_algorithm: Option<HashAlgorithm>,
    /// Original global metadata — re-used to build the footer metadata frame.
    global_meta: GlobalMetadata,
    /// True when a PrecederMetadata frame has been written but the
    /// corresponding DataObject has not yet been written.
    pending_preceder: bool,
    /// Per-object preceder payloads — stored so the footer metadata can
    /// include all per-object metadata (for decoders that skip preceders).
    /// Holds `None` for objects written without a preceder, which keeps the
    /// vector aligned 1:1 with `completed_objects` (checked in `finish`).
    preceder_payloads: Vec<Option<BTreeMap<String, ciborium::Value>>>,
    /// Intra-codec thread budget resolved from `EncodeOptions.threads`
    /// at construction time.  Passed through to every `write_object`
    /// pipeline call; axis A is not applicable in streaming mode
    /// because each `write_object` is a separate caller-paced event.
    intra_codec_threads: u32,
    /// Snapshot of the parallel-threshold option for the same reason.
    parallel_threshold_bytes: Option<usize>,
}
81
82impl<W: Write> StreamingEncoder<W> {
83    /// Begin a new streaming message.
84    ///
85    /// Writes the preamble (with `total_length = 0` for streaming mode)
86    /// and a header metadata frame containing the global metadata.
87    pub fn new(
88        mut writer: W,
89        global_meta: &GlobalMetadata,
90        options: &EncodeOptions,
91    ) -> Result<Self> {
92        let meta_cbor = metadata::global_metadata_to_cbor(global_meta)?;
93
94        // Streaming preamble: total_length=0 signals unknown length at write time.
95        // Always set PRECEDER_METADATA in streaming mode — the flag is advisory
96        // and decoders handle the absence of actual preceder frames gracefully.
97        let mut flags = MessageFlags::default();
98        flags.set(MessageFlags::HEADER_METADATA);
99        flags.set(MessageFlags::FOOTER_METADATA);
100        flags.set(MessageFlags::FOOTER_INDEX);
101        flags.set(MessageFlags::PRECEDER_METADATA);
102        if options.hash_algorithm.is_some() {
103            flags.set(MessageFlags::FOOTER_HASHES);
104        }
105
106        let preamble = Preamble {
107            version: 2,
108            flags,
109            reserved: 0,
110            total_length: 0,
111        };
112        let preamble_bytes = preamble_to_bytes(&preamble);
113        writer.write_all(&preamble_bytes)?;
114        let mut bytes_written = PREAMBLE_SIZE as u64;
115
116        // Write header metadata frame
117        let frame_bytes = build_frame(FrameType::HeaderMetadata, 1, 0, &meta_cbor);
118        writer.write_all(&frame_bytes)?;
119        bytes_written += frame_bytes.len() as u64;
120
121        write_padding(&mut writer, &mut bytes_written)?;
122
123        // Snapshot the thread budget now so that mid-message changes to
124        // TENSOGRAM_THREADS don't leak in between write_object calls —
125        // one message is deterministic.
126        let intra_codec_threads = crate::parallel::resolve_budget(options.threads);
127
128        Ok(Self {
129            writer,
130            object_offsets: Vec::new(),
131            object_lengths: Vec::new(),
132            hash_entries: Vec::new(),
133            completed_objects: Vec::new(),
134            bytes_written,
135            hash_algorithm: options.hash_algorithm,
136            global_meta: global_meta.clone(),
137            pending_preceder: false,
138            preceder_payloads: Vec::new(),
139            intra_codec_threads,
140            parallel_threshold_bytes: options.parallel_threshold_bytes,
141        })
142    }
143
144    /// Write a PrecederMetadata frame for the next data object.
145    ///
146    /// The `metadata` map becomes `base[0]` in a `GlobalMetadata` CBOR
147    /// wrapper.  Must be followed by exactly one
148    /// [`write_object`](Self::write_object) or
149    /// [`write_object_pre_encoded`](Self::write_object_pre_encoded) call
150    /// before another `write_preceder` or [`finish`](Self::finish).
151    pub fn write_preceder(&mut self, metadata: BTreeMap<String, ciborium::Value>) -> Result<()> {
152        if self.pending_preceder {
153            return Err(TensogramError::Framing(
154                "write_preceder called twice without an intervening write_object/write_object_pre_encoded".to_string(),
155            ));
156        }
157
158        // Reject _reserved_ in preceder metadata — this namespace is library-managed
159        // and would collide with the encoder's auto-populated _reserved_.tensor.
160        if metadata.contains_key(RESERVED_KEY) {
161            return Err(TensogramError::Metadata(format!(
162                "client code must not write '{RESERVED_KEY}' in preceder metadata; \
163                     this field is populated by the library"
164            )));
165        }
166
167        let preceder_meta = GlobalMetadata {
168            version: self.global_meta.version,
169            base: vec![metadata.clone()],
170            ..Default::default()
171        };
172        let cbor = crate::metadata::global_metadata_to_cbor(&preceder_meta)?;
173        let frame_bytes = build_frame(FrameType::PrecederMetadata, 1, 0, &cbor);
174        self.writer.write_all(&frame_bytes)?;
175        self.bytes_written += frame_bytes.len() as u64;
176
177        write_padding(&mut self.writer, &mut self.bytes_written)?;
178
179        self.pending_preceder = true;
180        // Store for inclusion in footer metadata
181        self.preceder_payloads.push(Some(metadata));
182        Ok(())
183    }
184
185    /// Encode and write a single data object frame.
186    ///
187    /// The descriptor's encoding/filter/compression pipeline is applied,
188    /// the payload is hashed (if configured), and the frame is written
189    /// immediately — no buffering.
190    ///
191    /// When `EncodeOptions.threads > 0` was passed to
192    /// [`StreamingEncoder::new`], the pipeline call may use up to that
193    /// many threads internally (axis B).  Axis A is not available in
194    /// streaming mode — each `write_object` is a caller-paced event
195    /// with no cross-object parallelism opportunity.
196    pub fn write_object(&mut self, desc: &DataObjectDescriptor, data: &[u8]) -> Result<()> {
197        validate_object(desc, data.len())?;
198
199        let shape_product = desc
200            .shape
201            .iter()
202            .try_fold(1u64, |acc, &x| acc.checked_mul(x))
203            .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
204        let num_elements = usize::try_from(shape_product)
205            .map_err(|_| TensogramError::Metadata("element count overflows usize".to_string()))?;
206
207        // Honour the intra-codec thread budget captured at construction.
208        // Small-message threshold: if the payload is below the threshold,
209        // skip the pool (the overhead would outweigh any codec win).
210        let parallel = crate::parallel::should_parallelise(
211            self.intra_codec_threads,
212            data.len(),
213            self.parallel_threshold_bytes,
214        );
215        let intra = if parallel {
216            self.intra_codec_threads
217        } else {
218            0
219        };
220
221        let config = crate::encode::build_pipeline_config_with_backend(
222            desc,
223            num_elements,
224            desc.dtype,
225            tensogram_encodings::pipeline::CompressionBackend::default(),
226            intra,
227        )?;
228
229        let result =
230            crate::parallel::run_maybe_pooled(self.intra_codec_threads, parallel, intra, || {
231                pipeline::encode_pipeline(data, &config)
232            })
233            .map_err(|e| TensogramError::Encoding(e.to_string()))?;
234
235        // Build final descriptor with computed fields
236        let mut final_desc = desc.clone();
237
238        if let Some(offsets) = &result.block_offsets {
239            final_desc.params.insert(
240                "szip_block_offsets".to_string(),
241                ciborium::Value::Array(
242                    offsets
243                        .iter()
244                        .map(|&o| ciborium::Value::Integer(o.into()))
245                        .collect(),
246                ),
247            );
248        }
249
250        self.write_object_inner(final_desc, &result.encoded_bytes)
251    }
252
253    /// Write a pre-encoded data object frame directly.
254    ///
255    /// Unlike [`write_object`](Self::write_object), this method does **not**
256    /// run the encoding pipeline — `pre_encoded_bytes` are written to the
257    /// stream as-is.  The descriptor must accurately describe the encoding
258    /// that was already applied (encoding, filter, compression, params) so
259    /// that decoders can reconstruct the original payload.
260    ///
261    /// This method participates in the same preceder consumption logic as
262    /// [`write_object`](Self::write_object) and can be freely intermixed
263    /// with it.
264    ///
265    /// # Errors
266    ///
267    /// Returns an error if the descriptor is invalid or the frame cannot be
268    /// written to the underlying writer.
269    #[tracing::instrument(skip(self, descriptor, pre_encoded_bytes))]
270    pub fn write_object_pre_encoded(
271        &mut self,
272        descriptor: &DataObjectDescriptor,
273        pre_encoded_bytes: &[u8],
274    ) -> Result<()> {
275        validate_object(descriptor, pre_encoded_bytes.len())?;
276
277        let shape_product = descriptor
278            .shape
279            .iter()
280            .try_fold(1u64, |acc, &x| acc.checked_mul(x))
281            .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
282        let num_elements = usize::try_from(shape_product)
283            .map_err(|_| TensogramError::Metadata("element count overflows usize".to_string()))?;
284
285        // Validate descriptor pipeline configuration without encoding.
286        build_pipeline_config(descriptor, num_elements, descriptor.dtype)?;
287
288        // Validate szip metadata — same checks as buffered encode_pre_encoded.
289        validate_no_szip_offsets_for_non_szip(descriptor)?;
290        if descriptor.compression == "szip" && descriptor.params.contains_key("szip_block_offsets")
291        {
292            validate_szip_block_offsets(&descriptor.params, pre_encoded_bytes.len())?;
293        }
294
295        self.write_object_inner(descriptor.clone(), pre_encoded_bytes)
296    }
297
298    /// Shared inner implementation for both [`write_object`](Self::write_object) and
299    /// [`write_object_pre_encoded`](Self::write_object_pre_encoded).
300    ///
301    /// Computes the hash, builds the data object frame, updates all bookkeeping,
302    /// consumes any pending preceder, and writes the frame to the stream.
303    fn write_object_inner(
304        &mut self,
305        mut final_desc: DataObjectDescriptor,
306        encoded_bytes: &[u8],
307    ) -> Result<()> {
308        // Compute hash
309        let hash_entry = if let Some(algorithm) = self.hash_algorithm {
310            let hash_value = compute_hash(encoded_bytes, algorithm);
311            let hash_type = algorithm.as_str().to_string();
312            let entry = Some((hash_type.clone(), hash_value.clone()));
313            final_desc.hash = Some(HashDescriptor {
314                hash_type,
315                value: hash_value,
316            });
317            entry
318        } else {
319            None
320        };
321
322        // Build the data object frame bytes
323        let frame_bytes =
324            crate::framing::encode_data_object_frame(&final_desc, encoded_bytes, false)?;
325
326        self.object_offsets.push(self.bytes_written);
327        self.object_lengths.push(frame_bytes.len() as u64);
328        self.hash_entries.push(hash_entry);
329        // Retain only the descriptor for footer metadata population.
330        // The encoded payload has already been written to the stream;
331        // keeping it in memory would negate streaming's memory benefits.
332        self.completed_objects.push(EncodedObject {
333            descriptor: final_desc,
334            encoded_payload: Vec::new(),
335        });
336
337        // Consume pending preceder — if no preceder was written for this
338        // object, record None so preceder_payloads stays aligned with objects.
339        if self.pending_preceder {
340            self.pending_preceder = false;
341        } else {
342            self.preceder_payloads.push(None);
343        }
344
345        // Write frame
346        self.writer.write_all(&frame_bytes)?;
347        self.bytes_written += frame_bytes.len() as u64;
348
349        // Align to 8 bytes
350        write_padding(&mut self.writer, &mut self.bytes_written)?;
351
352        Ok(())
353    }
354
355    /// Finalize the streaming message.
356    ///
357    /// Writes footer frames (payload metadata + hash + index) and the postamble.
358    /// Consumes the encoder and returns the underlying writer.
359    pub fn finish(mut self) -> Result<W> {
360        if self.pending_preceder {
361            return Err(TensogramError::Framing(
362                "dangling PrecederMetadata: finish called without a following write_object/write_object_pre_encoded"
363                    .to_string(),
364            ));
365        }
366
367        let footer_start = self.bytes_written;
368
369        // Footer metadata frame: updated global metadata with per-object payload entries.
370        // The header metadata was written without knowing the objects; here we write
371        // a footer metadata frame that supersedes it with payload populated.
372        // Preceder payloads are merged in so the footer is complete even for
373        // decoders that skip PrecederMetadata frames.
374        {
375            let mut enriched_meta = self.global_meta.clone();
376            populate_base_entries(&mut enriched_meta.base, &self.completed_objects);
377            populate_reserved_provenance(&mut enriched_meta.reserved);
378
379            // Merge preceder payloads into footer metadata base entries
380            // (preceder wins).  preceder_payloads is aligned 1:1 with
381            // completed_objects by write_preceder/write_object bookkeeping,
382            // so the lengths must match.
383            if self.preceder_payloads.len() != self.completed_objects.len() {
384                return Err(TensogramError::Framing(format!(
385                    "internal: preceder_payloads ({}) out of sync with completed_objects ({})",
386                    self.preceder_payloads.len(),
387                    self.completed_objects.len()
388                )));
389            }
390            for (i, prec) in self.preceder_payloads.iter().enumerate() {
391                if let Some(prec_map) = prec
392                    && i < enriched_meta.base.len()
393                {
394                    for (k, v) in prec_map {
395                        enriched_meta.base[i].insert(k.clone(), v.clone());
396                    }
397                }
398            }
399            let meta_cbor = metadata::global_metadata_to_cbor(&enriched_meta)?;
400            let frame_bytes = build_frame(FrameType::FooterMetadata, 1, 0, &meta_cbor);
401            self.writer.write_all(&frame_bytes)?;
402            self.bytes_written += frame_bytes.len() as u64;
403            write_padding(&mut self.writer, &mut self.bytes_written)?;
404        }
405
406        // Footer hash frame (if any objects had hashes)
407        let has_hashes = self.hash_entries.iter().any(|e| e.is_some());
408        if has_hashes {
409            let hash_type = self
410                .hash_algorithm
411                .map(|a| a.as_str().to_string())
412                .unwrap_or_default();
413            let hashes: Vec<String> = self
414                .hash_entries
415                .iter()
416                .map(|e| e.as_ref().map(|(_, v)| v.clone()).unwrap_or_default())
417                .collect();
418            let hash_frame = HashFrame {
419                object_count: self.object_offsets.len() as u64,
420                hash_type,
421                hashes,
422            };
423            let hash_cbor = metadata::hash_frame_to_cbor(&hash_frame)?;
424            let frame_bytes = build_frame(FrameType::FooterHash, 1, 0, &hash_cbor);
425            self.writer.write_all(&frame_bytes)?;
426            self.bytes_written += frame_bytes.len() as u64;
427
428            write_padding(&mut self.writer, &mut self.bytes_written)?;
429        }
430
431        // Footer index frame
432        let index = IndexFrame {
433            object_count: self.object_offsets.len() as u64,
434            offsets: self.object_offsets,
435            lengths: self.object_lengths,
436        };
437        let index_cbor = metadata::index_to_cbor(&index)?;
438        let frame_bytes = build_frame(FrameType::FooterIndex, 1, 0, &index_cbor);
439        self.writer.write_all(&frame_bytes)?;
440        self.bytes_written += frame_bytes.len() as u64;
441
442        write_padding(&mut self.writer, &mut self.bytes_written)?;
443
444        // Postamble
445        let postamble = Postamble {
446            first_footer_offset: footer_start,
447        };
448        let mut postamble_bytes = Vec::with_capacity(16);
449        postamble.write_to(&mut postamble_bytes);
450        self.writer.write_all(&postamble_bytes)?;
451
452        self.writer.flush()?;
453
454        Ok(self.writer)
455    }
456
    /// Returns the number of data objects written so far.
    ///
    /// Derived from the number of recorded object offsets; preceder frames
    /// and footer frames are not counted.
    pub fn object_count(&self) -> usize {
        self.object_offsets.len()
    }
461
    /// Returns the total bytes written so far.
    ///
    /// Counts the preamble, every frame emitted so far and the alignment
    /// padding that follows each frame.
    pub fn bytes_written(&self) -> u64 {
        self.bytes_written
    }
466}
467
468// ── Helpers ──────────────────────────────────────────────────────────────────
469
470fn preamble_to_bytes(preamble: &Preamble) -> Vec<u8> {
471    let mut out = Vec::with_capacity(PREAMBLE_SIZE);
472    preamble.write_to(&mut out);
473    out
474}
475
476fn build_frame(frame_type: FrameType, version: u16, flags: u16, payload: &[u8]) -> Vec<u8> {
477    let total_length = (FRAME_HEADER_SIZE + payload.len() + FRAME_END.len()) as u64;
478    let fh = FrameHeader {
479        frame_type,
480        version,
481        flags,
482        total_length,
483    };
484    let mut out = Vec::with_capacity(total_length as usize);
485    fh.write_to(&mut out);
486    out.extend_from_slice(payload);
487    out.extend_from_slice(FRAME_END);
488    out
489}
490
/// Source of zero bytes for alignment padding; at most 7 are ever needed.
const ZERO_PAD: [u8; 7] = [0; 7];

/// Writes zero bytes until `*bytes_written` is a multiple of 8.
///
/// Writes nothing when the count is already aligned.  `*bytes_written` is
/// advanced by the number of padding bytes, but only after the write
/// succeeded.
fn write_padding(writer: &mut impl Write, bytes_written: &mut u64) -> std::io::Result<()> {
    let misalignment = *bytes_written as usize % 8;
    if misalignment == 0 {
        return Ok(());
    }

    let gap = 8 - misalignment;
    writer.write_all(&ZERO_PAD[..gap])?;
    *bytes_written += gap as u64;
    Ok(())
}
501
502#[cfg(test)]
503mod tests {
504    use super::*;
505    use crate::Dtype;
506    use crate::decode::{DecodeOptions, decode};
507    use crate::encode::{EncodeOptions, encode};
508    use crate::types::{ByteOrder, DataObjectDescriptor};
509    use std::collections::BTreeMap;
510
511    fn make_descriptor(shape: Vec<u64>) -> DataObjectDescriptor {
512        let ndim = shape.len() as u64;
513        let mut strides = vec![0u64; shape.len()];
514        if !shape.is_empty() {
515            strides[shape.len() - 1] = 1;
516            for i in (0..shape.len() - 1).rev() {
517                strides[i] = strides[i + 1] * shape[i + 1];
518            }
519        }
520        DataObjectDescriptor {
521            obj_type: "ntensor".to_string(),
522            ndim,
523            shape,
524            strides,
525            dtype: Dtype::Float32,
526            byte_order: ByteOrder::native(),
527            encoding: "none".to_string(),
528            filter: "none".to_string(),
529            compression: "none".to_string(),
530            params: BTreeMap::new(),
531            hash: None,
532        }
533    }
534
535    #[test]
536    fn streaming_single_object_round_trip() {
537        let meta = GlobalMetadata::default();
538        let desc = make_descriptor(vec![4]);
539        let data = vec![0u8; 4 * 4];
540
541        // Streaming encode
542        let buf = Vec::new();
543        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
544        enc.write_object(&desc, &data).unwrap();
545        let result = enc.finish().unwrap();
546
547        // Decode should succeed
548        let (decoded_meta, objects) = decode(&result, &DecodeOptions::default()).unwrap();
549        assert_eq!(decoded_meta.version, 2);
550        assert_eq!(objects.len(), 1);
551        assert_eq!(objects[0].1, data);
552    }
553
554    #[test]
555    fn streaming_multi_object_round_trip() {
556        let meta = GlobalMetadata::default();
557        let desc1 = make_descriptor(vec![4]);
558        let desc2 = make_descriptor(vec![8]);
559        let data1 = vec![1u8; 4 * 4];
560        let data2 = vec![2u8; 8 * 4];
561
562        let buf = Vec::new();
563        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
564        enc.write_object(&desc1, &data1).unwrap();
565        enc.write_object(&desc2, &data2).unwrap();
566        assert_eq!(enc.object_count(), 2);
567        let result = enc.finish().unwrap();
568
569        let (_, objects) = decode(&result, &DecodeOptions::default()).unwrap();
570        assert_eq!(objects.len(), 2);
571        assert_eq!(objects[0].1, data1);
572        assert_eq!(objects[1].1, data2);
573    }
574
575    #[test]
576    fn streaming_matches_buffered_single_object() {
577        let meta = GlobalMetadata::default();
578        let desc = make_descriptor(vec![4]);
579        let data = vec![42u8; 4 * 4];
580        let options = EncodeOptions {
581            compression_backend: Default::default(),
582            hash_algorithm: Some(HashAlgorithm::Xxh3),
583            emit_preceders: false,
584            threads: 0,
585            parallel_threshold_bytes: None,
586        };
587
588        // Buffered encode
589        let buffered = encode(&meta, &[(&desc, &data)], &options).unwrap();
590        let (buf_meta, buf_objects) = decode(
591            &buffered,
592            &DecodeOptions {
593                verify_hash: true,
594                ..Default::default()
595            },
596        )
597        .unwrap();
598
599        // Streaming encode
600        let buf = Vec::new();
601        let mut enc = StreamingEncoder::new(buf, &meta, &options).unwrap();
602        enc.write_object(&desc, &data).unwrap();
603        let streamed = enc.finish().unwrap();
604        let (str_meta, str_objects) = decode(
605            &streamed,
606            &DecodeOptions {
607                verify_hash: true,
608                ..Default::default()
609            },
610        )
611        .unwrap();
612
613        // Data must match (wire bytes may differ due to header vs footer layout)
614        assert_eq!(buf_meta.version, str_meta.version);
615        assert_eq!(buf_objects.len(), str_objects.len());
616        assert_eq!(buf_objects[0].0.shape, str_objects[0].0.shape);
617        assert_eq!(buf_objects[0].0.dtype, str_objects[0].0.dtype);
618        assert_eq!(buf_objects[0].1, str_objects[0].1);
619        // Hash values must match
620        assert_eq!(
621            buf_objects[0].0.hash.as_ref().unwrap().value,
622            str_objects[0].0.hash.as_ref().unwrap().value
623        );
624    }
625
626    #[test]
627    fn streaming_hash_verification() {
628        let meta = GlobalMetadata::default();
629        let desc = make_descriptor(vec![4]);
630        let data = vec![42u8; 4 * 4];
631        let options = EncodeOptions {
632            compression_backend: Default::default(),
633            hash_algorithm: Some(HashAlgorithm::Xxh3),
634            emit_preceders: false,
635            threads: 0,
636            parallel_threshold_bytes: None,
637        };
638
639        let buf = Vec::new();
640        let mut enc = StreamingEncoder::new(buf, &meta, &options).unwrap();
641        enc.write_object(&desc, &data).unwrap();
642        let result = enc.finish().unwrap();
643
644        // Verify hash passes
645        let verify_opts = DecodeOptions {
646            verify_hash: true,
647            ..Default::default()
648        };
649        let (_, objects) = decode(&result, &verify_opts).unwrap();
650        assert!(objects[0].0.hash.is_some());
651    }
652
653    #[test]
654    fn streaming_no_objects() {
655        let meta = GlobalMetadata::default();
656        let options = EncodeOptions {
657            compression_backend: Default::default(),
658            hash_algorithm: None,
659            emit_preceders: false,
660            threads: 0,
661            parallel_threshold_bytes: None,
662        };
663
664        let buf = Vec::new();
665        let enc = StreamingEncoder::new(buf, &meta, &options).unwrap();
666        assert_eq!(enc.object_count(), 0);
667        let result = enc.finish().unwrap();
668
669        let (decoded_meta, objects) = decode(&result, &DecodeOptions::default()).unwrap();
670        assert_eq!(decoded_meta.version, 2);
671        assert_eq!(objects.len(), 0);
672    }
673
674    /// Threads budget on `StreamingEncoder` must not change the encoded
675    /// payload for transparent pipelines.  This locks in the pass-3
676    /// consistency: axis-B dispatch inside `write_object` is opt-in and
677    /// transparent-codec output is byte-identical across thread counts.
678    #[test]
679    fn streaming_threads_byte_identical_transparent() {
680        let meta = GlobalMetadata::default();
681        // One large object — 200 KiB — above the 64 KiB default threshold.
682        let desc = make_descriptor(vec![50_000]);
683        let data: Vec<u8> = (0..50_000)
684            .flat_map(|i| (250.0f32 + (i as f32).sin() * 30.0).to_ne_bytes())
685            .collect();
686
687        let mk = |threads: u32| -> Vec<u8> {
688            let buf = Vec::new();
689            let opts = EncodeOptions {
690                threads,
691                parallel_threshold_bytes: Some(0), // force parallel
692                ..Default::default()
693            };
694            let mut enc = StreamingEncoder::new(buf, &meta, &opts).unwrap();
695            enc.write_object(&desc, &data).unwrap();
696            enc.finish().unwrap()
697        };
698
699        // Compare encoded payload bytes (ignore provenance).
700        let payloads = |buf: &[u8]| -> Vec<Vec<u8>> {
701            crate::framing::decode_message(buf)
702                .unwrap()
703                .objects
704                .iter()
705                .map(|(_, p, _)| p.to_vec())
706                .collect()
707        };
708
709        let baseline = mk(0);
710        let payloads_baseline = payloads(&baseline);
711
712        for t in [1u32, 2, 4, 8] {
713            let got = mk(t);
714            assert_eq!(
715                payloads_baseline,
716                payloads(&got),
717                "streaming threads={t} payload must match sequential"
718            );
719        }
720    }
721
722    #[test]
723    fn streaming_with_metadata() {
724        let mut extra = BTreeMap::new();
725        extra.insert(
726            "centre".to_string(),
727            ciborium::Value::Text("ecmwf".to_string()),
728        );
729        let meta = GlobalMetadata {
730            version: 2,
731            extra,
732            ..Default::default()
733        };
734
735        let desc = make_descriptor(vec![4]);
736        let data = vec![0u8; 4 * 4];
737
738        let buf = Vec::new();
739        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
740        enc.write_object(&desc, &data).unwrap();
741        let result = enc.finish().unwrap();
742
743        let (decoded_meta, _) = decode(&result, &DecodeOptions::default()).unwrap();
744        assert_eq!(
745            decoded_meta.extra.get("centre"),
746            Some(&ciborium::Value::Text("ecmwf".to_string()))
747        );
748    }
749
750    // ── PrecederMetadata tests ───────────────────────────────────────────
751
/// Writes one preceder followed by one object, then checks that the payload
/// round-trips and that the preceder's `mars` key appears in `base[0]`.
#[test]
fn streaming_preceder_round_trip() {
    let global = GlobalMetadata::default();
    let descriptor = make_descriptor(vec![4]);
    let payload = vec![42u8; 4 * 4];

    let mars = ciborium::Value::Map(vec![(
        ciborium::Value::Text("param".to_string()),
        ciborium::Value::Text("2t".to_string()),
    )]);
    let preceder = BTreeMap::from([("mars".to_string(), mars)]);

    let mut encoder =
        StreamingEncoder::new(Vec::new(), &global, &EncodeOptions::default()).unwrap();
    encoder.write_preceder(preceder).unwrap();
    encoder.write_object(&descriptor, &payload).unwrap();
    let message = encoder.finish().unwrap();

    let (decoded, objects) = decode(&message, &DecodeOptions::default()).unwrap();
    assert_eq!(objects.len(), 1);
    assert_eq!(objects[0].1, payload);

    // The key supplied through the preceder is expected in base[0].
    assert!(
        decoded.base[0].contains_key("mars"),
        "mars key should be in base[0]"
    );
}
781
/// Supplies the same `source` key both in the global metadata's `base[0]`
/// and in a preceder, and checks that the decoded `base[0]` carries the
/// preceder's value.
#[test]
fn streaming_preceder_wins_over_footer() {
    let text = |s: &str| ciborium::Value::Text(s.to_string());

    // Value handed to the encoder up front via global_meta.base[0].
    let global = GlobalMetadata {
        version: 2,
        base: vec![BTreeMap::from([("source".to_string(), text("footer"))])],
        ..Default::default()
    };
    // Competing value delivered through the preceder.
    let preceder = BTreeMap::from([("source".to_string(), text("preceder"))]);

    let descriptor = make_descriptor(vec![4]);
    let payload = vec![0u8; 4 * 4];

    let mut encoder =
        StreamingEncoder::new(Vec::new(), &global, &EncodeOptions::default()).unwrap();
    encoder.write_preceder(preceder).unwrap();
    encoder.write_object(&descriptor, &payload).unwrap();
    let message = encoder.finish().unwrap();

    let (decoded, _) = decode(&message, &DecodeOptions::default()).unwrap();
    let source = match decoded.base[0].get("source") {
        Some(ciborium::Value::Text(s)) => Some(s.as_str()),
        _ => None,
    };
    assert_eq!(source, Some("preceder"), "preceder should win over footer");
}
819
/// Calls `write_preceder` twice in a row and checks that the second call
/// returns an error.
#[test]
fn streaming_consecutive_preceder_error() {
    let global = GlobalMetadata::default();
    let options = EncodeOptions::default();
    let mut encoder = StreamingEncoder::new(Vec::new(), &global, &options).unwrap();

    // First preceder is accepted; no object is written before the second one.
    encoder.write_preceder(BTreeMap::new()).unwrap();
    let second = encoder.write_preceder(BTreeMap::new());
    assert!(
        second.is_err(),
        "two write_preceder calls without intervening write_object should fail"
    );
}
833
/// Writes a preceder with no following object and checks that `finish`
/// returns an error.
#[test]
fn streaming_dangling_preceder_error() {
    let global = GlobalMetadata::default();
    let options = EncodeOptions::default();
    let mut encoder = StreamingEncoder::new(Vec::new(), &global, &options).unwrap();

    encoder.write_preceder(BTreeMap::new()).unwrap();
    // No write_object follows the preceder before finishing.
    let finished = encoder.finish();
    assert!(
        finished.is_err(),
        "finish with a dangling preceder should fail"
    );
}
847
/// Writes two objects, only the first of which is preceded by a preceder,
/// and checks that the preceder's `note` key appears in `base[0]` but not
/// in `base[1]`.
#[test]
fn streaming_mixed_objects_with_and_without_preceders() {
    let global = GlobalMetadata::default();
    let (first_desc, first_data) = (make_descriptor(vec![4]), vec![1u8; 4 * 4]);
    let (second_desc, second_data) = (make_descriptor(vec![8]), vec![2u8; 8 * 4]);

    let note = ciborium::Value::Text("only for obj 0".to_string());
    let preceder = BTreeMap::from([("note".to_string(), note)]);

    let mut encoder =
        StreamingEncoder::new(Vec::new(), &global, &EncodeOptions::default()).unwrap();
    // Object 0 gets a preceder.
    encoder.write_preceder(preceder).unwrap();
    encoder.write_object(&first_desc, &first_data).unwrap();
    // Object 1 is written bare.
    encoder.write_object(&second_desc, &second_data).unwrap();
    let message = encoder.finish().unwrap();

    let (decoded, objects) = decode(&message, &DecodeOptions::default()).unwrap();
    assert_eq!(objects.len(), 2);
    assert_eq!(objects[0].1, first_data);
    assert_eq!(objects[1].1, second_data);

    // Only the entry for object 0 may carry the preceder key.
    assert!(decoded.base[0].contains_key("note"));
    assert!(!decoded.base[1].contains_key("note"));
}
881
/// Checks that application keys written through a preceder (`units`, `mars`)
/// are present in the decoded `base[0]`, alongside a `_reserved_` map that
/// contains the `tensor` sub-key.
#[test]
fn streaming_preceder_metadata_preservation() {
    let text = |s: &str| ciborium::Value::Text(s.to_string());

    let meta = GlobalMetadata::default();
    let desc = make_descriptor(vec![2]);
    let data = vec![0u8; 2 * 4];

    let mut prec = BTreeMap::new();
    prec.insert("units".to_string(), text("K"));
    prec.insert(
        "mars".to_string(),
        ciborium::Value::Map(vec![
            (text("param"), text("2t")),
            (text("levtype"), text("sfc")),
        ]),
    );

    let buf = Vec::new();
    let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
    enc.write_preceder(prec).unwrap();
    enc.write_object(&desc, &data).unwrap();
    let result = enc.finish().unwrap();

    let (decoded_meta, _) = decode(&result, &DecodeOptions::default()).unwrap();
    let p = &decoded_meta.base[0];
    assert_eq!(p.get("units"), Some(&text("K")));
    assert!(p.contains_key("mars"));

    // Structural keys (ndim, shape) should be under _reserved_.tensor, so
    // check the `tensor` sub-key itself rather than only that `_reserved_`
    // exists.
    assert!(p.contains_key("_reserved_"));
    let Some(ciborium::Value::Map(reserved)) = p.get("_reserved_") else {
        panic!("_reserved_ should be a map");
    };
    assert!(
        reserved.iter().any(|(k, _)| *k == text("tensor")),
        "_reserved_.tensor should be present after preceder merge"
    );
}
922
923    // ── Edge case: preceder with _reserved_ rejected ─────────────────────
924
/// Passes a preceder that contains a `_reserved_` key to `write_preceder`
/// and checks that the call fails with an error whose text mentions
/// `_reserved_`.
#[test]
fn streaming_preceder_with_reserved_rejected() {
    let global = GlobalMetadata::default();
    let mut encoder =
        StreamingEncoder::new(Vec::new(), &global, &EncodeOptions::default()).unwrap();

    let preceder = BTreeMap::from([("_reserved_".to_string(), ciborium::Value::Map(vec![]))]);

    let outcome = encoder.write_preceder(preceder);
    assert!(
        outcome.is_err(),
        "_reserved_ in preceder should be rejected"
    );
    let err = outcome.unwrap_err().to_string();
    assert!(
        err.contains("_reserved_"),
        "error should mention _reserved_: {err}"
    );
}
942
/// Checks that when a preceder frame carries a `_reserved_` entry, decoding
/// still succeeds, the preceder's other keys (`mars`) are kept, and the
/// `_reserved_` map in the decoded `base[0]` holds the footer's `tensor` key
/// and not the preceder's `rogue` key.
///
/// The message is assembled frame by frame instead of going through
/// `StreamingEncoder`, so that the preceder can contain `_reserved_`.
#[test]
fn streaming_preceder_reserved_stripped_on_decode() {
    use crate::wire::*;

    /// Appends zero bytes until `out.len()` is a multiple of 8.
    fn pad_to_8(out: &mut Vec<u8>) {
        let pad = (8 - (out.len() % 8)) % 8;
        out.extend(std::iter::repeat_n(0u8, pad));
    }

    /// Appends one metadata frame (frame header, CBOR body, `FRAME_END`)
    /// and then pads `out` to a multiple of 8 bytes.
    fn push_metadata_frame(out: &mut Vec<u8>, frame_type: FrameType, cbor: &[u8]) {
        let total_length = (FRAME_HEADER_SIZE + cbor.len() + FRAME_END.len()) as u64;
        let fh = FrameHeader {
            frame_type,
            version: 1,
            flags: 0,
            total_length,
        };
        // Explicit reborrow so `out` stays usable whatever `write_to` takes.
        fh.write_to(&mut *out);
        out.extend_from_slice(cbor);
        out.extend_from_slice(FRAME_END);
        pad_to_8(out);
    }

    let text = |s: &str| ciborium::Value::Text(s.to_string());

    // Preceder entry: one application key plus a `_reserved_` map that a
    // preceder is not supposed to carry.
    let mut prec_entry = BTreeMap::new();
    prec_entry.insert(
        "mars".to_string(),
        ciborium::Value::Map(vec![(text("param"), text("2t"))]),
    );
    prec_entry.insert(
        "_reserved_".to_string(),
        ciborium::Value::Map(vec![(text("rogue"), text("bad"))]),
    );
    let preceder_meta = GlobalMetadata {
        version: 2,
        base: vec![prec_entry],
        ..Default::default()
    };
    let preceder_cbor = crate::metadata::global_metadata_to_cbor(&preceder_meta).unwrap();

    // Data object frame that follows the preceder.
    let desc_for_frame = make_descriptor(vec![4]);
    let payload = vec![0u8; 4 * 4];
    let frame =
        crate::framing::encode_data_object_frame(&desc_for_frame, &payload, false).unwrap();

    // Footer metadata with _reserved_.tensor
    let tensor_map = ciborium::Value::Map(vec![
        (text("ndim"), ciborium::Value::Integer(1.into())),
        (
            text("shape"),
            ciborium::Value::Array(vec![ciborium::Value::Integer(4.into())]),
        ),
        (
            text("strides"),
            ciborium::Value::Array(vec![ciborium::Value::Integer(1.into())]),
        ),
        (text("dtype"), text("float32")),
    ]);
    let mut footer_base = BTreeMap::new();
    footer_base.insert(
        "_reserved_".to_string(),
        ciborium::Value::Map(vec![(text("tensor"), tensor_map)]),
    );
    let footer_meta = GlobalMetadata {
        version: 2,
        base: vec![footer_base],
        ..Default::default()
    };
    let footer_cbor = crate::metadata::global_metadata_to_cbor(&footer_meta).unwrap();

    let header_meta_cbor =
        crate::metadata::global_metadata_to_cbor(&GlobalMetadata::default()).unwrap();

    // Assemble raw message: placeholder preamble, header metadata, preceder
    // metadata, data object, footer metadata, postamble.
    let mut out = Vec::new();
    out.extend_from_slice(&[0u8; PREAMBLE_SIZE]);

    push_metadata_frame(&mut out, FrameType::HeaderMetadata, &header_meta_cbor);
    push_metadata_frame(&mut out, FrameType::PrecederMetadata, &preceder_cbor);

    // Data object
    out.extend_from_slice(&frame);
    pad_to_8(&mut out);

    push_metadata_frame(&mut out, FrameType::FooterMetadata, &footer_cbor);

    // Postamble
    // NOTE(review): this stores the offset at which the postamble itself
    // starts (i.e. after the footer metadata frame) in `first_footer_offset`
    // — confirm that this is the intended value for the field.
    let postamble_offset = out.len();
    let postamble = Postamble {
        first_footer_offset: postamble_offset as u64,
    };
    postamble.write_to(&mut out);

    // Patch preamble now that the total length is known.
    let total_length = out.len() as u64;
    let mut flags = MessageFlags::default();
    flags.set(MessageFlags::HEADER_METADATA);
    flags.set(MessageFlags::FOOTER_METADATA);
    flags.set(MessageFlags::PRECEDER_METADATA);
    let preamble = Preamble {
        version: 2,
        flags,
        reserved: 0,
        total_length,
    };
    let mut preamble_bytes = Vec::new();
    preamble.write_to(&mut preamble_bytes);
    out[0..PREAMBLE_SIZE].copy_from_slice(&preamble_bytes);

    // Decode
    let decoded = crate::framing::decode_message(&out).unwrap();

    // The preceder's _reserved_ should have been stripped by the decoder.
    // The footer's _reserved_.tensor should be preserved.
    let base0 = &decoded.global_metadata.base[0];
    assert!(
        base0.contains_key("mars"),
        "mars from preceder should survive"
    );
    // _reserved_ should come from footer, not preceder
    let reserved = base0.get("_reserved_");
    assert!(
        reserved.is_some(),
        "_reserved_ from footer should be present"
    );
    // Fail loudly if `_reserved_` is not a map, instead of silently skipping
    // the key checks below.
    let Some(ciborium::Value::Map(pairs)) = reserved else {
        panic!("_reserved_ should be a map");
    };
    let has_tensor = pairs.iter().any(|(k, _)| *k == text("tensor"));
    assert!(has_tensor, "tensor key from footer should be preserved");
    let has_rogue = pairs.iter().any(|(k, _)| *k == text("rogue"));
    assert!(
        !has_rogue,
        "rogue key from preceder's _reserved_ should have been stripped"
    );
}
1125
1126    // ── write_object_pre_encoded tests ───────────────────────────────────
1127
/// Interleaves `write_object`, `write_object_pre_encoded` and `write_object`
/// on one encoder and checks that all three payloads decode unchanged.
#[test]
fn test_streaming_mixed_mode_pre_encoded() {
    let global = GlobalMetadata::default();

    let raw_desc_a = make_descriptor(vec![4]);
    let raw_data_a = vec![1u8; 4 * 4];
    // Middle object goes through the pre-encoded path. The original test
    // notes the descriptor uses encoding="none", so the pre-encoded bytes
    // equal the raw bytes.
    let pre_desc = make_descriptor(vec![5]);
    let pre_bytes = vec![2u8; 5 * 4];
    let raw_desc_b = make_descriptor(vec![6]);
    let raw_data_b = vec![3u8; 6 * 4];

    let mut encoder =
        StreamingEncoder::new(Vec::new(), &global, &EncodeOptions::default()).unwrap();
    encoder.write_object(&raw_desc_a, &raw_data_a).unwrap();
    encoder
        .write_object_pre_encoded(&pre_desc, &pre_bytes)
        .unwrap();
    encoder.write_object(&raw_desc_b, &raw_data_b).unwrap();
    assert_eq!(encoder.object_count(), 3);
    let message = encoder.finish().unwrap();

    let (_, objects) = decode(&message, &DecodeOptions::default()).unwrap();
    assert_eq!(objects.len(), 3);
    // Only decoded payloads are compared; the original test notes that raw
    // message bytes are non-deterministic because of provenance.
    assert_eq!(objects[0].1, raw_data_a, "object 0 payload mismatch");
    assert_eq!(objects[1].1, pre_bytes, "object 1 payload mismatch");
    assert_eq!(objects[2].1, raw_data_b, "object 2 payload mismatch");
}
1158
/// Writes a preceder followed by `write_object_pre_encoded` and checks that
/// the payload round-trips and the preceder's `mars` key is in `base[0]`.
#[test]
fn test_streaming_preceder_then_pre_encoded() {
    let global = GlobalMetadata::default();
    let descriptor = make_descriptor(vec![4]);
    let encoded_bytes = vec![42u8; 4 * 4];

    let mars = ciborium::Value::Map(vec![(
        ciborium::Value::Text("param".to_string()),
        ciborium::Value::Text("2t".to_string()),
    )]);
    let preceder = BTreeMap::from([("mars".to_string(), mars)]);

    let mut encoder =
        StreamingEncoder::new(Vec::new(), &global, &EncodeOptions::default()).unwrap();
    encoder.write_preceder(preceder).unwrap();
    encoder
        .write_object_pre_encoded(&descriptor, &encoded_bytes)
        .unwrap();
    let message = encoder.finish().unwrap();

    let (decoded, objects) = decode(&message, &DecodeOptions::default()).unwrap();
    assert_eq!(objects.len(), 1);
    // The payload has to come back unchanged.
    assert_eq!(objects[0].1, encoded_bytes, "pre-encoded payload mismatch");
    // The preceder key has to be attached to the first object's entry.
    assert!(
        decoded.base[0].contains_key("mars"),
        "mars key from preceder should be in base[0]"
    );
}
1193
/// Writes a preceder with a `units` key before one object and checks that
/// the decoded `base[0]` contains both `units` and a `_reserved_` map with
/// a `tensor` key.
#[test]
fn streaming_finish_preserves_preceder_does_not_clobber_reserved_tensor() {
    let global = GlobalMetadata::default();
    let descriptor = make_descriptor(vec![4]);
    let payload = vec![42u8; 4 * 4];

    let preceder = BTreeMap::from([(
        "units".to_string(),
        ciborium::Value::Text("K".to_string()),
    )]);

    let mut encoder =
        StreamingEncoder::new(Vec::new(), &global, &EncodeOptions::default()).unwrap();
    encoder.write_preceder(preceder).unwrap();
    encoder.write_object(&descriptor, &payload).unwrap();
    let message = encoder.finish().unwrap();

    let (decoded, _) = decode(&message, &DecodeOptions::default()).unwrap();
    let entry = &decoded.base[0];

    // Key supplied by the preceder.
    assert!(entry.contains_key("units"));

    // `_reserved_` must be a map that still has its `tensor` entry.
    let reserved = entry.get("_reserved_").expect("_reserved_ missing");
    let ciborium::Value::Map(pairs) = reserved else {
        panic!("_reserved_ should be a map");
    };
    let tensor_key = ciborium::Value::Text("tensor".to_string());
    assert!(
        pairs.iter().any(|(k, _)| *k == tensor_key),
        "_reserved_.tensor should be present after preceder merge"
    );
}
1231}