Skip to main content

tensogram/
pipeline.rs

1// (C) Copyright 2026- ECMWF and individual contributors.
2//
3// This software is licensed under the terms of the Apache Licence Version 2.0
4// which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5// In applying this licence, ECMWF does not waive the privileges and immunities
6// granted to it by virtue of its status as an intergovernmental organisation nor
7// does it submit to any jurisdiction.
8
9//! Shared encoding/filter/compression pipeline helpers for converters.
10//!
11//! Both `tensogram-grib` and `tensogram-netcdf` accept the same set of CLI
12//! flags (`--encoding` / `--bits` / `--filter` / `--compression` /
13//! `--compression-level`) and translate them into
14//! [`DataObjectDescriptor`] fields + `params` map entries in exactly the
15//! same way. This module centralises that translation so the two
16//! converters cannot drift out of sync.
17//!
18//! ## Usage
19//!
20//! ```no_run
21//! use tensogram::pipeline::{apply_pipeline, DataPipeline};
22//! use tensogram::types::{ByteOrder, DataObjectDescriptor};
23//! use tensogram::Dtype;
24//! use std::collections::BTreeMap;
25//!
26//! let mut desc = DataObjectDescriptor {
27//!     obj_type: "ntensor".to_string(),
28//!     ndim: 1,
29//!     shape: vec![4],
30//!     strides: vec![1],
31//!     dtype: Dtype::Float64,
32//!     byte_order: ByteOrder::Little,
33//!     encoding: "none".to_string(),
34//!     filter: "none".to_string(),
35//!     compression: "none".to_string(),
36//!     params: BTreeMap::new(),
37//!     hash: None,
38//! };
39//!
40//! let pipeline = DataPipeline {
41//!     compression: "zstd".to_string(),
42//!     ..Default::default()
43//! };
44//!
45//! let values = [1.0_f64, 2.0, 3.0, 4.0];
46//! apply_pipeline(&mut desc, Some(&values), &pipeline, "my_var").unwrap();
47//! assert_eq!(desc.compression, "zstd");
48//! ```
49
50use ciborium::Value as CborValue;
51use tensogram_encodings::simple_packing;
52
53use crate::types::DataObjectDescriptor;
54
/// Encoding/filter/compression configuration for data objects.
///
/// Each stage is selected by name. The [`Default`] value sets every stage
/// to `"none"`, which makes [`apply_pipeline`] leave a descriptor exactly
/// as it found it — i.e. uncompressed raw payloads identical to the
/// pre-pipeline behaviour. This is the shared type used by both
/// `tensogram-grib` and `tensogram-netcdf`; the two crates re-export it
/// from their own `lib.rs` for convenience.
///
/// The type compares by value (`PartialEq` / `Eq`), so two configurations
/// can be checked for equality directly, e.g. in tests or when
/// de-duplicating CLI-derived settings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DataPipeline {
    /// Encoding stage: `"none"` (default) or `"simple_packing"`.
    pub encoding: String,
    /// Bits per value for `simple_packing`. `None` means "use the
    /// default", which [`apply_pipeline`] resolves to 16. Not consulted
    /// when `encoding` is `"none"`.
    pub bits: Option<u32>,
    /// Filter stage: `"none"` (default) or `"shuffle"`.
    pub filter: String,
    /// Compression codec: `"none"` (default), `"zstd"`, `"lz4"`,
    /// `"blosc2"`, or `"szip"`.
    pub compression: String,
    /// Optional compression level. Used by `zstd` (default 3 when `None`)
    /// and `blosc2` (default 5 when `None`); ignored by other codecs.
    pub compression_level: Option<i32>,
}
76
77impl Default for DataPipeline {
78    fn default() -> Self {
79        Self {
80            encoding: "none".to_string(),
81            bits: None,
82            filter: "none".to_string(),
83            compression: "none".to_string(),
84            compression_level: None,
85        }
86    }
87}
88
89/// Apply a [`DataPipeline`] to a [`DataObjectDescriptor`] by setting its
90/// `encoding` / `filter` / `compression` fields and populating `params`.
91///
92/// `values` carries the float64 payload when available — `simple_packing`
93/// is a float64-only encoding, so this parameter is `Some(&[f64])` when
94/// the caller has typed f64 values and `None` otherwise (e.g. integer
95/// variables in a mixed NetCDF file). When `pipeline.encoding ==
96/// "simple_packing"` but `values` is `None`, the encoding stage is
97/// skipped with a stderr warning and the conversion continues with
98/// `encoding = "none"`.
99///
100/// `var_label` is embedded in warning/error messages for human-readable
101/// diagnostics — typically the variable name (`"temperature"`) or
102/// something like `"GRIB message"`.
103///
104/// # Errors
105///
106/// Returns a human-readable error string when `pipeline.encoding`,
107/// `pipeline.filter`, or `pipeline.compression` is not one of the
108/// recognised values. Callers wrap this string into their own
109/// converter-specific error type (`GribError::InvalidData` /
110/// `NetcdfError::InvalidData`).
111///
112/// Soft failures (`simple_packing` rejecting `NaN`-containing data, or
113/// `simple_packing` requested on a non-f64 variable) are reported as
114/// stderr warnings and do NOT return an error — the variable falls
115/// back to `encoding = "none"` and the conversion continues.
116pub fn apply_pipeline(
117    desc: &mut DataObjectDescriptor,
118    values: Option<&[f64]>,
119    pipeline: &DataPipeline,
120    var_label: &str,
121) -> Result<(), String> {
122    // ── Encoding stage ─────────────────────────────────────────────────
123    let mut applied_simple_packing = false;
124    match pipeline.encoding.as_str() {
125        "none" => {}
126        "simple_packing" => match values {
127            None => {
128                eprintln!(
129                    "warning: skipping simple_packing for {var_label} \
130                     (not a float64 payload)"
131                );
132            }
133            Some(values) => {
134                let bits = pipeline.bits.unwrap_or(16);
135                match simple_packing::compute_params(values, bits, 0) {
136                    Ok(params) => {
137                        desc.encoding = "simple_packing".to_string();
138                        desc.params.insert(
139                            "reference_value".to_string(),
140                            CborValue::Float(params.reference_value),
141                        );
142                        desc.params.insert(
143                            "binary_scale_factor".to_string(),
144                            CborValue::Integer((i64::from(params.binary_scale_factor)).into()),
145                        );
146                        desc.params.insert(
147                            "decimal_scale_factor".to_string(),
148                            CborValue::Integer((i64::from(params.decimal_scale_factor)).into()),
149                        );
150                        desc.params.insert(
151                            "bits_per_value".to_string(),
152                            CborValue::Integer((i64::from(params.bits_per_value)).into()),
153                        );
154                        applied_simple_packing = true;
155                    }
156                    Err(e) => {
157                        // Common cause: NaN values from unpacked fill_value.
158                        // simple_packing rejects NaN; the variable falls
159                        // back to encoding="none" so the conversion still
160                        // succeeds.
161                        eprintln!("warning: skipping simple_packing for {var_label}: {e}");
162                    }
163                }
164            }
165        },
166        other => {
167            return Err(format!(
168                "unknown encoding '{other}'; expected 'none' or 'simple_packing'"
169            ));
170        }
171    }
172
173    // ── Filter stage ───────────────────────────────────────────────────
174    match pipeline.filter.as_str() {
175        "none" => {}
176        "shuffle" => {
177            desc.filter = "shuffle".to_string();
178            // shuffle is run AFTER encoding by the pipeline, so the
179            // element size is the *post-encoding* byte width:
180            //   - simple_packing applied → ⌈bpv/8⌉
181            //   - otherwise → native dtype byte width
182            let element_size = if applied_simple_packing {
183                let bpv = pipeline.bits.unwrap_or(16) as usize;
184                bpv.div_ceil(8).max(1)
185            } else {
186                desc.dtype.byte_width()
187            };
188            desc.params.insert(
189                "shuffle_element_size".to_string(),
190                CborValue::Integer((element_size as i64).into()),
191            );
192        }
193        other => {
194            return Err(format!(
195                "unknown filter '{other}'; expected 'none' or 'shuffle'"
196            ));
197        }
198    }
199
200    // ── Compression stage ──────────────────────────────────────────────
201    match pipeline.compression.as_str() {
202        "none" => {}
203        "zstd" => {
204            desc.compression = "zstd".to_string();
205            let level = pipeline.compression_level.unwrap_or(3);
206            desc.params.insert(
207                "zstd_level".to_string(),
208                CborValue::Integer((i64::from(level)).into()),
209            );
210        }
211        "lz4" => {
212            desc.compression = "lz4".to_string();
213        }
214        "blosc2" => {
215            desc.compression = "blosc2".to_string();
216            let clevel = pipeline.compression_level.unwrap_or(5);
217            desc.params.insert(
218                "blosc2_clevel".to_string(),
219                CborValue::Integer((i64::from(clevel)).into()),
220            );
221            // Default sub-codec — users wanting a different one should
222            // construct a `DataObjectDescriptor` manually.
223            desc.params.insert(
224                "blosc2_codec".to_string(),
225                CborValue::Text("lz4".to_string()),
226            );
227        }
228        "szip" => {
229            desc.compression = "szip".to_string();
230            // Sensible szip defaults consistent with the rest of the
231            // codebase (see `tensogram` tests + examples).
232            desc.params
233                .insert("szip_rsi".to_string(), CborValue::Integer(128.into()));
234            desc.params
235                .insert("szip_block_size".to_string(), CborValue::Integer(16.into()));
236            desc.params
237                .insert("szip_flags".to_string(), CborValue::Integer(8.into()));
238        }
239        other => {
240            return Err(format!(
241                "unknown compression '{other}'; expected one of: none, zstd, lz4, blosc2, szip"
242            ));
243        }
244    }
245
246    Ok(())
247}
248
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::*;
    use crate::Dtype;
    use crate::types::ByteOrder;

    /// Fresh 1-D, four-element `Float64` descriptor with every stage set
    /// to `"none"` and an empty `params` map.
    fn mk_desc() -> DataObjectDescriptor {
        DataObjectDescriptor {
            obj_type: "ntensor".to_string(),
            ndim: 1,
            shape: vec![4],
            strides: vec![1],
            dtype: Dtype::Float64,
            byte_order: ByteOrder::Little,
            encoding: "none".to_string(),
            filter: "none".to_string(),
            compression: "none".to_string(),
            params: BTreeMap::new(),
            hash: None,
        }
    }

    /// Reads `key` from `desc.params` as an `i64`.
    ///
    /// Panics when the key is missing, holds a non-integer value, or holds
    /// an integer outside the `i64` range (checked conversion, so an
    /// out-of-range value can never be silently wrapped).
    fn int_param(desc: &DataObjectDescriptor, key: &str) -> i64 {
        match desc.params.get(key) {
            Some(CborValue::Integer(i)) => {
                let wide = i128::from(*i);
                i64::try_from(wide)
                    .unwrap_or_else(|_| panic!("{key} = {wide} does not fit in i64"))
            }
            other => panic!("{key} not an integer: {other:?}"),
        }
    }

    // ── Defaults ────────────────────────────────────────────────────

    #[test]
    fn default_pipeline_is_all_none() {
        let p = DataPipeline::default();
        assert_eq!(p.encoding, "none");
        assert_eq!(p.filter, "none");
        assert_eq!(p.compression, "none");
        assert!(p.bits.is_none());
        assert!(p.compression_level.is_none());
    }

    #[test]
    fn default_pipeline_leaves_descriptor_unchanged() {
        let mut desc = mk_desc();
        let values = [1.0, 2.0, 3.0, 4.0];
        apply_pipeline(&mut desc, Some(&values), &DataPipeline::default(), "x").unwrap();
        assert_eq!(desc.encoding, "none");
        assert_eq!(desc.filter, "none");
        assert_eq!(desc.compression, "none");
        assert!(desc.params.is_empty());
    }

    // ── Encoding ────────────────────────────────────────────────────

    #[test]
    fn simple_packing_populates_four_params() {
        let mut desc = mk_desc();
        let p = DataPipeline {
            encoding: "simple_packing".to_string(),
            bits: Some(16),
            ..Default::default()
        };
        let values = [0.0_f64, 1.0, 2.0, 3.0];
        apply_pipeline(&mut desc, Some(&values), &p, "test").unwrap();
        assert_eq!(desc.encoding, "simple_packing");
        assert_eq!(int_param(&desc, "bits_per_value"), 16);
        assert_eq!(int_param(&desc, "decimal_scale_factor"), 0);
        assert!(desc.params.contains_key("reference_value"));
        assert!(desc.params.contains_key("binary_scale_factor"));
    }

    #[test]
    fn simple_packing_defaults_to_16_bits() {
        // `bits: None` must behave exactly like `bits: Some(16)`.
        let mut desc = mk_desc();
        let p = DataPipeline {
            encoding: "simple_packing".to_string(),
            ..Default::default()
        };
        let values = [0.0_f64, 1.0, 2.0, 3.0];
        apply_pipeline(&mut desc, Some(&values), &p, "test").unwrap();
        assert_eq!(desc.encoding, "simple_packing");
        assert_eq!(int_param(&desc, "bits_per_value"), 16);
    }

    #[test]
    fn simple_packing_with_no_values_skips_with_warning() {
        let mut desc = mk_desc();
        let p = DataPipeline {
            encoding: "simple_packing".to_string(),
            ..Default::default()
        };
        apply_pipeline(&mut desc, None, &p, "int_var").unwrap();
        assert_eq!(desc.encoding, "none", "should skip, not set");
        assert!(desc.params.is_empty(), "no params should be inserted");
    }

    #[test]
    fn simple_packing_with_nan_values_skips_with_warning() {
        let mut desc = mk_desc();
        let p = DataPipeline {
            encoding: "simple_packing".to_string(),
            ..Default::default()
        };
        let values = [1.0_f64, f64::NAN, 3.0];
        apply_pipeline(&mut desc, Some(&values), &p, "nan_var").unwrap();
        assert_eq!(desc.encoding, "none", "NaN → skip");
        assert!(
            desc.params.is_empty(),
            "a skipped encoding must not leak packing params"
        );
    }

    #[test]
    fn unknown_encoding_errors() {
        let mut desc = mk_desc();
        let p = DataPipeline {
            encoding: "magic_packing".to_string(),
            ..Default::default()
        };
        let err = apply_pipeline(&mut desc, None, &p, "x").unwrap_err();
        assert!(err.contains("magic_packing"));
        assert!(err.contains("simple_packing"));
    }

    // ── Filter ──────────────────────────────────────────────────────

    #[test]
    fn shuffle_on_raw_f64_uses_native_byte_width() {
        let mut desc = mk_desc(); // f64 → 8 bytes
        let p = DataPipeline {
            filter: "shuffle".to_string(),
            ..Default::default()
        };
        apply_pipeline(&mut desc, None, &p, "x").unwrap();
        assert_eq!(desc.filter, "shuffle");
        assert_eq!(int_param(&desc, "shuffle_element_size"), 8);
    }

    #[test]
    fn shuffle_on_simple_packed_uses_post_pack_byte_width() {
        let mut desc = mk_desc();
        let p = DataPipeline {
            encoding: "simple_packing".to_string(),
            bits: Some(16),
            filter: "shuffle".to_string(),
            ..Default::default()
        };
        let values = [0.0_f64, 1.0, 2.0, 3.0];
        apply_pipeline(&mut desc, Some(&values), &p, "x").unwrap();
        assert_eq!(desc.filter, "shuffle");
        assert_eq!(
            int_param(&desc, "shuffle_element_size"),
            2,
            "16-bit packed → 2-byte elements"
        );
    }

    #[test]
    fn shuffle_with_24bit_packing_rounds_up() {
        let mut desc = mk_desc();
        let p = DataPipeline {
            encoding: "simple_packing".to_string(),
            bits: Some(24),
            filter: "shuffle".to_string(),
            ..Default::default()
        };
        let values = [0.0_f64, 1.0, 2.0, 3.0];
        apply_pipeline(&mut desc, Some(&values), &p, "x").unwrap();
        assert_eq!(int_param(&desc, "shuffle_element_size"), 3);
    }

    #[test]
    fn shuffle_after_skipped_simple_packing_uses_native_byte_width() {
        // simple_packing is requested but skipped (no f64 payload), so the
        // shuffle element size must fall back to the dtype width rather
        // than the requested packing width.
        let mut desc = mk_desc();
        let p = DataPipeline {
            encoding: "simple_packing".to_string(),
            bits: Some(16),
            filter: "shuffle".to_string(),
            ..Default::default()
        };
        apply_pipeline(&mut desc, None, &p, "x").unwrap();
        assert_eq!(desc.encoding, "none");
        assert_eq!(desc.filter, "shuffle");
        assert_eq!(int_param(&desc, "shuffle_element_size"), 8);
        assert_eq!(desc.params.len(), 1, "only the shuffle param is expected");
    }

    #[test]
    fn unknown_filter_errors() {
        let mut desc = mk_desc();
        let p = DataPipeline {
            filter: "wibble".to_string(),
            ..Default::default()
        };
        let err = apply_pipeline(&mut desc, None, &p, "x").unwrap_err();
        assert!(err.contains("wibble"));
    }

    // ── Compression ─────────────────────────────────────────────────

    #[test]
    fn zstd_with_default_level() {
        let mut desc = mk_desc();
        let p = DataPipeline {
            compression: "zstd".to_string(),
            ..Default::default()
        };
        apply_pipeline(&mut desc, None, &p, "x").unwrap();
        assert_eq!(desc.compression, "zstd");
        assert_eq!(int_param(&desc, "zstd_level"), 3);
    }

    #[test]
    fn zstd_with_custom_level() {
        let mut desc = mk_desc();
        let p = DataPipeline {
            compression: "zstd".to_string(),
            compression_level: Some(9),
            ..Default::default()
        };
        apply_pipeline(&mut desc, None, &p, "x").unwrap();
        assert_eq!(int_param(&desc, "zstd_level"), 9);
    }

    #[test]
    fn lz4_has_no_params() {
        let mut desc = mk_desc();
        let p = DataPipeline {
            compression: "lz4".to_string(),
            ..Default::default()
        };
        apply_pipeline(&mut desc, None, &p, "x").unwrap();
        assert_eq!(desc.compression, "lz4");
        assert!(desc.params.is_empty());
    }

    #[test]
    fn blosc2_with_default_level() {
        let mut desc = mk_desc();
        let p = DataPipeline {
            compression: "blosc2".to_string(),
            ..Default::default()
        };
        apply_pipeline(&mut desc, None, &p, "x").unwrap();
        assert_eq!(desc.compression, "blosc2");
        assert_eq!(int_param(&desc, "blosc2_clevel"), 5);
    }

    #[test]
    fn blosc2_with_custom_level() {
        let mut desc = mk_desc();
        let p = DataPipeline {
            compression: "blosc2".to_string(),
            compression_level: Some(7),
            ..Default::default()
        };
        apply_pipeline(&mut desc, None, &p, "x").unwrap();
        assert_eq!(desc.compression, "blosc2");
        assert_eq!(int_param(&desc, "blosc2_clevel"), 7);
        match desc.params.get("blosc2_codec") {
            Some(CborValue::Text(s)) => assert_eq!(s, "lz4"),
            other => panic!("blosc2_codec should be lz4: {other:?}"),
        }
    }

    #[test]
    fn szip_sets_defaults() {
        let mut desc = mk_desc();
        let p = DataPipeline {
            compression: "szip".to_string(),
            ..Default::default()
        };
        apply_pipeline(&mut desc, None, &p, "x").unwrap();
        assert_eq!(desc.compression, "szip");
        assert_eq!(int_param(&desc, "szip_rsi"), 128);
        assert_eq!(int_param(&desc, "szip_block_size"), 16);
        assert_eq!(int_param(&desc, "szip_flags"), 8);
    }

    #[test]
    fn unknown_compression_errors() {
        let mut desc = mk_desc();
        let p = DataPipeline {
            compression: "bogus".to_string(),
            ..Default::default()
        };
        let err = apply_pipeline(&mut desc, None, &p, "x").unwrap_err();
        assert!(err.contains("bogus"));
    }

    // ── Combined ────────────────────────────────────────────────────

    #[test]
    fn full_pipeline_simple_packing_shuffle_zstd() {
        let mut desc = mk_desc();
        let p = DataPipeline {
            encoding: "simple_packing".to_string(),
            bits: Some(24),
            filter: "shuffle".to_string(),
            compression: "zstd".to_string(),
            compression_level: Some(5),
        };
        let values = [1.0_f64, 2.0, 3.0, 4.0];
        apply_pipeline(&mut desc, Some(&values), &p, "x").unwrap();
        assert_eq!(desc.encoding, "simple_packing");
        assert_eq!(desc.filter, "shuffle");
        assert_eq!(desc.compression, "zstd");
        assert_eq!(int_param(&desc, "bits_per_value"), 24);
        assert_eq!(int_param(&desc, "shuffle_element_size"), 3);
        assert_eq!(int_param(&desc, "zstd_level"), 5);
    }
}
520}