1use std::collections::BTreeMap;
10
11use crate::dtype::Dtype;
12use crate::error::{Result, TensogramError};
13use crate::framing::{self, EncodedObject};
14use crate::hash::{HashAlgorithm, compute_hash};
15use crate::metadata::RESERVED_KEY;
16use crate::types::{DataObjectDescriptor, GlobalMetadata, HashDescriptor};
17#[cfg(feature = "blosc2")]
18use tensogram_encodings::pipeline::Blosc2Codec;
19#[cfg(feature = "sz3")]
20use tensogram_encodings::pipeline::Sz3ErrorBound;
21#[cfg(feature = "zfp")]
22use tensogram_encodings::pipeline::ZfpMode;
23use tensogram_encodings::pipeline::{
24 self, CompressionType, EncodingType, FilterType, PipelineConfig,
25};
26use tensogram_encodings::simple_packing::SimplePackingParams;
27
/// Options controlling how [`encode`] and [`encode_pre_encoded`] build a message.
#[derive(Debug, Clone)]
pub struct EncodeOptions {
    /// Hash algorithm applied to each encoded payload. When `Some`, the
    /// resulting digest is stored in the object's descriptor; when `None`,
    /// the encoder computes no hash.
    pub hash_algorithm: Option<HashAlgorithm>,
    /// Must be `false` for the buffered encoders in this file: they return an
    /// error when it is set and point callers at
    /// `StreamingEncoder::write_preceder()` instead.
    pub emit_preceders: bool,
    /// Compression backend forwarded to the pipeline configuration of every
    /// object.
    pub compression_backend: pipeline::CompressionBackend,
    /// Requested thread count, resolved into a budget by
    /// `crate::parallel::resolve_budget`. The default is `0`.
    pub threads: u32,
    /// Optional byte threshold handed to `crate::parallel::should_parallelise`
    /// together with the payload size when deciding whether to parallelise.
    pub parallel_threshold_bytes: Option<usize>,
    /// When set, raw input is scanned with `crate::strict_finite::scan` before
    /// encoding. `encode_pre_encoded` returns an error if this flag is set.
    pub reject_nan: bool,
    /// Same as `reject_nan`, but requests the infinity check of the scan.
    /// `encode_pre_encoded` returns an error if this flag is set.
    pub reject_inf: bool,
}
117
118impl Default for EncodeOptions {
119 fn default() -> Self {
120 Self {
121 hash_algorithm: Some(HashAlgorithm::Xxh3),
122 emit_preceders: false,
123 compression_backend: pipeline::CompressionBackend::default(),
124 threads: 0,
125 parallel_threshold_bytes: None,
126 reject_nan: false,
127 reject_inf: false,
128 }
129 }
130}
131
132pub(crate) fn validate_object(desc: &DataObjectDescriptor, data_len: usize) -> Result<()> {
133 if desc.obj_type.is_empty() {
134 return Err(TensogramError::Metadata(
135 "obj_type must not be empty".to_string(),
136 ));
137 }
138 if desc.ndim as usize != desc.shape.len() {
139 return Err(TensogramError::Metadata(format!(
140 "ndim {} does not match shape.len() {}",
141 desc.ndim,
142 desc.shape.len()
143 )));
144 }
145 if desc.strides.len() != desc.shape.len() {
146 return Err(TensogramError::Metadata(format!(
147 "strides.len() {} does not match shape.len() {}",
148 desc.strides.len(),
149 desc.shape.len()
150 )));
151 }
152 if desc.encoding == "none" {
153 let product = desc
154 .shape
155 .iter()
156 .try_fold(1u64, |acc, &x| acc.checked_mul(x))
157 .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
158 if desc.dtype.byte_width() > 0 {
159 let expected_bytes = product
160 .checked_mul(desc.dtype.byte_width() as u64)
161 .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
162 if expected_bytes != data_len as u64 {
163 return Err(TensogramError::Metadata(format!(
164 "data_len {data_len} does not match expected {expected_bytes} bytes from shape and dtype"
165 )));
166 }
167 } else if desc.dtype == crate::Dtype::Bitmask {
168 let expected_bytes = product.div_ceil(8);
170 if expected_bytes != data_len as u64 {
171 return Err(TensogramError::Metadata(format!(
172 "data_len {data_len} does not match expected {expected_bytes} bytes for bitmask (ceil({product}/8))"
173 )));
174 }
175 }
176 }
177 Ok(())
178}
179
/// Selects how `encode_one_object` treats the bytes it is given.
#[derive(Debug, Clone, Copy)]
enum EncodeMode {
    /// The bytes are run through `pipeline::encode_pipeline`.
    Raw,
    /// The bytes are copied into the message unchanged.
    PreEncoded,
}
185
186fn encode_one_object(
194 desc: &DataObjectDescriptor,
195 data: &[u8],
196 mode: EncodeMode,
197 options: &EncodeOptions,
198 intra_codec_threads: u32,
199) -> Result<EncodedObject> {
200 validate_object(desc, data.len())?;
201
202 if matches!(mode, EncodeMode::Raw) && (options.reject_nan || options.reject_inf) {
207 let parallel = crate::parallel::should_parallelise(
208 intra_codec_threads,
209 data.len(),
210 options.parallel_threshold_bytes,
211 );
212 crate::strict_finite::scan(
213 data,
214 desc.dtype,
215 desc.byte_order,
216 options.reject_nan,
217 options.reject_inf,
218 parallel,
219 )?;
220 }
221
222 let shape_product = desc
223 .shape
224 .iter()
225 .try_fold(1u64, |acc, &x| acc.checked_mul(x))
226 .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
227 let num_elements = usize::try_from(shape_product)
228 .map_err(|_| TensogramError::Metadata("element count overflows usize".to_string()))?;
229 let dtype = desc.dtype;
230
231 let mut config = build_pipeline_config_with_backend(
232 desc,
233 num_elements,
234 dtype,
235 options.compression_backend,
236 intra_codec_threads,
237 )?;
238
239 let mut final_desc = desc.clone();
241
242 let inline_hash_requested = matches!(mode, EncodeMode::Raw)
254 && match options.hash_algorithm {
255 Some(HashAlgorithm::Xxh3) => true,
256 None => false,
257 };
258 config.compute_hash = inline_hash_requested;
259
260 let (encoded_payload, inline_hash) = match mode {
261 EncodeMode::Raw => {
262 let result = pipeline::encode_pipeline(data, &config)
264 .map_err(|e| TensogramError::Encoding(e.to_string()))?;
265
266 if let Some(offsets) = &result.block_offsets {
268 final_desc.params.insert(
269 "szip_block_offsets".to_string(),
270 ciborium::Value::Array(
271 offsets
272 .iter()
273 .map(|&o| ciborium::Value::Integer(o.into()))
274 .collect(),
275 ),
276 );
277 }
278
279 (result.encoded_bytes, result.hash)
280 }
281 EncodeMode::PreEncoded => {
282 validate_no_szip_offsets_for_non_szip(desc)?;
287 if desc.compression == "szip" && desc.params.contains_key("szip_block_offsets") {
288 validate_szip_block_offsets(&desc.params, data.len())?;
289 }
290 (data.to_vec(), None)
291 }
292 };
293
294 if let Some(algorithm) = options.hash_algorithm {
299 let hash_value = match inline_hash {
300 Some(digest) => crate::hash::format_xxh3_digest(digest),
301 None => compute_hash(&encoded_payload, algorithm),
302 };
303 final_desc.hash = Some(HashDescriptor {
304 hash_type: algorithm.as_str().to_string(),
305 value: hash_value,
306 });
307 }
308
309 Ok(EncodedObject {
310 descriptor: final_desc,
311 encoded_payload,
312 })
313}
314
315fn encode_inner(
316 global_metadata: &GlobalMetadata,
317 descriptors: &[(&DataObjectDescriptor, &[u8])],
318 options: &EncodeOptions,
319 mode: EncodeMode,
320) -> Result<Vec<u8>> {
321 if options.emit_preceders {
324 return Err(TensogramError::Encoding(
325 "emit_preceders is not supported in buffered mode; use StreamingEncoder::write_preceder() instead".to_string(),
326 ));
327 }
328
329 if matches!(mode, EncodeMode::PreEncoded) && (options.reject_nan || options.reject_inf) {
336 return Err(TensogramError::Encoding(
337 "reject_nan / reject_inf do not apply to encode_pre_encoded: \
338 pre-encoded bytes are opaque to the library. Clear these \
339 flags before calling encode_pre_encoded, or use encode() \
340 on raw data."
341 .to_string(),
342 ));
343 }
344
345 let budget = crate::parallel::resolve_budget(options.threads);
352 let total_bytes: usize = descriptors.iter().map(|(_, d)| d.len()).sum();
353 let parallel =
354 crate::parallel::should_parallelise(budget, total_bytes, options.parallel_threshold_bytes);
355
356 let any_axis_b = descriptors
357 .iter()
358 .any(|(d, _)| crate::parallel::is_axis_b_friendly(&d.encoding, &d.filter, &d.compression));
359 let use_axis_a = parallel && crate::parallel::use_axis_a(descriptors.len(), budget, any_axis_b);
360
361 let intra_codec_threads = if parallel && !use_axis_a { budget } else { 0 };
365
366 let encode_one = |(desc, data): &(&DataObjectDescriptor, &[u8])| {
367 encode_one_object(desc, data, mode, options, intra_codec_threads)
368 };
369
370 let encoded_objects: Vec<EncodedObject> = if use_axis_a {
371 #[cfg(feature = "threads")]
375 {
376 use rayon::prelude::*;
377 crate::parallel::with_pool(budget, || {
378 descriptors
379 .par_iter()
380 .map(&encode_one)
381 .collect::<Result<Vec<_>>>()
382 })?
383 }
384 #[cfg(not(feature = "threads"))]
385 {
386 descriptors.iter().map(encode_one).collect::<Result<_>>()?
387 }
388 } else {
389 crate::parallel::run_maybe_pooled(budget, parallel, intra_codec_threads, || {
394 descriptors.iter().map(encode_one).collect::<Result<_>>()
395 })?
396 };
397
398 validate_no_client_reserved(global_metadata)?;
400
401 if global_metadata.base.len() > descriptors.len() {
405 return Err(TensogramError::Metadata(format!(
406 "metadata base has {} entries but only {} descriptors provided; \
407 extra base entries would be discarded",
408 global_metadata.base.len(),
409 descriptors.len()
410 )));
411 }
412
413 let mut enriched_meta = global_metadata.clone();
416 populate_base_entries(&mut enriched_meta.base, &encoded_objects);
417 populate_reserved_provenance(&mut enriched_meta.reserved);
418
419 framing::encode_message(&enriched_meta, &encoded_objects)
420}
421
/// Encodes raw object data into a single message buffer.
///
/// Each `(descriptor, data)` pair is validated and run through the encoding
/// pipeline selected by its descriptor; the library then adds its reserved
/// metadata and frames the result.
///
/// # Errors
/// Returns whatever `encode_inner` reports for `EncodeMode::Raw`: unsupported
/// options, invalid metadata, or a failure while encoding an object.
#[tracing::instrument(skip(global_metadata, descriptors, options), fields(objects = descriptors.len()))]
pub fn encode(
    global_metadata: &GlobalMetadata,
    descriptors: &[(&DataObjectDescriptor, &[u8])],
    options: &EncodeOptions,
) -> Result<Vec<u8>> {
    encode_inner(global_metadata, descriptors, options, EncodeMode::Raw)
}
435
/// Builds a message from payloads that the caller has already encoded.
///
/// The bytes of each `(descriptor, data)` pair are copied into the message
/// unchanged; the descriptor is still validated, including any
/// `szip_block_offsets` parameter it carries.
///
/// # Errors
/// Returns an error if `options.reject_nan` or `options.reject_inf` is set
/// (those checks cannot be applied to pre-encoded bytes), plus anything else
/// `encode_inner` reports for `EncodeMode::PreEncoded`.
#[tracing::instrument(name = "encode_pre_encoded", skip_all, fields(num_objects = descriptors.len()))]
pub fn encode_pre_encoded(
    global_metadata: &GlobalMetadata,
    descriptors: &[(&DataObjectDescriptor, &[u8])],
    options: &EncodeOptions,
) -> Result<Vec<u8>> {
    encode_inner(
        global_metadata,
        descriptors,
        options,
        EncodeMode::PreEncoded,
    )
}
465
466fn validate_no_client_reserved(meta: &GlobalMetadata) -> Result<()> {
471 if !meta.reserved.is_empty() {
472 return Err(TensogramError::Metadata(format!(
473 "client code must not write to '{RESERVED_KEY}' at message level; \
474 this field is populated by the library"
475 )));
476 }
477 for (i, entry) in meta.base.iter().enumerate() {
478 if entry.contains_key(RESERVED_KEY) {
479 return Err(TensogramError::Metadata(format!(
480 "client code must not write to '{RESERVED_KEY}' in base[{i}]; \
481 this field is populated by the library"
482 )));
483 }
484 }
485 Ok(())
486}
487
488pub(crate) fn populate_base_entries(
494 base: &mut Vec<BTreeMap<String, ciborium::Value>>,
495 encoded_objects: &[crate::framing::EncodedObject],
496) {
497 use ciborium::Value;
498
499 base.resize_with(encoded_objects.len(), BTreeMap::new);
501
502 for (entry, obj) in base.iter_mut().zip(encoded_objects.iter()) {
503 let desc = &obj.descriptor;
504
505 let tensor_map = Value::Map(vec![
506 (
507 Value::Text("ndim".to_string()),
508 Value::Integer(desc.ndim.into()),
509 ),
510 (
511 Value::Text("shape".to_string()),
512 Value::Array(
513 desc.shape
514 .iter()
515 .map(|&d| Value::Integer(d.into()))
516 .collect(),
517 ),
518 ),
519 (
520 Value::Text("strides".to_string()),
521 Value::Array(
522 desc.strides
523 .iter()
524 .map(|&s| Value::Integer(s.into()))
525 .collect(),
526 ),
527 ),
528 (
529 Value::Text("dtype".to_string()),
530 Value::Text(desc.dtype.to_string()),
531 ),
532 ]);
533
534 let reserved_map = Value::Map(vec![(Value::Text("tensor".to_string()), tensor_map)]);
535
536 entry.insert(RESERVED_KEY.to_string(), reserved_map);
537 }
538}
539
540pub(crate) fn populate_reserved_provenance(reserved: &mut BTreeMap<String, ciborium::Value>) {
551 use ciborium::Value;
552 #[cfg(not(target_arch = "wasm32"))]
553 use std::time::SystemTime;
554
555 let version_str = env!("CARGO_PKG_VERSION");
557 let encoder_map = Value::Map(vec![
558 (
559 Value::Text("name".to_string()),
560 Value::Text("tensogram".to_string()),
561 ),
562 (
563 Value::Text("version".to_string()),
564 Value::Text(version_str.to_string()),
565 ),
566 ]);
567 reserved.insert("encoder".to_string(), encoder_map);
568
569 #[cfg(not(target_arch = "wasm32"))]
574 {
575 let secs = SystemTime::now()
576 .duration_since(SystemTime::UNIX_EPOCH)
577 .unwrap_or_default()
578 .as_secs();
579
580 let days = secs / 86400;
583 let day_secs = secs % 86400;
584 let hours = day_secs / 3600;
585 let minutes = (day_secs % 3600) / 60;
586 let seconds = day_secs % 60;
587 let (y, m, d) = civil_from_days(days as i64);
589 let timestamp = format!("{y:04}-{m:02}-{d:02}T{hours:02}:{minutes:02}:{seconds:02}Z");
590 reserved.insert("time".to_string(), Value::Text(timestamp));
591 }
592
593 let id = uuid::Uuid::new_v4();
595 reserved.insert("uuid".to_string(), Value::Text(id.to_string()));
596}
597
/// Converts a day count relative to 1970-01-01 into a proleptic Gregorian
/// `(year, month, day)` triple, with `month` in `1..=12` and `day` in `1..=31`.
///
/// Internally the calendar is shifted so that each 400-year era starts on
/// 1 March, which puts the leap day at the very end of a year.
#[cfg(not(target_arch = "wasm32"))]
fn civil_from_days(days: i64) -> (i64, u32, u32) {
    const DAYS_PER_ERA: i64 = 146097;
    // Days from 0000-03-01 to 1970-01-01.
    const EPOCH_SHIFT: i64 = 719468;

    let shifted = days + EPOCH_SHIFT;
    // Floor division by the era length, written out for negative inputs.
    let era_numerator = if shifted >= 0 {
        shifted
    } else {
        shifted - (DAYS_PER_ERA - 1)
    };
    let era = era_numerator / DAYS_PER_ERA;
    let day_of_era = (shifted - era * DAYS_PER_ERA) as u32;

    let year_of_era =
        (day_of_era - day_of_era / 1460 + day_of_era / 36524 - day_of_era / 146096) / 365;
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);

    // Month index counted from March (0 = March, ..., 11 = February).
    let march_based_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * march_based_month + 2) / 5 + 1;
    let month = if march_based_month < 10 {
        march_based_month + 3
    } else {
        march_based_month - 9
    };

    // January and February belong to the following civil year.
    let year = year_of_era as i64 + era * 400 + i64::from(month <= 2);
    (year, month, day)
}
616
/// Builds a [`PipelineConfig`] for `desc` using the default compression
/// backend and an intra-codec thread count of `0`.
///
/// Convenience wrapper around [`build_pipeline_config_with_backend`]; see
/// that function for the accepted descriptor values and the errors returned.
pub(crate) fn build_pipeline_config(
    desc: &DataObjectDescriptor,
    num_values: usize,
    dtype: Dtype,
) -> Result<PipelineConfig> {
    build_pipeline_config_with_backend(
        desc,
        num_values,
        dtype,
        pipeline::CompressionBackend::default(),
        0,
    )
}
630
/// Translates the string-typed `encoding`, `filter` and `compression` fields
/// of a descriptor (plus their entries in `desc.params`) into a typed
/// [`PipelineConfig`].
///
/// The three stages are resolved in that order, so an error in an earlier
/// stage is reported before later stages are looked at. Compression codecs
/// are only recognised when the matching cargo feature is enabled; otherwise
/// their name falls through to the "unknown compression" error.
///
/// The returned config always has `compute_hash` set to `false`; callers that
/// want inline hashing flip it afterwards.
///
/// # Errors
/// * [`TensogramError::Encoding`] for an unknown encoding, filter,
///   compression, blosc2 codec, zfp mode or sz3 error-bound mode, and for
///   `simple_packing` on a dtype whose byte width is not 8.
/// * [`TensogramError::Metadata`] for missing, mistyped or out-of-range
///   parameters.
pub(crate) fn build_pipeline_config_with_backend(
    desc: &DataObjectDescriptor,
    num_values: usize,
    dtype: Dtype,
    compression_backend: pipeline::CompressionBackend,
    intra_codec_threads: u32,
) -> Result<PipelineConfig> {
    let encoding = match desc.encoding.as_str() {
        "none" => EncodingType::None,
        "simple_packing" => {
            // NOTE(review): this accepts every dtype that is 8 bytes wide,
            // while the message promises float64 only — confirm whether other
            // 8-byte dtypes should be rejected here.
            if dtype.byte_width() != 8 {
                return Err(TensogramError::Encoding(
                    "simple_packing only supports float64 dtype".to_string(),
                ));
            }
            let params = extract_simple_packing_params(&desc.params)?;
            EncodingType::SimplePacking(params)
        }
        other => {
            return Err(TensogramError::Encoding(format!(
                "unknown encoding: {other}"
            )));
        }
    };

    let filter = match desc.filter.as_str() {
        "none" => FilterType::None,
        "shuffle" => {
            let element_size = usize::try_from(get_u64_param(
                &desc.params,
                "shuffle_element_size",
            )?)
            .map_err(|_| {
                TensogramError::Metadata("shuffle_element_size out of usize range".to_string())
            })?;
            FilterType::Shuffle { element_size }
        }
        other => return Err(TensogramError::Encoding(format!("unknown filter: {other}"))),
    };

    let compression = match desc.compression.as_str() {
        "none" => CompressionType::None,
        #[cfg(any(feature = "szip", feature = "szip-pure"))]
        "szip" => {
            let rsi = u32::try_from(get_u64_param(&desc.params, "szip_rsi")?)
                .map_err(|_| TensogramError::Metadata("szip_rsi out of u32 range".to_string()))?;
            let block_size = u32::try_from(get_u64_param(&desc.params, "szip_block_size")?)
                .map_err(|_| {
                    TensogramError::Metadata("szip_block_size out of u32 range".to_string())
                })?;
            let flags = u32::try_from(get_u64_param(&desc.params, "szip_flags")?)
                .map_err(|_| TensogramError::Metadata("szip_flags out of u32 range".to_string()))?;
            // Sample width seen by szip depends on what the earlier stages
            // produce: packed values, single shuffled bytes, or whole elements.
            let bits_per_sample = match (&encoding, &filter) {
                (EncodingType::SimplePacking(params), _) => params.bits_per_value,
                (EncodingType::None, FilterType::Shuffle { .. }) => 8,
                (EncodingType::None, FilterType::None) => (dtype.byte_width() * 8) as u32,
            };
            CompressionType::Szip {
                rsi,
                block_size,
                flags,
                bits_per_sample,
            }
        }
        #[cfg(any(feature = "zstd", feature = "zstd-pure"))]
        "zstd" => {
            // NOTE(review): `unwrap_or` also swallows a present-but-non-integer
            // `zstd_level`, silently using 3 — confirm this is intended.
            let level_i64 = get_i64_param(&desc.params, "zstd_level").unwrap_or(3);
            let level = i32::try_from(level_i64).map_err(|_| {
                TensogramError::Metadata(format!("zstd_level value {level_i64} out of i32 range"))
            })?;
            CompressionType::Zstd { level }
        }
        #[cfg(feature = "lz4")]
        "lz4" => CompressionType::Lz4,
        #[cfg(feature = "blosc2")]
        "blosc2" => {
            // A missing or non-text `blosc2_codec` selects "lz4".
            let codec_str = match desc.params.get("blosc2_codec") {
                Some(ciborium::Value::Text(s)) => s.as_str(),
                _ => "lz4",
            };
            let codec = match codec_str {
                "blosclz" => Blosc2Codec::Blosclz,
                "lz4" => Blosc2Codec::Lz4,
                "lz4hc" => Blosc2Codec::Lz4hc,
                "zlib" => Blosc2Codec::Zlib,
                "zstd" => Blosc2Codec::Zstd,
                other => {
                    return Err(TensogramError::Encoding(format!(
                        "unknown blosc2 codec: {other}"
                    )));
                }
            };
            // NOTE(review): as with `zstd_level`, any lookup error (missing or
            // mistyped value) falls back to 5 — confirm this is intended.
            let clevel_i64 = get_i64_param(&desc.params, "blosc2_clevel").unwrap_or(5);
            let clevel = i32::try_from(clevel_i64).map_err(|_| {
                TensogramError::Metadata(format!(
                    "blosc2_clevel value {clevel_i64} out of i32 range"
                ))
            })?;
            // Element size in bytes as seen by blosc2 after the earlier stages.
            let typesize = match (&encoding, &filter) {
                (EncodingType::SimplePacking(params), _) => {
                    (params.bits_per_value as usize).div_ceil(8)
                }
                (EncodingType::None, FilterType::Shuffle { .. }) => 1,
                (EncodingType::None, FilterType::None) => dtype.byte_width(),
            };
            CompressionType::Blosc2 {
                codec,
                clevel,
                typesize,
            }
        }
        #[cfg(feature = "zfp")]
        "zfp" => {
            let mode_str = match desc.params.get("zfp_mode") {
                Some(ciborium::Value::Text(s)) => s.clone(),
                _ => {
                    return Err(TensogramError::Metadata(
                        "missing required parameter: zfp_mode".to_string(),
                    ));
                }
            };
            // Each mode requires exactly one additional parameter.
            let mode = match mode_str.as_str() {
                "fixed_rate" => {
                    let rate = get_f64_param(&desc.params, "zfp_rate")?;
                    ZfpMode::FixedRate { rate }
                }
                "fixed_precision" => {
                    let precision = u32::try_from(get_u64_param(&desc.params, "zfp_precision")?)
                        .map_err(|_| {
                            TensogramError::Metadata("zfp_precision out of u32 range".to_string())
                        })?;
                    ZfpMode::FixedPrecision { precision }
                }
                "fixed_accuracy" => {
                    let tolerance = get_f64_param(&desc.params, "zfp_tolerance")?;
                    ZfpMode::FixedAccuracy { tolerance }
                }
                other => {
                    return Err(TensogramError::Encoding(format!(
                        "unknown zfp_mode: {other}"
                    )));
                }
            };
            CompressionType::Zfp { mode }
        }
        #[cfg(feature = "sz3")]
        "sz3" => {
            let mode_str = match desc.params.get("sz3_error_bound_mode") {
                Some(ciborium::Value::Text(s)) => s.clone(),
                _ => {
                    return Err(TensogramError::Metadata(
                        "missing required parameter: sz3_error_bound_mode".to_string(),
                    ));
                }
            };
            // The bound value is read before the mode name is checked, so a
            // missing `sz3_error_bound` is reported first.
            let bound_val = get_f64_param(&desc.params, "sz3_error_bound")?;
            let error_bound = match mode_str.as_str() {
                "abs" => Sz3ErrorBound::Absolute(bound_val),
                "rel" => Sz3ErrorBound::Relative(bound_val),
                "psnr" => Sz3ErrorBound::Psnr(bound_val),
                other => {
                    return Err(TensogramError::Encoding(format!(
                        "unknown sz3_error_bound_mode: {other}"
                    )));
                }
            };
            CompressionType::Sz3 { error_bound }
        }
        other => {
            return Err(TensogramError::Encoding(format!(
                "unknown compression: {other}"
            )));
        }
    };

    Ok(PipelineConfig {
        encoding,
        filter,
        compression,
        num_values,
        byte_order: desc.byte_order,
        dtype_byte_width: dtype.byte_width(),
        swap_unit_size: dtype.swap_unit_size(),
        compression_backend,
        intra_codec_threads,
        // Hashing is opt-in; callers enable it on the returned config.
        compute_hash: false,
    })
}
828
829fn extract_simple_packing_params(
830 params: &BTreeMap<String, ciborium::Value>,
831) -> Result<SimplePackingParams> {
832 let reference_value = get_f64_param(params, "reference_value")?;
833 if reference_value.is_nan() || reference_value.is_infinite() {
834 return Err(TensogramError::Metadata(format!(
835 "reference_value must be finite, got {reference_value}"
836 )));
837 }
838 Ok(SimplePackingParams {
839 reference_value,
840 binary_scale_factor: i32::try_from(get_i64_param(params, "binary_scale_factor")?).map_err(
841 |_| TensogramError::Metadata("binary_scale_factor out of i32 range".to_string()),
842 )?,
843 decimal_scale_factor: i32::try_from(get_i64_param(params, "decimal_scale_factor")?)
844 .map_err(|_| {
845 TensogramError::Metadata("decimal_scale_factor out of i32 range".to_string())
846 })?,
847 bits_per_value: u32::try_from(get_u64_param(params, "bits_per_value")?)
848 .map_err(|_| TensogramError::Metadata("bits_per_value out of u32 range".to_string()))?,
849 })
850}
851
852pub(crate) fn get_f64_param(params: &BTreeMap<String, ciborium::Value>, key: &str) -> Result<f64> {
853 match params.get(key) {
854 Some(ciborium::Value::Float(f)) => Ok(*f),
855 Some(ciborium::Value::Integer(i)) => {
856 let n: i128 = (*i).into();
859 Ok(n as f64)
860 }
861 Some(other) => Err(TensogramError::Metadata(format!(
862 "expected number for {key}, got {other:?}"
863 ))),
864 None => Err(TensogramError::Metadata(format!(
865 "missing required parameter: {key}"
866 ))),
867 }
868}
869
870pub(crate) fn get_i64_param(params: &BTreeMap<String, ciborium::Value>, key: &str) -> Result<i64> {
871 match params.get(key) {
872 Some(ciborium::Value::Integer(i)) => {
873 let n: i128 = (*i).into();
874 i64::try_from(n).map_err(|_| {
875 TensogramError::Metadata(format!("integer value {n} out of i64 range for {key}"))
876 })
877 }
878 Some(other) => Err(TensogramError::Metadata(format!(
879 "expected integer for {key}, got {other:?}"
880 ))),
881 None => Err(TensogramError::Metadata(format!(
882 "missing required parameter: {key}"
883 ))),
884 }
885}
886
887pub(crate) fn get_u64_param(params: &BTreeMap<String, ciborium::Value>, key: &str) -> Result<u64> {
888 match params.get(key) {
889 Some(ciborium::Value::Integer(i)) => {
890 let n: i128 = (*i).into();
891 u64::try_from(n).map_err(|_| {
892 TensogramError::Metadata(format!("integer value {n} out of u64 range for {key}"))
893 })
894 }
895 Some(other) => Err(TensogramError::Metadata(format!(
896 "expected integer for {key}, got {other:?}"
897 ))),
898 None => Err(TensogramError::Metadata(format!(
899 "missing required parameter: {key}"
900 ))),
901 }
902}
903
904pub(crate) fn validate_szip_block_offsets(
905 params: &BTreeMap<String, ciborium::Value>,
906 encoded_bytes_len: usize,
907) -> Result<()> {
908 let value = params.get("szip_block_offsets").ok_or_else(|| {
909 TensogramError::Metadata(
910 "missing required parameter: szip_block_offsets for szip compression".to_string(),
911 )
912 })?;
913
914 let offsets = match value {
915 ciborium::Value::Array(arr) => arr,
916 other => {
917 return Err(TensogramError::Metadata(format!(
918 "szip_block_offsets must be an array, got {other:?}"
919 )));
920 }
921 };
922
923 if offsets.is_empty() {
924 return Err(TensogramError::Metadata(
925 "szip_block_offsets must not be empty; first offset must be 0".to_string(),
926 ));
927 }
928
929 let bit_bound = encoded_bytes_len.checked_mul(8).ok_or_else(|| {
930 TensogramError::Metadata(format!(
931 "encoded byte length {encoded_bytes_len} overflows bit-bound calculation"
932 ))
933 })?;
934 let bit_bound_u64 = u64::try_from(bit_bound).map_err(|_| {
935 TensogramError::Metadata(format!(
936 "bit-bound {bit_bound} derived from {encoded_bytes_len} bytes does not fit in u64"
937 ))
938 })?;
939
940 let mut parsed_offsets = Vec::with_capacity(offsets.len());
941 for (idx, item) in offsets.iter().enumerate() {
942 let offset = match item {
943 ciborium::Value::Integer(i) => {
944 let n: i128 = (*i).into();
945 u64::try_from(n).map_err(|_| {
946 TensogramError::Metadata(format!(
947 "szip_block_offsets[{idx}] = {n} out of u64 range"
948 ))
949 })?
950 }
951 other => {
952 return Err(TensogramError::Metadata(format!(
953 "szip_block_offsets[{idx}] must be an integer, got {other:?}"
954 )));
955 }
956 };
957
958 if offset > bit_bound_u64 {
959 return Err(TensogramError::Metadata(format!(
960 "szip_block_offsets[{idx}] = {offset} exceeds bit bound {bit_bound_u64} (encoded_bytes_len = {encoded_bytes_len} bytes, {bit_bound_u64} bits)"
961 )));
962 }
963
964 if idx == 0 {
965 if offset != 0 {
966 return Err(TensogramError::Metadata(format!(
967 "szip_block_offsets[0] must be 0, got {offset}"
968 )));
969 }
970 } else {
971 let prev = parsed_offsets[idx - 1];
972 if offset <= prev {
973 return Err(TensogramError::Metadata(format!(
974 "szip_block_offsets must be strictly increasing: szip_block_offsets[{}] = {}, szip_block_offsets[{idx}] = {offset}",
975 idx - 1,
976 prev
977 )));
978 }
979 }
980
981 parsed_offsets.push(offset);
982 }
983
984 Ok(())
985}
986
987pub(crate) fn validate_no_szip_offsets_for_non_szip(desc: &DataObjectDescriptor) -> Result<()> {
988 if desc.compression != "szip" && desc.params.contains_key("szip_block_offsets") {
989 return Err(TensogramError::Metadata(format!(
990 "szip_block_offsets provided but compression is '{}', not 'szip'",
991 desc.compression
992 )));
993 }
994 Ok(())
995}
996
997#[cfg(test)]
1000mod tests {
1001 use super::*;
1002 use crate::decode::{DecodeOptions, decode};
1003 use crate::types::{ByteOrder, GlobalMetadata};
1004 use std::collections::BTreeMap;
1005
1006 fn make_descriptor(shape: Vec<u64>) -> DataObjectDescriptor {
1007 let strides = {
1008 let mut s = vec![1u64; shape.len()];
1009 for i in (0..shape.len().saturating_sub(1)).rev() {
1010 s[i] = s[i + 1] * shape[i + 1];
1011 }
1012 s
1013 };
1014 DataObjectDescriptor {
1015 obj_type: "ntensor".to_string(),
1016 ndim: shape.len() as u64,
1017 shape,
1018 strides,
1019 dtype: Dtype::Float32,
1020 byte_order: ByteOrder::native(),
1021 encoding: "none".to_string(),
1022 filter: "none".to_string(),
1023 compression: "none".to_string(),
1024 params: BTreeMap::new(),
1025 hash: None,
1026 }
1027 }
1028
/// Supplying more `base` entries than descriptors must be an error that
/// mentions both counts.
#[test]
fn test_base_more_entries_than_descriptors_rejected() {
    let meta = GlobalMetadata {
        version: 2,
        base: vec![BTreeMap::new(); 5],
        ..Default::default()
    };
    let desc = make_descriptor(vec![4]);
    let data = vec![0u8; 16];
    let options = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };

    let objects = [(&desc, data.as_slice()), (&desc, data.as_slice())];
    let err = match encode(&meta, &objects, &options) {
        Ok(_) => panic!("5 base entries with 2 descriptors should fail"),
        Err(e) => e.to_string(),
    };

    assert!(
        err.contains("5") && err.contains("2"),
        "error should mention counts: {err}"
    );
}
1066
/// An empty `base` is extended to one entry per object, each carrying the
/// reserved key.
#[test]
fn test_base_fewer_entries_than_descriptors_auto_extended() {
    let meta = GlobalMetadata {
        version: 2,
        base: Vec::new(),
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let data = vec![0u8; 8];
    let options = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };

    let payload = data.as_slice();
    let objects = [(&desc, payload); 3];
    let msg = encode(&meta, &objects, &options).unwrap();

    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
    assert_eq!(decoded.base.len(), 3);
    assert!(
        decoded
            .base
            .iter()
            .all(|entry| entry.contains_key("_reserved_")),
        "auto-extended base entry should have _reserved_"
    );
}
1102
/// Keys inside a `base` entry that share names with top-level metadata
/// fields survive a round trip without affecting those fields.
#[test]
fn test_base_entry_with_top_level_key_names_no_collision() {
    let text = |s: &str| ciborium::Value::Text(s.to_string());

    let entry = BTreeMap::from([
        ("version".to_string(), text("my-version")),
        ("base".to_string(), text("not-the-real-base")),
    ]);
    let meta = GlobalMetadata {
        version: 2,
        base: vec![entry],
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let data = vec![0u8; 8];
    let options = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };

    let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();

    assert_eq!(decoded.version, 2);

    let first = &decoded.base[0];
    assert_eq!(first.get("version"), Some(&text("my-version")));
    assert_eq!(first.get("base"), Some(&text("not-the-real-base")));
}
1141
/// The reserved key is only forbidden as a top-level key of a `base` entry;
/// nested inside another value it is accepted and preserved.
#[test]
fn test_base_entry_with_deeply_nested_reserved_allowed() {
    let nested = ciborium::Value::Map(vec![(
        ciborium::Value::Text("_reserved_".to_string()),
        ciborium::Value::Text("nested-is-ok".to_string()),
    )]);
    let entry = BTreeMap::from([("foo".to_string(), nested)]);
    let meta = GlobalMetadata {
        version: 2,
        base: vec![entry],
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let data = vec![0u8; 8];
    let options = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };

    let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();

    match decoded.base[0].get("foo").unwrap() {
        ciborium::Value::Map(pairs) => assert_eq!(pairs.len(), 1),
        _ => panic!("expected map for foo"),
    }
}
1174
/// A caller-populated message-level reserved map must be rejected.
#[test]
fn test_reserved_rejected_at_message_level() {
    let reserved = BTreeMap::from([(
        "rogue".to_string(),
        ciborium::Value::Text("bad".to_string()),
    )]);
    let meta = GlobalMetadata {
        version: 2,
        reserved,
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let data = vec![0u8; 8];

    let objects = [(&desc, data.as_slice())];
    let result = encode(&meta, &objects, &EncodeOptions::default());
    assert!(result.is_err());

    let err = result.unwrap_err().to_string();
    let names_key = err.contains("_reserved_");
    let names_level = err.contains("message level");
    assert!(names_key && names_level, "error: {err}");
}
1203
/// A caller-populated reserved key inside a `base` entry must be rejected,
/// and the error must name the entry index.
#[test]
fn test_reserved_rejected_in_base_entry() {
    let entry = BTreeMap::from([("_reserved_".to_string(), ciborium::Value::Map(vec![]))]);
    let meta = GlobalMetadata {
        version: 2,
        base: vec![entry],
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let data = vec![0u8; 8];

    let objects = [(&desc, data.as_slice())];
    let result = encode(&meta, &objects, &EncodeOptions::default());
    assert!(result.is_err());

    let err = result.unwrap_err().to_string();
    let names_key = err.contains("_reserved_");
    let names_entry = err.contains("base[0]");
    assert!(names_key && names_entry, "error: {err}");
}
1227
/// After encoding, the reserved `tensor` map holds exactly `ndim`, `shape`,
/// `strides` and `dtype`.
#[test]
fn test_reserved_tensor_has_four_keys_after_encode() {
    let meta = GlobalMetadata::default();
    let desc = make_descriptor(vec![3, 4]);
    // 3 x 4 float32 elements, 4 bytes each.
    let data = vec![0u8; 3 * 4 * 4];
    let options = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };
    let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();

    let reserved = decoded.base[0]
        .get("_reserved_")
        .expect("_reserved_ missing");
    let ciborium::Value::Map(pairs) = reserved else {
        panic!("_reserved_ is not a map");
    };

    let tensor_key = ciborium::Value::Text("tensor".to_string());
    let tensor_entry = pairs.iter().find(|(k, _)| *k == tensor_key);
    assert!(tensor_entry.is_some(), "missing tensor key in _reserved_");
    let Some((_, ciborium::Value::Map(tensor_pairs))) = tensor_entry else {
        panic!("tensor is not a map");
    };

    let keys: Vec<String> = tensor_pairs
        .iter()
        .filter_map(|(k, _)| match k {
            ciborium::Value::Text(s) => Some(s.clone()),
            _ => None,
        })
        .collect();

    assert_eq!(keys.len(), 4, "tensor should have 4 keys, got: {keys:?}");
    assert!(keys.contains(&"ndim".to_string()));
    assert!(keys.contains(&"shape".to_string()));
    assert!(keys.contains(&"strides".to_string()));
    assert!(keys.contains(&"dtype".to_string()));
}
1272
/// A scalar (rank-0) object is recorded with `ndim` 0 and an empty `shape`.
#[test]
fn test_reserved_tensor_scalar_ndim_zero() {
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 0,
        shape: vec![],
        strides: vec![],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "none".to_string(),
        params: BTreeMap::new(),
        hash: None,
    };
    // One float32 element.
    let data = vec![0u8; 4];
    let meta = GlobalMetadata::default();
    let options = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };
    let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();

    let reserved = decoded.base[0]
        .get("_reserved_")
        .expect("_reserved_ missing");
    let ciborium::Value::Map(pairs) = reserved else {
        panic!("_reserved_ is not a map");
    };

    let tensor_key = ciborium::Value::Text("tensor".to_string());
    let tensor_entry = pairs.iter().find(|(k, _)| *k == tensor_key);
    let Some((_, ciborium::Value::Map(tensor_pairs))) = tensor_entry else {
        panic!("tensor missing or not a map");
    };

    let ndim_key = ciborium::Value::Text("ndim".to_string());
    let ndim = tensor_pairs.iter().find(|(k, _)| *k == ndim_key);
    assert!(
        matches!(ndim, Some((_, ciborium::Value::Integer(i))) if i128::from(*i) == 0),
        "ndim should be 0 for scalar"
    );

    let shape_key = ciborium::Value::Text("shape".to_string());
    let shape = tensor_pairs.iter().find(|(k, _)| *k == shape_key);
    assert!(
        matches!(shape, Some((_, ciborium::Value::Array(a))) if a.is_empty()),
        "shape should be [] for scalar"
    );
}
1329
#[test]
/// The same key (`mars`) is placed in both a `base` entry and `extra`; after
/// an encode/decode round trip each location must still hold its own value.
fn test_extra_with_keys_colliding_with_base_entry_keys() {
    let text = |s: &str| ciborium::Value::Text(s.to_string());

    let meta = GlobalMetadata {
        version: 2,
        base: vec![BTreeMap::from([("mars".to_string(), text("base-mars"))])],
        extra: BTreeMap::from([("mars".to_string(), text("extra-mars"))]),
        ..Default::default()
    };
    let descriptor = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let opts = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };

    let encoded = encode(&meta, &[(&descriptor, payload.as_slice())], &opts).unwrap();
    let (decoded, _) = decode(&encoded, &DecodeOptions::default()).unwrap();

    let from_base = decoded.base[0].get("mars");
    assert_eq!(from_base, Some(&text("base-mars")));

    let from_extra = decoded.extra.get("mars");
    assert_eq!(from_extra, Some(&text("extra-mars")));
}
1369
#[test]
/// Encodes metadata whose `extra` map is empty and checks that `extra` is
/// still empty after decoding.
// NOTE(review): despite the name, the raw CBOR bytes are not inspected here;
// only the decoded value is checked.
fn test_empty_extra_omitted_from_cbor() {
    let descriptor = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let opts = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };
    let meta = GlobalMetadata {
        version: 2,
        extra: BTreeMap::new(),
        ..Default::default()
    };

    let encoded = encode(&meta, &[(&descriptor, payload.as_slice())], &opts).unwrap();
    let (decoded, _objects) = decode(&encoded, &DecodeOptions::default()).unwrap();

    assert!(decoded.extra.is_empty());
}
1387
#[test]
/// A nested CBOR map stored under `extra["nested"]` must survive an
/// encode/decode round trip with its value intact, not merely its key.
fn test_extra_with_nested_maps_round_trips() {
    use ciborium::Value;

    // Keys are listed in sorted order so that an encoder which canonicalises
    // map key order would still yield a value equal to this one.
    let nested = Value::Map(vec![
        (
            Value::Text("key1".to_string()),
            Value::Integer(42.into()),
        ),
        (
            Value::Text("key2".to_string()),
            Value::Map(vec![(
                Value::Text("deep".to_string()),
                Value::Bool(true),
            )]),
        ),
    ]);
    let meta = GlobalMetadata {
        version: 2,
        // Cloned because `nested` is compared against the decoded value below.
        extra: BTreeMap::from([("nested".to_string(), nested.clone())]),
        ..Default::default()
    };
    let desc = make_descriptor(vec![2]);
    let data = vec![0u8; 8];
    let options = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };

    let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
    let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();

    assert!(decoded.extra.contains_key("nested"));
    assert_eq!(
        decoded.extra.get("nested"),
        Some(&nested),
        "nested map should round-trip unchanged"
    );
}
1421
#[test]
/// Hand-built CBOR carrying top-level `common` and `payload` keys alongside
/// `version` must decode without error, with `version` preserved and `base`,
/// `extra` and `reserved` all left empty.
fn test_old_common_payload_keys_silently_ignored() {
    use ciborium::Value;

    let key = |name: &str| Value::Text(name.to_string());
    let legacy = Value::Map(vec![
        (key("version"), Value::Integer(2.into())),
        (key("common"), Value::Map(Vec::new())),
        (key("payload"), Value::Array(Vec::new())),
    ]);

    let mut cbor_bytes = Vec::new();
    ciborium::into_writer(&legacy, &mut cbor_bytes).unwrap();
    let decoded: GlobalMetadata =
        crate::metadata::cbor_to_global_metadata(&cbor_bytes).unwrap();

    assert_eq!(decoded.version, 2);
    assert!(decoded.base.is_empty());
    assert!(decoded.extra.is_empty());
    assert!(decoded.reserved.is_empty());
}
1444
#[test]
/// A top-level key spelled `reserved` (without the surrounding underscores)
/// must not populate `GlobalMetadata::reserved` when decoding.
fn test_old_reserved_key_name_ignored() {
    use ciborium::Value;

    let text = |s: &str| Value::Text(s.to_string());
    let rogue_map = Value::Map(vec![(text("rogue"), text("value"))]);
    let document = Value::Map(vec![
        (text("version"), Value::Integer(2.into())),
        (text("reserved"), rogue_map),
    ]);

    let mut cbor_bytes = Vec::new();
    ciborium::into_writer(&document, &mut cbor_bytes).unwrap();
    let decoded: GlobalMetadata =
        crate::metadata::cbor_to_global_metadata(&cbor_bytes).unwrap();

    assert!(
        decoded.reserved.is_empty(),
        "old 'reserved' key should be ignored"
    );
}
1468
#[test]
/// `encode` must fail when only the second `base` entry carries a
/// `_reserved_` key, and the error text must mention `base[1]`.
fn test_reserved_rejected_in_second_base_entry_only() {
    let clean_entry = BTreeMap::from([(
        "clean".to_string(),
        ciborium::Value::Text("ok".to_string()),
    )]);
    let tainted_entry = BTreeMap::from([(
        "_reserved_".to_string(),
        ciborium::Value::Map(vec![]),
    )]);
    let meta = GlobalMetadata {
        version: 2,
        base: vec![clean_entry, tainted_entry],
        ..Default::default()
    };

    // Two objects, one per base entry, sharing the same descriptor and bytes.
    let descriptor = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let objects = [
        (&descriptor, payload.as_slice()),
        (&descriptor, payload.as_slice()),
    ];

    let result = encode(&meta, &objects, &EncodeOptions::default());
    assert!(result.is_err());
    let err = result.unwrap_err().to_string();
    assert!(
        err.contains("base[1]"),
        "error should mention base[1]: {err}"
    );
}
1497
#[test]
/// With two `base` entries that contain only ordinary keys, encoding two
/// objects succeeds and both entries keep their keys after decoding.
fn test_reserved_accepted_when_all_base_entries_clean() {
    let first = BTreeMap::from([(
        "key0".to_string(),
        ciborium::Value::Text("val0".to_string()),
    )]);
    let second = BTreeMap::from([(
        "key1".to_string(),
        ciborium::Value::Text("val1".to_string()),
    )]);
    let meta = GlobalMetadata {
        version: 2,
        base: vec![first, second],
        ..Default::default()
    };

    let descriptor = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let objects = [
        (&descriptor, payload.as_slice()),
        (&descriptor, payload.as_slice()),
    ];
    let opts = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };

    let encoded = encode(&meta, &objects, &opts).unwrap();
    let (decoded, _) = decode(&encoded, &DecodeOptions::default()).unwrap();

    assert_eq!(decoded.base.len(), 2);
    assert!(decoded.base[0].contains_key("key0"));
    assert!(decoded.base[1].contains_key("key1"));
}
1533
#[test]
/// For every `Dtype` in the table below, encodes a 4-element 1-D object and
/// checks that the decoded `_reserved_.tensor.dtype` entry is the expected
/// lowercase name.
fn test_reserved_tensor_dtype_strings_for_all_dtypes() {
    let dtypes_and_expected = [
        (Dtype::Float16, "float16"),
        (Dtype::Bfloat16, "bfloat16"),
        (Dtype::Float32, "float32"),
        (Dtype::Float64, "float64"),
        (Dtype::Complex64, "complex64"),
        (Dtype::Complex128, "complex128"),
        (Dtype::Int8, "int8"),
        (Dtype::Int16, "int16"),
        (Dtype::Int32, "int32"),
        (Dtype::Int64, "int64"),
        (Dtype::Uint8, "uint8"),
        (Dtype::Uint16, "uint16"),
        (Dtype::Uint32, "uint32"),
        (Dtype::Uint64, "uint64"),
    ];

    // Lookup keys are the same for every dtype, so build them once.
    let tensor_key = ciborium::Value::Text("tensor".to_string());
    let dtype_key = ciborium::Value::Text("dtype".to_string());

    for (dtype, expected_str) in dtypes_and_expected {
        let num_elements: u64 = 4;
        // Payload is sized from the dtype's own byte width.
        let data = vec![0u8; num_elements as usize * dtype.byte_width()];
        let desc = DataObjectDescriptor {
            obj_type: "ntensor".to_string(),
            ndim: 1,
            shape: vec![num_elements],
            strides: vec![1],
            dtype,
            byte_order: ByteOrder::native(),
            encoding: "none".to_string(),
            filter: "none".to_string(),
            compression: "none".to_string(),
            params: BTreeMap::new(),
            hash: None,
        };
        let options = EncodeOptions {
            hash_algorithm: None,
            ..Default::default()
        };
        let msg = encode(
            &GlobalMetadata::default(),
            &[(&desc, data.as_slice())],
            &options,
        )
        .unwrap();
        let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();

        let reserved = decoded.base[0]
            .get("_reserved_")
            .unwrap_or_else(|| panic!("_reserved_ missing for dtype {dtype}"));
        let ciborium::Value::Map(pairs) = reserved else {
            panic!("_reserved_ is not a map for dtype {dtype}");
        };
        let Some((_, ciborium::Value::Map(tensor_pairs))) =
            pairs.iter().find(|(k, _)| *k == tensor_key)
        else {
            panic!("tensor missing or not a map for dtype {dtype}");
        };

        let dtype_val = tensor_pairs.iter().find(|(k, _)| *k == dtype_key);
        assert!(
            matches!(
                dtype_val,
                Some((_, ciborium::Value::Text(s))) if s == expected_str
            ),
            "dtype for {dtype} should be '{expected_str}', got: {dtype_val:?}"
        );
    }
}
1609
#[test]
/// Serialises a `GlobalMetadata` with every field populated via
/// `global_metadata_to_cbor`, parses it back with `cbor_to_global_metadata`,
/// and checks that each field's content is recovered.
fn test_global_metadata_serde_all_fields_populated() {
    use ciborium::Value;

    let original = GlobalMetadata {
        version: 2,
        base: vec![BTreeMap::from([(
            "key".to_string(),
            Value::Text("base_val".to_string()),
        )])],
        reserved: BTreeMap::from([(
            "encoder".to_string(),
            Value::Text("test".to_string()),
        )]),
        extra: BTreeMap::from([("custom".to_string(), Value::Integer(42.into()))]),
    };

    let cbor_bytes = crate::metadata::global_metadata_to_cbor(&original).unwrap();
    let decoded: GlobalMetadata =
        crate::metadata::cbor_to_global_metadata(&cbor_bytes).unwrap();

    assert_eq!(decoded.version, 2);
    assert_eq!(decoded.base.len(), 1);

    let base_value = decoded.base[0].get("key");
    assert_eq!(base_value, Some(&Value::Text("base_val".to_string())));

    assert!(decoded.reserved.contains_key("encoder"));

    let extra_value = decoded.extra.get("custom");
    assert_eq!(extra_value, Some(&Value::Integer(42.into())));
}
1648
#[test]
/// After encoding with default metadata, the decoded `reserved` map must
/// contain `encoder` (a map with `name` and `version`), `uuid` (36-character
/// text with four hyphens) and `time` (text containing `T` and ending in `Z`).
fn test_provenance_fields_present_after_encode() {
    let descriptor = make_descriptor(vec![2]);
    let payload = vec![0u8; 8];
    let opts = EncodeOptions {
        hash_algorithm: None,
        ..Default::default()
    };
    let encoded = encode(
        &GlobalMetadata::default(),
        &[(&descriptor, payload.as_slice())],
        &opts,
    )
    .unwrap();
    let (decoded, _) = decode(&encoded, &DecodeOptions::default()).unwrap();

    assert!(decoded.reserved.contains_key("encoder"));
    assert!(decoded.reserved.contains_key("time"));
    assert!(decoded.reserved.contains_key("uuid"));

    // encoder: a map that has both a "name" and a "version" key.
    let ciborium::Value::Map(pairs) = decoded.reserved.get("encoder").unwrap() else {
        panic!("encoder should be a map");
    };
    let has_key = |name: &str| {
        pairs
            .iter()
            .any(|(k, _)| matches!(k, ciborium::Value::Text(t) if t == name))
    };
    let has_name = has_key("name");
    let has_version = has_key("version");
    assert!(has_name, "encoder map should have 'name' key");
    assert!(has_version, "encoder map should have 'version' key");

    // uuid: only length and hyphen count are checked, not the full format.
    let ciborium::Value::Text(uuid_str) = decoded.reserved.get("uuid").unwrap() else {
        panic!("uuid should be a text");
    };
    assert_eq!(uuid_str.len(), 36, "UUID should be 36 chars: {uuid_str}");
    let hyphens = uuid_str.chars().filter(|c| *c == '-').count();
    assert_eq!(hyphens, 4, "UUID should have 4 hyphens: {uuid_str}");

    // time: only the 'T' separator and trailing 'Z' are checked.
    let ciborium::Value::Text(time_str) = decoded.reserved.get("time").unwrap() else {
        panic!("time should be a text");
    };
    assert!(
        time_str.ends_with('Z'),
        "time should end with Z: {time_str}"
    );
    assert!(
        time_str.contains('T'),
        "time should contain T separator: {time_str}"
    );
}
1708
#[test]
/// When a document carries both `_reserved_` and `reserved` top-level keys,
/// only the contents of `_reserved_` end up in `GlobalMetadata::reserved`.
fn test_both_reserved_and_reserved_underscore_only_new_captured() {
    use ciborium::Value;

    let text = |s: &str| Value::Text(s.to_string());
    let current = Value::Map(vec![(text("encoder"), text("tensogram"))]);
    let legacy = Value::Map(vec![(text("old"), text("ignored"))]);
    let document = Value::Map(vec![
        (text("_reserved_"), current),
        (text("reserved"), legacy),
        (text("version"), Value::Integer(2.into())),
    ]);

    let mut cbor_bytes = Vec::new();
    ciborium::into_writer(&document, &mut cbor_bytes).unwrap();
    let decoded: GlobalMetadata =
        crate::metadata::cbor_to_global_metadata(&cbor_bytes).unwrap();

    assert!(decoded.reserved.contains_key("encoder"));
    assert!(!decoded.reserved.contains_key("old"));
}
1737
#[test]
/// Encodes raw data with `encode`, decodes it, feeds the decoded descriptor
/// and payload to `encode_pre_encoded`, decodes again, and checks that both
/// decoded payloads are identical.
// NOTE(review): the descriptor comes from `make_descriptor`, defined outside
// this view; whether it actually selects simple packing is not visible here.
fn test_encode_pre_encoded_roundtrip_simple_packing() {
    let desc = make_descriptor(vec![4]);
    let raw_data: Vec<u8> = vec![0u8; 4 * 4];
    let meta = GlobalMetadata::default();
    let options = EncodeOptions::default();

    // First pass: regular encode followed by decode.
    let msg1 = encode(&meta, &[(&desc, raw_data.as_slice())], &options).unwrap();
    let (_, objects1) = decode(&msg1, &DecodeOptions::default()).unwrap();
    let (decoded_desc1, decoded_payload1) = &objects1[0];

    // Second pass: `decoded_desc1` is already a reference, so it is passed
    // directly instead of borrowing a temporary clone.
    let msg2 = encode_pre_encoded(
        &meta,
        &[(decoded_desc1, decoded_payload1.as_slice())],
        &options,
    )
    .unwrap();
    let (_, objects2) = decode(&msg2, &DecodeOptions::default()).unwrap();
    let (_, decoded_payload2) = &objects2[0];

    assert_eq!(
        decoded_payload1, decoded_payload2,
        "decoded payloads should be equal after encode/re-encode roundtrip"
    );
}
1779
#[test]
/// `encode_pre_encoded` must return an error when `emit_preceders` is set,
/// and the error text must name `emit_preceders`.
fn test_encode_pre_encoded_rejects_emit_preceders() {
    let options = EncodeOptions {
        emit_preceders: true,
        ..Default::default()
    };
    let payload = vec![0u8; 8];
    let descriptor = make_descriptor(vec![2]);

    let result = encode_pre_encoded(
        &GlobalMetadata::default(),
        &[(&descriptor, payload.as_slice())],
        &options,
    );

    assert!(
        result.is_err(),
        "encode_pre_encoded with emit_preceders=true should fail"
    );
    let err = result.unwrap_err().to_string();
    assert!(
        err.contains("emit_preceders"),
        "error should mention emit_preceders: {err}"
    );
}
1801
#[test]
/// A hash supplied by the caller on the descriptor must not survive
/// `encode_pre_encoded`: the decoded descriptor's hash has to equal
/// `compute_hash` over the decoded payload, not the caller's placeholder.
fn test_encode_pre_encoded_overwrites_caller_hash() {
    // Deliberately bogus hash value on the input descriptor.
    let bogus = HashDescriptor {
        hash_type: "xxh3".to_string(),
        value: "deadbeefdeadbeef".to_string(),
    };
    let mut desc = make_descriptor(vec![2]);
    desc.hash = Some(bogus);

    let data = vec![0xAB_u8; 8];
    let meta = GlobalMetadata::default();
    let options = EncodeOptions::default();

    let msg = encode_pre_encoded(&meta, &[(&desc, data.as_slice())], &options).unwrap();
    let (_, objects) = decode(&msg, &DecodeOptions::default()).unwrap();
    let (decoded_desc, decoded_payload) = &objects[0];

    // Recompute the hash using the algorithm configured in `options`.
    let algorithm = options.hash_algorithm.expect("expected hash algorithm");
    let computed_hash = compute_hash(decoded_payload, algorithm);

    let stored = decoded_desc
        .hash
        .as_ref()
        .expect("hash should be present in decoded descriptor");
    let stored_hash = stored.value.clone();

    assert_ne!(
        stored_hash, "deadbeefdeadbeef",
        "caller's garbage hash must be overwritten"
    );
    assert_eq!(
        stored_hash, computed_hash,
        "library-computed hash must match hash over decoded payload"
    );
}
1842
#[test]
/// The offset list `[0, 100, 200]` with a second argument of 100 is accepted
/// by `validate_szip_block_offsets`.
fn test_validate_szip_block_offsets_happy_path() {
    let offsets = [0u64, 100, 200];
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(offsets.into_iter().map(ciborium::Value::from).collect()),
    )]);

    assert!(validate_szip_block_offsets(&params, 100).is_ok());
}
1853
#[test]
/// With no `szip_block_offsets` entry in `params`, validation fails and the
/// error mentions both "missing" and the key name.
fn test_validate_szip_block_offsets_missing_key() {
    // Element types are inferred from the validator's parameter type.
    let params = BTreeMap::new();

    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(
        err.contains("missing") && err.contains("szip_block_offsets"),
        "error: {err}"
    );
}
1866
#[test]
/// When `szip_block_offsets` holds an integer instead of an array, validation
/// fails and the error mentions both "array" and the key name.
fn test_validate_szip_block_offsets_not_array() {
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Integer(0.into()),
    )]);

    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(
        err.contains("array") && err.contains("szip_block_offsets"),
        "error: {err}"
    );
}
1883
#[test]
/// When the offsets array contains a text element, validation fails and the
/// error mentions both "integer" and the key name.
fn test_validate_szip_block_offsets_non_integer_element() {
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(vec![
            ciborium::Value::Integer(0.into()),
            ciborium::Value::Text("x".to_string()),
        ]),
    )]);

    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(
        err.contains("integer") && err.contains("szip_block_offsets"),
        "error: {err}"
    );
}
1903
#[test]
/// When the first offset is 5 rather than 0, validation fails and the error
/// contains "must be 0" and "got 5".
fn test_validate_szip_block_offsets_nonzero_first() {
    let offsets = [5u64, 100, 200];
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(offsets.into_iter().map(ciborium::Value::from).collect()),
    )]);

    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(
        err.contains("must be 0") && err.contains("got 5"),
        "error: {err}"
    );
}
1920
#[test]
/// When an offset decreases (`[0, 100, 50]`), validation fails and the error
/// contains either "increasing" or "monotonic".
fn test_validate_szip_block_offsets_non_monotonic() {
    let offsets = [0u64, 100, 50];
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(offsets.into_iter().map(ciborium::Value::from).collect()),
    )]);

    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(
        err.contains("increasing") || err.contains("monotonic"),
        "error: {err}"
    );
}
1937
#[test]
/// With a second argument of 100, an offset of 801 is rejected and the error
/// mentions both "800" and "801".
// NOTE(review): presumably the second argument is a byte length and offsets
// are bit positions (100 bytes -> 800 bits); confirm against the validator.
fn test_validate_szip_block_offsets_offset_beyond_bound() {
    let offsets = [0u64, 100, 801];
    let params = BTreeMap::from([(
        "szip_block_offsets".to_string(),
        ciborium::Value::Array(offsets.into_iter().map(ciborium::Value::from).collect()),
    )]);

    let err = validate_szip_block_offsets(&params, 100)
        .unwrap_err()
        .to_string();
    assert!(err.contains("800") && err.contains("801"), "error: {err}");
}
1951
#[test]
/// A descriptor whose compression is `zstd` but whose params carry
/// `szip_block_offsets` is rejected, and the error names both the key and
/// the compression.
fn test_validate_no_szip_offsets_for_non_szip_rejects() {
    let offsets: Vec<ciborium::Value> = [0u64, 1]
        .into_iter()
        .map(ciborium::Value::from)
        .collect();
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![2],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "zstd".to_string(),
        params: BTreeMap::from([(
            "szip_block_offsets".to_string(),
            ciborium::Value::Array(offsets),
        )]),
        hash: None,
    };

    let outcome = validate_no_szip_offsets_for_non_szip(&desc);
    let err = outcome.unwrap_err().to_string();
    assert!(
        err.contains("szip_block_offsets") && err.contains("zstd"),
        "error: {err}"
    );
}
1981
#[test]
/// A descriptor whose compression is `szip` may carry `szip_block_offsets`
/// in its params; the check returns `Ok`.
fn test_validate_no_szip_offsets_for_non_szip_allows_szip() {
    let offsets: Vec<ciborium::Value> = [0u64, 1]
        .into_iter()
        .map(ciborium::Value::from)
        .collect();
    let desc = DataObjectDescriptor {
        obj_type: "ntensor".to_string(),
        ndim: 1,
        shape: vec![2],
        strides: vec![1],
        dtype: Dtype::Float32,
        byte_order: ByteOrder::native(),
        encoding: "none".to_string(),
        filter: "none".to_string(),
        compression: "szip".to_string(),
        params: BTreeMap::from([(
            "szip_block_offsets".to_string(),
            ciborium::Value::Array(offsets),
        )]),
        hash: None,
    };

    let outcome = validate_no_szip_offsets_for_non_szip(&desc);
    assert!(outcome.is_ok());
}
2005}