Skip to main content

tensogram/
types.rs

1// (C) Copyright 2026- ECMWF and individual contributors.
2//
3// This software is licensed under the terms of the Apache Licence Version 2.0
4// which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5// In applying this licence, ECMWF does not waive the privileges and immunities
6// granted to it by virtue of its status as an intergovernmental organisation nor
7// does it submit to any jurisdiction.
8
9use serde::{Deserialize, Serialize};
10use std::collections::BTreeMap;
11
12use crate::dtype::Dtype;
13use crate::error::Result;
14use crate::error::TensogramError;
15
16pub use tensogram_encodings::ByteOrder;
17
// `HashDescriptor` was the v2 per-object hash carrier.  In v3 the
// per-object hash lives in the inline hash slot of the data-object
// frame footer (see `plans/WIRE_FORMAT.md` §2.2 and §2.4), and the
// message-level aggregate frame ([`HashFrame`]) stores hex-encoded
// digest strings directly.  The struct was removed in Wave 2.2 along
// with the standalone `hash::verify_hash(data, &HashDescriptor)`
// helper — frame-level verification goes through
// [`crate::hash::hash_frame_body`] / [`crate::hash::verify_frame_hash`]
// instead.
27
28/// On-wire descriptor for one of the three NaN / Inf companion-frame
29/// masks (see `plans/WIRE_FORMAT.md` §6.5.1).
30///
31/// `offset` and `length` locate the mask blob inside the frame's
32/// payload region; `method` names the compression scheme (`rle`,
33/// `roaring`, `blosc2`, `zstd`, `lz4`, or `none`); `params` carries
34/// any method-specific parameters (e.g. zstd level, blosc2 sub-codec).
35#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
36pub struct MaskDescriptor {
37    /// Canonical method name — `rle` | `roaring` | `blosc2` | `zstd` |
38    /// `lz4` | `none`.
39    pub method: String,
40    /// Byte offset of the mask blob, measured from the start of the
41    /// frame's payload region (= the first byte after the 16-byte
42    /// frame header).
43    pub offset: u64,
44    /// Byte length of the (compressed) mask blob on disk.
45    pub length: u64,
46    /// Method-specific parameters (e.g. `{ "level": 3 }` for zstd).
47    /// Empty map is serialised as absent to match the canonical
48    /// zero-cost form.
49    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
50    pub params: BTreeMap<String, ciborium::Value>,
51}
52
53/// Top-level `masks` sub-map for the `NTensorFrame` (wire type 9,
54/// see `plans/WIRE_FORMAT.md` §6.5).
55///
56/// All three fields are optional — a frame can carry any subset (or
57/// none, in which case the entire `masks` sub-map is absent).  Field
58/// names serialise as `nan`, `inf+`, `inf-` per the canonical sort
59/// order (byte-lex: `inf+` < `inf-` < `nan`).
60#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
61pub struct MasksMetadata {
62    /// Mask recording element positions that were NaN on encode.
63    #[serde(default, skip_serializing_if = "Option::is_none")]
64    pub nan: Option<MaskDescriptor>,
65    /// Mask recording element positions that were `+Inf` on encode.
66    #[serde(rename = "inf+", default, skip_serializing_if = "Option::is_none")]
67    pub pos_inf: Option<MaskDescriptor>,
68    /// Mask recording element positions that were `-Inf` on encode.
69    #[serde(rename = "inf-", default, skip_serializing_if = "Option::is_none")]
70    pub neg_inf: Option<MaskDescriptor>,
71}
72
73impl MasksMetadata {
74    /// `true` when every kind is absent.  In that case the `masks`
75    /// field on the descriptor should be `None` rather than
76    /// `Some(empty)`, to match the canonical zero-cost form.
77    pub fn is_empty(&self) -> bool {
78        self.nan.is_none() && self.pos_inf.is_none() && self.neg_inf.is_none()
79    }
80}
81
82/// Per-object descriptor — merges tensor metadata and encoding instructions.
83///
84/// Each data object frame carries one of these as its CBOR descriptor.
85/// This replaces the v1 split between `ObjectDescriptor` (tensor info)
86/// and `PayloadDescriptor` (encoding info).
87#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct DataObjectDescriptor {
89    // ── Tensor metadata ──
90    #[serde(rename = "type")]
91    pub obj_type: String,
92    pub ndim: u64,
93    pub shape: Vec<u64>,
94    pub strides: Vec<u64>,
95    pub dtype: Dtype,
96
97    // ── Encoding pipeline ──
98    /// Wire byte order for the payload.  Optional on deserialize — a
99    /// missing key defaults to [`ByteOrder::native()`], matching the
100    /// behaviour of the Python binding's dict-to-descriptor path and
101    /// the payload bytes a caller typically produces with
102    /// native-endian conversions (`to_ne_bytes`, `std::vector<T>`, etc.).
103    #[serde(default = "ByteOrder::native")]
104    pub byte_order: ByteOrder,
105    pub encoding: String,
106    pub filter: String,
107    pub compression: String,
108
109    /// Optional NaN / Inf companion-mask metadata (`NTensorFrame`,
110    /// wire type 9 — see `plans/WIRE_FORMAT.md` §6.5).  `None` means
111    /// no mask sections are present, and the frame is byte-compatible
112    /// with an `NTensorFrame` emitted without the `allow_nan` /
113    /// `allow_inf` opt-in.
114    /// Declared **before** `params` so that the flattened `params` map
115    /// below does not absorb the `masks` key at deserialisation time.
116    #[serde(default, skip_serializing_if = "Option::is_none")]
117    pub masks: Option<MasksMetadata>,
118
119    /// Encoding/filter/compression parameters (reference_value, bits_per_value,
120    /// szip_block_offsets, etc.). Stored as ciborium::Value for flexibility.
121    #[serde(flatten)]
122    pub params: BTreeMap<String, ciborium::Value>,
123}
124
125impl DataObjectDescriptor {
126    /// Compute the total number of elements implied by `shape`, validating
127    /// against u64 overflow and usize range.
128    ///
129    /// Returns the element count in `usize`, or an error when the product
130    /// overflows `u64` or cannot fit in `usize`.
131    #[inline]
132    pub fn num_elements(&self) -> Result<usize> {
133        let shape_product = self
134            .shape
135            .iter()
136            .try_fold(1u64, |acc, &x| acc.checked_mul(x))
137            .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
138        usize::try_from(shape_product)
139            .map_err(|_| TensogramError::Metadata("element count overflows usize".to_string()))
140    }
141}
142
143/// Global message metadata (carried in header/footer metadata frames).
144///
145/// The CBOR metadata frame is **fully free-form**.  The only named
146/// top-level sections the library interprets are:
147/// - `base`: per-object metadata array — one entry per data object, each
148///   entry holds ALL structured metadata for that object independently.
149///   The encoder auto-populates `_reserved_.tensor` (ndim/shape/strides/dtype)
150///   in each entry.
151/// - `_reserved_`: library internals (provenance: encoder info, time, uuid).
152///   Client code can read but MUST NOT write — the encoder validates this.
153/// - `_extra_`: client-writable catch-all for ad-hoc message-level annotations.
154///
155/// Any other top-level key supplied by the caller (including a stray
156/// legacy `"version"` key) is routed into `_extra_` on decode.  The
157/// wire-format version lives **only** in the preamble — see
158/// [`crate::wire::WIRE_VERSION`] and [`crate::wire::Preamble`].
159#[derive(Debug, Clone, Default, Serialize, Deserialize)]
160pub struct GlobalMetadata {
161    /// Per-object metadata array.  Each entry holds ALL structured metadata
162    /// for that data object.  Entries are independent — no tracking of what
163    /// is common across objects.
164    ///
165    /// The encoder auto-populates `_reserved_.tensor` (with ndim, shape,
166    /// strides, dtype) in each entry.  Application code may pre-populate
167    /// additional keys (e.g. `"mars": {…}`) before encoding; the encoder
168    /// preserves them.
169    #[serde(default, skip_serializing_if = "Vec::is_empty")]
170    pub base: Vec<BTreeMap<String, ciborium::Value>>,
171
172    /// Library internals — provenance info (encoder, time, uuid).
173    /// Client code can read but MUST NOT write; the encoder overwrites this.
174    #[serde(
175        rename = "_reserved_",
176        default,
177        skip_serializing_if = "BTreeMap::is_empty"
178    )]
179    pub reserved: BTreeMap<String, ciborium::Value>,
180
181    /// Client-writable catch-all for ad-hoc message-level annotations.
182    #[serde(
183        rename = "_extra_",
184        default,
185        skip_serializing_if = "BTreeMap::is_empty"
186    )]
187    pub extra: BTreeMap<String, ciborium::Value>,
188}
189
/// Payload of the index frame: where each data object sits in the
/// message.
///
/// v3 CBOR schema (see `plans/WIRE_FORMAT.md` §6.2):
///
/// ```cbor
/// { "offsets": [u64, ...], "lengths": [u64, ...] }
/// ```
///
/// The number of objects is `offsets.len()`; the `object_count` key
/// that used to be serialised has been dropped.
///
/// NOTE(review): this type derives no serde traits, so the CBOR
/// encoding above is presumably produced by hand-written code
/// elsewhere — confirm.
#[derive(Debug, Clone, Default)]
pub struct IndexFrame {
    /// For each data object frame, its byte offset measured from the
    /// start of the message.
    pub offsets: Vec<u64>,
    /// For each data object frame, its total length in bytes, not
    /// counting alignment padding.
    pub lengths: Vec<u64>,
}
207
/// Payload of the hash frame: the integrity hash of every object.
///
/// v3 CBOR schema (see `plans/WIRE_FORMAT.md` §6.3):
///
/// ```cbor
/// { "algorithm": "xxh3", "hashes": ["hex", "hex", ...] }
/// ```
///
/// The key formerly called `hash_type` is now `algorithm`, making it
/// clear that the value names the algorithm rather than a type
/// identifier.  The number of objects is `hashes.len()`.
#[derive(Debug, Clone)]
pub struct HashFrame {
    /// Name of the hash algorithm (the schema example uses `"xxh3"`).
    pub algorithm: String,
    /// One hex-encoded digest string per object.
    pub hashes: Vec<String>,
}
224
225/// A decoded object: its descriptor paired with its raw decoded payload bytes.
226pub type DecodedObject = (DataObjectDescriptor, Vec<u8>);
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Check that `err` is the `Metadata` variant and that its message
    /// contains `needle`; panic with a diagnostic otherwise.
    #[track_caller]
    fn expect_metadata_error(err: TensogramError, needle: &str) {
        match err {
            TensogramError::Metadata(msg) => {
                assert!(msg.contains(needle), "unexpected message: {msg}");
            }
            other => panic!("expected Metadata error, got: {other:?}"),
        }
    }

    /// Produce a minimal descriptor whose only meaningful field is
    /// `shape`.  `num_elements()` inspects nothing else, so the
    /// remaining fields just carry placeholder values.
    fn descriptor_with_shape(shape: Vec<u64>) -> DataObjectDescriptor {
        let none = || String::from("none");
        // Read the rank before `shape` is moved into the struct.
        let ndim = shape.len() as u64;
        DataObjectDescriptor {
            obj_type: String::from("ntensor"),
            ndim,
            shape,
            strides: vec![],
            dtype: Dtype::Float32,
            byte_order: ByteOrder::native(),
            encoding: none(),
            filter: none(),
            compression: none(),
            masks: None,
            params: BTreeMap::new(),
        }
    }

    #[test]
    fn masks_metadata_is_empty_detects_every_kind_absent() {
        assert!(MasksMetadata::default().is_empty());
    }

    #[test]
    fn masks_metadata_is_empty_false_when_any_kind_present() {
        let mask = || {
            Some(MaskDescriptor {
                method: String::from("roaring"),
                offset: 0,
                length: 1,
                params: BTreeMap::new(),
            })
        };
        // One metadata value per mask kind, each with only that kind set.
        let single_kind = [
            MasksMetadata {
                nan: mask(),
                pos_inf: None,
                neg_inf: None,
            },
            MasksMetadata {
                nan: None,
                pos_inf: mask(),
                neg_inf: None,
            },
            MasksMetadata {
                nan: None,
                pos_inf: None,
                neg_inf: mask(),
            },
        ];
        for masks in &single_kind {
            assert!(!masks.is_empty());
        }
    }

    #[test]
    fn descriptor_deserialize_defaults_byte_order_to_native() {
        // No `byte_order` key at all: the serde default must kick in.
        let json = r#"{
            "type": "ntensor",
            "ndim": 1,
            "shape": [4],
            "strides": [1],
            "dtype": "float32",
            "encoding": "none",
            "filter": "none",
            "compression": "none"
        }"#;
        let parsed = serde_json::from_str::<DataObjectDescriptor>(json)
            .expect("deserialize should succeed without byte_order");
        assert_eq!(parsed.byte_order, ByteOrder::native());
    }

    #[test]
    fn descriptor_deserialize_honours_explicit_byte_order() {
        let cases = [("little", ByteOrder::Little), ("big", ByteOrder::Big)];
        for (literal, expected) in cases {
            let json = format!(
                r#"{{
                    "type": "ntensor", "ndim": 1, "shape": [4], "strides": [1],
                    "dtype": "float32", "byte_order": "{literal}",
                    "encoding": "none", "filter": "none", "compression": "none"
                }}"#
            );
            let parsed = serde_json::from_str::<DataObjectDescriptor>(&json)
                .expect("deserialize should accept explicit byte_order");
            assert_eq!(parsed.byte_order, expected);
        }
    }

    #[test]
    fn num_elements_empty_shape_is_one() {
        // A scalar tensor (ndim = 0) has an empty shape, and the empty
        // product is 1 by convention.
        let scalar = descriptor_with_shape(Vec::new());
        assert_eq!(scalar.num_elements().unwrap(), 1);
    }

    #[test]
    fn num_elements_single_dim() {
        let vector = descriptor_with_shape(vec![100]);
        assert_eq!(vector.num_elements().unwrap(), 100);
    }

    #[test]
    fn num_elements_multi_dim() {
        let cube = descriptor_with_shape(vec![3, 4, 5]);
        assert_eq!(cube.num_elements().unwrap(), 60);
    }

    #[test]
    fn num_elements_zero_dim_yields_zero() {
        // A zero-length dimension means a zero-element tensor, which is
        // valid and must not be reported as an error.
        let hollow = descriptor_with_shape(vec![10, 0, 5]);
        assert_eq!(hollow.num_elements().unwrap(), 0);
    }

    #[test]
    fn num_elements_u64_overflow_is_metadata_error() {
        // u64::MAX × 2 cannot be represented, so checked_mul fails.
        let huge = descriptor_with_shape(vec![u64::MAX, 2]);
        expect_metadata_error(huge.num_elements().unwrap_err(), "shape product overflow");
    }

    #[cfg(target_pointer_width = "32")]
    #[test]
    fn num_elements_usize_overflow_is_metadata_error_on_32bit() {
        // With a 32-bit `usize`, a product that fits in u64 but exceeds
        // usize::MAX exercises the second error path.  On 64-bit
        // targets that path would need a product beyond `u64::MAX`,
        // which the u64-overflow test above already covers.
        let too_wide = descriptor_with_shape(vec![(usize::MAX as u64) + 1]);
        expect_metadata_error(
            too_wide.num_elements().unwrap_err(),
            "element count overflows usize",
        );
    }
}