1use std::collections::BTreeMap;
10
11use crate::dtype::Dtype;
12use crate::error::{Result, TensogramError};
13use crate::framing::{self, EncodedObject};
14use crate::hash::{HashAlgorithm, compute_hash};
15use crate::metadata::RESERVED_KEY;
16use crate::types::{DataObjectDescriptor, GlobalMetadata, HashDescriptor};
17#[cfg(feature = "blosc2")]
18use tensogram_encodings::pipeline::Blosc2Codec;
19#[cfg(feature = "sz3")]
20use tensogram_encodings::pipeline::Sz3ErrorBound;
21#[cfg(feature = "zfp")]
22use tensogram_encodings::pipeline::ZfpMode;
23use tensogram_encodings::pipeline::{
24 self, CompressionType, EncodingType, FilterType, PipelineConfig,
25};
26use tensogram_encodings::simple_packing::SimplePackingParams;
27
/// Caller-tunable settings for the buffered encoders in this module.
#[derive(Debug, Clone)]
pub struct EncodeOptions {
    /// Hash to compute over each object's encoded payload. When `Some`, the
    /// digest and algorithm name are written into the emitted descriptor's
    /// `hash` field; when `None`, no digest is computed and the `hash` value
    /// supplied on the input descriptor is carried through unchanged.
    pub hash_algorithm: Option<HashAlgorithm>,
    /// Must be `false` for the buffered encoders: `encode_inner` returns an
    /// `Encoding` error when this is set, pointing callers at
    /// `StreamingEncoder::write_preceder()`.
    pub emit_preceders: bool,
    /// Compression backend selector, copied verbatim into the `PipelineConfig`
    /// built for every object.
    pub compression_backend: pipeline::CompressionBackend,
    /// Requested thread count, resolved into a budget by
    /// `crate::parallel::resolve_budget`.
    /// NOTE(review): the meaning of `0` (the default) is defined in
    /// `crate::parallel` and is not visible here — confirm before documenting.
    pub threads: u32,
    /// Optional size threshold handed to `crate::parallel::should_parallelise`
    /// together with the summed input size of all objects.
    /// NOTE(review): `None` presumably selects a library default — confirm.
    pub parallel_threshold_bytes: Option<usize>,
}
73
impl Default for EncodeOptions {
    /// Returns the default options: XXH3 payload hashing enabled, no
    /// preceders, the default compression backend, a thread request of `0`,
    /// and no explicit parallel threshold.
    fn default() -> Self {
        Self {
            hash_algorithm: Some(HashAlgorithm::Xxh3),
            emit_preceders: false,
            compression_backend: pipeline::CompressionBackend::default(),
            threads: 0,
            parallel_threshold_bytes: None,
        }
    }
}
85
86pub(crate) fn validate_object(desc: &DataObjectDescriptor, data_len: usize) -> Result<()> {
87 if desc.obj_type.is_empty() {
88 return Err(TensogramError::Metadata(
89 "obj_type must not be empty".to_string(),
90 ));
91 }
92 if desc.ndim as usize != desc.shape.len() {
93 return Err(TensogramError::Metadata(format!(
94 "ndim {} does not match shape.len() {}",
95 desc.ndim,
96 desc.shape.len()
97 )));
98 }
99 if desc.strides.len() != desc.shape.len() {
100 return Err(TensogramError::Metadata(format!(
101 "strides.len() {} does not match shape.len() {}",
102 desc.strides.len(),
103 desc.shape.len()
104 )));
105 }
106 if desc.encoding == "none" {
107 let product = desc
108 .shape
109 .iter()
110 .try_fold(1u64, |acc, &x| acc.checked_mul(x))
111 .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
112 if desc.dtype.byte_width() > 0 {
113 let expected_bytes = product
114 .checked_mul(desc.dtype.byte_width() as u64)
115 .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
116 if expected_bytes != data_len as u64 {
117 return Err(TensogramError::Metadata(format!(
118 "data_len {data_len} does not match expected {expected_bytes} bytes from shape and dtype"
119 )));
120 }
121 } else if desc.dtype == crate::Dtype::Bitmask {
122 let expected_bytes = product.div_ceil(8);
124 if expected_bytes != data_len as u64 {
125 return Err(TensogramError::Metadata(format!(
126 "data_len {data_len} does not match expected {expected_bytes} bytes for bitmask (ceil({product}/8))"
127 )));
128 }
129 }
130 }
131 Ok(())
132}
133
/// Selects how `encode_one_object` treats the bytes it is given.
#[derive(Debug, Clone, Copy)]
enum EncodeMode {
    /// The bytes are run through `pipeline::encode_pipeline`.
    Raw,
    /// The bytes are copied into the payload as-is (after szip block-offset
    /// validation); the pipeline is not executed.
    PreEncoded,
}
139
140fn encode_one_object(
148 desc: &DataObjectDescriptor,
149 data: &[u8],
150 mode: EncodeMode,
151 options: &EncodeOptions,
152 intra_codec_threads: u32,
153) -> Result<EncodedObject> {
154 validate_object(desc, data.len())?;
155
156 let shape_product = desc
157 .shape
158 .iter()
159 .try_fold(1u64, |acc, &x| acc.checked_mul(x))
160 .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
161 let num_elements = usize::try_from(shape_product)
162 .map_err(|_| TensogramError::Metadata("element count overflows usize".to_string()))?;
163 let dtype = desc.dtype;
164
165 let config = build_pipeline_config_with_backend(
166 desc,
167 num_elements,
168 dtype,
169 options.compression_backend,
170 intra_codec_threads,
171 )?;
172
173 let mut final_desc = desc.clone();
175
176 let encoded_payload: Vec<u8> = match mode {
177 EncodeMode::Raw => {
178 let result = pipeline::encode_pipeline(data, &config)
180 .map_err(|e| TensogramError::Encoding(e.to_string()))?;
181
182 if let Some(offsets) = &result.block_offsets {
184 final_desc.params.insert(
185 "szip_block_offsets".to_string(),
186 ciborium::Value::Array(
187 offsets
188 .iter()
189 .map(|&o| ciborium::Value::Integer(o.into()))
190 .collect(),
191 ),
192 );
193 }
194
195 result.encoded_bytes
196 }
197 EncodeMode::PreEncoded => {
198 validate_no_szip_offsets_for_non_szip(desc)?;
202 if desc.compression == "szip" && desc.params.contains_key("szip_block_offsets") {
203 validate_szip_block_offsets(&desc.params, data.len())?;
204 }
205 data.to_vec()
206 }
207 };
208
209 if let Some(algorithm) = options.hash_algorithm {
211 let hash_value = compute_hash(&encoded_payload, algorithm);
212 final_desc.hash = Some(HashDescriptor {
213 hash_type: algorithm.as_str().to_string(),
214 value: hash_value,
215 });
216 }
217
218 Ok(EncodedObject {
219 descriptor: final_desc,
220 encoded_payload,
221 })
222}
223
/// Shared implementation behind `encode` and `encode_pre_encoded`.
///
/// Encodes every `(descriptor, bytes)` pair according to `mode`, validates the
/// caller's metadata, enriches a copy of it with library-populated entries and
/// frames everything into one message buffer.
///
/// # Errors
/// * `TensogramError::Encoding` if `options.emit_preceders` is set.
/// * Any error from encoding an individual object.
/// * `TensogramError::Metadata` if the caller wrote to the reserved key or
///   supplied more `base` entries than descriptors.
/// * Any error from `framing::encode_message`.
fn encode_inner(
    global_metadata: &GlobalMetadata,
    descriptors: &[(&DataObjectDescriptor, &[u8])],
    options: &EncodeOptions,
    mode: EncodeMode,
) -> Result<Vec<u8>> {
    if options.emit_preceders {
        return Err(TensogramError::Encoding(
            "emit_preceders is not supported in buffered mode; use StreamingEncoder::write_preceder() instead".to_string(),
        ));
    }

    // Decide whether to parallelise at all, based on the resolved thread
    // budget and the summed input size of all objects.
    let budget = crate::parallel::resolve_budget(options.threads);
    let total_bytes: usize = descriptors.iter().map(|(_, d)| d.len()).sum();
    let parallel =
        crate::parallel::should_parallelise(budget, total_bytes, options.parallel_threshold_bytes);

    // "Axis A" here is the branch that iterates over objects with `par_iter`;
    // the decision takes into account whether any object's codec combination
    // is flagged as "axis B friendly" by `crate::parallel`.
    let any_axis_b = descriptors
        .iter()
        .any(|(d, _)| crate::parallel::is_axis_b_friendly(&d.encoding, &d.filter, &d.compression));
    let use_axis_a = parallel && crate::parallel::use_axis_a(descriptors.len(), budget, any_axis_b);

    // The budget is handed to the codecs only when we parallelise but do not
    // spread across objects; otherwise codecs receive 0.
    let intra_codec_threads = if parallel && !use_axis_a { budget } else { 0 };

    let encode_one = |(desc, data): &(&DataObjectDescriptor, &[u8])| {
        encode_one_object(desc, data, mode, options, intra_codec_threads)
    };

    let encoded_objects: Vec<EncodedObject> = if use_axis_a {
        // Across-object parallelism needs the `threads` feature; without it
        // the same work is done with a plain sequential iterator.
        #[cfg(feature = "threads")]
        {
            use rayon::prelude::*;
            crate::parallel::with_pool(budget, || {
                descriptors
                    .par_iter()
                    .map(&encode_one)
                    .collect::<Result<Vec<_>>>()
            })?
        }
        #[cfg(not(feature = "threads"))]
        {
            descriptors.iter().map(encode_one).collect::<Result<_>>()?
        }
    } else {
        crate::parallel::run_maybe_pooled(budget, parallel, intra_codec_threads, || {
            descriptors.iter().map(encode_one).collect::<Result<_>>()
        })?
    };

    // NOTE(review): the metadata checks below run only after all objects have
    // been encoded, so invalid metadata is reported after the encode work has
    // already been paid for. Hoisting them would change which error wins when
    // both the metadata and an object are invalid — confirm before moving.
    validate_no_client_reserved(global_metadata)?;

    // Fewer base entries than descriptors is fine (they are auto-extended by
    // `populate_base_entries`); more would silently lose caller data.
    if global_metadata.base.len() > descriptors.len() {
        return Err(TensogramError::Metadata(format!(
            "metadata base has {} entries but only {} descriptors provided; \
             extra base entries would be discarded",
            global_metadata.base.len(),
            descriptors.len()
        )));
    }

    // Work on a copy so the caller's metadata is never mutated.
    let mut enriched_meta = global_metadata.clone();
    populate_base_entries(&mut enriched_meta.base, &encoded_objects);
    populate_reserved_provenance(&mut enriched_meta.reserved);

    framing::encode_message(&enriched_meta, &encoded_objects)
}
314
/// Encodes raw object data into a single message buffer.
///
/// Each `(descriptor, bytes)` pair is run through the encoding pipeline
/// described by its descriptor (`EncodeMode::Raw`).
///
/// # Errors
/// Returns whatever `encode_inner` returns: validation, configuration,
/// pipeline or framing errors.
#[tracing::instrument(skip(global_metadata, descriptors, options), fields(objects = descriptors.len()))]
pub fn encode(
    global_metadata: &GlobalMetadata,
    descriptors: &[(&DataObjectDescriptor, &[u8])],
    options: &EncodeOptions,
) -> Result<Vec<u8>> {
    encode_inner(global_metadata, descriptors, options, EncodeMode::Raw)
}
328
/// Builds a message from payloads that the caller has already encoded.
///
/// The bytes of each pair are stored verbatim (`EncodeMode::PreEncoded`); the
/// descriptor is still validated and, when hashing is enabled in `options`,
/// the hash is computed over the bytes as given.
///
/// # Errors
/// Returns whatever `encode_inner` returns: validation, configuration or
/// framing errors.
#[tracing::instrument(name = "encode_pre_encoded", skip_all, fields(num_objects = descriptors.len()))]
pub fn encode_pre_encoded(
    global_metadata: &GlobalMetadata,
    descriptors: &[(&DataObjectDescriptor, &[u8])],
    options: &EncodeOptions,
) -> Result<Vec<u8>> {
    encode_inner(
        global_metadata,
        descriptors,
        options,
        EncodeMode::PreEncoded,
    )
}
354
355fn validate_no_client_reserved(meta: &GlobalMetadata) -> Result<()> {
360 if !meta.reserved.is_empty() {
361 return Err(TensogramError::Metadata(format!(
362 "client code must not write to '{RESERVED_KEY}' at message level; \
363 this field is populated by the library"
364 )));
365 }
366 for (i, entry) in meta.base.iter().enumerate() {
367 if entry.contains_key(RESERVED_KEY) {
368 return Err(TensogramError::Metadata(format!(
369 "client code must not write to '{RESERVED_KEY}' in base[{i}]; \
370 this field is populated by the library"
371 )));
372 }
373 }
374 Ok(())
375}
376
377pub(crate) fn populate_base_entries(
383 base: &mut Vec<BTreeMap<String, ciborium::Value>>,
384 encoded_objects: &[crate::framing::EncodedObject],
385) {
386 use ciborium::Value;
387
388 base.resize_with(encoded_objects.len(), BTreeMap::new);
390
391 for (entry, obj) in base.iter_mut().zip(encoded_objects.iter()) {
392 let desc = &obj.descriptor;
393
394 let tensor_map = Value::Map(vec![
395 (
396 Value::Text("ndim".to_string()),
397 Value::Integer(desc.ndim.into()),
398 ),
399 (
400 Value::Text("shape".to_string()),
401 Value::Array(
402 desc.shape
403 .iter()
404 .map(|&d| Value::Integer(d.into()))
405 .collect(),
406 ),
407 ),
408 (
409 Value::Text("strides".to_string()),
410 Value::Array(
411 desc.strides
412 .iter()
413 .map(|&s| Value::Integer(s.into()))
414 .collect(),
415 ),
416 ),
417 (
418 Value::Text("dtype".to_string()),
419 Value::Text(desc.dtype.to_string()),
420 ),
421 ]);
422
423 let reserved_map = Value::Map(vec![(Value::Text("tensor".to_string()), tensor_map)]);
424
425 entry.insert(RESERVED_KEY.to_string(), reserved_map);
426 }
427}
428
429pub(crate) fn populate_reserved_provenance(reserved: &mut BTreeMap<String, ciborium::Value>) {
440 use ciborium::Value;
441 #[cfg(not(target_arch = "wasm32"))]
442 use std::time::SystemTime;
443
444 let version_str = env!("CARGO_PKG_VERSION");
446 let encoder_map = Value::Map(vec![
447 (
448 Value::Text("name".to_string()),
449 Value::Text("tensogram".to_string()),
450 ),
451 (
452 Value::Text("version".to_string()),
453 Value::Text(version_str.to_string()),
454 ),
455 ]);
456 reserved.insert("encoder".to_string(), encoder_map);
457
458 #[cfg(not(target_arch = "wasm32"))]
463 {
464 let secs = SystemTime::now()
465 .duration_since(SystemTime::UNIX_EPOCH)
466 .unwrap_or_default()
467 .as_secs();
468
469 let days = secs / 86400;
472 let day_secs = secs % 86400;
473 let hours = day_secs / 3600;
474 let minutes = (day_secs % 3600) / 60;
475 let seconds = day_secs % 60;
476 let (y, m, d) = civil_from_days(days as i64);
478 let timestamp = format!("{y:04}-{m:02}-{d:02}T{hours:02}:{minutes:02}:{seconds:02}Z");
479 reserved.insert("time".to_string(), Value::Text(timestamp));
480 }
481
482 let id = uuid::Uuid::new_v4();
484 reserved.insert("uuid".to_string(), Value::Text(id.to_string()));
485}
486
/// Converts a day count relative to 1970-01-01 into a proleptic Gregorian
/// `(year, month, day)` triple, with `month` in `1..=12` and `day` in `1..=31`.
///
/// Internally the calendar is shifted so that each "year" starts on 1 March,
/// which puts the leap day at the very end of the shifted year.
#[cfg(not(target_arch = "wasm32"))]
fn civil_from_days(days: i64) -> (i64, u32, u32) {
    // One full 400-year Gregorian cycle.
    const DAYS_PER_ERA: i64 = 146097;
    // Days from 0000-03-01 to 1970-01-01.
    const EPOCH_SHIFT: i64 = 719468;

    let shifted = days + EPOCH_SHIFT;

    // Floor division by the era length (plain `/` truncates toward zero, so
    // negative inputs are biased first).
    let biased = if shifted >= 0 {
        shifted
    } else {
        shifted - (DAYS_PER_ERA - 1)
    };
    let era = biased / DAYS_PER_ERA;

    let day_of_era = (shifted - era * DAYS_PER_ERA) as u32;
    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36524 - day_of_era / 146096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);

    // Month index counted from March (0 = March, ..., 11 = February).
    let march_based_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * march_based_month + 2) / 5 + 1;

    // January and February belong to the following civil year.
    let (month, carry) = if march_based_month < 10 {
        (march_based_month + 3, 0)
    } else {
        (march_based_month - 9, 1)
    };

    (year_of_era as i64 + era * 400 + carry, month, day)
}
505
/// Convenience wrapper around `build_pipeline_config_with_backend` that uses
/// the default compression backend and passes `0` for the intra-codec thread
/// count.
///
/// # Errors
/// Same as `build_pipeline_config_with_backend`.
pub(crate) fn build_pipeline_config(
    desc: &DataObjectDescriptor,
    num_values: usize,
    dtype: Dtype,
) -> Result<PipelineConfig> {
    build_pipeline_config_with_backend(
        desc,
        num_values,
        dtype,
        pipeline::CompressionBackend::default(),
        0,
    )
}
519
/// Translates the string-typed `encoding` / `filter` / `compression` fields of
/// a descriptor, plus its `params` map, into a typed `PipelineConfig`.
///
/// Recognised values:
/// * encoding: `"none"`, `"simple_packing"`
/// * filter: `"none"`, `"shuffle"` (requires `shuffle_element_size`)
/// * compression: `"none"`, and — each only when the matching cargo feature is
///   enabled — `"szip"`, `"zstd"`, `"lz4"`, `"blosc2"`, `"zfp"`, `"sz3"`.
///   A name whose feature is disabled falls through to the "unknown
///   compression" error.
///
/// `num_values`, `compression_backend` and `intra_codec_threads` are copied
/// into the resulting config unchanged; byte order comes from the descriptor
/// and the byte width / swap unit size from `dtype`.
///
/// # Errors
/// * `TensogramError::Encoding` for unknown encoding, filter, compression,
///   blosc2 codec, zfp mode or sz3 error-bound mode names, and for
///   `simple_packing` on a dtype whose byte width is not 8.
/// * `TensogramError::Metadata` for missing, mistyped or out-of-range
///   required parameters.
pub(crate) fn build_pipeline_config_with_backend(
    desc: &DataObjectDescriptor,
    num_values: usize,
    dtype: Dtype,
    compression_backend: pipeline::CompressionBackend,
    intra_codec_threads: u32,
) -> Result<PipelineConfig> {
    let encoding = match desc.encoding.as_str() {
        "none" => EncodingType::None,
        "simple_packing" => {
            // NOTE(review): only the byte width is tested, so any 8-byte dtype
            // passes this guard even though the message says float64 —
            // confirm whether other 8-byte dtypes can reach this point.
            if dtype.byte_width() != 8 {
                return Err(TensogramError::Encoding(
                    "simple_packing only supports float64 dtype".to_string(),
                ));
            }
            let params = extract_simple_packing_params(&desc.params)?;
            EncodingType::SimplePacking(params)
        }
        other => {
            return Err(TensogramError::Encoding(format!(
                "unknown encoding: {other}"
            )));
        }
    };

    let filter = match desc.filter.as_str() {
        "none" => FilterType::None,
        "shuffle" => {
            let element_size = usize::try_from(get_u64_param(
                &desc.params,
                "shuffle_element_size",
            )?)
            .map_err(|_| {
                TensogramError::Metadata("shuffle_element_size out of usize range".to_string())
            })?;
            FilterType::Shuffle { element_size }
        }
        other => return Err(TensogramError::Encoding(format!("unknown filter: {other}"))),
    };

    let compression = match desc.compression.as_str() {
        "none" => CompressionType::None,
        #[cfg(any(feature = "szip", feature = "szip-pure"))]
        "szip" => {
            let rsi = u32::try_from(get_u64_param(&desc.params, "szip_rsi")?)
                .map_err(|_| TensogramError::Metadata("szip_rsi out of u32 range".to_string()))?;
            let block_size = u32::try_from(get_u64_param(&desc.params, "szip_block_size")?)
                .map_err(|_| {
                    TensogramError::Metadata("szip_block_size out of u32 range".to_string())
                })?;
            let flags = u32::try_from(get_u64_param(&desc.params, "szip_flags")?)
                .map_err(|_| TensogramError::Metadata("szip_flags out of u32 range".to_string()))?;
            // The sample width szip sees depends on the earlier stages: the
            // packed bit width after simple_packing, single bytes after a
            // shuffle, otherwise the full element width.
            let bits_per_sample = match (&encoding, &filter) {
                (EncodingType::SimplePacking(params), _) => params.bits_per_value,
                (EncodingType::None, FilterType::Shuffle { .. }) => 8,
                (EncodingType::None, FilterType::None) => (dtype.byte_width() * 8) as u32,
            };
            CompressionType::Szip {
                rsi,
                block_size,
                flags,
                bits_per_sample,
            }
        }
        #[cfg(any(feature = "zstd", feature = "zstd-pure"))]
        "zstd" => {
            // NOTE(review): `unwrap_or(3)` supplies the default when the key is
            // absent, but it also hides a present-but-non-integer value.
            let level_i64 = get_i64_param(&desc.params, "zstd_level").unwrap_or(3);
            let level = i32::try_from(level_i64).map_err(|_| {
                TensogramError::Metadata(format!("zstd_level value {level_i64} out of i32 range"))
            })?;
            CompressionType::Zstd { level }
        }
        #[cfg(feature = "lz4")]
        "lz4" => CompressionType::Lz4,
        #[cfg(feature = "blosc2")]
        "blosc2" => {
            // A missing or non-text `blosc2_codec` selects "lz4".
            let codec_str = match desc.params.get("blosc2_codec") {
                Some(ciborium::Value::Text(s)) => s.as_str(),
                _ => "lz4",
            };
            let codec = match codec_str {
                "blosclz" => Blosc2Codec::Blosclz,
                "lz4" => Blosc2Codec::Lz4,
                "lz4hc" => Blosc2Codec::Lz4hc,
                "zlib" => Blosc2Codec::Zlib,
                "zstd" => Blosc2Codec::Zstd,
                other => {
                    return Err(TensogramError::Encoding(format!(
                        "unknown blosc2 codec: {other}"
                    )));
                }
            };
            // NOTE(review): as with `zstd_level`, a mistyped `blosc2_clevel`
            // silently becomes the default of 5.
            let clevel_i64 = get_i64_param(&desc.params, "blosc2_clevel").unwrap_or(5);
            let clevel = i32::try_from(clevel_i64).map_err(|_| {
                TensogramError::Metadata(format!(
                    "blosc2_clevel value {clevel_i64} out of i32 range"
                ))
            })?;
            // Element size seen by blosc2, derived from the earlier stages in
            // the same way as the szip sample width (in bytes, rounded up).
            let typesize = match (&encoding, &filter) {
                (EncodingType::SimplePacking(params), _) => {
                    (params.bits_per_value as usize).div_ceil(8)
                }
                (EncodingType::None, FilterType::Shuffle { .. }) => 1,
                (EncodingType::None, FilterType::None) => dtype.byte_width(),
            };
            CompressionType::Blosc2 {
                codec,
                clevel,
                typesize,
            }
        }
        #[cfg(feature = "zfp")]
        "zfp" => {
            // `zfp_mode` is mandatory and selects which numeric parameter is
            // read next.
            let mode_str = match desc.params.get("zfp_mode") {
                Some(ciborium::Value::Text(s)) => s.clone(),
                _ => {
                    return Err(TensogramError::Metadata(
                        "missing required parameter: zfp_mode".to_string(),
                    ));
                }
            };
            let mode = match mode_str.as_str() {
                "fixed_rate" => {
                    let rate = get_f64_param(&desc.params, "zfp_rate")?;
                    ZfpMode::FixedRate { rate }
                }
                "fixed_precision" => {
                    let precision = u32::try_from(get_u64_param(&desc.params, "zfp_precision")?)
                        .map_err(|_| {
                            TensogramError::Metadata("zfp_precision out of u32 range".to_string())
                        })?;
                    ZfpMode::FixedPrecision { precision }
                }
                "fixed_accuracy" => {
                    let tolerance = get_f64_param(&desc.params, "zfp_tolerance")?;
                    ZfpMode::FixedAccuracy { tolerance }
                }
                other => {
                    return Err(TensogramError::Encoding(format!(
                        "unknown zfp_mode: {other}"
                    )));
                }
            };
            CompressionType::Zfp { mode }
        }
        #[cfg(feature = "sz3")]
        "sz3" => {
            // Both the mode name and the bound value are mandatory.
            let mode_str = match desc.params.get("sz3_error_bound_mode") {
                Some(ciborium::Value::Text(s)) => s.clone(),
                _ => {
                    return Err(TensogramError::Metadata(
                        "missing required parameter: sz3_error_bound_mode".to_string(),
                    ));
                }
            };
            let bound_val = get_f64_param(&desc.params, "sz3_error_bound")?;
            let error_bound = match mode_str.as_str() {
                "abs" => Sz3ErrorBound::Absolute(bound_val),
                "rel" => Sz3ErrorBound::Relative(bound_val),
                "psnr" => Sz3ErrorBound::Psnr(bound_val),
                other => {
                    return Err(TensogramError::Encoding(format!(
                        "unknown sz3_error_bound_mode: {other}"
                    )));
                }
            };
            CompressionType::Sz3 { error_bound }
        }
        other => {
            return Err(TensogramError::Encoding(format!(
                "unknown compression: {other}"
            )));
        }
    };

    Ok(PipelineConfig {
        encoding,
        filter,
        compression,
        num_values,
        byte_order: desc.byte_order,
        dtype_byte_width: dtype.byte_width(),
        swap_unit_size: dtype.swap_unit_size(),
        compression_backend,
        intra_codec_threads,
    })
}
712
713fn extract_simple_packing_params(
714 params: &BTreeMap<String, ciborium::Value>,
715) -> Result<SimplePackingParams> {
716 let reference_value = get_f64_param(params, "reference_value")?;
717 if reference_value.is_nan() || reference_value.is_infinite() {
718 return Err(TensogramError::Metadata(format!(
719 "reference_value must be finite, got {reference_value}"
720 )));
721 }
722 Ok(SimplePackingParams {
723 reference_value,
724 binary_scale_factor: i32::try_from(get_i64_param(params, "binary_scale_factor")?).map_err(
725 |_| TensogramError::Metadata("binary_scale_factor out of i32 range".to_string()),
726 )?,
727 decimal_scale_factor: i32::try_from(get_i64_param(params, "decimal_scale_factor")?)
728 .map_err(|_| {
729 TensogramError::Metadata("decimal_scale_factor out of i32 range".to_string())
730 })?,
731 bits_per_value: u32::try_from(get_u64_param(params, "bits_per_value")?)
732 .map_err(|_| TensogramError::Metadata("bits_per_value out of u32 range".to_string()))?,
733 })
734}
735
736pub(crate) fn get_f64_param(params: &BTreeMap<String, ciborium::Value>, key: &str) -> Result<f64> {
737 match params.get(key) {
738 Some(ciborium::Value::Float(f)) => Ok(*f),
739 Some(ciborium::Value::Integer(i)) => {
740 let n: i128 = (*i).into();
743 Ok(n as f64)
744 }
745 Some(other) => Err(TensogramError::Metadata(format!(
746 "expected number for {key}, got {other:?}"
747 ))),
748 None => Err(TensogramError::Metadata(format!(
749 "missing required parameter: {key}"
750 ))),
751 }
752}
753
754pub(crate) fn get_i64_param(params: &BTreeMap<String, ciborium::Value>, key: &str) -> Result<i64> {
755 match params.get(key) {
756 Some(ciborium::Value::Integer(i)) => {
757 let n: i128 = (*i).into();
758 i64::try_from(n).map_err(|_| {
759 TensogramError::Metadata(format!("integer value {n} out of i64 range for {key}"))
760 })
761 }
762 Some(other) => Err(TensogramError::Metadata(format!(
763 "expected integer for {key}, got {other:?}"
764 ))),
765 None => Err(TensogramError::Metadata(format!(
766 "missing required parameter: {key}"
767 ))),
768 }
769}
770
771pub(crate) fn get_u64_param(params: &BTreeMap<String, ciborium::Value>, key: &str) -> Result<u64> {
772 match params.get(key) {
773 Some(ciborium::Value::Integer(i)) => {
774 let n: i128 = (*i).into();
775 u64::try_from(n).map_err(|_| {
776 TensogramError::Metadata(format!("integer value {n} out of u64 range for {key}"))
777 })
778 }
779 Some(other) => Err(TensogramError::Metadata(format!(
780 "expected integer for {key}, got {other:?}"
781 ))),
782 None => Err(TensogramError::Metadata(format!(
783 "missing required parameter: {key}"
784 ))),
785 }
786}
787
788pub(crate) fn validate_szip_block_offsets(
789 params: &BTreeMap<String, ciborium::Value>,
790 encoded_bytes_len: usize,
791) -> Result<()> {
792 let value = params.get("szip_block_offsets").ok_or_else(|| {
793 TensogramError::Metadata(
794 "missing required parameter: szip_block_offsets for szip compression".to_string(),
795 )
796 })?;
797
798 let offsets = match value {
799 ciborium::Value::Array(arr) => arr,
800 other => {
801 return Err(TensogramError::Metadata(format!(
802 "szip_block_offsets must be an array, got {other:?}"
803 )));
804 }
805 };
806
807 if offsets.is_empty() {
808 return Err(TensogramError::Metadata(
809 "szip_block_offsets must not be empty; first offset must be 0".to_string(),
810 ));
811 }
812
813 let bit_bound = encoded_bytes_len.checked_mul(8).ok_or_else(|| {
814 TensogramError::Metadata(format!(
815 "encoded byte length {encoded_bytes_len} overflows bit-bound calculation"
816 ))
817 })?;
818 let bit_bound_u64 = u64::try_from(bit_bound).map_err(|_| {
819 TensogramError::Metadata(format!(
820 "bit-bound {bit_bound} derived from {encoded_bytes_len} bytes does not fit in u64"
821 ))
822 })?;
823
824 let mut parsed_offsets = Vec::with_capacity(offsets.len());
825 for (idx, item) in offsets.iter().enumerate() {
826 let offset = match item {
827 ciborium::Value::Integer(i) => {
828 let n: i128 = (*i).into();
829 u64::try_from(n).map_err(|_| {
830 TensogramError::Metadata(format!(
831 "szip_block_offsets[{idx}] = {n} out of u64 range"
832 ))
833 })?
834 }
835 other => {
836 return Err(TensogramError::Metadata(format!(
837 "szip_block_offsets[{idx}] must be an integer, got {other:?}"
838 )));
839 }
840 };
841
842 if offset > bit_bound_u64 {
843 return Err(TensogramError::Metadata(format!(
844 "szip_block_offsets[{idx}] = {offset} exceeds bit bound {bit_bound_u64} (encoded_bytes_len = {encoded_bytes_len} bytes, {bit_bound_u64} bits)"
845 )));
846 }
847
848 if idx == 0 {
849 if offset != 0 {
850 return Err(TensogramError::Metadata(format!(
851 "szip_block_offsets[0] must be 0, got {offset}"
852 )));
853 }
854 } else {
855 let prev = parsed_offsets[idx - 1];
856 if offset <= prev {
857 return Err(TensogramError::Metadata(format!(
858 "szip_block_offsets must be strictly increasing: szip_block_offsets[{}] = {}, szip_block_offsets[{idx}] = {offset}",
859 idx - 1,
860 prev
861 )));
862 }
863 }
864
865 parsed_offsets.push(offset);
866 }
867
868 Ok(())
869}
870
871pub(crate) fn validate_no_szip_offsets_for_non_szip(desc: &DataObjectDescriptor) -> Result<()> {
872 if desc.compression != "szip" && desc.params.contains_key("szip_block_offsets") {
873 return Err(TensogramError::Metadata(format!(
874 "szip_block_offsets provided but compression is '{}', not 'szip'",
875 desc.compression
876 )));
877 }
878 Ok(())
879}
880
881#[cfg(test)]
884mod tests {
885 use super::*;
886 use crate::decode::{DecodeOptions, decode};
887 use crate::types::{ByteOrder, GlobalMetadata};
888 use std::collections::BTreeMap;
889
890 fn make_descriptor(shape: Vec<u64>) -> DataObjectDescriptor {
891 let strides = {
892 let mut s = vec![1u64; shape.len()];
893 for i in (0..shape.len().saturating_sub(1)).rev() {
894 s[i] = s[i + 1] * shape[i + 1];
895 }
896 s
897 };
898 DataObjectDescriptor {
899 obj_type: "ntensor".to_string(),
900 ndim: shape.len() as u64,
901 shape,
902 strides,
903 dtype: Dtype::Float32,
904 byte_order: ByteOrder::native(),
905 encoding: "none".to_string(),
906 filter: "none".to_string(),
907 compression: "none".to_string(),
908 params: BTreeMap::new(),
909 hash: None,
910 }
911 }
912
913 #[test]
916 fn test_base_more_entries_than_descriptors_rejected() {
917 let meta = GlobalMetadata {
919 version: 2,
920 base: vec![
921 BTreeMap::new(),
922 BTreeMap::new(),
923 BTreeMap::new(),
924 BTreeMap::new(),
925 BTreeMap::new(),
926 ],
927 ..Default::default()
928 };
929 let desc = make_descriptor(vec![4]);
930 let data = vec![0u8; 16];
931 let options = EncodeOptions {
932 hash_algorithm: None,
933 ..Default::default()
934 };
935 let result = encode(
936 &meta,
937 &[(&desc, data.as_slice()), (&desc, data.as_slice())],
938 &options,
939 );
940 assert!(
941 result.is_err(),
942 "5 base entries with 2 descriptors should fail"
943 );
944 let err = result.unwrap_err().to_string();
945 assert!(
946 err.contains("5") && err.contains("2"),
947 "error should mention counts: {err}"
948 );
949 }
950
951 #[test]
952 fn test_base_fewer_entries_than_descriptors_auto_extended() {
953 let meta = GlobalMetadata {
955 version: 2,
956 base: vec![],
957 ..Default::default()
958 };
959 let desc = make_descriptor(vec![2]);
960 let data = vec![0u8; 8];
961 let options = EncodeOptions {
962 hash_algorithm: None,
963 ..Default::default()
964 };
965 let msg = encode(
966 &meta,
967 &[
968 (&desc, data.as_slice()),
969 (&desc, data.as_slice()),
970 (&desc, data.as_slice()),
971 ],
972 &options,
973 )
974 .unwrap();
975
976 let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
977 assert_eq!(decoded.base.len(), 3);
978 for entry in &decoded.base {
980 assert!(
981 entry.contains_key("_reserved_"),
982 "auto-extended base entry should have _reserved_"
983 );
984 }
985 }
986
987 #[test]
988 fn test_base_entry_with_top_level_key_names_no_collision() {
989 let mut entry = BTreeMap::new();
991 entry.insert(
992 "version".to_string(),
993 ciborium::Value::Text("my-version".to_string()),
994 );
995 entry.insert(
996 "base".to_string(),
997 ciborium::Value::Text("not-the-real-base".to_string()),
998 );
999 let meta = GlobalMetadata {
1000 version: 2,
1001 base: vec![entry],
1002 ..Default::default()
1003 };
1004 let desc = make_descriptor(vec![2]);
1005 let data = vec![0u8; 8];
1006 let options = EncodeOptions {
1007 hash_algorithm: None,
1008 ..Default::default()
1009 };
1010 let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
1011 let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1012
1013 assert_eq!(decoded.version, 2);
1015 assert_eq!(
1017 decoded.base[0].get("version"),
1018 Some(&ciborium::Value::Text("my-version".to_string()))
1019 );
1020 assert_eq!(
1021 decoded.base[0].get("base"),
1022 Some(&ciborium::Value::Text("not-the-real-base".to_string()))
1023 );
1024 }
1025
1026 #[test]
1027 fn test_base_entry_with_deeply_nested_reserved_allowed() {
1028 let nested = ciborium::Value::Map(vec![(
1031 ciborium::Value::Text("_reserved_".to_string()),
1032 ciborium::Value::Text("nested-is-ok".to_string()),
1033 )]);
1034 let mut entry = BTreeMap::new();
1035 entry.insert("foo".to_string(), nested);
1036 let meta = GlobalMetadata {
1037 version: 2,
1038 base: vec![entry],
1039 ..Default::default()
1040 };
1041 let desc = make_descriptor(vec![2]);
1042 let data = vec![0u8; 8];
1043 let options = EncodeOptions {
1044 hash_algorithm: None,
1045 ..Default::default()
1046 };
1047 let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
1049 let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1050 let foo = decoded.base[0].get("foo").unwrap();
1052 if let ciborium::Value::Map(pairs) = foo {
1053 assert_eq!(pairs.len(), 1);
1054 } else {
1055 panic!("expected map for foo");
1056 }
1057 }
1058
1059 #[test]
1062 fn test_reserved_rejected_at_message_level() {
1063 let mut reserved = BTreeMap::new();
1064 reserved.insert(
1065 "rogue".to_string(),
1066 ciborium::Value::Text("bad".to_string()),
1067 );
1068 let meta = GlobalMetadata {
1069 version: 2,
1070 reserved,
1071 ..Default::default()
1072 };
1073 let desc = make_descriptor(vec![2]);
1074 let data = vec![0u8; 8];
1075 let result = encode(
1076 &meta,
1077 &[(&desc, data.as_slice())],
1078 &EncodeOptions::default(),
1079 );
1080 assert!(result.is_err());
1081 let err = result.unwrap_err().to_string();
1082 assert!(
1083 err.contains("_reserved_") && err.contains("message level"),
1084 "error: {err}"
1085 );
1086 }
1087
1088 #[test]
1089 fn test_reserved_rejected_in_base_entry() {
1090 let mut entry = BTreeMap::new();
1091 entry.insert("_reserved_".to_string(), ciborium::Value::Map(vec![]));
1092 let meta = GlobalMetadata {
1093 version: 2,
1094 base: vec![entry],
1095 ..Default::default()
1096 };
1097 let desc = make_descriptor(vec![2]);
1098 let data = vec![0u8; 8];
1099 let result = encode(
1100 &meta,
1101 &[(&desc, data.as_slice())],
1102 &EncodeOptions::default(),
1103 );
1104 assert!(result.is_err());
1105 let err = result.unwrap_err().to_string();
1106 assert!(
1107 err.contains("_reserved_") && err.contains("base[0]"),
1108 "error: {err}"
1109 );
1110 }
1111
1112 #[test]
1113 fn test_reserved_tensor_has_four_keys_after_encode() {
1114 let meta = GlobalMetadata::default();
1115 let desc = make_descriptor(vec![3, 4]);
1116 let data = vec![0u8; 3 * 4 * 4]; let options = EncodeOptions {
1118 hash_algorithm: None,
1119 ..Default::default()
1120 };
1121 let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
1122 let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1123
1124 let reserved = decoded.base[0]
1125 .get("_reserved_")
1126 .expect("_reserved_ missing");
1127 if let ciborium::Value::Map(pairs) = reserved {
1128 let tensor_entry = pairs
1130 .iter()
1131 .find(|(k, _)| *k == ciborium::Value::Text("tensor".to_string()));
1132 assert!(tensor_entry.is_some(), "missing tensor key in _reserved_");
1133 if let Some((_, ciborium::Value::Map(tensor_pairs))) = tensor_entry {
1134 let keys: Vec<String> = tensor_pairs
1135 .iter()
1136 .filter_map(|(k, _)| {
1137 if let ciborium::Value::Text(s) = k {
1138 Some(s.clone())
1139 } else {
1140 None
1141 }
1142 })
1143 .collect();
1144 assert_eq!(keys.len(), 4, "tensor should have 4 keys, got: {keys:?}");
1145 assert!(keys.contains(&"ndim".to_string()));
1146 assert!(keys.contains(&"shape".to_string()));
1147 assert!(keys.contains(&"strides".to_string()));
1148 assert!(keys.contains(&"dtype".to_string()));
1149 } else {
1150 panic!("tensor is not a map");
1151 }
1152 } else {
1153 panic!("_reserved_ is not a map");
1154 }
1155 }
1156
1157 #[test]
1158 fn test_reserved_tensor_scalar_ndim_zero() {
1159 let desc = DataObjectDescriptor {
1161 obj_type: "ntensor".to_string(),
1162 ndim: 0,
1163 shape: vec![],
1164 strides: vec![],
1165 dtype: Dtype::Float32,
1166 byte_order: ByteOrder::native(),
1167 encoding: "none".to_string(),
1168 filter: "none".to_string(),
1169 compression: "none".to_string(),
1170 params: BTreeMap::new(),
1171 hash: None,
1172 };
1173 let data = vec![0u8; 4]; let meta = GlobalMetadata::default();
1175 let options = EncodeOptions {
1176 hash_algorithm: None,
1177 ..Default::default()
1178 };
1179 let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
1180 let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1181
1182 let reserved = decoded.base[0]
1183 .get("_reserved_")
1184 .expect("_reserved_ missing");
1185 if let ciborium::Value::Map(pairs) = reserved {
1186 let tensor_entry = pairs
1187 .iter()
1188 .find(|(k, _)| *k == ciborium::Value::Text("tensor".to_string()));
1189 if let Some((_, ciborium::Value::Map(tensor_pairs))) = tensor_entry {
1190 let ndim = tensor_pairs
1192 .iter()
1193 .find(|(k, _)| *k == ciborium::Value::Text("ndim".to_string()));
1194 assert!(
1195 matches!(ndim, Some((_, ciborium::Value::Integer(i))) if i128::from(*i) == 0),
1196 "ndim should be 0 for scalar"
1197 );
1198 let shape = tensor_pairs
1200 .iter()
1201 .find(|(k, _)| *k == ciborium::Value::Text("shape".to_string()));
1202 assert!(
1203 matches!(shape, Some((_, ciborium::Value::Array(a))) if a.is_empty()),
1204 "shape should be [] for scalar"
1205 );
1206 } else {
1207 panic!("tensor missing or not a map");
1208 }
1209 } else {
1210 panic!("_reserved_ is not a map");
1211 }
1212 }
1213
/// The same key (`"mars"`) may appear in both a `base` entry and in `extra`;
/// after a round trip each location must still hold its own value.
#[test]
fn test_extra_with_keys_colliding_with_base_entry_keys() {
    let base_entry = BTreeMap::from([(
        "mars".to_string(),
        ciborium::Value::Text("base-mars".to_string()),
    )]);
    let extra = BTreeMap::from([(
        "mars".to_string(),
        ciborium::Value::Text("extra-mars".to_string()),
    )]);
    let meta = GlobalMetadata {
        version: 2,
        base: vec![base_entry],
        extra,
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let options = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };

    let msg = encode(&meta, &[(&desc, payload.as_slice())], &options).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();

    let expected_base = ciborium::Value::Text("base-mars".to_string());
    let expected_extra = ciborium::Value::Text("extra-mars".to_string());
    assert_eq!(decoded.base[0].get("mars"), Some(&expected_base));
    assert_eq!(decoded.extra.get("mars"), Some(&expected_extra));
}
1253
/// Encoding metadata whose `extra` map is empty must decode back to an
/// empty `extra` map.
#[test]
fn test_empty_extra_omitted_from_cbor() {
    let desc = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let options = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };
    let meta = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };

    let objects = [(&desc, payload.as_slice())];
    let msg = encode(&meta, &objects, &options).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();

    assert!(decoded.extra.is_empty());
}
1271
/// A nested CBOR map stored under `extra["nested"]` must come back from an
/// encode/decode round trip with the same value, not merely the same key.
#[test]
fn test_extra_with_nested_maps_round_trips() {
    let nested = ciborium::Value::Map(vec![
        (
            ciborium::Value::Text("key1".to_string()),
            ciborium::Value::Integer(42.into()),
        ),
        (
            ciborium::Value::Text("key2".to_string()),
            ciborium::Value::Map(vec![(
                ciborium::Value::Text("deep".to_string()),
                ciborium::Value::Bool(true),
            )]),
        ),
    ]);
    let mut extra = BTreeMap::new();
    // Keep the original around so the decoded value can be compared to it.
    extra.insert("nested".to_string(), nested.clone());
    let meta = GlobalMetadata {
        version: 2,
        extra,
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let data = vec![0u8; 8];
    let options = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };
    let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();

    // Checking only for the key would pass even if the nested content were
    // dropped or altered; compare the full value instead.
    assert_eq!(
        decoded.extra.get("nested"),
        Some(&nested),
        "nested map should survive the round trip unchanged"
    );
}
1305
/// A CBOR metadata map carrying `common` and `payload` keys alongside
/// `version` must decode without error, leaving `base`, `extra` and
/// `reserved` empty.
#[test]
fn test_old_common_payload_keys_silently_ignored() {
    use ciborium::Value;

    let text = |s: &str| Value::Text(s.to_string());
    let legacy_doc = Value::Map(vec![
        (text("version"), Value::Integer(2.into())),
        (text("common"), Value::Map(vec![])),
        (text("payload"), Value::Array(vec![])),
    ]);
    let mut encoded = Vec::new();
    ciborium::into_writer(&legacy_doc, &mut encoded).unwrap();

    let decoded: GlobalMetadata = crate::metadata::cbor_to_global_metadata(&encoded).unwrap();

    assert_eq!(decoded.version, 2);
    assert!(decoded.base.is_empty());
    assert!(decoded.extra.is_empty());
    assert!(decoded.reserved.is_empty());
}
1328
/// A top-level key named `reserved` (no underscores) must not populate
/// `GlobalMetadata::reserved` when decoding.
#[test]
fn test_old_reserved_key_name_ignored() {
    use ciborium::Value;

    let text = |s: &str| Value::Text(s.to_string());
    let rogue_map = Value::Map(vec![(text("rogue"), text("value"))]);
    let doc = Value::Map(vec![
        (text("version"), Value::Integer(2.into())),
        (text("reserved"), rogue_map),
    ]);
    let mut encoded = Vec::new();
    ciborium::into_writer(&doc, &mut encoded).unwrap();

    let decoded: GlobalMetadata = crate::metadata::cbor_to_global_metadata(&encoded).unwrap();

    assert!(
        decoded.reserved.is_empty(),
        "old 'reserved' key should be ignored"
    );
}
1352
/// When only the second `base` entry contains a `_reserved_` key, `encode`
/// must fail and the error text must point at `base[1]`.
#[test]
fn test_reserved_rejected_in_second_base_entry_only() {
    let clean_entry = BTreeMap::from([(
        "clean".to_string(),
        ciborium::Value::Text("ok".to_string()),
    )]);
    let offending_entry =
        BTreeMap::from([("_reserved_".to_string(), ciborium::Value::Map(vec![]))]);
    let meta = GlobalMetadata {
        version: 2,
        base: vec![clean_entry, offending_entry],
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];

    // Two objects, so that both base entries have a matching object.
    let objects = [(&desc, payload.as_slice()), (&desc, payload.as_slice())];
    let result = encode(&meta, &objects, &EncodeOptions::default());

    assert!(result.is_err());
    let err = result.unwrap_err().to_string();
    assert!(
        err.contains("base[1]"),
        "error should mention base[1]: {err}"
    );
}
1381
/// Two `base` entries without any `_reserved_` key must encode successfully
/// and both entries must be present, in order, after decoding.
#[test]
fn test_reserved_accepted_when_all_base_entries_clean() {
    let first = BTreeMap::from([(
        "key0".to_string(),
        ciborium::Value::Text("val0".to_string()),
    )]);
    let second = BTreeMap::from([(
        "key1".to_string(),
        ciborium::Value::Text("val1".to_string()),
    )]);
    let meta = GlobalMetadata {
        version: 2,
        base: vec![first, second],
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let options = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };

    let objects = [(&desc, payload.as_slice()), (&desc, payload.as_slice())];
    let msg = encode(&meta, &objects, &options).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();

    assert_eq!(decoded.base.len(), 2);
    assert!(decoded.base[0].contains_key("key0"));
    assert!(decoded.base[1].contains_key("key1"));
}
1417
/// For every dtype in the table below, encodes a four-element 1-D object and
/// checks that the decoded `_reserved_.tensor.dtype` entry is the expected
/// lowercase string.
#[test]
fn test_reserved_tensor_dtype_strings_for_all_dtypes() {
    let dtypes_and_expected = [
        (Dtype::Float16, "float16"),
        (Dtype::Bfloat16, "bfloat16"),
        (Dtype::Float32, "float32"),
        (Dtype::Float64, "float64"),
        (Dtype::Complex64, "complex64"),
        (Dtype::Complex128, "complex128"),
        (Dtype::Int8, "int8"),
        (Dtype::Int16, "int16"),
        (Dtype::Int32, "int32"),
        (Dtype::Int64, "int64"),
        (Dtype::Uint8, "uint8"),
        (Dtype::Uint16, "uint16"),
        (Dtype::Uint32, "uint32"),
        (Dtype::Uint64, "uint64"),
    ];

    // Lookup keys are the same for every iteration, so build them once.
    let tensor_key = ciborium::Value::Text("tensor".to_string());
    let dtype_key = ciborium::Value::Text("dtype".to_string());

    for (dtype, expected_str) in dtypes_and_expected {
        let num_elements: u64 = 4;
        // Payload size follows from the element count and the dtype's width.
        let data = vec![0u8; num_elements as usize * dtype.byte_width()];

        let desc = DataObjectDescriptor {
            obj_type: "ntensor".to_string(),
            ndim: 1,
            shape: vec![num_elements],
            strides: vec![1],
            dtype,
            byte_order: ByteOrder::native(),
            encoding: "none".to_string(),
            filter: "none".to_string(),
            compression: "none".to_string(),
            params: BTreeMap::new(),
            hash: None,
        };
        let meta = GlobalMetadata::default();
        let options = EncodeOptions {
            hash_algorithm: None,
            ..Default::default()
        };
        let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
        let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();

        let reserved = decoded.base[0]
            .get("_reserved_")
            .unwrap_or_else(|| panic!("_reserved_ missing for dtype {dtype}"));
        let pairs = match reserved {
            ciborium::Value::Map(pairs) => pairs,
            _ => panic!("_reserved_ is not a map for dtype {dtype}"),
        };
        let tensor_pairs = match pairs.iter().find(|(k, _)| *k == tensor_key) {
            Some((_, ciborium::Value::Map(tensor_pairs))) => tensor_pairs,
            _ => panic!("tensor missing or not a map for dtype {dtype}"),
        };

        let dtype_val = tensor_pairs.iter().find(|(k, _)| *k == dtype_key);
        assert!(
            matches!(
                dtype_val,
                Some((_, ciborium::Value::Text(s))) if s == expected_str
            ),
            "dtype for {dtype} should be '{expected_str}', got: {dtype_val:?}"
        );
    }
}
1493
/// Serialises a `GlobalMetadata` with every field populated to CBOR and back,
/// then checks that each field survived.
#[test]
fn test_global_metadata_serde_all_fields_populated() {
    use ciborium::Value;

    let meta = GlobalMetadata {
        version: 2,
        base: vec![BTreeMap::from([(
            "key".to_string(),
            Value::Text("base_val".to_string()),
        )])],
        reserved: BTreeMap::from([("encoder".to_string(), Value::Text("test".to_string()))]),
        extra: BTreeMap::from([("custom".to_string(), Value::Integer(42.into()))]),
    };

    let cbor_bytes = crate::metadata::global_metadata_to_cbor(&meta).unwrap();
    let decoded: GlobalMetadata = crate::metadata::cbor_to_global_metadata(&cbor_bytes).unwrap();

    assert_eq!(decoded.version, 2);
    assert_eq!(decoded.base.len(), 1);

    let expected_base = Value::Text("base_val".to_string());
    assert_eq!(decoded.base[0].get("key"), Some(&expected_base));

    assert!(decoded.reserved.contains_key("encoder"));

    let expected_extra = Value::Integer(42.into());
    assert_eq!(decoded.extra.get("custom"), Some(&expected_extra));
}
1532
/// After an encode/decode round trip of default metadata, the decoded
/// `reserved` map must contain `encoder` (a map with `name` and `version`),
/// `uuid` (36 characters, four hyphens) and `time` (contains `T`, ends in `Z`).
#[test]
fn test_provenance_fields_present_after_encode() {
    let meta = GlobalMetadata::default();
    let desc = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let options = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };
    let msg = encode(&meta, &[(&desc, payload.as_slice())], &options).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();

    assert!(decoded.reserved.contains_key("encoder"));
    assert!(decoded.reserved.contains_key("time"));
    assert!(decoded.reserved.contains_key("uuid"));

    // encoder: a map that carries at least `name` and `version`.
    match decoded.reserved.get("encoder").unwrap() {
        ciborium::Value::Map(pairs) => {
            let name_key = ciborium::Value::Text("name".to_string());
            let version_key = ciborium::Value::Text("version".to_string());
            let has_name = pairs.iter().any(|(k, _)| *k == name_key);
            let has_version = pairs.iter().any(|(k, _)| *k == version_key);
            assert!(has_name, "encoder map should have 'name' key");
            assert!(has_version, "encoder map should have 'version' key");
        }
        _ => panic!("encoder should be a map"),
    }

    // uuid: text of length 36 with four hyphens.
    match decoded.reserved.get("uuid").unwrap() {
        ciborium::Value::Text(uuid_str) => {
            assert_eq!(uuid_str.len(), 36, "UUID should be 36 chars: {uuid_str}");
            assert_eq!(
                uuid_str.matches('-').count(),
                4,
                "UUID should have 4 hyphens: {uuid_str}"
            );
        }
        _ => panic!("uuid should be a text"),
    }

    // time: text with a `T` separator and a trailing `Z`.
    match decoded.reserved.get("time").unwrap() {
        ciborium::Value::Text(time_str) => {
            assert!(
                time_str.ends_with('Z'),
                "time should end with Z: {time_str}"
            );
            assert!(
                time_str.contains('T'),
                "time should contain T separator: {time_str}"
            );
        }
        _ => panic!("time should be a text"),
    }
}
1592
/// When a CBOR map carries both `_reserved_` and `reserved`, only the
/// contents of `_reserved_` must end up in `GlobalMetadata::reserved`.
#[test]
fn test_both_reserved_and_reserved_underscore_only_new_captured() {
    use ciborium::Value;

    let text = |s: &str| Value::Text(s.to_string());
    let current = Value::Map(vec![(text("encoder"), text("tensogram"))]);
    let legacy = Value::Map(vec![(text("old"), text("ignored"))]);
    let doc = Value::Map(vec![
        (text("_reserved_"), current),
        (text("reserved"), legacy),
        (text("version"), Value::Integer(2.into())),
    ]);
    let mut encoded = Vec::new();
    ciborium::into_writer(&doc, &mut encoded).unwrap();

    let decoded: GlobalMetadata = crate::metadata::cbor_to_global_metadata(&encoded).unwrap();

    assert!(decoded.reserved.contains_key("encoder"));
    assert!(!decoded.reserved.contains_key("old"));
}
1621
/// Encodes raw data with `encode`, decodes it, feeds the decoded descriptor
/// and payload to `encode_pre_encoded`, decodes again, and checks that both
/// decoded payloads are identical.
#[test]
fn test_encode_pre_encoded_roundtrip_simple_packing() {
    let desc = make_descriptor(vec![4]);
    let raw_data: Vec<u8> = vec![0u8; 4 * 4];
    let meta = GlobalMetadata::default();
    let options = EncodeOptions::default();

    // First pass: regular encode followed by decode.
    let msg1 = encode(&meta, &[(&desc, raw_data.as_slice())], &options).unwrap();
    let (_, objects1) = decode(&msg1, &DecodeOptions::default()).unwrap();
    let (decoded_desc1, decoded_payload1) = &objects1[0];

    // Second pass: re-wrap the decoded object. `decoded_desc1` is already a
    // reference, so it is passed through directly instead of being cloned.
    let msg2 = encode_pre_encoded(
        &meta,
        &[(decoded_desc1, decoded_payload1.as_slice())],
        &options,
    )
    .unwrap();
    let (_, objects2) = decode(&msg2, &DecodeOptions::default()).unwrap();
    let (_, decoded_payload2) = &objects2[0];

    assert_eq!(
        decoded_payload1, decoded_payload2,
        "decoded payloads should be equal after encode/re-encode roundtrip"
    );
}
1663
/// `encode_pre_encoded` must return an error that mentions `emit_preceders`
/// when that option is enabled.
#[test]
fn test_encode_pre_encoded_rejects_emit_preceders() {
    let options = EncodeOptions {
        emit_preceders: true,
        ..Default::default()
    };
    let meta = GlobalMetadata::default();
    let desc = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];

    let objects = [(&desc, payload.as_slice())];
    let result = encode_pre_encoded(&meta, &objects, &options);

    assert!(
        result.is_err(),
        "encode_pre_encoded with emit_preceders=true should fail"
    );
    let err = result.unwrap_err().to_string();
    assert!(
        err.contains("emit_preceders"),
        "error should mention emit_preceders: {err}"
    );
}
1685
/// A hash supplied by the caller in the descriptor must not survive
/// `encode_pre_encoded`: the stored hash has to equal the one computed over
/// the decoded payload with the configured algorithm.
#[test]
fn test_encode_pre_encoded_overwrites_caller_hash() {
    // Deliberately bogus hash value set by the "caller".
    let bogus_hash = HashDescriptor {
        hash_type: "xxh3".to_string(),
        value: "deadbeefdeadbeef".to_string(),
    };
    let mut desc = make_descriptor(vec![2]);
    desc.hash = Some(bogus_hash);

    let payload = vec![0xAB_u8; 8];
    let meta = GlobalMetadata::default();
    let options = EncodeOptions::default();

    let msg = encode_pre_encoded(&meta, &[(&desc, payload.as_slice())], &options).unwrap();
    let (_, objects) = decode(&msg, &DecodeOptions::default()).unwrap();
    let (decoded_desc, decoded_payload) = &objects[0];

    let algorithm = options.hash_algorithm.expect("expected hash algorithm");
    let computed_hash = compute_hash(decoded_payload, algorithm);

    let stored_hash = decoded_desc
        .hash
        .as_ref()
        .expect("hash should be present in decoded descriptor")
        .value
        .clone();

    assert_ne!(
        stored_hash, "deadbeefdeadbeef",
        "caller's garbage hash must be overwritten"
    );
    assert_eq!(
        stored_hash, computed_hash,
        "library-computed hash must match hash over decoded payload"
    );
}
1726
/// Offsets `[0, 100, 200]` with a second argument of 100 must be accepted by
/// `validate_szip_block_offsets`.
#[test]
fn test_validate_szip_block_offsets_happy_path() {
    let mut params = BTreeMap::new();
    params.insert(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(vec![0u64, 100, 200].into_iter().map(|n| n.into()).collect()),
    );

    assert!(validate_szip_block_offsets(&params, 100).is_ok());
}
1737
/// An empty params map must be rejected with an error mentioning both
/// "missing" and `szip_block_offsets`.
#[test]
fn test_validate_szip_block_offsets_missing_key() {
    let params = BTreeMap::new();

    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(
        err.contains("missing") && err.contains("szip_block_offsets"),
        "error: {err}"
    );
}
1750
/// A `szip_block_offsets` value that is an integer rather than an array must
/// be rejected with an error mentioning "array" and the key name.
#[test]
fn test_validate_szip_block_offsets_not_array() {
    let mut params = BTreeMap::new();
    params.insert(
        "szip_block_offsets".to_string(),
        ciborium::Value::Integer(0.into()),
    );

    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(
        err.contains("array") && err.contains("szip_block_offsets"),
        "error: {err}"
    );
}
1767
/// An offsets array containing a text element must be rejected with an error
/// mentioning "integer" and the key name.
#[test]
fn test_validate_szip_block_offsets_non_integer_element() {
    let mut params = BTreeMap::new();
    params.insert(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(vec![
            ciborium::Value::Integer(0.into()),
            ciborium::Value::Text("x".to_string()),
        ]),
    );

    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(
        err.contains("integer") && err.contains("szip_block_offsets"),
        "error: {err}"
    );
}
1787
/// Offsets whose first element is 5 instead of 0 must be rejected; the error
/// is expected to contain "must be 0" and "got 5".
#[test]
fn test_validate_szip_block_offsets_nonzero_first() {
    let mut params = BTreeMap::new();
    params.insert(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(vec![5u64, 100, 200].into_iter().map(|n| n.into()).collect()),
    );

    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(
        err.contains("must be 0") && err.contains("got 5"),
        "error: {err}"
    );
}
1804
/// Offsets that decrease (`[0, 100, 50]`) must be rejected; the error is
/// expected to mention either "increasing" or "monotonic".
#[test]
fn test_validate_szip_block_offsets_non_monotonic() {
    let mut params = BTreeMap::new();
    params.insert(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(vec![0u64, 100, 50].into_iter().map(|n| n.into()).collect()),
    );

    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(
        err.contains("increasing") || err.contains("monotonic"),
        "error: {err}"
    );
}
1821
/// An offset of 801 with a second argument of 100 must be rejected, and the
/// error must mention both 800 and 801.
///
/// NOTE(review): 800 = 8 * 100, so the offsets are presumably bit offsets
/// checked against a byte length; confirm against `validate_szip_block_offsets`.
#[test]
fn test_validate_szip_block_offsets_offset_beyond_bound() {
    let mut params = BTreeMap::new();
    params.insert(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(vec![0u64, 100, 801].into_iter().map(|n| n.into()).collect()),
    );

    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(err.contains("800") && err.contains("801"), "error: {err}");
}
1835
/// A descriptor with `compression: "zstd"` that still carries
/// `szip_block_offsets` in its params must be rejected, with an error naming
/// both the key and the compression.
#[test]
fn test_validate_no_szip_offsets_for_non_szip_rejects() {
    let offsets: Vec<ciborium::Value> = vec![0u64, 1].into_iter().map(|n| n.into()).collect();
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(offsets),
    )]);
    let zstd_desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![2],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "zstd".to_string(),
        params,
        hash: None,
    };

    let err = validate_no_szip_offsets_for_non_szip(&zstd_desc)
        .unwrap_err()
        .to_string();
    assert!(
        err.contains("szip_block_offsets") && err.contains("zstd"),
        "error: {err}"
    );
}
1865
/// A descriptor with `compression: "szip"` may carry `szip_block_offsets`;
/// the check must succeed.
#[test]
fn test_validate_no_szip_offsets_for_non_szip_allows_szip() {
    let offsets: Vec<ciborium::Value> = vec![0u64, 1].into_iter().map(|n| n.into()).collect();
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(offsets),
    )]);
    let szip_desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![2],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "szip".to_string(),
        params,
        hash: None,
    };

    let outcome = validate_no_szip_offsets_for_non_szip(&szip_desc);
    assert!(outcome.is_ok());
}
1888 }
1889}