Skip to main content

tensogram_core/
encode.rs

1// (C) Copyright 2026- ECMWF and individual contributors.
2//
3// This software is licensed under the terms of the Apache Licence Version 2.0
4// which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5// In applying this licence, ECMWF does not waive the privileges and immunities
6// granted to it by virtue of its status as an intergovernmental organisation nor
7// does it submit to any jurisdiction.
8
9use std::collections::BTreeMap;
10
11use crate::dtype::Dtype;
12use crate::error::{Result, TensogramError};
13use crate::framing::{self, EncodedObject};
14use crate::hash::{HashAlgorithm, compute_hash};
15use crate::metadata::RESERVED_KEY;
16use crate::types::{DataObjectDescriptor, GlobalMetadata, HashDescriptor};
17#[cfg(feature = "blosc2")]
18use tensogram_encodings::pipeline::Blosc2Codec;
19#[cfg(feature = "sz3")]
20use tensogram_encodings::pipeline::Sz3ErrorBound;
21#[cfg(feature = "zfp")]
22use tensogram_encodings::pipeline::ZfpMode;
23use tensogram_encodings::pipeline::{
24    self, CompressionType, EncodingType, FilterType, PipelineConfig,
25};
26use tensogram_encodings::simple_packing::SimplePackingParams;
27
28/// Options for encoding.
29#[derive(Debug, Clone)]
30pub struct EncodeOptions {
31    /// Hash algorithm to use for payload integrity. None = no hashing.
32    pub hash_algorithm: Option<HashAlgorithm>,
33    /// Reserved for future buffered-mode preceder support.
34    ///
35    /// Currently, setting this to `true` in buffered mode (`encode()`)
36    /// returns an error — use [`StreamingEncoder::write_preceder`](crate::streaming::StreamingEncoder::write_preceder) instead.
37    /// The streaming encoder ignores this field; it emits preceders only
38    /// when `write_preceder()` is called explicitly.
39    pub emit_preceders: bool,
40    /// Which backend to use for szip / zstd when both FFI and pure-Rust
41    /// implementations are compiled in.
42    ///
43    /// Defaults to `Ffi` on native (faster, battle-tested) and `Pure` on
44    /// `wasm32` (FFI cannot exist).  Override with
45    /// `TENSOGRAM_COMPRESSION_BACKEND=pure` env variable, or set this
46    /// field explicitly.
47    pub compression_backend: pipeline::CompressionBackend,
48    /// Thread budget for the multi-threaded coding pipeline.
49    ///
50    /// - `0` (default) — sequential (current behaviour).  Can be
51    ///   overridden at runtime via `TENSOGRAM_THREADS=N`.
52    /// - `1` — explicit single-threaded execution (bypasses env).
53    /// - `N ≥ 2` — scoped pool of `N` workers.  Output bytes are
54    ///   byte-identical to the sequential path regardless of `N`.
55    ///
56    /// When more than one data object is being encoded the budget is
57    /// spent axis-B-first (intra-codec parallelism) — this codebase
58    /// tends to have a small number of very large messages.  See the
59    /// [multi-threaded pipeline guide](../../docs/src/guide/multi-threaded-pipeline.md)
60    /// for the full policy.
61    ///
62    /// Ignored with a one-time `tracing::warn!` when the `threads`
63    /// cargo feature is disabled.
64    pub threads: u32,
65    /// Minimum total payload bytes below which the parallel path is
66    /// skipped even when `threads > 0`.
67    ///
68    /// `None` uses [`crate::DEFAULT_PARALLEL_THRESHOLD_BYTES`] (64 KiB).
69    /// Set to `Some(0)` to force the parallel path for testing; set to
70    /// `Some(usize::MAX)` to force sequential.
71    pub parallel_threshold_bytes: Option<usize>,
72}
73
74impl Default for EncodeOptions {
75    fn default() -> Self {
76        Self {
77            hash_algorithm: Some(HashAlgorithm::Xxh3),
78            emit_preceders: false,
79            compression_backend: pipeline::CompressionBackend::default(),
80            threads: 0,
81            parallel_threshold_bytes: None,
82        }
83    }
84}
85
86pub(crate) fn validate_object(desc: &DataObjectDescriptor, data_len: usize) -> Result<()> {
87    if desc.obj_type.is_empty() {
88        return Err(TensogramError::Metadata(
89            "obj_type must not be empty".to_string(),
90        ));
91    }
92    if desc.ndim as usize != desc.shape.len() {
93        return Err(TensogramError::Metadata(format!(
94            "ndim {} does not match shape.len() {}",
95            desc.ndim,
96            desc.shape.len()
97        )));
98    }
99    if desc.strides.len() != desc.shape.len() {
100        return Err(TensogramError::Metadata(format!(
101            "strides.len() {} does not match shape.len() {}",
102            desc.strides.len(),
103            desc.shape.len()
104        )));
105    }
106    if desc.encoding == "none" {
107        let product = desc
108            .shape
109            .iter()
110            .try_fold(1u64, |acc, &x| acc.checked_mul(x))
111            .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
112        if desc.dtype.byte_width() > 0 {
113            let expected_bytes = product
114                .checked_mul(desc.dtype.byte_width() as u64)
115                .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
116            if expected_bytes != data_len as u64 {
117                return Err(TensogramError::Metadata(format!(
118                    "data_len {data_len} does not match expected {expected_bytes} bytes from shape and dtype"
119                )));
120            }
121        } else if desc.dtype == crate::Dtype::Bitmask {
122            // Bitmask: expected data length is ceil(shape_product / 8)
123            let expected_bytes = product.div_ceil(8);
124            if expected_bytes != data_len as u64 {
125                return Err(TensogramError::Metadata(format!(
126                    "data_len {data_len} does not match expected {expected_bytes} bytes for bitmask (ceil({product}/8))"
127                )));
128            }
129        }
130    }
131    Ok(())
132}
133
/// How the caller-supplied bytes of an object are to be treated.
#[derive(Debug, Clone, Copy)]
enum EncodeMode {
    /// Bytes are raw values; the encoding pipeline is run on them.
    Raw,
    /// Bytes are already encoded; they are used as the payload as-is.
    PreEncoded,
}
139
140/// Encode a single object: run the pipeline (or validate pre-encoded
141/// bytes), compute its hash, and return the `EncodedObject`.
142///
143/// `intra_codec_threads` is passed through to [`PipelineConfig`] and
144/// honoured by axis-B-capable codecs (blosc2, zstd, simple_packing,
145/// shuffle).  Pure functional — no shared state, safe to call from
146/// multiple rayon workers in parallel.
147fn encode_one_object(
148    desc: &DataObjectDescriptor,
149    data: &[u8],
150    mode: EncodeMode,
151    options: &EncodeOptions,
152    intra_codec_threads: u32,
153) -> Result<EncodedObject> {
154    validate_object(desc, data.len())?;
155
156    let shape_product = desc
157        .shape
158        .iter()
159        .try_fold(1u64, |acc, &x| acc.checked_mul(x))
160        .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
161    let num_elements = usize::try_from(shape_product)
162        .map_err(|_| TensogramError::Metadata("element count overflows usize".to_string()))?;
163    let dtype = desc.dtype;
164
165    let config = build_pipeline_config_with_backend(
166        desc,
167        num_elements,
168        dtype,
169        options.compression_backend,
170        intra_codec_threads,
171    )?;
172
173    // Build the final descriptor with computed fields
174    let mut final_desc = desc.clone();
175
176    let encoded_payload: Vec<u8> = match mode {
177        EncodeMode::Raw => {
178            // Run the full encoding pipeline.
179            let result = pipeline::encode_pipeline(data, &config)
180                .map_err(|e| TensogramError::Encoding(e.to_string()))?;
181
182            // Store szip block offsets if produced
183            if let Some(offsets) = &result.block_offsets {
184                final_desc.params.insert(
185                    "szip_block_offsets".to_string(),
186                    ciborium::Value::Array(
187                        offsets
188                            .iter()
189                            .map(|&o| ciborium::Value::Integer(o.into()))
190                            .collect(),
191                    ),
192                );
193            }
194
195            result.encoded_bytes
196        }
197        EncodeMode::PreEncoded => {
198            // Caller's bytes are already encoded — use them directly.
199            // build_pipeline_config was already called above for
200            // defense-in-depth validation of encoding/compression params.
201            validate_no_szip_offsets_for_non_szip(desc)?;
202            if desc.compression == "szip" && desc.params.contains_key("szip_block_offsets") {
203                validate_szip_block_offsets(&desc.params, data.len())?;
204            }
205            data.to_vec()
206        }
207    };
208
209    // Compute hash over encoded payload (overwrites any caller-supplied hash).
210    if let Some(algorithm) = options.hash_algorithm {
211        let hash_value = compute_hash(&encoded_payload, algorithm);
212        final_desc.hash = Some(HashDescriptor {
213            hash_type: algorithm.as_str().to_string(),
214            value: hash_value,
215        });
216    }
217
218    Ok(EncodedObject {
219        descriptor: final_desc,
220        encoded_payload,
221    })
222}
223
224fn encode_inner(
225    global_metadata: &GlobalMetadata,
226    descriptors: &[(&DataObjectDescriptor, &[u8])],
227    options: &EncodeOptions,
228    mode: EncodeMode,
229) -> Result<Vec<u8>> {
230    // Buffered encode does not support emit_preceders — use StreamingEncoder
231    // with write_preceder() instead.
232    if options.emit_preceders {
233        return Err(TensogramError::Encoding(
234            "emit_preceders is not supported in buffered mode; use StreamingEncoder::write_preceder() instead".to_string(),
235        ));
236    }
237
238    // ── Thread-budget dispatch (axis-B-first policy) ────────────────────
239    //
240    // Resolve the effective thread budget (explicit option > env var),
241    // decide if the workload is large enough to parallelise, and pick
242    // axis A (par_iter across objects) vs axis B (sequential, codec
243    // uses the budget internally).
244    let budget = crate::parallel::resolve_budget(options.threads);
245    let total_bytes: usize = descriptors.iter().map(|(_, d)| d.len()).sum();
246    let parallel =
247        crate::parallel::should_parallelise(budget, total_bytes, options.parallel_threshold_bytes);
248
249    let any_axis_b = descriptors
250        .iter()
251        .any(|(d, _)| crate::parallel::is_axis_b_friendly(&d.encoding, &d.filter, &d.compression));
252    let use_axis_a = parallel && crate::parallel::use_axis_a(descriptors.len(), budget, any_axis_b);
253
254    // Axis B gets the full budget; axis A keeps codecs sequential so
255    // that the product of axis A and axis B threads never exceeds the
256    // caller's ask.
257    let intra_codec_threads = if parallel && !use_axis_a { budget } else { 0 };
258
259    let encode_one = |(desc, data): &(&DataObjectDescriptor, &[u8])| {
260        encode_one_object(desc, data, mode, options, intra_codec_threads)
261    };
262
263    let encoded_objects: Vec<EncodedObject> = if use_axis_a {
264        // Axis A: par_iter across objects.  Requires the `threads`
265        // feature; when it's off, the caller's budget silently falls
266        // back to sequential (with a one-time warning from `with_pool`).
267        #[cfg(feature = "threads")]
268        {
269            use rayon::prelude::*;
270            crate::parallel::with_pool(budget, || {
271                descriptors
272                    .par_iter()
273                    .map(&encode_one)
274                    .collect::<Result<Vec<_>>>()
275            })?
276        }
277        #[cfg(not(feature = "threads"))]
278        {
279            descriptors.iter().map(encode_one).collect::<Result<_>>()?
280        }
281    } else {
282        // Axis B (or purely sequential): iterate objects in order.
283        // Install the pool when there's an intra-codec budget so that
284        // parallel primitives inside codec implementations (e.g.
285        // `simple_packing` chunked par_iter) actually use it.
286        crate::parallel::run_maybe_pooled(budget, parallel, intra_codec_threads, || {
287            descriptors.iter().map(encode_one).collect::<Result<_>>()
288        })?
289    };
290
291    // Validate that the caller hasn't written to _reserved_ at any level.
292    validate_no_client_reserved(global_metadata)?;
293
294    // Validate base/descriptor count: base may be shorter (auto-extended) or
295    // equal, but having MORE base entries than descriptors is an error —
296    // the extra entries would be silently discarded.
297    if global_metadata.base.len() > descriptors.len() {
298        return Err(TensogramError::Metadata(format!(
299            "metadata base has {} entries but only {} descriptors provided; \
300             extra base entries would be discarded",
301            global_metadata.base.len(),
302            descriptors.len()
303        )));
304    }
305
306    // Populate per-object base entries with _reserved_.tensor (ndim/shape/strides/dtype).
307    // Pre-existing application keys (e.g. "mars") are preserved.
308    let mut enriched_meta = global_metadata.clone();
309    populate_base_entries(&mut enriched_meta.base, &encoded_objects);
310    populate_reserved_provenance(&mut enriched_meta.reserved);
311
312    framing::encode_message(&enriched_meta, &encoded_objects)
313}
314
315/// Encode a complete Tensogram message.
316///
317/// `global_metadata` is the message-level metadata (version, MARS keys, etc.).
318/// `descriptors` is a list of (DataObjectDescriptor, raw_data) pairs.
319/// Returns the complete wire-format message.
320#[tracing::instrument(skip(global_metadata, descriptors, options), fields(objects = descriptors.len()))]
321pub fn encode(
322    global_metadata: &GlobalMetadata,
323    descriptors: &[(&DataObjectDescriptor, &[u8])],
324    options: &EncodeOptions,
325) -> Result<Vec<u8>> {
326    encode_inner(global_metadata, descriptors, options, EncodeMode::Raw)
327}
328
329/// Encode a pre-encoded Tensogram message where callers supply already-encoded bytes.
330///
331/// Use this when the payload bytes have already been encoded/compressed by an external
332/// pipeline. The library will:
333/// - Validate object descriptors (shape, dtype, etc.)
334/// - Validate encoding/compression params via `build_pipeline_config()` (defense-in-depth)
335/// - Use the caller's bytes directly as the encoded payload (no pipeline call)
336/// - Compute a fresh xxh3 hash over the caller's bytes (overwrites any caller-supplied hash)
337/// - Preserve caller-supplied `szip_block_offsets` in descriptor params
338///
339/// Callers must NOT set `emit_preceders = true` — use `StreamingEncoder::write_preceder()`
340/// for streaming preceder support.
341#[tracing::instrument(name = "encode_pre_encoded", skip_all, fields(num_objects = descriptors.len()))]
342pub fn encode_pre_encoded(
343    global_metadata: &GlobalMetadata,
344    descriptors: &[(&DataObjectDescriptor, &[u8])],
345    options: &EncodeOptions,
346) -> Result<Vec<u8>> {
347    encode_inner(
348        global_metadata,
349        descriptors,
350        options,
351        EncodeMode::PreEncoded,
352    )
353}
354
355/// Validate that the caller hasn't written to `_reserved_` at any level.
356///
357/// The `_reserved_` namespace is library-managed.  Client code must not
358/// set it in the message-level metadata or in any `base[i]` entry.
359fn validate_no_client_reserved(meta: &GlobalMetadata) -> Result<()> {
360    if !meta.reserved.is_empty() {
361        return Err(TensogramError::Metadata(format!(
362            "client code must not write to '{RESERVED_KEY}' at message level; \
363             this field is populated by the library"
364        )));
365    }
366    for (i, entry) in meta.base.iter().enumerate() {
367        if entry.contains_key(RESERVED_KEY) {
368            return Err(TensogramError::Metadata(format!(
369                "client code must not write to '{RESERVED_KEY}' in base[{i}]; \
370                 this field is populated by the library"
371            )));
372        }
373    }
374    Ok(())
375}
376
377/// Populate per-object base entries with tensor metadata under `_reserved_.tensor`.
378///
379/// Resizes `base` to match the object count, then inserts a `_reserved_`
380/// map containing `tensor: {ndim, shape, strides, dtype}` into each entry.
381/// Pre-existing application keys (e.g. `"mars"`) are preserved.
382pub(crate) fn populate_base_entries(
383    base: &mut Vec<BTreeMap<String, ciborium::Value>>,
384    encoded_objects: &[crate::framing::EncodedObject],
385) {
386    use ciborium::Value;
387
388    // Ensure base has exactly one entry per object.
389    base.resize_with(encoded_objects.len(), BTreeMap::new);
390
391    for (entry, obj) in base.iter_mut().zip(encoded_objects.iter()) {
392        let desc = &obj.descriptor;
393
394        let tensor_map = Value::Map(vec![
395            (
396                Value::Text("ndim".to_string()),
397                Value::Integer(desc.ndim.into()),
398            ),
399            (
400                Value::Text("shape".to_string()),
401                Value::Array(
402                    desc.shape
403                        .iter()
404                        .map(|&d| Value::Integer(d.into()))
405                        .collect(),
406                ),
407            ),
408            (
409                Value::Text("strides".to_string()),
410                Value::Array(
411                    desc.strides
412                        .iter()
413                        .map(|&s| Value::Integer(s.into()))
414                        .collect(),
415                ),
416            ),
417            (
418                Value::Text("dtype".to_string()),
419                Value::Text(desc.dtype.to_string()),
420            ),
421        ]);
422
423        let reserved_map = Value::Map(vec![(Value::Text("tensor".to_string()), tensor_map)]);
424
425        entry.insert(RESERVED_KEY.to_string(), reserved_map);
426    }
427}
428
429/// Populate the `reserved` section with provenance fields as specified in
430/// `WIRE_FORMAT.md`:
431///
432/// - `encoder.name` — `"tensogram"`
433/// - `encoder.version` — library version at encode time
434/// - `time` — UTC ISO 8601 timestamp
435/// - `uuid` — RFC 4122 v4 UUID
436///
437/// Pre-existing keys in `reserved` are preserved; only these four are
438/// set (or overwritten).
439pub(crate) fn populate_reserved_provenance(reserved: &mut BTreeMap<String, ciborium::Value>) {
440    use ciborium::Value;
441    #[cfg(not(target_arch = "wasm32"))]
442    use std::time::SystemTime;
443
444    // encoder.name + encoder.version
445    let version_str = env!("CARGO_PKG_VERSION");
446    let encoder_map = Value::Map(vec![
447        (
448            Value::Text("name".to_string()),
449            Value::Text("tensogram".to_string()),
450        ),
451        (
452            Value::Text("version".to_string()),
453            Value::Text(version_str.to_string()),
454        ),
455    ]);
456    reserved.insert("encoder".to_string(), encoder_map);
457
458    // time — ISO 8601 UTC
459    // On wasm32-unknown-unknown, SystemTime::now() panics. Skip the `time`
460    // field entirely rather than encoding a misleading epoch-0 timestamp.
461    // Callers can set a timestamp via `_extra_` if needed.
462    #[cfg(not(target_arch = "wasm32"))]
463    {
464        let secs = SystemTime::now()
465            .duration_since(SystemTime::UNIX_EPOCH)
466            .unwrap_or_default()
467            .as_secs();
468
469        // Simple UTC format: YYYY-MM-DDThh:mm:ssZ
470        // We compute from epoch seconds to avoid adding a datetime crate.
471        let days = secs / 86400;
472        let day_secs = secs % 86400;
473        let hours = day_secs / 3600;
474        let minutes = (day_secs % 3600) / 60;
475        let seconds = day_secs % 60;
476        // Civil date from days since 1970-01-01 (Howard Hinnant algorithm)
477        let (y, m, d) = civil_from_days(days as i64);
478        let timestamp = format!("{y:04}-{m:02}-{d:02}T{hours:02}:{minutes:02}:{seconds:02}Z");
479        reserved.insert("time".to_string(), Value::Text(timestamp));
480    }
481
482    // uuid — RFC 4122 v4
483    let id = uuid::Uuid::new_v4();
484    reserved.insert("uuid".to_string(), Value::Text(id.to_string()));
485}
486
/// Map a day count relative to 1970-01-01 onto a proleptic Gregorian
/// `(year, month, day)` triple.
///
/// Implements Howard Hinnant's `civil_from_days` algorithm (public
/// domain).  Month is in `1..=12`, day in `1..=31`.
#[cfg(not(target_arch = "wasm32"))]
fn civil_from_days(days: i64) -> (i64, u32, u32) {
    // A 400-year Gregorian cycle ("era") has a fixed number of days.
    const DAYS_PER_ERA: i64 = 146_097;

    // Shift the origin to 0000-03-01 so that leap days fall at year end.
    let shifted = days + 719_468;
    let era = shifted.div_euclid(DAYS_PER_ERA);
    // The Euclidean remainder is in [0, 146096], so it always fits in u32.
    let day_of_era = shifted.rem_euclid(DAYS_PER_ERA) as u32;

    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36524 - day_of_era / 146096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);

    // Month index counted from March (0 = March, …, 11 = February).
    let march_based_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * march_based_month + 2) / 5 + 1;
    let month = if march_based_month < 10 {
        march_based_month + 3
    } else {
        march_based_month - 9
    };

    // January and February belong to the following calendar year.
    let march_based_year = i64::from(year_of_era) + era * 400;
    let year = if month <= 2 {
        march_based_year + 1
    } else {
        march_based_year
    };

    (year, month, day)
}
505
506pub(crate) fn build_pipeline_config(
507    desc: &DataObjectDescriptor,
508    num_values: usize,
509    dtype: Dtype,
510) -> Result<PipelineConfig> {
511    build_pipeline_config_with_backend(
512        desc,
513        num_values,
514        dtype,
515        pipeline::CompressionBackend::default(),
516        0,
517    )
518}
519
520/// Build a pipeline config with an explicit compression backend override
521/// and an intra-codec thread budget.
522///
523/// `intra_codec_threads == 0` preserves the pre-threads behaviour and is
524/// what direct pipeline callers (benchmarks, external code) should use.
525pub(crate) fn build_pipeline_config_with_backend(
526    desc: &DataObjectDescriptor,
527    num_values: usize,
528    dtype: Dtype,
529    compression_backend: pipeline::CompressionBackend,
530    intra_codec_threads: u32,
531) -> Result<PipelineConfig> {
532    let encoding = match desc.encoding.as_str() {
533        "none" => EncodingType::None,
534        "simple_packing" => {
535            if dtype.byte_width() != 8 {
536                return Err(TensogramError::Encoding(
537                    "simple_packing only supports float64 dtype".to_string(),
538                ));
539            }
540            let params = extract_simple_packing_params(&desc.params)?;
541            EncodingType::SimplePacking(params)
542        }
543        other => {
544            return Err(TensogramError::Encoding(format!(
545                "unknown encoding: {other}"
546            )));
547        }
548    };
549
550    let filter = match desc.filter.as_str() {
551        "none" => FilterType::None,
552        "shuffle" => {
553            let element_size = usize::try_from(get_u64_param(
554                &desc.params,
555                "shuffle_element_size",
556            )?)
557            .map_err(|_| {
558                TensogramError::Metadata("shuffle_element_size out of usize range".to_string())
559            })?;
560            FilterType::Shuffle { element_size }
561        }
562        other => return Err(TensogramError::Encoding(format!("unknown filter: {other}"))),
563    };
564
565    let compression = match desc.compression.as_str() {
566        "none" => CompressionType::None,
567        #[cfg(any(feature = "szip", feature = "szip-pure"))]
568        "szip" => {
569            let rsi = u32::try_from(get_u64_param(&desc.params, "szip_rsi")?)
570                .map_err(|_| TensogramError::Metadata("szip_rsi out of u32 range".to_string()))?;
571            let block_size = u32::try_from(get_u64_param(&desc.params, "szip_block_size")?)
572                .map_err(|_| {
573                    TensogramError::Metadata("szip_block_size out of u32 range".to_string())
574                })?;
575            let flags = u32::try_from(get_u64_param(&desc.params, "szip_flags")?)
576                .map_err(|_| TensogramError::Metadata("szip_flags out of u32 range".to_string()))?;
577            let bits_per_sample = match (&encoding, &filter) {
578                (EncodingType::SimplePacking(params), _) => params.bits_per_value,
579                (EncodingType::None, FilterType::Shuffle { .. }) => 8,
580                (EncodingType::None, FilterType::None) => (dtype.byte_width() * 8) as u32,
581            };
582            CompressionType::Szip {
583                rsi,
584                block_size,
585                flags,
586                bits_per_sample,
587            }
588        }
589        #[cfg(any(feature = "zstd", feature = "zstd-pure"))]
590        "zstd" => {
591            let level_i64 = get_i64_param(&desc.params, "zstd_level").unwrap_or(3);
592            let level = i32::try_from(level_i64).map_err(|_| {
593                TensogramError::Metadata(format!("zstd_level value {level_i64} out of i32 range"))
594            })?;
595            CompressionType::Zstd { level }
596        }
597        #[cfg(feature = "lz4")]
598        "lz4" => CompressionType::Lz4,
599        #[cfg(feature = "blosc2")]
600        "blosc2" => {
601            let codec_str = match desc.params.get("blosc2_codec") {
602                Some(ciborium::Value::Text(s)) => s.as_str(),
603                _ => "lz4",
604            };
605            let codec = match codec_str {
606                "blosclz" => Blosc2Codec::Blosclz,
607                "lz4" => Blosc2Codec::Lz4,
608                "lz4hc" => Blosc2Codec::Lz4hc,
609                "zlib" => Blosc2Codec::Zlib,
610                "zstd" => Blosc2Codec::Zstd,
611                other => {
612                    return Err(TensogramError::Encoding(format!(
613                        "unknown blosc2 codec: {other}"
614                    )));
615                }
616            };
617            let clevel_i64 = get_i64_param(&desc.params, "blosc2_clevel").unwrap_or(5);
618            let clevel = i32::try_from(clevel_i64).map_err(|_| {
619                TensogramError::Metadata(format!(
620                    "blosc2_clevel value {clevel_i64} out of i32 range"
621                ))
622            })?;
623            let typesize = match (&encoding, &filter) {
624                (EncodingType::SimplePacking(params), _) => {
625                    (params.bits_per_value as usize).div_ceil(8)
626                }
627                (EncodingType::None, FilterType::Shuffle { .. }) => 1,
628                (EncodingType::None, FilterType::None) => dtype.byte_width(),
629            };
630            CompressionType::Blosc2 {
631                codec,
632                clevel,
633                typesize,
634            }
635        }
636        #[cfg(feature = "zfp")]
637        "zfp" => {
638            let mode_str = match desc.params.get("zfp_mode") {
639                Some(ciborium::Value::Text(s)) => s.clone(),
640                _ => {
641                    return Err(TensogramError::Metadata(
642                        "missing required parameter: zfp_mode".to_string(),
643                    ));
644                }
645            };
646            let mode = match mode_str.as_str() {
647                "fixed_rate" => {
648                    let rate = get_f64_param(&desc.params, "zfp_rate")?;
649                    ZfpMode::FixedRate { rate }
650                }
651                "fixed_precision" => {
652                    let precision = u32::try_from(get_u64_param(&desc.params, "zfp_precision")?)
653                        .map_err(|_| {
654                            TensogramError::Metadata("zfp_precision out of u32 range".to_string())
655                        })?;
656                    ZfpMode::FixedPrecision { precision }
657                }
658                "fixed_accuracy" => {
659                    let tolerance = get_f64_param(&desc.params, "zfp_tolerance")?;
660                    ZfpMode::FixedAccuracy { tolerance }
661                }
662                other => {
663                    return Err(TensogramError::Encoding(format!(
664                        "unknown zfp_mode: {other}"
665                    )));
666                }
667            };
668            CompressionType::Zfp { mode }
669        }
670        #[cfg(feature = "sz3")]
671        "sz3" => {
672            let mode_str = match desc.params.get("sz3_error_bound_mode") {
673                Some(ciborium::Value::Text(s)) => s.clone(),
674                _ => {
675                    return Err(TensogramError::Metadata(
676                        "missing required parameter: sz3_error_bound_mode".to_string(),
677                    ));
678                }
679            };
680            let bound_val = get_f64_param(&desc.params, "sz3_error_bound")?;
681            let error_bound = match mode_str.as_str() {
682                "abs" => Sz3ErrorBound::Absolute(bound_val),
683                "rel" => Sz3ErrorBound::Relative(bound_val),
684                "psnr" => Sz3ErrorBound::Psnr(bound_val),
685                other => {
686                    return Err(TensogramError::Encoding(format!(
687                        "unknown sz3_error_bound_mode: {other}"
688                    )));
689                }
690            };
691            CompressionType::Sz3 { error_bound }
692        }
693        other => {
694            return Err(TensogramError::Encoding(format!(
695                "unknown compression: {other}"
696            )));
697        }
698    };
699
700    Ok(PipelineConfig {
701        encoding,
702        filter,
703        compression,
704        num_values,
705        byte_order: desc.byte_order,
706        dtype_byte_width: dtype.byte_width(),
707        swap_unit_size: dtype.swap_unit_size(),
708        compression_backend,
709        intra_codec_threads,
710    })
711}
712
713fn extract_simple_packing_params(
714    params: &BTreeMap<String, ciborium::Value>,
715) -> Result<SimplePackingParams> {
716    let reference_value = get_f64_param(params, "reference_value")?;
717    if reference_value.is_nan() || reference_value.is_infinite() {
718        return Err(TensogramError::Metadata(format!(
719            "reference_value must be finite, got {reference_value}"
720        )));
721    }
722    Ok(SimplePackingParams {
723        reference_value,
724        binary_scale_factor: i32::try_from(get_i64_param(params, "binary_scale_factor")?).map_err(
725            |_| TensogramError::Metadata("binary_scale_factor out of i32 range".to_string()),
726        )?,
727        decimal_scale_factor: i32::try_from(get_i64_param(params, "decimal_scale_factor")?)
728            .map_err(|_| {
729                TensogramError::Metadata("decimal_scale_factor out of i32 range".to_string())
730            })?,
731        bits_per_value: u32::try_from(get_u64_param(params, "bits_per_value")?)
732            .map_err(|_| TensogramError::Metadata("bits_per_value out of u32 range".to_string()))?,
733    })
734}
735
736pub(crate) fn get_f64_param(params: &BTreeMap<String, ciborium::Value>, key: &str) -> Result<f64> {
737    match params.get(key) {
738        Some(ciborium::Value::Float(f)) => Ok(*f),
739        Some(ciborium::Value::Integer(i)) => {
740            // i128 → f64 may lose precision for very large integers (> 2^53),
741            // but this is acceptable for a float accessor on an integer value.
742            let n: i128 = (*i).into();
743            Ok(n as f64)
744        }
745        Some(other) => Err(TensogramError::Metadata(format!(
746            "expected number for {key}, got {other:?}"
747        ))),
748        None => Err(TensogramError::Metadata(format!(
749            "missing required parameter: {key}"
750        ))),
751    }
752}
753
754pub(crate) fn get_i64_param(params: &BTreeMap<String, ciborium::Value>, key: &str) -> Result<i64> {
755    match params.get(key) {
756        Some(ciborium::Value::Integer(i)) => {
757            let n: i128 = (*i).into();
758            i64::try_from(n).map_err(|_| {
759                TensogramError::Metadata(format!("integer value {n} out of i64 range for {key}"))
760            })
761        }
762        Some(other) => Err(TensogramError::Metadata(format!(
763            "expected integer for {key}, got {other:?}"
764        ))),
765        None => Err(TensogramError::Metadata(format!(
766            "missing required parameter: {key}"
767        ))),
768    }
769}
770
771pub(crate) fn get_u64_param(params: &BTreeMap<String, ciborium::Value>, key: &str) -> Result<u64> {
772    match params.get(key) {
773        Some(ciborium::Value::Integer(i)) => {
774            let n: i128 = (*i).into();
775            u64::try_from(n).map_err(|_| {
776                TensogramError::Metadata(format!("integer value {n} out of u64 range for {key}"))
777            })
778        }
779        Some(other) => Err(TensogramError::Metadata(format!(
780            "expected integer for {key}, got {other:?}"
781        ))),
782        None => Err(TensogramError::Metadata(format!(
783            "missing required parameter: {key}"
784        ))),
785    }
786}
787
788pub(crate) fn validate_szip_block_offsets(
789    params: &BTreeMap<String, ciborium::Value>,
790    encoded_bytes_len: usize,
791) -> Result<()> {
792    let value = params.get("szip_block_offsets").ok_or_else(|| {
793        TensogramError::Metadata(
794            "missing required parameter: szip_block_offsets for szip compression".to_string(),
795        )
796    })?;
797
798    let offsets = match value {
799        ciborium::Value::Array(arr) => arr,
800        other => {
801            return Err(TensogramError::Metadata(format!(
802                "szip_block_offsets must be an array, got {other:?}"
803            )));
804        }
805    };
806
807    if offsets.is_empty() {
808        return Err(TensogramError::Metadata(
809            "szip_block_offsets must not be empty; first offset must be 0".to_string(),
810        ));
811    }
812
813    let bit_bound = encoded_bytes_len.checked_mul(8).ok_or_else(|| {
814        TensogramError::Metadata(format!(
815            "encoded byte length {encoded_bytes_len} overflows bit-bound calculation"
816        ))
817    })?;
818    let bit_bound_u64 = u64::try_from(bit_bound).map_err(|_| {
819        TensogramError::Metadata(format!(
820            "bit-bound {bit_bound} derived from {encoded_bytes_len} bytes does not fit in u64"
821        ))
822    })?;
823
824    let mut parsed_offsets = Vec::with_capacity(offsets.len());
825    for (idx, item) in offsets.iter().enumerate() {
826        let offset = match item {
827            ciborium::Value::Integer(i) => {
828                let n: i128 = (*i).into();
829                u64::try_from(n).map_err(|_| {
830                    TensogramError::Metadata(format!(
831                        "szip_block_offsets[{idx}] = {n} out of u64 range"
832                    ))
833                })?
834            }
835            other => {
836                return Err(TensogramError::Metadata(format!(
837                    "szip_block_offsets[{idx}] must be an integer, got {other:?}"
838                )));
839            }
840        };
841
842        if offset > bit_bound_u64 {
843            return Err(TensogramError::Metadata(format!(
844                "szip_block_offsets[{idx}] = {offset} exceeds bit bound {bit_bound_u64} (encoded_bytes_len = {encoded_bytes_len} bytes, {bit_bound_u64} bits)"
845            )));
846        }
847
848        if idx == 0 {
849            if offset != 0 {
850                return Err(TensogramError::Metadata(format!(
851                    "szip_block_offsets[0] must be 0, got {offset}"
852                )));
853            }
854        } else {
855            let prev = parsed_offsets[idx - 1];
856            if offset <= prev {
857                return Err(TensogramError::Metadata(format!(
858                    "szip_block_offsets must be strictly increasing: szip_block_offsets[{}] = {}, szip_block_offsets[{idx}] = {offset}",
859                    idx - 1,
860                    prev
861                )));
862            }
863        }
864
865        parsed_offsets.push(offset);
866    }
867
868    Ok(())
869}
870
871pub(crate) fn validate_no_szip_offsets_for_non_szip(desc: &DataObjectDescriptor) -> Result<()> {
872    if desc.compression != "szip" && desc.params.contains_key("szip_block_offsets") {
873        return Err(TensogramError::Metadata(format!(
874            "szip_block_offsets provided but compression is '{}', not 'szip'",
875            desc.compression
876        )));
877    }
878    Ok(())
879}
880
881// ── Edge case tests ─────────────────────────────────────────────────────────
882
883#[cfg(test)]
884mod tests {
885    use super::*;
886    use crate::decode::{DecodeOptions, decode};
887    use crate::types::{ByteOrder, GlobalMetadata};
888    use std::collections::BTreeMap;
889
890    fn make_descriptor(shape: Vec<u64>) -> DataObjectDescriptor {
891        let strides = {
892            let mut s = vec![1u64; shape.len()];
893            for i in (0..shape.len().saturating_sub(1)).rev() {
894                s[i] = s[i + 1] * shape[i + 1];
895            }
896            s
897        };
898        DataObjectDescriptor {
899            obj_type: "ntensor".to_string(),
900            ndim: shape.len() as u64,
901            shape,
902            strides,
903            dtype: Dtype::Float32,
904            byte_order: ByteOrder::native(),
905            encoding: "none".to_string(),
906            filter: "none".to_string(),
907            compression: "none".to_string(),
908            params: BTreeMap::new(),
909            hash: None,
910        }
911    }
912
913    // ── Category 1: base array mismatches ────────────────────────────────
914
915    #[test]
916    fn test_base_more_entries_than_descriptors_rejected() {
917        // base has 5 entries but only 2 descriptors — should error.
918        let meta = GlobalMetadata {
919            version: 2,
920            base: vec![
921                BTreeMap::new(),
922                BTreeMap::new(),
923                BTreeMap::new(),
924                BTreeMap::new(),
925                BTreeMap::new(),
926            ],
927            ..Default::default()
928        };
929        let desc = make_descriptor(vec![4]);
930        let data = vec![0u8; 16];
931        let options = EncodeOptions {
932            hash_algorithm: None,
933            ..Default::default()
934        };
935        let result = encode(
936            &meta,
937            &[(&desc, data.as_slice()), (&desc, data.as_slice())],
938            &options,
939        );
940        assert!(
941            result.is_err(),
942            "5 base entries with 2 descriptors should fail"
943        );
944        let err = result.unwrap_err().to_string();
945        assert!(
946            err.contains("5") && err.contains("2"),
947            "error should mention counts: {err}"
948        );
949    }
950
951    #[test]
952    fn test_base_fewer_entries_than_descriptors_auto_extended() {
953        // base has 0 entries but 3 descriptors — auto-extends, _reserved_ inserted.
954        let meta = GlobalMetadata {
955            version: 2,
956            base: vec![],
957            ..Default::default()
958        };
959        let desc = make_descriptor(vec![2]);
960        let data = vec![0u8; 8];
961        let options = EncodeOptions {
962            hash_algorithm: None,
963            ..Default::default()
964        };
965        let msg = encode(
966            &meta,
967            &[
968                (&desc, data.as_slice()),
969                (&desc, data.as_slice()),
970                (&desc, data.as_slice()),
971            ],
972            &options,
973        )
974        .unwrap();
975
976        let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
977        assert_eq!(decoded.base.len(), 3);
978        // Each entry should have _reserved_ with tensor info
979        for entry in &decoded.base {
980            assert!(
981                entry.contains_key("_reserved_"),
982                "auto-extended base entry should have _reserved_"
983            );
984        }
985    }
986
987    #[test]
988    fn test_base_entry_with_top_level_key_names_no_collision() {
989        // base[0] contains a key named "version" — no collision with top-level version.
990        let mut entry = BTreeMap::new();
991        entry.insert(
992            "version".to_string(),
993            ciborium::Value::Text("my-version".to_string()),
994        );
995        entry.insert(
996            "base".to_string(),
997            ciborium::Value::Text("not-the-real-base".to_string()),
998        );
999        let meta = GlobalMetadata {
1000            version: 2,
1001            base: vec![entry],
1002            ..Default::default()
1003        };
1004        let desc = make_descriptor(vec![2]);
1005        let data = vec![0u8; 8];
1006        let options = EncodeOptions {
1007            hash_algorithm: None,
1008            ..Default::default()
1009        };
1010        let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
1011        let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1012
1013        // Top-level version is still 2
1014        assert_eq!(decoded.version, 2);
1015        // base[0] should have both custom keys preserved
1016        assert_eq!(
1017            decoded.base[0].get("version"),
1018            Some(&ciborium::Value::Text("my-version".to_string()))
1019        );
1020        assert_eq!(
1021            decoded.base[0].get("base"),
1022            Some(&ciborium::Value::Text("not-the-real-base".to_string()))
1023        );
1024    }
1025
1026    #[test]
1027    fn test_base_entry_with_deeply_nested_reserved_allowed() {
1028        // Only top-level _reserved_ in base[i] should be rejected.
1029        // Deeply nested _reserved_ (like {"foo": {"_reserved_": ...}}) is fine.
1030        let nested = ciborium::Value::Map(vec![(
1031            ciborium::Value::Text("_reserved_".to_string()),
1032            ciborium::Value::Text("nested-is-ok".to_string()),
1033        )]);
1034        let mut entry = BTreeMap::new();
1035        entry.insert("foo".to_string(), nested);
1036        let meta = GlobalMetadata {
1037            version: 2,
1038            base: vec![entry],
1039            ..Default::default()
1040        };
1041        let desc = make_descriptor(vec![2]);
1042        let data = vec![0u8; 8];
1043        let options = EncodeOptions {
1044            hash_algorithm: None,
1045            ..Default::default()
1046        };
1047        // Should succeed — only top-level _reserved_ is rejected
1048        let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
1049        let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1050        // The nested _reserved_ should survive
1051        let foo = decoded.base[0].get("foo").unwrap();
1052        if let ciborium::Value::Map(pairs) = foo {
1053            assert_eq!(pairs.len(), 1);
1054        } else {
1055            panic!("expected map for foo");
1056        }
1057    }
1058
1059    // ── Category 2: _reserved_ edge cases ────────────────────────────────
1060
1061    #[test]
1062    fn test_reserved_rejected_at_message_level() {
1063        let mut reserved = BTreeMap::new();
1064        reserved.insert(
1065            "rogue".to_string(),
1066            ciborium::Value::Text("bad".to_string()),
1067        );
1068        let meta = GlobalMetadata {
1069            version: 2,
1070            reserved,
1071            ..Default::default()
1072        };
1073        let desc = make_descriptor(vec![2]);
1074        let data = vec![0u8; 8];
1075        let result = encode(
1076            &meta,
1077            &[(&desc, data.as_slice())],
1078            &EncodeOptions::default(),
1079        );
1080        assert!(result.is_err());
1081        let err = result.unwrap_err().to_string();
1082        assert!(
1083            err.contains("_reserved_") && err.contains("message level"),
1084            "error: {err}"
1085        );
1086    }
1087
1088    #[test]
1089    fn test_reserved_rejected_in_base_entry() {
1090        let mut entry = BTreeMap::new();
1091        entry.insert("_reserved_".to_string(), ciborium::Value::Map(vec![]));
1092        let meta = GlobalMetadata {
1093            version: 2,
1094            base: vec![entry],
1095            ..Default::default()
1096        };
1097        let desc = make_descriptor(vec![2]);
1098        let data = vec![0u8; 8];
1099        let result = encode(
1100            &meta,
1101            &[(&desc, data.as_slice())],
1102            &EncodeOptions::default(),
1103        );
1104        assert!(result.is_err());
1105        let err = result.unwrap_err().to_string();
1106        assert!(
1107            err.contains("_reserved_") && err.contains("base[0]"),
1108            "error: {err}"
1109        );
1110    }
1111
1112    #[test]
1113    fn test_reserved_tensor_has_four_keys_after_encode() {
1114        let meta = GlobalMetadata::default();
1115        let desc = make_descriptor(vec![3, 4]);
1116        let data = vec![0u8; 3 * 4 * 4]; // 3*4 float32
1117        let options = EncodeOptions {
1118            hash_algorithm: None,
1119            ..Default::default()
1120        };
1121        let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
1122        let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1123
1124        let reserved = decoded.base[0]
1125            .get("_reserved_")
1126            .expect("_reserved_ missing");
1127        if let ciborium::Value::Map(pairs) = reserved {
1128            // Should have "tensor" key
1129            let tensor_entry = pairs
1130                .iter()
1131                .find(|(k, _)| *k == ciborium::Value::Text("tensor".to_string()));
1132            assert!(tensor_entry.is_some(), "missing tensor key in _reserved_");
1133            if let Some((_, ciborium::Value::Map(tensor_pairs))) = tensor_entry {
1134                let keys: Vec<String> = tensor_pairs
1135                    .iter()
1136                    .filter_map(|(k, _)| {
1137                        if let ciborium::Value::Text(s) = k {
1138                            Some(s.clone())
1139                        } else {
1140                            None
1141                        }
1142                    })
1143                    .collect();
1144                assert_eq!(keys.len(), 4, "tensor should have 4 keys, got: {keys:?}");
1145                assert!(keys.contains(&"ndim".to_string()));
1146                assert!(keys.contains(&"shape".to_string()));
1147                assert!(keys.contains(&"strides".to_string()));
1148                assert!(keys.contains(&"dtype".to_string()));
1149            } else {
1150                panic!("tensor is not a map");
1151            }
1152        } else {
1153            panic!("_reserved_ is not a map");
1154        }
1155    }
1156
1157    #[test]
1158    fn test_reserved_tensor_scalar_ndim_zero() {
1159        // Scalar: ndim=0, shape=[], strides=[]
1160        let desc = DataObjectDescriptor {
1161            obj_type: "ntensor".to_string(),
1162            ndim: 0,
1163            shape: vec![],
1164            strides: vec![],
1165            dtype: Dtype::Float32,
1166            byte_order: ByteOrder::native(),
1167            encoding: "none".to_string(),
1168            filter: "none".to_string(),
1169            compression: "none".to_string(),
1170            params: BTreeMap::new(),
1171            hash: None,
1172        };
1173        let data = vec![0u8; 4]; // 1 float32
1174        let meta = GlobalMetadata::default();
1175        let options = EncodeOptions {
1176            hash_algorithm: None,
1177            ..Default::default()
1178        };
1179        let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
1180        let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1181
1182        let reserved = decoded.base[0]
1183            .get("_reserved_")
1184            .expect("_reserved_ missing");
1185        if let ciborium::Value::Map(pairs) = reserved {
1186            let tensor_entry = pairs
1187                .iter()
1188                .find(|(k, _)| *k == ciborium::Value::Text("tensor".to_string()));
1189            if let Some((_, ciborium::Value::Map(tensor_pairs))) = tensor_entry {
1190                // ndim should be 0
1191                let ndim = tensor_pairs
1192                    .iter()
1193                    .find(|(k, _)| *k == ciborium::Value::Text("ndim".to_string()));
1194                assert!(
1195                    matches!(ndim, Some((_, ciborium::Value::Integer(i))) if i128::from(*i) == 0),
1196                    "ndim should be 0 for scalar"
1197                );
1198                // shape should be empty array
1199                let shape = tensor_pairs
1200                    .iter()
1201                    .find(|(k, _)| *k == ciborium::Value::Text("shape".to_string()));
1202                assert!(
1203                    matches!(shape, Some((_, ciborium::Value::Array(a))) if a.is_empty()),
1204                    "shape should be [] for scalar"
1205                );
1206            } else {
1207                panic!("tensor missing or not a map");
1208            }
1209        } else {
1210            panic!("_reserved_ is not a map");
1211        }
1212    }
1213
1214    // ── Category 3: _extra_ edge cases ───────────────────────────────────
1215
1216    #[test]
1217    fn test_extra_with_keys_colliding_with_base_entry_keys() {
1218        // _extra_ has key "mars", base[0] also has key "mars" — different scopes, both survive
1219        let mut entry = BTreeMap::new();
1220        entry.insert(
1221            "mars".to_string(),
1222            ciborium::Value::Text("base-mars".to_string()),
1223        );
1224        let mut extra = BTreeMap::new();
1225        extra.insert(
1226            "mars".to_string(),
1227            ciborium::Value::Text("extra-mars".to_string()),
1228        );
1229        let meta = GlobalMetadata {
1230            version: 2,
1231            base: vec![entry],
1232            extra,
1233            ..Default::default()
1234        };
1235        let desc = make_descriptor(vec![2]);
1236        let data = vec![0u8; 8];
1237        let options = EncodeOptions {
1238            hash_algorithm: None,
1239            ..Default::default()
1240        };
1241        let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
1242        let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1243
1244        assert_eq!(
1245            decoded.base[0].get("mars"),
1246            Some(&ciborium::Value::Text("base-mars".to_string()))
1247        );
1248        assert_eq!(
1249            decoded.extra.get("mars"),
1250            Some(&ciborium::Value::Text("extra-mars".to_string()))
1251        );
1252    }
1253
1254    #[test]
1255    fn test_empty_extra_omitted_from_cbor() {
1256        let meta = GlobalMetadata {
1257            version: 2,
1258            extra: BTreeMap::new(),
1259            ..Default::default()
1260        };
1261        let desc = make_descriptor(vec![2]);
1262        let data = vec![0u8; 8];
1263        let options = EncodeOptions {
1264            hash_algorithm: None,
1265            ..Default::default()
1266        };
1267        let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
1268        let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1269        assert!(decoded.extra.is_empty());
1270    }
1271
1272    #[test]
1273    fn test_extra_with_nested_maps_round_trips() {
1274        let nested = ciborium::Value::Map(vec![
1275            (
1276                ciborium::Value::Text("key1".to_string()),
1277                ciborium::Value::Integer(42.into()),
1278            ),
1279            (
1280                ciborium::Value::Text("key2".to_string()),
1281                ciborium::Value::Map(vec![(
1282                    ciborium::Value::Text("deep".to_string()),
1283                    ciborium::Value::Bool(true),
1284                )]),
1285            ),
1286        ]);
1287        let mut extra = BTreeMap::new();
1288        extra.insert("nested".to_string(), nested.clone());
1289        let meta = GlobalMetadata {
1290            version: 2,
1291            extra,
1292            ..Default::default()
1293        };
1294        let desc = make_descriptor(vec![2]);
1295        let data = vec![0u8; 8];
1296        let options = EncodeOptions {
1297            hash_algorithm: None,
1298            ..Default::default()
1299        };
1300        let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
1301        let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1302        // Nested maps should round-trip
1303        assert!(decoded.extra.contains_key("nested"));
1304    }
1305
1306    // ── Category 4: Serde deserialization ────────────────────────────────
1307
1308    #[test]
1309    fn test_old_common_payload_keys_silently_ignored() {
1310        // Simulate an old v2 message with "common" and "payload" keys at top level.
1311        // GlobalMetadata uses `deny_unknown_fields` is NOT set (serde default),
1312        // so unknown keys should be silently ignored.
1313        use ciborium::Value;
1314        let cbor = Value::Map(vec![
1315            (Value::Text("version".to_string()), Value::Integer(2.into())),
1316            (Value::Text("common".to_string()), Value::Map(vec![])),
1317            (Value::Text("payload".to_string()), Value::Array(vec![])),
1318        ]);
1319        let mut bytes = Vec::new();
1320        ciborium::into_writer(&cbor, &mut bytes).unwrap();
1321
1322        let decoded: GlobalMetadata = crate::metadata::cbor_to_global_metadata(&bytes).unwrap();
1323        assert_eq!(decoded.version, 2);
1324        assert!(decoded.base.is_empty());
1325        assert!(decoded.extra.is_empty());
1326        assert!(decoded.reserved.is_empty());
1327    }
1328
1329    #[test]
1330    fn test_old_reserved_key_name_ignored() {
1331        // "reserved" (old name) should be ignored, only "_reserved_" is captured.
1332        use ciborium::Value;
1333        let cbor = Value::Map(vec![
1334            (Value::Text("version".to_string()), Value::Integer(2.into())),
1335            (
1336                Value::Text("reserved".to_string()),
1337                Value::Map(vec![(
1338                    Value::Text("rogue".to_string()),
1339                    Value::Text("value".to_string()),
1340                )]),
1341            ),
1342        ]);
1343        let mut bytes = Vec::new();
1344        ciborium::into_writer(&cbor, &mut bytes).unwrap();
1345
1346        let decoded: GlobalMetadata = crate::metadata::cbor_to_global_metadata(&bytes).unwrap();
1347        assert!(
1348            decoded.reserved.is_empty(),
1349            "old 'reserved' key should be ignored"
1350        );
1351    }
1352
1353    // ── Category 4b: validate_no_client_reserved — multi-entry base ────
1354
1355    #[test]
1356    fn test_reserved_rejected_in_second_base_entry_only() {
1357        // base[0] is clean, base[1] has _reserved_ → should fail, mentioning base[1]
1358        let mut entry0 = BTreeMap::new();
1359        entry0.insert("clean".to_string(), ciborium::Value::Text("ok".to_string()));
1360        let mut entry1 = BTreeMap::new();
1361        entry1.insert("_reserved_".to_string(), ciborium::Value::Map(vec![]));
1362        let meta = GlobalMetadata {
1363            version: 2,
1364            base: vec![entry0, entry1],
1365            ..Default::default()
1366        };
1367        let desc = make_descriptor(vec![2]);
1368        let data = vec![0u8; 8];
1369        let result = encode(
1370            &meta,
1371            &[(&desc, data.as_slice()), (&desc, data.as_slice())],
1372            &EncodeOptions::default(),
1373        );
1374        assert!(result.is_err());
1375        let err = result.unwrap_err().to_string();
1376        assert!(
1377            err.contains("base[1]"),
1378            "error should mention base[1]: {err}"
1379        );
1380    }
1381
1382    #[test]
1383    fn test_reserved_accepted_when_all_base_entries_clean() {
1384        // Multiple base entries, none have _reserved_ → should succeed
1385        let mut e0 = BTreeMap::new();
1386        e0.insert(
1387            "key0".to_string(),
1388            ciborium::Value::Text("val0".to_string()),
1389        );
1390        let mut e1 = BTreeMap::new();
1391        e1.insert(
1392            "key1".to_string(),
1393            ciborium::Value::Text("val1".to_string()),
1394        );
1395        let meta = GlobalMetadata {
1396            version: 2,
1397            base: vec![e0, e1],
1398            ..Default::default()
1399        };
1400        let desc = make_descriptor(vec![2]);
1401        let data = vec![0u8; 8];
1402        let options = EncodeOptions {
1403            hash_algorithm: None,
1404            ..Default::default()
1405        };
1406        let msg = encode(
1407            &meta,
1408            &[(&desc, data.as_slice()), (&desc, data.as_slice())],
1409            &options,
1410        )
1411        .unwrap();
1412        let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1413        assert_eq!(decoded.base.len(), 2);
1414        assert!(decoded.base[0].contains_key("key0"));
1415        assert!(decoded.base[1].contains_key("key1"));
1416    }
1417
1418    // ── Category 5: populate_base_entries — all dtypes ───────────────────
1419
1420    #[test]
1421    fn test_reserved_tensor_dtype_strings_for_all_dtypes() {
1422        // Verify that _reserved_.tensor.dtype string is correct for every Dtype variant
1423        let dtypes_and_expected = [
1424            (Dtype::Float16, "float16"),
1425            (Dtype::Bfloat16, "bfloat16"),
1426            (Dtype::Float32, "float32"),
1427            (Dtype::Float64, "float64"),
1428            (Dtype::Complex64, "complex64"),
1429            (Dtype::Complex128, "complex128"),
1430            (Dtype::Int8, "int8"),
1431            (Dtype::Int16, "int16"),
1432            (Dtype::Int32, "int32"),
1433            (Dtype::Int64, "int64"),
1434            (Dtype::Uint8, "uint8"),
1435            (Dtype::Uint16, "uint16"),
1436            (Dtype::Uint32, "uint32"),
1437            (Dtype::Uint64, "uint64"),
1438        ];
1439
1440        for (dtype, expected_str) in dtypes_and_expected {
1441            let byte_width = dtype.byte_width();
1442            let num_elements: u64 = 4;
1443            let data_len = num_elements as usize * byte_width;
1444
1445            let desc = DataObjectDescriptor {
1446                obj_type: "ntensor".to_string(),
1447                ndim: 1,
1448                shape: vec![num_elements],
1449                strides: vec![1],
1450                dtype,
1451                byte_order: ByteOrder::native(),
1452                encoding: "none".to_string(),
1453                filter: "none".to_string(),
1454                compression: "none".to_string(),
1455                params: BTreeMap::new(),
1456                hash: None,
1457            };
1458            let data = vec![0u8; data_len];
1459            let meta = GlobalMetadata::default();
1460            let options = EncodeOptions {
1461                hash_algorithm: None,
1462                ..Default::default()
1463            };
1464            let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
1465            let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1466
1467            let reserved = decoded.base[0]
1468                .get("_reserved_")
1469                .unwrap_or_else(|| panic!("_reserved_ missing for dtype {dtype}"));
1470            if let ciborium::Value::Map(pairs) = reserved {
1471                let tensor_entry = pairs
1472                    .iter()
1473                    .find(|(k, _)| *k == ciborium::Value::Text("tensor".to_string()));
1474                if let Some((_, ciborium::Value::Map(tensor_pairs))) = tensor_entry {
1475                    let dtype_val = tensor_pairs
1476                        .iter()
1477                        .find(|(k, _)| *k == ciborium::Value::Text("dtype".to_string()));
1478                    assert!(
1479                        matches!(
1480                            dtype_val,
1481                            Some((_, ciborium::Value::Text(s))) if s == expected_str
1482                        ),
1483                        "dtype for {dtype} should be '{expected_str}', got: {dtype_val:?}"
1484                    );
1485                } else {
1486                    panic!("tensor missing or not a map for dtype {dtype}");
1487                }
1488            } else {
1489                panic!("_reserved_ is not a map for dtype {dtype}");
1490            }
1491        }
1492    }
1493
1494    // ── Category 6: GlobalMetadata serde with all fields ─────────────────
1495
1496    #[test]
1497    fn test_global_metadata_serde_all_fields_populated() {
1498        // base + reserved + extra all populated — verify CBOR round-trip
1499        use ciborium::Value;
1500
1501        let mut base_entry = BTreeMap::new();
1502        base_entry.insert("key".to_string(), Value::Text("base_val".to_string()));
1503        let mut reserved = BTreeMap::new();
1504        reserved.insert("encoder".to_string(), Value::Text("test".to_string()));
1505        let mut extra = BTreeMap::new();
1506        extra.insert("custom".to_string(), Value::Integer(42.into()));
1507
1508        let meta = GlobalMetadata {
1509            version: 2,
1510            base: vec![base_entry],
1511            reserved,
1512            extra,
1513        };
1514
1515        // Serialize to CBOR and back
1516        let cbor_bytes = crate::metadata::global_metadata_to_cbor(&meta).unwrap();
1517        let decoded: GlobalMetadata =
1518            crate::metadata::cbor_to_global_metadata(&cbor_bytes).unwrap();
1519
1520        assert_eq!(decoded.version, 2);
1521        assert_eq!(decoded.base.len(), 1);
1522        assert_eq!(
1523            decoded.base[0].get("key"),
1524            Some(&Value::Text("base_val".to_string()))
1525        );
1526        assert!(decoded.reserved.contains_key("encoder"));
1527        assert_eq!(
1528            decoded.extra.get("custom"),
1529            Some(&Value::Integer(42.into()))
1530        );
1531    }
1532
1533    // ── Category 7: populate_reserved_provenance ─────────────────────────
1534
/// After a plain `encode()` of default metadata, the decoded message-level
/// `reserved` map must contain `encoder`, `time` and `uuid` entries, and each
/// entry must have the CBOR shape checked below.
#[test]
fn test_provenance_fields_present_after_encode() {
    let meta = GlobalMetadata::default();
    let descriptor = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let options = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };

    let message = encode(&meta, &[(&descriptor, payload.as_slice())], &options).unwrap();
    let (decoded, _) = decode(&message, &DecodeOptions::default()).unwrap();

    // All three provenance keys must be present at message level.
    assert!(decoded.reserved.contains_key("encoder"));
    assert!(decoded.reserved.contains_key("time"));
    assert!(decoded.reserved.contains_key("uuid"));

    // `encoder` is a map keyed (at least) by the text keys "name" and "version".
    let encoder_pairs = match decoded.reserved.get("encoder").unwrap() {
        ciborium::Value::Map(pairs) => pairs,
        _ => panic!("encoder should be a map"),
    };
    let has_text_key = |wanted: &str| {
        encoder_pairs.iter().any(|(key, _)| {
            matches!(key, ciborium::Value::Text(text) if text.as_str() == wanted)
        })
    };
    let has_name = has_text_key("name");
    let has_version = has_text_key("version");
    assert!(has_name, "encoder map should have 'name' key");
    assert!(has_version, "encoder map should have 'version' key");

    // `uuid` is text in the hyphenated form: 36 characters, 4 of them hyphens.
    let uuid_str = match decoded.reserved.get("uuid").unwrap() {
        ciborium::Value::Text(text) => text,
        _ => panic!("uuid should be a text"),
    };
    assert_eq!(uuid_str.len(), 36, "UUID should be 36 chars: {uuid_str}");
    let hyphen_count = uuid_str.matches('-').count();
    assert_eq!(hyphen_count, 4, "UUID should have 4 hyphens: {uuid_str}");

    // `time` is text that looks like an ISO 8601 UTC timestamp ("…T…Z").
    let time_str = match decoded.reserved.get("time").unwrap() {
        ciborium::Value::Text(text) => text,
        _ => panic!("time should be a text"),
    };
    assert!(
        time_str.ends_with('Z'),
        "time should end with Z: {time_str}"
    );
    assert!(
        time_str.contains('T'),
        "time should contain T separator: {time_str}"
    );
}
1592
/// A CBOR metadata map that carries both a `_reserved_` entry and a plain
/// `reserved` entry must decode so that only the `_reserved_` contents show up
/// in `GlobalMetadata::reserved`.
#[test]
fn test_both_reserved_and_reserved_underscore_only_new_captured() {
    use ciborium::Value;

    // Small builders to keep the hand-written CBOR tree readable.
    let text = |s: &str| Value::Text(s.to_string());
    let single_entry_map = |key: &str, value: &str| Value::Map(vec![(text(key), text(value))]);

    let cbor = Value::Map(vec![
        (text("_reserved_"), single_entry_map("encoder", "tensogram")),
        (text("reserved"), single_entry_map("old", "ignored")),
        (text("version"), Value::Integer(2.into())),
    ]);

    let mut bytes = Vec::new();
    ciborium::into_writer(&cbor, &mut bytes).unwrap();

    let decoded: GlobalMetadata = crate::metadata::cbor_to_global_metadata(&bytes).unwrap();

    // The `_reserved_` content is kept, the `reserved` content is dropped.
    assert!(decoded.reserved.contains_key("encoder"));
    assert!(!decoded.reserved.contains_key("old"));
}
1621
1622    // ── Category 8: encode_pre_encoded smoke tests ───────────────────────
1623
/// Roundtrip: encode raw bytes via `encode()`, decode, then feed the decoded
/// descriptor and payload bytes to `encode_pre_encoded()` and decode again.
/// Both decoded payloads must be byte-identical. We compare payload bytes, NOT
/// raw wire messages (provenance UUIDs make raw message equality impossible).
///
/// NOTE(review): despite the `simple_packing` suffix in the test name, the
/// descriptor comes from `make_descriptor`, which per the note below uses
/// encoding="none"; no simple-packing path is exercised here.
#[test]
fn test_encode_pre_encoded_roundtrip_simple_packing() {
    // Use encoding="none" (raw float32) for maximum portability — no feature flags needed.
    let desc = make_descriptor(vec![4]);
    let raw_data: Vec<u8> = vec![0u8; 4 * 4]; // 4 float32 values, all-zero

    let meta = GlobalMetadata::default();
    let options = EncodeOptions::default();

    // Step 1: encode normally
    let msg1 = encode(&meta, &[(&desc, raw_data.as_slice())], &options).unwrap();

    // Step 2: decode to get the encoded payload bytes + descriptor
    let (_, objects1) = decode(&msg1, &DecodeOptions::default()).unwrap();
    let (decoded_desc1, decoded_payload1) = &objects1[0];

    // Step 3: re-encode the same bytes via encode_pre_encoded.
    // `decoded_desc1` is already a reference into `objects1`, so it is passed
    // as-is; cloning it only to borrow the clone was a wasted allocation.
    let msg2 = encode_pre_encoded(
        &meta,
        &[(decoded_desc1, decoded_payload1.as_slice())],
        &options,
    )
    .unwrap();

    // Step 4: decode the second message
    let (_, objects2) = decode(&msg2, &DecodeOptions::default()).unwrap();
    let (_, decoded_payload2) = &objects2[0];

    // Payloads must be identical — same bytes, same encoding
    // (raw wire messages differ due to non-deterministic provenance UUIDs)
    assert_eq!(
        decoded_payload1, decoded_payload2,
        "decoded payloads should be equal after encode/re-encode roundtrip"
    );
}
1663
/// Buffered `encode_pre_encoded` must refuse `emit_preceders = true` and the
/// resulting error text must name the offending option.
#[test]
fn test_encode_pre_encoded_rejects_emit_preceders() {
    let meta = GlobalMetadata::default();
    let descriptor = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let options = EncodeOptions {
        emit_preceders: true,
        ..Default::default()
    };

    let err = match encode_pre_encoded(&meta, &[(&descriptor, payload.as_slice())], &options) {
        Ok(_) => panic!("encode_pre_encoded with emit_preceders=true should fail"),
        Err(error) => error.to_string(),
    };

    assert!(
        err.contains("emit_preceders"),
        "error should mention emit_preceders: {err}"
    );
}
1685
/// A hash planted on the caller's descriptor must not survive
/// `encode_pre_encoded`: the hash stored in the decoded descriptor has to equal
/// the one computed over the decoded payload with the configured algorithm.
#[test]
fn test_encode_pre_encoded_overwrites_caller_hash() {
    const BOGUS_HASH: &str = "deadbeefdeadbeef";

    // Descriptor carrying a deliberately wrong hash value.
    let mut desc = make_descriptor(vec![2]);
    desc.hash = Some(HashDescriptor {
        hash_type: "xxh3".to_string(),
        value: BOGUS_HASH.to_string(),
    });

    let payload = vec![0xAB_u8; 8]; // non-trivial payload bytes
    let meta = GlobalMetadata::default();
    let options = EncodeOptions::default(); // includes xxh3 hashing

    let msg = encode_pre_encoded(&meta, &[(&desc, payload.as_slice())], &options).unwrap();
    let (_, objects) = decode(&msg, &DecodeOptions::default()).unwrap();
    let (decoded_desc, decoded_payload) = &objects[0];

    // Recompute the reference hash with the same algorithm the encoder used.
    let algorithm = options.hash_algorithm.expect("expected hash algorithm");
    let computed_hash = compute_hash(decoded_payload, algorithm);

    let stored_hash = match decoded_desc.hash.as_ref() {
        Some(descriptor_hash) => descriptor_hash.value.clone(),
        None => panic!("hash should be present in decoded descriptor"),
    };

    assert_ne!(
        stored_hash, BOGUS_HASH,
        "caller's garbage hash must be overwritten"
    );
    assert_eq!(
        stored_hash, computed_hash,
        "library-computed hash must match hash over decoded payload"
    );
}
1726
/// The offsets array `[0, 100, 200]` checked with a second argument of 100
/// must validate successfully.
#[test]
fn test_validate_szip_block_offsets_happy_path() {
    let offsets: Vec<ciborium::Value> = [0u64, 100, 200]
        .iter()
        .map(|&offset| ciborium::Value::from(offset))
        .collect();
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(offsets),
    )]);

    let outcome = validate_szip_block_offsets(&params, 100);
    assert!(outcome.is_ok());
}
1737
/// With no `szip_block_offsets` entry in the params, validation must fail and
/// the error text must mention both "missing" and the key name.
#[test]
fn test_validate_szip_block_offsets_missing_key() {
    let empty_params = BTreeMap::new();

    let outcome = validate_szip_block_offsets(&empty_params, 100);
    let err = outcome.unwrap_err().to_string();

    let mentions_missing = err.contains("missing");
    let mentions_key = err.contains("szip_block_offsets");
    assert!(mentions_missing && mentions_key, "error: {err}");
}
1750
/// A `szip_block_offsets` value that is a bare integer instead of an array
/// must be rejected with an error mentioning "array" and the key name.
#[test]
fn test_validate_szip_block_offsets_not_array() {
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Integer(0.into()),
    )]);

    let outcome = validate_szip_block_offsets(&params, 100);
    let err = outcome.unwrap_err().to_string();

    let mentions_array = err.contains("array");
    let mentions_key = err.contains("szip_block_offsets");
    assert!(mentions_array && mentions_key, "error: {err}");
}
1767
/// An offsets array containing a text element must be rejected with an error
/// mentioning "integer" and the key name.
#[test]
fn test_validate_szip_block_offsets_non_integer_element() {
    let mixed_elements = vec![
        ciborium::Value::Integer(0.into()),
        ciborium::Value::Text("x".to_string()),
    ];
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(mixed_elements),
    )]);

    let outcome = validate_szip_block_offsets(&params, 100);
    let err = outcome.unwrap_err().to_string();

    let mentions_integer = err.contains("integer");
    let mentions_key = err.contains("szip_block_offsets");
    assert!(mentions_integer && mentions_key, "error: {err}");
}
1787
/// An offsets array whose first element is 5 rather than 0 must be rejected;
/// the error text must contain "must be 0" and "got 5".
#[test]
fn test_validate_szip_block_offsets_nonzero_first() {
    let offsets: Vec<ciborium::Value> = [5u64, 100, 200]
        .iter()
        .map(|&offset| ciborium::Value::from(offset))
        .collect();
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(offsets),
    )]);

    let outcome = validate_szip_block_offsets(&params, 100);
    let err = outcome.unwrap_err().to_string();

    let states_requirement = err.contains("must be 0");
    let reports_actual = err.contains("got 5");
    assert!(states_requirement && reports_actual, "error: {err}");
}
1804
/// An offsets array that goes backwards (`[0, 100, 50]`) must be rejected; the
/// error text must mention either "increasing" or "monotonic".
#[test]
fn test_validate_szip_block_offsets_non_monotonic() {
    let offsets: Vec<ciborium::Value> = [0u64, 100, 50]
        .iter()
        .map(|&offset| ciborium::Value::from(offset))
        .collect();
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(offsets),
    )]);

    let outcome = validate_szip_block_offsets(&params, 100);
    let err = outcome.unwrap_err().to_string();

    // Either wording is accepted for the ordering violation.
    let mentions_ordering = ["increasing", "monotonic"]
        .iter()
        .any(|needle| err.contains(*needle));
    assert!(mentions_ordering, "error: {err}");
}
1821
/// With a second argument of 100, an offset of 801 must be rejected and the
/// error text must contain both "800" and "801".
// NOTE(review): 800 looks like 100 bytes expressed in bits — confirm against
// `validate_szip_block_offsets`.
#[test]
fn test_validate_szip_block_offsets_offset_beyond_bound() {
    let offsets: Vec<ciborium::Value> = [0u64, 100, 801]
        .iter()
        .map(|&offset| ciborium::Value::from(offset))
        .collect();
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(offsets),
    )]);

    let outcome = validate_szip_block_offsets(&params, 100);
    let err = outcome.unwrap_err().to_string();

    let reports_limit = err.contains("800");
    let reports_offender = err.contains("801");
    assert!(reports_limit && reports_offender, "error: {err}");
}
1835
/// A descriptor that declares `zstd` compression but still carries a
/// `szip_block_offsets` param must be rejected, with an error naming both the
/// param key and the compression.
#[test]
fn test_validate_no_szip_offsets_for_non_szip_rejects() {
    let offsets: Vec<ciborium::Value> = [0u64, 1]
        .iter()
        .map(|&offset| ciborium::Value::from(offset))
        .collect();
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(offsets),
    )]);

    let desc = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 1,
        shape: vec![2],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: String::from("none"),
        filter: String::from("none"),
        compression: String::from("zstd"),
        params,
        hash: None,
    };

    let outcome = validate_no_szip_offsets_for_non_szip(&desc);
    let err = outcome.unwrap_err().to_string();

    let mentions_key = err.contains("szip_block_offsets");
    let mentions_compression = err.contains("zstd");
    assert!(mentions_key && mentions_compression, "error: {err}");
}
1865
/// A descriptor that declares `szip` compression may carry a
/// `szip_block_offsets` param; the check must pass.
#[test]
fn test_validate_no_szip_offsets_for_non_szip_allows_szip() {
    let offsets: Vec<ciborium::Value> = [0u64, 1]
        .iter()
        .map(|&offset| ciborium::Value::from(offset))
        .collect();
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(offsets),
    )]);

    let desc = DataObjectDescriptor {
        obj_type: String::from("ntensor"),
        ndim: 1,
        shape: vec![2],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: String::from("none"),
        filter: String::from("none"),
        compression: String::from("szip"),
        params,
        hash: None,
    };

    let outcome = validate_no_szip_offsets_for_non_szip(&desc);
    assert!(outcome.is_ok());
}
1889}