Skip to main content

tensogram_wasm/
extras.rs

1// (C) Copyright 2026- ECMWF and individual contributors.
2//
3// This software is licensed under the terms of the Apache Licence Version 2.0
4// which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5// In applying this licence, ECMWF does not waive the privileges and immunities
6// granted to it by virtue of its status as an intergovernmental organisation nor
7// does it submit to any jurisdiction.
8
9//! Additional WASM exports covering the Scope C.1 API-parity gap.
10//!
11//! Each function here is a thin binding over a public function in the
12//! Rust core or in `tensogram-encodings`.  All errors are mapped via
13//! [`crate::convert::js_err`] so the TypeScript wrapper's
14//! `mapTensogramError` sees a consistent message shape.
15
16use crate::convert::*;
17use serde::Serialize;
18use tensogram::{self as core, DecodeOptions};
19use tensogram_encodings::simple_packing;
20use wasm_bindgen::prelude::*;
21
22// ── decode_range ─────────────────────────────────────────────────────────────
23
24/// Decode partial sub-tensor ranges from a single data object.
25///
26/// @param buf - Complete wire-format message bytes.
27/// @param object_index - Zero-based index of the target object.
28/// @param ranges - Flat `BigUint64Array` of `[offset0, count0, offset1, count1, …]`
29///   pairs, in element units (not bytes).  Empty array returns an empty result.
30///   hash is checked against the payload before any range is decoded.
31/// @returns `{ descriptor, parts: Uint8Array[] }` — one raw-bytes view
32///   per requested range, in request order.  Callers (e.g. the TS
33///   wrapper's `decodeRange`) convert each `Uint8Array` into a
34///   dtype-typed view.
35///
36/// Throws on out-of-range object index, unsupported `filter` (e.g.
37/// shuffle), or bitmask dtype (matching the Rust-core contract).
38///
39/// Implementation note: the `parts` array is built manually with
40/// `Uint8Array::from` rather than via `to_js` — `serde_wasm_bindgen`
41/// serialises `&[u8]` as a plain JS `Array<number>` by default, and the
42/// TS wrapper expects honest `Uint8Array` instances so that
43/// `typedArrayFor` can wrap them zero-copy.
44#[wasm_bindgen]
45pub fn decode_range(
46    buf: &[u8],
47    object_index: usize,
48    ranges: &js_sys::BigUint64Array,
49) -> Result<JsValue, JsValue> {
50    let flat: Vec<u64> = ranges.to_vec();
51    if !flat.len().is_multiple_of(2) {
52        return Err(JsValue::from(js_sys::Error::new(
53            "ranges length must be a multiple of 2 (flat [offset, count] pairs)",
54        )));
55    }
56    let range_pairs: Vec<(u64, u64)> = flat.chunks_exact(2).map(|w| (w[0], w[1])).collect();
57
58    let options = DecodeOptions {
59        ..Default::default()
60    };
61    let (descriptor, parts) =
62        core::decode_range(buf, object_index, &range_pairs, &options).map_err(js_err)?;
63
64    let result = js_sys::Object::new();
65    js_sys::Reflect::set(&result, &"descriptor".into(), &to_js(&descriptor)?)
66        .map_err(|_| JsValue::from(js_sys::Error::new("failed to set descriptor")))?;
67    // `js_sys::Array::new_with_length` is defined in terms of `u32`; a
68    // Rust-side Vec of 2^32 entries would already have failed the
69    // earlier `to_vec()` from the input `BigUint64Array` (WASM linear
70    // memory is itself u32-indexed), so the cast here can't truncate
71    // in any reachable execution.
72    let parts_js = js_sys::Array::new_with_length(parts.len() as u32);
73    for (i, bytes) in parts.iter().enumerate() {
74        parts_js.set(i as u32, js_sys::Uint8Array::from(bytes.as_slice()).into());
75    }
76    js_sys::Reflect::set(&result, &"parts".into(), &parts_js)
77        .map_err(|_| JsValue::from(js_sys::Error::new("failed to set parts")))?;
78    Ok(result.into())
79}
80
81// ── compute_hash ─────────────────────────────────────────────────────────────
82
83/// Compute the hex-encoded hash of a byte slice.
84///
85/// @param data - Bytes to hash.
86/// @param algo - Algorithm name; default `"xxh3"`.  Unknown algorithm
87///   names raise a metadata error.
88/// @returns The hex digest as a string (16 chars for xxh3-64).
89#[wasm_bindgen]
90pub fn compute_hash(data: &[u8], algo: Option<String>) -> Result<String, JsValue> {
91    let name = algo.as_deref().unwrap_or("xxh3");
92    core::parse_hash_name(Some(name)).map_err(js_err)?;
93    Ok(core::compute_hash(data))
94}
95
96// ── simple_packing_compute_params ────────────────────────────────────────────
97
/// JS-side shape of the simple-packing params — matches the `sp_`-prefixed
/// snake_case keys a descriptor's `params` map expects, so the caller can
/// spread the result straight into a descriptor literal.
///
/// Field declaration order is the order `Serialize` emits the keys in, so
/// reordering fields changes the key order of the resulting JS object.
#[derive(Serialize)]
struct SimplePackingParamsJs {
    // Copied from the `reference_value` returned by `simple_packing::compute_params`.
    sp_reference_value: f64,
    // Copied from the computed `binary_scale_factor`.
    sp_binary_scale_factor: i32,
    // Copied from the computed `decimal_scale_factor` (not echoed from the caller's input).
    sp_decimal_scale_factor: i32,
    // Copied from the computed `bits_per_value`.
    sp_bits_per_value: u32,
}
108
109/// Compute the simple-packing parameters (reference value, binary/decimal
110/// scale factors, bits-per-value) for a float64 array.
111///
112/// @param values - `Float64Array` — finite, non-NaN.
113/// @param bits_per_value - Quantization depth (0–64; 0 denotes a
114///   constant-field packing).
115/// @param decimal_scale_factor - Power-of-10 scaling applied before
116///   packing.  Typically `0`.
117/// @returns Plain JS object with ``sp_``-prefixed keys matching the
118///   on-wire descriptor params: `{ sp_reference_value,
119///   sp_binary_scale_factor, sp_decimal_scale_factor, sp_bits_per_value }`.
120///   Spread into a descriptor to apply:
121///   `{ ...computed, encoding: "simple_packing", …}`.
122///
123///   Note: the encoder also auto-computes these values when the
124///   descriptor carries only `sp_bits_per_value` (and optionally
125///   `sp_decimal_scale_factor`) — calling this function explicitly
126///   is only needed if the caller wants to cache or inspect the
127///   derived params across multiple encodes.
128#[wasm_bindgen]
129pub fn simple_packing_compute_params(
130    values: &[f64],
131    bits_per_value: u32,
132    decimal_scale_factor: i32,
133) -> Result<JsValue, JsValue> {
134    let params = simple_packing::compute_params(values, bits_per_value, decimal_scale_factor)
135        .map_err(|e| JsValue::from(js_sys::Error::new(&format!("encoding error: {e}"))))?;
136    to_js(&SimplePackingParamsJs {
137        sp_reference_value: params.reference_value,
138        sp_binary_scale_factor: params.binary_scale_factor,
139        sp_decimal_scale_factor: params.decimal_scale_factor,
140        sp_bits_per_value: params.bits_per_value,
141    })
142}
143
144// ── encode_pre_encoded ───────────────────────────────────────────────────────
145
146/// Encode a complete Tensogram message from pre-encoded data objects.
147///
148/// Like [`crate::encode`], but each object's bytes are assumed already
149/// encoded by the caller (according to its descriptor's pipeline) and
150/// are written verbatim.  The library validates descriptor structure and
151/// any `szip_block_offsets` it finds but never runs the encoding
152/// pipeline.  The hash is always recomputed from the caller's bytes.
153///
154/// @param metadata_js - GlobalMetadata (JS object, `version: 3` required).
155/// @param objects_js - Array of `{descriptor, data}`; each `data` must
156///   be a `Uint8Array` (opaque pre-encoded bytes).
157/// @param hash - Whether to stamp an xxh3 hash.  Default `true`.
158/// @returns Full wire-format message as `Uint8Array`.
159#[wasm_bindgen]
160pub fn encode_pre_encoded(
161    metadata_js: JsValue,
162    objects_js: js_sys::Array,
163    hash: Option<bool>,
164) -> Result<js_sys::Uint8Array, JsValue> {
165    let metadata = metadata_from_js(&metadata_js)?;
166    let (descriptors, data_vec) = extract_descriptor_data_pairs(&objects_js)?;
167    let pairs: Vec<(&core::DataObjectDescriptor, &[u8])> = descriptors
168        .iter()
169        .zip(data_vec.iter())
170        .map(|(d, v)| (d, v.as_slice()))
171        .collect();
172    // encode_pre_encoded rejects strict-finite flags at the Rust
173    // level; we hardcode them off here so the WASM surface can't
174    // forward them accidentally.
175    let encoded =
176        core::encode_pre_encoded(&metadata, &pairs, &build_encode_options(hash)).map_err(js_err)?;
177    Ok(js_sys::Uint8Array::from(encoded.as_slice()))
178}
179
180// ── validate_buffer ──────────────────────────────────────────────────────────
181
182/// Validate a single Tensogram message buffer.  Returns a JSON string
183/// matching the structure emitted by `tensogram::validate::ValidationReport`:
184/// `{ issues: [...], object_count: N, hash_verified: bool }`.
185///
186/// The TypeScript wrapper parses this JSON once and exposes a typed
187/// `ValidationReport`.  Keeping the bridge as a JSON string avoids
188/// lossy conversion of large integers and keeps the WASM surface
189/// language-neutral (it already matches the FFI contract).
190///
191/// @param buf - The wire-format bytes of a single message (not a file).
192/// @param level - One of `"quick"` / `"default"` / `"checksum"` / `"full"`.
193///   `None` defaults to `"default"`.
194/// @param check_canonical - When true, adds RFC 8949 §4.2 canonical
195///   CBOR key-ordering checks.
196#[wasm_bindgen]
197pub fn validate_buffer(
198    buf: &[u8],
199    level: Option<String>,
200    check_canonical: bool,
201) -> Result<String, JsValue> {
202    let options = parse_validate_options(level.as_deref(), check_canonical)?;
203    let report = core::validate::validate_message(buf, &options);
204    serde_json::to_string(&report).map_err(|e| JsValue::from(js_sys::Error::new(&format!("encoding error: {e}"))))
205}
206
207fn parse_validate_options(
208    level: Option<&str>,
209    check_canonical: bool,
210) -> Result<core::validate::ValidateOptions, JsValue> {
211    use core::validate::{ValidateOptions, ValidationLevel};
212
213    let (max_level, checksum_only) = match level.unwrap_or("default") {
214        "quick" => (ValidationLevel::Structure, false),
215        "default" => (ValidationLevel::Integrity, false),
216        "checksum" => (ValidationLevel::Integrity, true),
217        "full" => (ValidationLevel::Fidelity, false),
218        other => {
219            return Err(JsValue::from(js_sys::Error::new(&format!(
220                "unknown validation level '{other}', expected one of: quick, default, checksum, full",
221            ))));
222        }
223    };
224    Ok(ValidateOptions {
225        max_level,
226        check_canonical,
227        checksum_only,
228    })
229}