1use ciborium::Value as CborValue;
51use tensogram_encodings::simple_packing;
52
53use crate::types::DataObjectDescriptor;
54
/// User-selected encoding / filter / compression stages for one data object.
///
/// Each stage is selected by name; [`apply_pipeline`] checks the names and
/// translates them into descriptor fields and params. The `Default`
/// pipeline is a pass-through with every stage set to `"none"`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DataPipeline {
    /// Encoding stage: `"none"` or `"simple_packing"`.
    pub encoding: String,
    /// Bits per packed value for `simple_packing`; `None` lets
    /// [`apply_pipeline`] pick its default.
    pub bits: Option<u32>,
    /// Filter stage: `"none"` or `"shuffle"`.
    pub filter: String,
    /// Compression stage: `"none"`, `"zstd"`, `"lz4"`, `"blosc2"` or `"szip"`.
    pub compression: String,
    /// Level for codecs that take one (`zstd`, `blosc2`); `None` lets
    /// [`apply_pipeline`] pick the codec's default.
    pub compression_level: Option<i32>,
}
76
77impl Default for DataPipeline {
78 fn default() -> Self {
79 Self {
80 encoding: "none".to_string(),
81 bits: None,
82 filter: "none".to_string(),
83 compression: "none".to_string(),
84 compression_level: None,
85 }
86 }
87}
88
89pub fn apply_pipeline(
117 desc: &mut DataObjectDescriptor,
118 values: Option<&[f64]>,
119 pipeline: &DataPipeline,
120 var_label: &str,
121) -> Result<(), String> {
122 let mut applied_simple_packing = false;
124 match pipeline.encoding.as_str() {
125 "none" => {}
126 "simple_packing" => match values {
127 None => {
128 eprintln!(
129 "warning: skipping simple_packing for {var_label} \
130 (not a float64 payload)"
131 );
132 }
133 Some(values) => {
134 let bits = pipeline.bits.unwrap_or(16);
135 match simple_packing::compute_params(values, bits, 0) {
136 Ok(params) => {
137 desc.encoding = "simple_packing".to_string();
138 desc.params.insert(
139 "reference_value".to_string(),
140 CborValue::Float(params.reference_value),
141 );
142 desc.params.insert(
143 "binary_scale_factor".to_string(),
144 CborValue::Integer((i64::from(params.binary_scale_factor)).into()),
145 );
146 desc.params.insert(
147 "decimal_scale_factor".to_string(),
148 CborValue::Integer((i64::from(params.decimal_scale_factor)).into()),
149 );
150 desc.params.insert(
151 "bits_per_value".to_string(),
152 CborValue::Integer((i64::from(params.bits_per_value)).into()),
153 );
154 applied_simple_packing = true;
155 }
156 Err(e) => {
157 eprintln!("warning: skipping simple_packing for {var_label}: {e}");
162 }
163 }
164 }
165 },
166 other => {
167 return Err(format!(
168 "unknown encoding '{other}'; expected 'none' or 'simple_packing'"
169 ));
170 }
171 }
172
173 match pipeline.filter.as_str() {
175 "none" => {}
176 "shuffle" => {
177 desc.filter = "shuffle".to_string();
178 let element_size = if applied_simple_packing {
183 let bpv = pipeline.bits.unwrap_or(16) as usize;
184 bpv.div_ceil(8).max(1)
185 } else {
186 desc.dtype.byte_width()
187 };
188 desc.params.insert(
189 "shuffle_element_size".to_string(),
190 CborValue::Integer((element_size as i64).into()),
191 );
192 }
193 other => {
194 return Err(format!(
195 "unknown filter '{other}'; expected 'none' or 'shuffle'"
196 ));
197 }
198 }
199
200 match pipeline.compression.as_str() {
202 "none" => {}
203 "zstd" => {
204 desc.compression = "zstd".to_string();
205 let level = pipeline.compression_level.unwrap_or(3);
206 desc.params.insert(
207 "zstd_level".to_string(),
208 CborValue::Integer((i64::from(level)).into()),
209 );
210 }
211 "lz4" => {
212 desc.compression = "lz4".to_string();
213 }
214 "blosc2" => {
215 desc.compression = "blosc2".to_string();
216 let clevel = pipeline.compression_level.unwrap_or(5);
217 desc.params.insert(
218 "blosc2_clevel".to_string(),
219 CborValue::Integer((i64::from(clevel)).into()),
220 );
221 desc.params.insert(
224 "blosc2_codec".to_string(),
225 CborValue::Text("lz4".to_string()),
226 );
227 }
228 "szip" => {
229 desc.compression = "szip".to_string();
230 desc.params
233 .insert("szip_rsi".to_string(), CborValue::Integer(128.into()));
234 desc.params
235 .insert("szip_block_size".to_string(), CborValue::Integer(16.into()));
236 desc.params
237 .insert("szip_flags".to_string(), CborValue::Integer(8.into()));
238 }
239 other => {
240 return Err(format!(
241 "unknown compression '{other}'; expected one of: none, zstd, lz4, blosc2, szip"
242 ));
243 }
244 }
245
246 Ok(())
247}
248
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::*;
    use crate::Dtype;
    use crate::types::ByteOrder;

    /// Baseline descriptor with every pipeline stage `"none"`, empty params
    /// and no hash. It is 1-D with four float64 elements, little-endian.
    fn fresh_descriptor() -> DataObjectDescriptor {
        DataObjectDescriptor {
            obj_type: "ntensor".to_string(),
            ndim: 1,
            shape: vec![4],
            strides: vec![1],
            dtype: Dtype::Float64,
            byte_order: ByteOrder::Little,
            encoding: "none".to_string(),
            filter: "none".to_string(),
            compression: "none".to_string(),
            params: BTreeMap::new(),
            hash: None,
        }
    }

    /// Run `pipeline` against a fresh descriptor and hand the result back,
    /// panicking if `apply_pipeline` reports an error.
    fn applied(
        pipeline: &DataPipeline,
        values: Option<&[f64]>,
        label: &str,
    ) -> DataObjectDescriptor {
        let mut desc = fresh_descriptor();
        apply_pipeline(&mut desc, values, pipeline, label).unwrap();
        desc
    }

    /// Run `pipeline` against a fresh descriptor and return its error
    /// message, panicking if it unexpectedly succeeds. No payload is passed
    /// and the label is `"x"`.
    fn rejected(pipeline: &DataPipeline) -> String {
        let mut desc = fresh_descriptor();
        apply_pipeline(&mut desc, None, pipeline, "x").unwrap_err()
    }

    /// Read `key` from `desc.params` as an `i64`; panics if it is absent or
    /// not a CBOR integer.
    fn param_as_i64(desc: &DataObjectDescriptor, key: &str) -> i64 {
        let Some(CborValue::Integer(raw)) = desc.params.get(key) else {
            panic!("{key} not an integer: {:?}", desc.params.get(key));
        };
        i128::from(*raw) as i64
    }

    #[test]
    fn default_pipeline_is_all_none() {
        let DataPipeline {
            encoding,
            bits,
            filter,
            compression,
            compression_level,
        } = DataPipeline::default();
        assert_eq!(encoding, "none");
        assert_eq!(filter, "none");
        assert_eq!(compression, "none");
        assert!(bits.is_none());
        assert!(compression_level.is_none());
    }

    #[test]
    fn default_pipeline_leaves_descriptor_unchanged() {
        let values = [1.0, 2.0, 3.0, 4.0];
        let desc = applied(&DataPipeline::default(), Some(&values), "x");
        assert_eq!(desc.encoding, "none");
        assert_eq!(desc.filter, "none");
        assert_eq!(desc.compression, "none");
        assert!(desc.params.is_empty());
    }

    #[test]
    fn simple_packing_populates_four_params() {
        let pipeline = DataPipeline {
            encoding: "simple_packing".to_string(),
            bits: Some(16),
            ..Default::default()
        };
        let values = [0.0_f64, 1.0, 2.0, 3.0];
        let desc = applied(&pipeline, Some(&values), "test");
        assert_eq!(desc.encoding, "simple_packing");
        assert_eq!(param_as_i64(&desc, "bits_per_value"), 16);
        assert_eq!(param_as_i64(&desc, "decimal_scale_factor"), 0);
        assert!(desc.params.contains_key("reference_value"));
        assert!(desc.params.contains_key("binary_scale_factor"));
    }

    #[test]
    fn simple_packing_with_no_values_skips_with_warning() {
        let pipeline = DataPipeline {
            encoding: "simple_packing".to_string(),
            ..Default::default()
        };
        let desc = applied(&pipeline, None, "int_var");
        assert_eq!(desc.encoding, "none", "should skip, not set");
        assert!(desc.params.is_empty(), "no params should be inserted");
    }

    #[test]
    fn simple_packing_with_nan_values_skips_with_warning() {
        let pipeline = DataPipeline {
            encoding: "simple_packing".to_string(),
            ..Default::default()
        };
        let values = [1.0_f64, f64::NAN, 3.0];
        let desc = applied(&pipeline, Some(&values), "nan_var");
        assert_eq!(desc.encoding, "none", "NaN → skip");
    }

    #[test]
    fn unknown_encoding_errors() {
        let err = rejected(&DataPipeline {
            encoding: "magic_packing".to_string(),
            ..Default::default()
        });
        assert!(err.contains("magic_packing"));
        assert!(err.contains("simple_packing"));
    }

    #[test]
    fn shuffle_on_raw_f64_uses_native_byte_width() {
        let pipeline = DataPipeline {
            filter: "shuffle".to_string(),
            ..Default::default()
        };
        let desc = applied(&pipeline, None, "x");
        assert_eq!(desc.filter, "shuffle");
        assert_eq!(param_as_i64(&desc, "shuffle_element_size"), 8);
    }

    #[test]
    fn shuffle_on_simple_packed_uses_post_pack_byte_width() {
        let pipeline = DataPipeline {
            encoding: "simple_packing".to_string(),
            bits: Some(16),
            filter: "shuffle".to_string(),
            ..Default::default()
        };
        let values = [0.0_f64, 1.0, 2.0, 3.0];
        let desc = applied(&pipeline, Some(&values), "x");
        assert_eq!(desc.filter, "shuffle");
        assert_eq!(
            param_as_i64(&desc, "shuffle_element_size"),
            2,
            "16-bit packed → 2-byte elements"
        );
    }

    #[test]
    fn shuffle_with_24bit_packing_rounds_up() {
        let pipeline = DataPipeline {
            encoding: "simple_packing".to_string(),
            bits: Some(24),
            filter: "shuffle".to_string(),
            ..Default::default()
        };
        let values = [0.0_f64, 1.0, 2.0, 3.0];
        let desc = applied(&pipeline, Some(&values), "x");
        assert_eq!(param_as_i64(&desc, "shuffle_element_size"), 3);
    }

    #[test]
    fn unknown_filter_errors() {
        let err = rejected(&DataPipeline {
            filter: "wibble".to_string(),
            ..Default::default()
        });
        assert!(err.contains("wibble"));
    }

    #[test]
    fn zstd_with_default_level() {
        let pipeline = DataPipeline {
            compression: "zstd".to_string(),
            ..Default::default()
        };
        let desc = applied(&pipeline, None, "x");
        assert_eq!(desc.compression, "zstd");
        assert_eq!(param_as_i64(&desc, "zstd_level"), 3);
    }

    #[test]
    fn zstd_with_custom_level() {
        let pipeline = DataPipeline {
            compression: "zstd".to_string(),
            compression_level: Some(9),
            ..Default::default()
        };
        let desc = applied(&pipeline, None, "x");
        assert_eq!(param_as_i64(&desc, "zstd_level"), 9);
    }

    #[test]
    fn lz4_has_no_params() {
        let pipeline = DataPipeline {
            compression: "lz4".to_string(),
            ..Default::default()
        };
        let desc = applied(&pipeline, None, "x");
        assert_eq!(desc.compression, "lz4");
        assert!(desc.params.is_empty());
    }

    #[test]
    fn blosc2_with_custom_level() {
        let pipeline = DataPipeline {
            compression: "blosc2".to_string(),
            compression_level: Some(7),
            ..Default::default()
        };
        let desc = applied(&pipeline, None, "x");
        assert_eq!(desc.compression, "blosc2");
        assert_eq!(param_as_i64(&desc, "blosc2_clevel"), 7);
        let Some(CborValue::Text(codec)) = desc.params.get("blosc2_codec") else {
            panic!(
                "blosc2_codec should be lz4: {:?}",
                desc.params.get("blosc2_codec")
            );
        };
        assert_eq!(codec, "lz4");
    }

    #[test]
    fn szip_sets_defaults() {
        let pipeline = DataPipeline {
            compression: "szip".to_string(),
            ..Default::default()
        };
        let desc = applied(&pipeline, None, "x");
        assert_eq!(desc.compression, "szip");
        for (key, expected) in [
            ("szip_rsi", 128_i64),
            ("szip_block_size", 16),
            ("szip_flags", 8),
        ] {
            assert_eq!(param_as_i64(&desc, key), expected);
        }
    }

    #[test]
    fn unknown_compression_errors() {
        let err = rejected(&DataPipeline {
            compression: "bogus".to_string(),
            ..Default::default()
        });
        assert!(err.contains("bogus"));
    }

    #[test]
    fn full_pipeline_simple_packing_shuffle_zstd() {
        let pipeline = DataPipeline {
            encoding: "simple_packing".to_string(),
            bits: Some(24),
            filter: "shuffle".to_string(),
            compression: "zstd".to_string(),
            compression_level: Some(5),
        };
        let values = [1.0_f64, 2.0, 3.0, 4.0];
        let desc = applied(&pipeline, Some(&values), "x");
        assert_eq!(desc.encoding, "simple_packing");
        assert_eq!(desc.filter, "shuffle");
        assert_eq!(desc.compression, "zstd");
        assert_eq!(param_as_i64(&desc, "bits_per_value"), 24);
        assert_eq!(param_as_i64(&desc, "shuffle_element_size"), 3);
        assert_eq!(param_as_i64(&desc, "zstd_level"), 5);
    }
}