1use std::collections::BTreeMap;
10
11use crate::dtype::Dtype;
12use crate::error::{Result, TensogramError};
13use crate::framing::{self, EncodedObject};
14use crate::metadata::RESERVED_KEY;
15use crate::substitute_and_mask::{self, MaskSet};
16use crate::types::{DataObjectDescriptor, GlobalMetadata, MaskDescriptor, MasksMetadata};
17pub use tensogram_encodings::bitmask::MaskMethod;
18#[cfg(feature = "blosc2")]
19use tensogram_encodings::pipeline::Blosc2Codec;
20#[cfg(feature = "sz3")]
21use tensogram_encodings::pipeline::Sz3ErrorBound;
22#[cfg(feature = "zfp")]
23use tensogram_encodings::pipeline::ZfpMode;
24use tensogram_encodings::pipeline::{
25 self, ByteOrder, CompressionType, EncodingType, FilterType, PipelineConfig,
26};
27use tensogram_encodings::simple_packing::{self, SimplePackingParams};
28
/// Options controlling how [`encode`] and [`encode_pre_encoded`] build a
/// message.
///
/// Start from [`EncodeOptions::default`] and override individual fields with
/// struct-update syntax.
#[derive(Debug, Clone)]
pub struct EncodeOptions {
    /// Enables hashing. For raw encodes this requests an inline hash from the
    /// encoding pipeline (`PipelineConfig::compute_hash`); the flag is also
    /// forwarded to `framing::encode_message`. Default: `true`.
    pub hashing: bool,
    /// Compression backend selection, forwarded into every object's
    /// `PipelineConfig`. Default: `CompressionBackend::default()`.
    pub compression_backend: pipeline::CompressionBackend,
    /// Requested thread budget, resolved through
    /// `crate::parallel::resolve_budget`. Default: `0`.
    /// NOTE(review): the meaning of `0` (auto vs. sequential) is defined by
    /// `resolve_budget` — confirm there.
    pub threads: u32,
    /// Optional input-size threshold (bytes) passed to
    /// `crate::parallel::should_parallelise`; `None` (the default) leaves the
    /// decision to that function.
    pub parallel_threshold_bytes: Option<usize>,
    /// Whether NaN values are accepted in raw input; forwarded to
    /// `substitute_and_mask`. Default: `false`.
    pub allow_nan: bool,
    /// Whether +/-Inf values are accepted in raw input; forwarded to
    /// `substitute_and_mask`. Default: `false`.
    pub allow_inf: bool,
    /// Method used to encode the NaN mask, forwarded to
    /// `compose_payload_region`.
    pub nan_mask_method: MaskMethod,
    /// Method used to encode the +Inf mask, forwarded to
    /// `compose_payload_region`.
    pub pos_inf_mask_method: MaskMethod,
    /// Method used to encode the -Inf mask, forwarded to
    /// `compose_payload_region`.
    pub neg_inf_mask_method: MaskMethod,
    /// Mask size threshold in bytes, forwarded to `compose_payload_region`.
    /// Default: `128`. NOTE(review): exact semantics live in
    /// `compose_payload_region` — confirm there.
    pub small_mask_threshold_bytes: usize,
    /// Where the aggregate hash frame is emitted. `Auto` resolves to `Header`
    /// for buffered encodes and to `Footer` for streaming encodes.
    pub aggregate_hash: AggregateHashPolicy,
}
122
/// Placement of the message-level aggregate hash frame.
///
/// The aggregate depends on per-object hashes, so it can only be written to
/// the header when every object is encoded before the header is emitted
/// (buffered mode). `Auto` is resolved per encode path: `Header` when
/// buffered, `Footer` when streaming.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AggregateHashPolicy {
    /// Let the encoder choose (the default): `Header` when buffered,
    /// `Footer` when streaming.
    #[default]
    Auto,
    /// Emit no aggregate hash frame.
    None,
    /// Emit the aggregate hash in the header only. Rejected in streaming
    /// mode.
    Header,
    /// Emit the aggregate hash in the footer only.
    Footer,
    /// Emit the aggregate hash in both header and footer. Rejected in
    /// streaming mode.
    Both,
}
161
162impl AggregateHashPolicy {
163 pub(crate) fn resolved_buffered(self) -> Self {
166 match self {
167 AggregateHashPolicy::Auto => AggregateHashPolicy::Header,
168 other => other,
169 }
170 }
171
172 pub(crate) fn resolved_streaming(self) -> Result<Self> {
176 match self {
177 AggregateHashPolicy::Auto => Ok(AggregateHashPolicy::Footer),
178 AggregateHashPolicy::None => Ok(AggregateHashPolicy::None),
179 AggregateHashPolicy::Footer => Ok(AggregateHashPolicy::Footer),
180 AggregateHashPolicy::Header => Err(TensogramError::Encoding(
181 "AggregateHashPolicy::Header is not supported in streaming mode \
182 — the header is written before any data object, so per-object \
183 hashes are not yet known. Use Auto (defaults to Footer in \
184 streaming) or Footer explicitly."
185 .to_string(),
186 )),
187 AggregateHashPolicy::Both => Err(TensogramError::Encoding(
188 "AggregateHashPolicy::Both is not supported in streaming mode \
189 — the header is written before any data object, so per-object \
190 hashes are not yet known. Use Auto (defaults to Footer in \
191 streaming) or Footer explicitly."
192 .to_string(),
193 )),
194 }
195 }
196
197 pub(crate) fn emits_header(self) -> bool {
201 matches!(
202 self,
203 AggregateHashPolicy::Header | AggregateHashPolicy::Both
204 )
205 }
206
207 pub(crate) fn emits_footer(self) -> bool {
211 matches!(
212 self,
213 AggregateHashPolicy::Footer | AggregateHashPolicy::Both
214 )
215 }
216}
217
218impl Default for EncodeOptions {
219 fn default() -> Self {
220 Self {
221 hashing: true,
222 compression_backend: pipeline::CompressionBackend::default(),
223 threads: 0,
224 parallel_threshold_bytes: None,
225 allow_nan: false,
226 allow_inf: false,
227 nan_mask_method: MaskMethod::default(),
228 pos_inf_mask_method: MaskMethod::default(),
229 neg_inf_mask_method: MaskMethod::default(),
230 small_mask_threshold_bytes: 128,
231 aggregate_hash: AggregateHashPolicy::Auto,
232 }
233 }
234}
235
236pub(crate) fn validate_object(desc: &DataObjectDescriptor, data_len: usize) -> Result<()> {
237 if desc.obj_type.is_empty() {
238 return Err(TensogramError::Metadata(
239 "obj_type must not be empty".to_string(),
240 ));
241 }
242 if desc.ndim as usize != desc.shape.len() {
243 return Err(TensogramError::Metadata(format!(
244 "ndim {} does not match shape.len() {}",
245 desc.ndim,
246 desc.shape.len()
247 )));
248 }
249 if desc.strides.len() != desc.shape.len() {
250 return Err(TensogramError::Metadata(format!(
251 "strides.len() {} does not match shape.len() {}",
252 desc.strides.len(),
253 desc.shape.len()
254 )));
255 }
256 if desc.encoding == "none" {
257 let product = desc
258 .shape
259 .iter()
260 .try_fold(1u64, |acc, &x| acc.checked_mul(x))
261 .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
262 if desc.dtype.byte_width() > 0 {
263 let expected_bytes = product
264 .checked_mul(desc.dtype.byte_width() as u64)
265 .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
266 if expected_bytes != data_len as u64 {
267 return Err(TensogramError::Metadata(format!(
268 "data_len {data_len} does not match expected {expected_bytes} bytes from shape and dtype"
269 )));
270 }
271 } else if desc.dtype == crate::Dtype::Bitmask {
272 let expected_bytes = product.div_ceil(8);
274 if expected_bytes != data_len as u64 {
275 return Err(TensogramError::Metadata(format!(
276 "data_len {data_len} does not match expected {expected_bytes} bytes for bitmask (ceil({product}/8))"
277 )));
278 }
279 }
280 }
281 if let Some(masks) = &desc.masks {
286 validate_mask_params(masks)?;
287 }
288 Ok(())
289}
290
/// Returns the parameter names a mask method accepts, or `None` when the
/// method name is not recognised.
fn mask_method_allowed_params(method: &str) -> Option<&'static [&'static str]> {
    let allowed: &'static [&'static str] = match method {
        "zstd" => &["level"],
        "blosc2" => &["codec", "level"],
        "none" | "rle" | "roaring" | "lz4" => &[],
        _ => return None,
    };
    Some(allowed)
}
308
309fn validate_mask_descriptor(kind: &str, md: &MaskDescriptor) -> Result<()> {
310 let allowed = mask_method_allowed_params(&md.method).ok_or_else(|| {
311 TensogramError::Metadata(format!(
312 "mask {kind} has unknown method {method:?}; \
313 expected one of: none, rle, roaring, lz4, zstd, blosc2",
314 kind = kind,
315 method = md.method,
316 ))
317 })?;
318 for k in md.params.keys() {
319 if !allowed.contains(&k.as_str()) {
320 return Err(TensogramError::Metadata(format!(
321 "mask {kind} (method {method:?}) has unknown param {key:?}; \
322 allowed for this method: {allowed:?}",
323 kind = kind,
324 method = md.method,
325 key = k,
326 allowed = allowed,
327 )));
328 }
329 }
330 Ok(())
331}
332
333fn validate_mask_params(masks: &MasksMetadata) -> Result<()> {
334 if let Some(md) = &masks.nan {
335 validate_mask_descriptor("nan", md)?;
336 }
337 if let Some(md) = &masks.pos_inf {
338 validate_mask_descriptor("inf+", md)?;
339 }
340 if let Some(md) = &masks.neg_inf {
341 validate_mask_descriptor("inf-", md)?;
342 }
343 Ok(())
344}
345
/// How [`encode_one_object`] treats its input bytes.
#[derive(Debug, Clone, Copy)]
enum EncodeMode {
    /// Raw element bytes. They go through NaN/Inf substitution, simple-packing
    /// parameter resolution, and the encoding pipeline.
    Raw,
    /// Bytes already encoded by the caller. They are stored as-is after any
    /// szip block offsets are validated.
    PreEncoded,
}
351
/// Encodes a single `(descriptor, bytes)` pair into an [`EncodedObject`].
///
/// In [`EncodeMode::Raw`]:
/// - the bytes go through `substitute_and_mask`, governed by
///   `options.allow_nan` / `options.allow_inf`;
/// - simple-packing parameters are resolved into the descriptor;
/// - the result is run through the encoding pipeline, and any block offsets
///   the pipeline reports are stored as `szip_block_offsets`.
///
/// In [`EncodeMode::PreEncoded`] the bytes are copied as-is, after checking
/// that `szip_block_offsets` only appears for szip and is well formed.
///
/// In both modes the payload is then combined with the (possibly empty) mask
/// set by `compose_payload_region`, and any resulting mask metadata is
/// attached to the returned descriptor.
///
/// `intra_codec_threads` is forwarded to the pipeline config and to the
/// parallelisation decision for substitution.
///
/// # Errors
/// Propagates descriptor validation, substitution, pipeline-config and payload
/// composition errors. Pipeline failures are wrapped as
/// `TensogramError::Encoding`.
fn encode_one_object(
    desc: &DataObjectDescriptor,
    data: &[u8],
    mode: EncodeMode,
    options: &EncodeOptions,
    intra_codec_threads: u32,
) -> Result<EncodedObject> {
    validate_object(desc, data.len())?;

    // Raw input: substitute_and_mask returns the (possibly rewritten) bytes
    // plus a mask set. Presumably this replaces allowed NaN/Inf values and
    // records their positions — TODO confirm in substitute_and_mask.
    // Pre-encoded input passes through untouched with an empty mask set.
    let (pipeline_input, mask_set) = if matches!(mode, EncodeMode::Raw) {
        let parallel = crate::parallel::should_parallelise(
            intra_codec_threads,
            data.len(),
            options.parallel_threshold_bytes,
        );
        let (cow, masks) = substitute_and_mask::substitute_and_mask(
            data,
            desc.dtype,
            desc.byte_order,
            options.allow_nan,
            options.allow_inf,
            parallel,
        )?;
        (cow, masks)
    } else {
        (std::borrow::Cow::Borrowed(data), MaskSet::empty(0))
    };

    let num_elements = desc.num_elements()?;
    let dtype = desc.dtype;

    // The descriptor is cloned so that resolved params (simple packing, szip
    // offsets, masks) can be written without mutating the caller's copy.
    let mut final_desc = desc.clone();
    if matches!(mode, EncodeMode::Raw) {
        // NOTE(review): params are derived from the original `data`, not the
        // substituted `pipeline_input`. If allow_nan/allow_inf let non-finite
        // values through, confirm compute_params tolerates them.
        resolve_simple_packing_params(&mut final_desc, data)?;
    }

    let mut config = build_pipeline_config_with_backend(
        &final_desc,
        num_elements,
        dtype,
        options.compression_backend,
        intra_codec_threads,
    )?;

    // An inline hash is only requested when the pipeline actually runs (raw
    // mode) and hashing is enabled.
    let inline_hash_requested = matches!(mode, EncodeMode::Raw) && options.hashing;
    config.compute_hash = inline_hash_requested;

    let (encoded_payload, inline_hash) = match mode {
        EncodeMode::Raw => {
            let result = pipeline::encode_pipeline(pipeline_input.as_ref(), &config)
                .map_err(|e| TensogramError::Encoding(e.to_string()))?;

            // Persist block offsets reported by the pipeline so they travel
            // with the descriptor.
            if let Some(offsets) = &result.block_offsets {
                final_desc.params.insert(
                    "szip_block_offsets".to_string(),
                    ciborium::Value::Array(
                        offsets
                            .iter()
                            .map(|&o| ciborium::Value::Integer(o.into()))
                            .collect(),
                    ),
                );
            }

            (result.encoded_bytes, result.hash)
        }
        EncodeMode::PreEncoded => {
            validate_no_szip_offsets_for_non_szip(desc)?;
            if desc.compression == "szip" && desc.params.contains_key("szip_block_offsets") {
                validate_szip_block_offsets(&desc.params, data.len())?;
            }
            (data.to_vec(), None)
        }
    };

    let (payload_region, masks_metadata) = compose_payload_region(
        encoded_payload,
        mask_set,
        &options.nan_mask_method,
        &options.pos_inf_mask_method,
        &options.neg_inf_mask_method,
        options.small_mask_threshold_bytes,
    )?;
    if let Some(m) = masks_metadata {
        final_desc.masks = Some(m);
    }
    let encoded_payload = payload_region;

    // `inline_hash` is not consumed below; it is explicitly discarded here.
    // NOTE(review): confirm whether the inline hash is meant to be stored on
    // the EncodedObject.
    let _ = (inline_hash, options);

    Ok(EncodedObject {
        descriptor: final_desc,
        encoded_payload,
    })
}
528
529fn encode_inner(
530 global_metadata: &GlobalMetadata,
531 descriptors: &[(&DataObjectDescriptor, &[u8])],
532 options: &EncodeOptions,
533 mode: EncodeMode,
534) -> Result<Vec<u8>> {
535 let budget = crate::parallel::resolve_budget(options.threads)?;
545 let total_bytes: usize = descriptors.iter().map(|(_, d)| d.len()).sum();
546 let parallel =
547 crate::parallel::should_parallelise(budget, total_bytes, options.parallel_threshold_bytes);
548
549 let any_axis_b = descriptors
550 .iter()
551 .any(|(d, _)| crate::parallel::is_axis_b_friendly(&d.encoding, &d.filter, &d.compression));
552 let use_axis_a = parallel && crate::parallel::use_axis_a(descriptors.len(), budget, any_axis_b);
553
554 let intra_codec_threads = if parallel && !use_axis_a { budget } else { 0 };
558
559 let encode_one = |(desc, data): &(&DataObjectDescriptor, &[u8])| {
560 encode_one_object(desc, data, mode, options, intra_codec_threads)
561 };
562
563 let encoded_objects: Vec<EncodedObject> = if use_axis_a {
564 #[cfg(feature = "threads")]
568 {
569 use rayon::prelude::*;
570 crate::parallel::with_pool(budget, || {
571 descriptors
572 .par_iter()
573 .map(&encode_one)
574 .collect::<Result<Vec<_>>>()
575 })?
576 }
577 #[cfg(not(feature = "threads"))]
578 {
579 descriptors.iter().map(encode_one).collect::<Result<_>>()?
580 }
581 } else {
582 crate::parallel::run_maybe_pooled(budget, parallel, intra_codec_threads, || {
587 descriptors.iter().map(encode_one).collect::<Result<_>>()
588 })?
589 };
590
591 validate_no_client_reserved(global_metadata)?;
593
594 if global_metadata.base.len() > descriptors.len() {
598 return Err(TensogramError::Metadata(format!(
599 "metadata base has {} entries but only {} descriptors provided; \
600 extra base entries would be discarded",
601 global_metadata.base.len(),
602 descriptors.len()
603 )));
604 }
605
606 let mut enriched_meta = global_metadata.clone();
609 populate_base_entries(&mut enriched_meta.base, &encoded_objects);
610 populate_reserved_provenance(&mut enriched_meta.reserved);
611
612 let resolved = options.aggregate_hash.resolved_buffered();
617 let hash_policy = framing::HashFramePolicy {
618 header: resolved.emits_header(),
619 footer: resolved.emits_footer(),
620 };
621 framing::encode_message(
622 &enriched_meta,
623 &encoded_objects,
624 options.hashing,
625 hash_policy,
626 )
627}
628
/// Encodes a complete message from raw (not yet encoded) object data.
///
/// For each `(descriptor, bytes)` pair:
/// - the bytes are validated against the descriptor;
/// - NaN/Inf handling follows `options.allow_nan` / `options.allow_inf`;
/// - the data runs through the descriptor's encoding → filter → compression
///   pipeline.
///
/// The objects are then framed together with `global_metadata`, enriched with
/// library-owned reserved entries.
///
/// # Errors
/// Returns an error if the client wrote reserved metadata keys, if `base` has
/// more entries than `descriptors`, if any descriptor or its params are
/// invalid, or if encoding or framing fails.
#[tracing::instrument(skip(global_metadata, descriptors, options), fields(objects = descriptors.len()))]
pub fn encode(
    global_metadata: &GlobalMetadata,
    descriptors: &[(&DataObjectDescriptor, &[u8])],
    options: &EncodeOptions,
) -> Result<Vec<u8>> {
    encode_inner(global_metadata, descriptors, options, EncodeMode::Raw)
}
642
/// Encodes a complete message from object payloads the caller has already
/// encoded.
///
/// Payload bytes are stored as-is. No NaN/Inf substitution runs, and no
/// pipeline or inline per-object hash is computed. Descriptors are still
/// validated:
/// - `szip_block_offsets` is rejected unless compression is `"szip"`;
/// - for szip it is checked against the payload length.
///
/// `options.hashing` is still forwarded to message framing.
///
/// NOTE(review): `validate_object` also runs on the pre-encoded byte length.
/// So for `encoding == "none"` the supplied bytes must match the shape × dtype
/// size even when a compression codec is set. Confirm this is intended.
///
/// # Errors
/// Same validation and framing errors as [`encode`], plus szip-offset
/// validation errors.
#[tracing::instrument(name = "encode_pre_encoded", skip_all, fields(num_objects = descriptors.len()))]
pub fn encode_pre_encoded(
    global_metadata: &GlobalMetadata,
    descriptors: &[(&DataObjectDescriptor, &[u8])],
    options: &EncodeOptions,
) -> Result<Vec<u8>> {
    encode_inner(
        global_metadata,
        descriptors,
        options,
        EncodeMode::PreEncoded,
    )
}
674
675fn validate_no_client_reserved(meta: &GlobalMetadata) -> Result<()> {
680 if !meta.reserved.is_empty() {
681 return Err(TensogramError::Metadata(format!(
682 "client code must not write to '{RESERVED_KEY}' at message level; \
683 this field is populated by the library"
684 )));
685 }
686 for (i, entry) in meta.base.iter().enumerate() {
687 if entry.contains_key(RESERVED_KEY) {
688 return Err(TensogramError::Metadata(format!(
689 "client code must not write to '{RESERVED_KEY}' in base[{i}]; \
690 this field is populated by the library"
691 )));
692 }
693 }
694 Ok(())
695}
696
697pub(crate) fn populate_base_entries(
703 base: &mut Vec<BTreeMap<String, ciborium::Value>>,
704 encoded_objects: &[crate::framing::EncodedObject],
705) {
706 use ciborium::Value;
707
708 base.resize_with(encoded_objects.len(), BTreeMap::new);
710
711 for (entry, obj) in base.iter_mut().zip(encoded_objects.iter()) {
712 let desc = &obj.descriptor;
713
714 let tensor_map = Value::Map(vec![
715 (
716 Value::Text("ndim".to_string()),
717 Value::Integer(desc.ndim.into()),
718 ),
719 (
720 Value::Text("shape".to_string()),
721 Value::Array(
722 desc.shape
723 .iter()
724 .map(|&d| Value::Integer(d.into()))
725 .collect(),
726 ),
727 ),
728 (
729 Value::Text("strides".to_string()),
730 Value::Array(
731 desc.strides
732 .iter()
733 .map(|&s| Value::Integer(s.into()))
734 .collect(),
735 ),
736 ),
737 (
738 Value::Text("dtype".to_string()),
739 Value::Text(desc.dtype.to_string()),
740 ),
741 ]);
742
743 let reserved_map = Value::Map(vec![(Value::Text("tensor".to_string()), tensor_map)]);
744
745 entry.insert(RESERVED_KEY.to_string(), reserved_map);
746 }
747}
748
/// Writes library-owned provenance into the message-level reserved map.
///
/// Always sets:
/// - `"encoder"`: `{ "name": "tensogram", "version": <CARGO_PKG_VERSION> }`,
///   where the version is this crate's, captured at compile time;
/// - `"uuid"`: a freshly generated v4 UUID string.
///
/// On non-wasm32 targets it also sets `"time"`: the current UTC wall-clock
/// time as `YYYY-MM-DDTHH:MM:SSZ` (whole seconds). No time is written on
/// wasm32.
///
/// Existing values under these keys are overwritten.
pub(crate) fn populate_reserved_provenance(reserved: &mut BTreeMap<String, ciborium::Value>) {
    use ciborium::Value;
    #[cfg(not(target_arch = "wasm32"))]
    use std::time::SystemTime;

    let version_str = env!("CARGO_PKG_VERSION");
    let encoder_map = Value::Map(vec![
        (
            Value::Text("name".to_string()),
            Value::Text("tensogram".to_string()),
        ),
        (
            Value::Text("version".to_string()),
            Value::Text(version_str.to_string()),
        ),
    ]);
    reserved.insert("encoder".to_string(), encoder_map);

    #[cfg(not(target_arch = "wasm32"))]
    {
        // A clock set before the Unix epoch falls back to 0 seconds
        // (1970-01-01T00:00:00Z) instead of failing the encode.
        let secs = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        // Split into whole days since the epoch plus the time of day.
        let days = secs / 86400;
        let day_secs = secs % 86400;
        let hours = day_secs / 3600;
        let minutes = (day_secs % 3600) / 60;
        let seconds = day_secs % 60;
        let (y, m, d) = civil_from_days(days as i64);
        let timestamp = format!("{y:04}-{m:02}-{d:02}T{hours:02}:{minutes:02}:{seconds:02}Z");
        reserved.insert("time".to_string(), Value::Text(timestamp));
    }

    let id = uuid::Uuid::new_v4();
    reserved.insert("uuid".to_string(), Value::Text(id.to_string()));
}
806
/// Converts a day count relative to 1970-01-01 into a proleptic-Gregorian
/// `(year, month, day)` triple. Negative counts are dates before the epoch.
///
/// This is Howard Hinnant's `civil_from_days`. Years are counted from
/// March 1st, so the leap day is the last day of the shifted year. A 400-year
/// "era" spans exactly 146 097 days.
#[cfg(not(target_arch = "wasm32"))]
fn civil_from_days(days: i64) -> (i64, u32, u32) {
    const DAYS_PER_ERA: i64 = 146_097;
    // Move the epoch from 1970-01-01 to 0000-03-01.
    let shifted = days + 719_468;
    let era = shifted.div_euclid(DAYS_PER_ERA);
    let day_of_era = (shifted - era * DAYS_PER_ERA) as u32; // [0, 146096]
    let year_of_era = (day_of_era - day_of_era / 1_460 + day_of_era / 36_524
        - day_of_era / 146_096)
        / 365; // [0, 399]
    // Day within the March-based year: [0, 365].
    let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // Month index counted from March: [0, 11].
    let march_month = (5 * day_of_year + 2) / 153;
    let day = day_of_year - (153 * march_month + 2) / 5 + 1;
    let month = if march_month < 10 {
        march_month + 3
    } else {
        march_month - 9
    };
    // January and February belong to the following civil year.
    let year = i64::from(year_of_era) + era * 400 + i64::from(month <= 2);
    (year, month, day)
}
825
/// Builds a [`PipelineConfig`] for `desc` using the default compression
/// backend and `0` intra-codec threads.
///
/// Thin convenience wrapper around [`build_pipeline_config_with_backend`].
///
/// # Errors
/// Same as [`build_pipeline_config_with_backend`].
pub(crate) fn build_pipeline_config(
    desc: &DataObjectDescriptor,
    num_values: usize,
    dtype: Dtype,
) -> Result<PipelineConfig> {
    build_pipeline_config_with_backend(
        desc,
        num_values,
        dtype,
        pipeline::CompressionBackend::default(),
        0,
    )
}
839
840fn resolve_encoding(desc: &DataObjectDescriptor, dtype: Dtype) -> Result<EncodingType> {
842 match desc.encoding.as_str() {
843 "none" => Ok(EncodingType::None),
844 "simple_packing" => {
845 if dtype != Dtype::Float64 {
850 return Err(TensogramError::Encoding(format!(
851 "simple_packing only supports float64 dtype; got {dtype:?}"
852 )));
853 }
854 let params = extract_simple_packing_params(&desc.params)?;
855 Ok(EncodingType::SimplePacking(params))
856 }
857 other => Err(TensogramError::Encoding(format!(
858 "unknown encoding: {other}"
859 ))),
860 }
861}
862
863fn resolve_filter(desc: &DataObjectDescriptor) -> Result<FilterType> {
865 match desc.filter.as_str() {
866 "none" => Ok(FilterType::None),
867 "shuffle" => {
868 let element_size = usize::try_from(get_u64_param(
869 &desc.params,
870 "shuffle_element_size",
871 )?)
872 .map_err(|_| {
873 TensogramError::Metadata("shuffle_element_size out of usize range".to_string())
874 })?;
875 Ok(FilterType::Shuffle { element_size })
876 }
877 other => Err(TensogramError::Encoding(format!("unknown filter: {other}"))),
878 }
879}
880
/// Maps `desc.compression` and its params to a pipeline [`CompressionType`].
///
/// Params read from `desc.params`:
/// - `szip`: required `szip_rsi`, `szip_block_size`, `szip_flags` (each must
///   fit in `u32`). Bits-per-sample is derived from `encoding` / `filter`.
/// - `zstd`: optional `zstd_level` (default `3`).
/// - `lz4`: none.
/// - `blosc2`: optional `blosc2_codec` (default `"lz4"`) and `blosc2_clevel`
///   (default `5`). The typesize is derived from `encoding` / `filter`.
/// - `zfp`: required `zfp_mode`, plus that mode's own parameter.
/// - `sz3`: required `sz3_error_bound_mode` (`abs` / `rel` / `psnr`) and
///   `sz3_error_bound`.
/// - `rle` / `roaring`: no params; only accepted for `Dtype::Bitmask`.
///
/// Codec arms exist only when their cargo feature is enabled. A disabled
/// codec's name falls through to the "unknown compression" error.
///
/// # Errors
/// `TensogramError::Metadata` for missing, mistyped or out-of-range params.
/// `TensogramError::Encoding` for unknown names/modes or an unsupported dtype.
fn resolve_compression(
    desc: &DataObjectDescriptor,
    dtype: Dtype,
    encoding: &EncodingType,
    filter: &FilterType,
) -> Result<CompressionType> {
    match desc.compression.as_str() {
        "none" => Ok(CompressionType::None),
        #[cfg(any(feature = "szip", feature = "szip-pure"))]
        "szip" => {
            let rsi = u32::try_from(get_u64_param(&desc.params, "szip_rsi")?)
                .map_err(|_| TensogramError::Metadata("szip_rsi out of u32 range".to_string()))?;
            let block_size = u32::try_from(get_u64_param(&desc.params, "szip_block_size")?)
                .map_err(|_| {
                    TensogramError::Metadata("szip_block_size out of u32 range".to_string())
                })?;
            let flags = u32::try_from(get_u64_param(&desc.params, "szip_flags")?)
                .map_err(|_| TensogramError::Metadata("szip_flags out of u32 range".to_string()))?;
            // Sample width seen by szip: the packed bit width for simple
            // packing, 8 after a shuffle filter, else the dtype width in bits.
            let bits_per_sample = match (encoding, filter) {
                (EncodingType::SimplePacking(params), _) => params.bits_per_value,
                (EncodingType::None, FilterType::Shuffle { .. }) => 8,
                (EncodingType::None, FilterType::None) => (dtype.byte_width() * 8) as u32,
            };
            Ok(CompressionType::Szip {
                rsi,
                block_size,
                flags,
                bits_per_sample,
            })
        }
        #[cfg(any(feature = "zstd", feature = "zstd-pure"))]
        "zstd" => {
            let level_i64 = get_i64_param_or_default(&desc.params, "zstd_level", 3)?;
            let level = i32::try_from(level_i64).map_err(|_| {
                TensogramError::Metadata(format!("zstd_level value {level_i64} out of i32 range"))
            })?;
            Ok(CompressionType::Zstd { level })
        }
        #[cfg(feature = "lz4")]
        "lz4" => Ok(CompressionType::Lz4),
        #[cfg(feature = "blosc2")]
        "blosc2" => {
            let codec_str = get_text_param_or_default(&desc.params, "blosc2_codec", "lz4")?;
            let codec = match codec_str {
                "blosclz" => Blosc2Codec::Blosclz,
                "lz4" => Blosc2Codec::Lz4,
                "lz4hc" => Blosc2Codec::Lz4hc,
                "zlib" => Blosc2Codec::Zlib,
                "zstd" => Blosc2Codec::Zstd,
                other => {
                    return Err(TensogramError::Encoding(format!(
                        "unknown blosc2 codec: {other}"
                    )));
                }
            };
            let clevel_i64 = get_i64_param_or_default(&desc.params, "blosc2_clevel", 5)?;
            let clevel = i32::try_from(clevel_i64).map_err(|_| {
                TensogramError::Metadata(format!(
                    "blosc2_clevel value {clevel_i64} out of i32 range"
                ))
            })?;
            // Element size handed to blosc2: packed bytes per value for simple
            // packing, 1 after a shuffle filter, else the dtype width.
            let typesize = match (encoding, filter) {
                (EncodingType::SimplePacking(params), _) => {
                    (params.bits_per_value as usize).div_ceil(8)
                }
                (EncodingType::None, FilterType::Shuffle { .. }) => 1,
                (EncodingType::None, FilterType::None) => dtype.byte_width(),
            };
            Ok(CompressionType::Blosc2 {
                codec,
                clevel,
                typesize,
            })
        }
        #[cfg(feature = "zfp")]
        "zfp" => {
            let mode_str = match desc.params.get("zfp_mode") {
                Some(ciborium::Value::Text(s)) => s.clone(),
                _ => {
                    return Err(TensogramError::Metadata(
                        "missing required parameter: zfp_mode".to_string(),
                    ));
                }
            };
            // Each zfp mode reads its own single control parameter.
            let mode = match mode_str.as_str() {
                "fixed_rate" => {
                    let rate = get_f64_param(&desc.params, "zfp_rate")?;
                    ZfpMode::FixedRate { rate }
                }
                "fixed_precision" => {
                    let precision = u32::try_from(get_u64_param(&desc.params, "zfp_precision")?)
                        .map_err(|_| {
                            TensogramError::Metadata("zfp_precision out of u32 range".to_string())
                        })?;
                    ZfpMode::FixedPrecision { precision }
                }
                "fixed_accuracy" => {
                    let tolerance = get_f64_param(&desc.params, "zfp_tolerance")?;
                    ZfpMode::FixedAccuracy { tolerance }
                }
                other => {
                    return Err(TensogramError::Encoding(format!(
                        "unknown zfp_mode: {other}"
                    )));
                }
            };
            Ok(CompressionType::Zfp { mode })
        }
        #[cfg(feature = "sz3")]
        "sz3" => {
            let mode_str = match desc.params.get("sz3_error_bound_mode") {
                Some(ciborium::Value::Text(s)) => s.clone(),
                _ => {
                    return Err(TensogramError::Metadata(
                        "missing required parameter: sz3_error_bound_mode".to_string(),
                    ));
                }
            };
            // The bound value is read before the mode string is interpreted.
            let bound_val = get_f64_param(&desc.params, "sz3_error_bound")?;
            let error_bound = match mode_str.as_str() {
                "abs" => Sz3ErrorBound::Absolute(bound_val),
                "rel" => Sz3ErrorBound::Relative(bound_val),
                "psnr" => Sz3ErrorBound::Psnr(bound_val),
                other => {
                    return Err(TensogramError::Encoding(format!(
                        "unknown sz3_error_bound_mode: {other}"
                    )));
                }
            };
            Ok(CompressionType::Sz3 { error_bound })
        }
        "rle" => {
            // Bitmask-only codec.
            if dtype != Dtype::Bitmask {
                return Err(TensogramError::Encoding(format!(
                    "compression \"rle\" only supports dtype=bitmask, got dtype={:?}",
                    dtype
                )));
            }
            Ok(CompressionType::Rle)
        }
        "roaring" => {
            // Bitmask-only codec.
            if dtype != Dtype::Bitmask {
                return Err(TensogramError::Encoding(format!(
                    "compression \"roaring\" only supports dtype=bitmask, got dtype={:?}",
                    dtype
                )));
            }
            Ok(CompressionType::Roaring)
        }
        other => Err(TensogramError::Encoding(format!(
            "unknown compression: {other}"
        ))),
    }
}
1049
/// Resolves a descriptor's encoding, filter and compression into a complete
/// [`PipelineConfig`].
///
/// The three stages are resolved in that order, so the first invalid stage's
/// error is the one reported. Compression is resolved last because szip and
/// blosc2 derive their sample width from the resolved encoding and filter.
///
/// `compute_hash` is always `false` here. Callers that want an inline hash set
/// it afterwards, as `encode_one_object` does.
///
/// # Errors
/// Propagates the errors of `resolve_encoding`, `resolve_filter` and
/// `resolve_compression`.
pub(crate) fn build_pipeline_config_with_backend(
    desc: &DataObjectDescriptor,
    num_values: usize,
    dtype: Dtype,
    compression_backend: pipeline::CompressionBackend,
    intra_codec_threads: u32,
) -> Result<PipelineConfig> {
    let encoding = resolve_encoding(desc, dtype)?;
    let filter = resolve_filter(desc)?;
    let compression = resolve_compression(desc, dtype, &encoding, &filter)?;

    Ok(PipelineConfig {
        encoding,
        filter,
        compression,
        num_values,
        byte_order: desc.byte_order,
        dtype_byte_width: dtype.byte_width(),
        swap_unit_size: dtype.swap_unit_size(),
        compression_backend,
        intra_codec_threads,
        compute_hash: false,
    })
}
1083
1084fn extract_simple_packing_params(
1085 params: &BTreeMap<String, ciborium::Value>,
1086) -> Result<SimplePackingParams> {
1087 let reference_value = get_f64_param(params, "sp_reference_value")?;
1088 if reference_value.is_nan() || reference_value.is_infinite() {
1089 return Err(TensogramError::Metadata(format!(
1090 "sp_reference_value must be finite, got {reference_value}"
1091 )));
1092 }
1093 Ok(SimplePackingParams {
1094 reference_value,
1095 binary_scale_factor: i32::try_from(get_i64_param(params, "sp_binary_scale_factor")?)
1096 .map_err(|_| {
1097 TensogramError::Metadata("sp_binary_scale_factor out of i32 range".to_string())
1098 })?,
1099 decimal_scale_factor: i32::try_from(get_i64_param(params, "sp_decimal_scale_factor")?)
1100 .map_err(|_| {
1101 TensogramError::Metadata("sp_decimal_scale_factor out of i32 range".to_string())
1102 })?,
1103 bits_per_value: u32::try_from(get_u64_param(params, "sp_bits_per_value")?).map_err(
1104 |_| TensogramError::Metadata("sp_bits_per_value out of u32 range".to_string()),
1105 )?,
1106 })
1107}
1108
/// Completes a `simple_packing` descriptor's params in place before encoding.
///
/// Does nothing for other encodings. For `simple_packing`:
/// - `dtype` must be `Float64`;
/// - `sp_bits_per_value` is always required;
/// - `sp_reference_value` and `sp_binary_scale_factor` must be given
///   together or not at all;
/// - when both are given (explicit mode), only a missing
///   `sp_decimal_scale_factor` is defaulted to `0`;
/// - when neither is given (auto mode), all four `sp_*` params are computed
///   from `data_bytes` with `simple_packing::compute_params` and written back.
///   `sp_decimal_scale_factor` defaults to `0` if absent.
///
/// # Errors
/// `TensogramError::Encoding` for a non-float64 dtype or a `compute_params`
/// failure. `TensogramError::Metadata` for missing or half-specified params,
/// out-of-range values, or a byte length that is not a multiple of 8.
pub(crate) fn resolve_simple_packing_params(
    desc: &mut DataObjectDescriptor,
    data_bytes: &[u8],
) -> Result<()> {
    if desc.encoding != "simple_packing" {
        return Ok(());
    }

    if desc.dtype != Dtype::Float64 {
        return Err(TensogramError::Encoding(format!(
            "simple_packing only supports float64 dtype; got {:?}",
            desc.dtype
        )));
    }

    // The bit width is a user choice; it is never inferred from the data.
    if !desc.params.contains_key("sp_bits_per_value") {
        return Err(TensogramError::Metadata(
            "simple_packing requires sp_bits_per_value (the encoder can \
             auto-compute sp_reference_value + sp_binary_scale_factor from \
             the data, but the bit-width and decimal scale are the user \
             knobs). Provide at least sp_bits_per_value, or the full \
             explicit 4-key set."
                .to_string(),
        ));
    }

    // Reference value and binary scale factor are coupled: supplying only one
    // of them is ambiguous, so it is rejected.
    let has_ref = desc.params.contains_key("sp_reference_value");
    let has_bsf = desc.params.contains_key("sp_binary_scale_factor");
    if has_ref ^ has_bsf {
        let (set, missing) = if has_ref {
            ("sp_reference_value", "sp_binary_scale_factor")
        } else {
            ("sp_binary_scale_factor", "sp_reference_value")
        };
        return Err(TensogramError::Metadata(format!(
            "simple_packing: descriptor sets {set} but not {missing}. \
             Provide both for explicit-params encoding, or neither to \
             let the encoder auto-compute them from the data."
        )));
    }

    // Explicit mode: keep the caller's values; only default the decimal scale.
    if has_ref && has_bsf {
        desc.params
            .entry("sp_decimal_scale_factor".to_string())
            .or_insert(ciborium::Value::Integer(0i64.into()));
        return Ok(());
    }

    // Auto mode: derive reference value and binary scale from the data.
    let bits_per_value = u32::try_from(get_u64_param(&desc.params, "sp_bits_per_value")?)
        .map_err(|_| TensogramError::Metadata("sp_bits_per_value out of u32 range".to_string()))?;
    let decimal_scale_factor = i32::try_from(get_i64_param_or_default(
        &desc.params,
        "sp_decimal_scale_factor",
        0,
    )?)
    .map_err(|_| {
        TensogramError::Metadata("sp_decimal_scale_factor out of i32 range".to_string())
    })?;

    let values = bytes_as_f64_vec(data_bytes, desc.byte_order)?;
    let params = simple_packing::compute_params(&values, bits_per_value, decimal_scale_factor)
        .map_err(|e| TensogramError::Encoding(e.to_string()))?;

    // Write back all four keys so the descriptor is fully explicit.
    desc.params.insert(
        "sp_reference_value".to_string(),
        ciborium::Value::Float(params.reference_value),
    );
    desc.params.insert(
        "sp_binary_scale_factor".to_string(),
        ciborium::Value::Integer(i64::from(params.binary_scale_factor).into()),
    );
    desc.params.insert(
        "sp_decimal_scale_factor".to_string(),
        ciborium::Value::Integer(i64::from(params.decimal_scale_factor).into()),
    );
    desc.params.insert(
        "sp_bits_per_value".to_string(),
        ciborium::Value::Integer(i64::from(params.bits_per_value).into()),
    );
    Ok(())
}
1244
1245fn bytes_as_f64_vec(bytes: &[u8], byte_order: ByteOrder) -> Result<Vec<f64>> {
1253 if !bytes.len().is_multiple_of(8) {
1254 return Err(TensogramError::Metadata(format!(
1255 "simple_packing: input byte length {} is not a multiple of 8 (float64)",
1256 bytes.len()
1257 )));
1258 }
1259 let n = bytes.len() / 8;
1260 let mut out: Vec<f64> = Vec::new();
1261 out.try_reserve_exact(n).map_err(|e| {
1262 TensogramError::Encoding(format!(
1263 "simple_packing: failed to reserve {} bytes for byte-to-f64 \
1264 conversion: {e}",
1265 n.saturating_mul(std::mem::size_of::<f64>()),
1266 ))
1267 })?;
1268 for chunk in bytes.chunks_exact(8) {
1269 let mut buf = [0u8; 8];
1270 buf.copy_from_slice(chunk);
1271 out.push(match byte_order {
1272 ByteOrder::Big => f64::from_be_bytes(buf),
1273 ByteOrder::Little => f64::from_le_bytes(buf),
1274 });
1275 }
1276 Ok(out)
1277}
1278
/// Largest magnitude, 2^53, below which every integer converts to `f64`
/// exactly. `f64` carries 53 significand bits. Integer params beyond this
/// bound are rejected by `get_f64_param`.
const F64_EXACT_INT_BOUND: i128 = 1 << f64::MANTISSA_DIGITS;
1283
1284pub(crate) fn get_f64_param(params: &BTreeMap<String, ciborium::Value>, key: &str) -> Result<f64> {
1285 match params.get(key) {
1286 Some(ciborium::Value::Float(f)) => Ok(*f),
1287 Some(ciborium::Value::Integer(i)) => {
1288 let n: i128 = (*i).into();
1294 if n.abs() > F64_EXACT_INT_BOUND {
1295 return Err(TensogramError::Metadata(format!(
1296 "{key}: integer value {n} is outside the f64 \
1297 exact-representable range [-2^53, 2^53]; \
1298 converting to f64 would silently lose precision. \
1299 Supply a float literal or pick a parameter that \
1300 accepts integers up to i64::MAX."
1301 )));
1302 }
1303 Ok(n as f64)
1306 }
1307 Some(other) => Err(TensogramError::Metadata(format!(
1308 "expected number for {key}, got {kind}",
1309 kind = crate::metadata::cbor_value_kind(other),
1310 ))),
1311 None => Err(TensogramError::Metadata(format!(
1312 "missing required parameter: {key}"
1313 ))),
1314 }
1315}
1316
1317pub(crate) fn get_i64_param(params: &BTreeMap<String, ciborium::Value>, key: &str) -> Result<i64> {
1318 match params.get(key) {
1319 Some(ciborium::Value::Integer(i)) => {
1320 let n: i128 = (*i).into();
1321 i64::try_from(n).map_err(|_| {
1322 TensogramError::Metadata(format!("integer value {n} out of i64 range for {key}"))
1323 })
1324 }
1325 Some(other) => Err(TensogramError::Metadata(format!(
1326 "expected integer for {key}, got {kind}",
1327 kind = crate::metadata::cbor_value_kind(other),
1328 ))),
1329 None => Err(TensogramError::Metadata(format!(
1330 "missing required parameter: {key}"
1331 ))),
1332 }
1333}
1334
1335pub(crate) fn get_i64_param_or_default(
1347 params: &BTreeMap<String, ciborium::Value>,
1348 key: &str,
1349 default: i64,
1350) -> Result<i64> {
1351 match params.get(key) {
1352 Some(ciborium::Value::Integer(i)) => {
1353 let n: i128 = (*i).into();
1354 i64::try_from(n).map_err(|_| {
1355 TensogramError::Metadata(format!("integer value {n} out of i64 range for {key}"))
1356 })
1357 }
1358 Some(other) => Err(TensogramError::Metadata(format!(
1359 "expected integer for {key}, got {kind}; \
1360 if you meant to use the default ({default}), omit the key",
1361 kind = crate::metadata::cbor_value_kind(other),
1362 ))),
1363 None => Ok(default),
1364 }
1365}
1366
1367pub(crate) fn get_u64_param(params: &BTreeMap<String, ciborium::Value>, key: &str) -> Result<u64> {
1368 match params.get(key) {
1369 Some(ciborium::Value::Integer(i)) => {
1370 let n: i128 = (*i).into();
1371 u64::try_from(n).map_err(|_| {
1372 TensogramError::Metadata(format!("integer value {n} out of u64 range for {key}"))
1373 })
1374 }
1375 Some(other) => Err(TensogramError::Metadata(format!(
1376 "expected integer for {key}, got {kind}",
1377 kind = crate::metadata::cbor_value_kind(other),
1378 ))),
1379 None => Err(TensogramError::Metadata(format!(
1380 "missing required parameter: {key}"
1381 ))),
1382 }
1383}
1384
/// Reads an optional text parameter, falling back to `default` only when
/// `key` is absent.
///
/// A present key with a non-text value is an error, not a fallback.
///
/// # Errors
/// `TensogramError::Metadata` if the value is not CBOR text.
#[cfg(any(feature = "blosc2", test))]
pub(crate) fn get_text_param_or_default<'a>(
    params: &'a BTreeMap<String, ciborium::Value>,
    key: &str,
    default: &'a str,
) -> Result<&'a str> {
    let Some(value) = params.get(key) else {
        return Ok(default);
    };

    if let ciborium::Value::Text(text) = value {
        Ok(text.as_str())
    } else {
        Err(TensogramError::Metadata(format!(
            "expected text for {key}, got {kind}; \
             if you meant to use the default ({default:?}), omit the key",
            kind = crate::metadata::cbor_value_kind(value),
        )))
    }
}
1411
1412pub(crate) fn validate_szip_block_offsets(
1413 params: &BTreeMap<String, ciborium::Value>,
1414 encoded_bytes_len: usize,
1415) -> Result<()> {
1416 let value = params.get("szip_block_offsets").ok_or_else(|| {
1417 TensogramError::Metadata(
1418 "missing required parameter: szip_block_offsets for szip compression".to_string(),
1419 )
1420 })?;
1421
1422 let offsets = match value {
1423 ciborium::Value::Array(arr) => arr,
1424 other => {
1425 return Err(TensogramError::Metadata(format!(
1426 "szip_block_offsets must be an array, got {other:?}"
1427 )));
1428 }
1429 };
1430
1431 if offsets.is_empty() {
1432 return Err(TensogramError::Metadata(
1433 "szip_block_offsets must not be empty; first offset must be 0".to_string(),
1434 ));
1435 }
1436
1437 let bit_bound = encoded_bytes_len.checked_mul(8).ok_or_else(|| {
1438 TensogramError::Metadata(format!(
1439 "encoded byte length {encoded_bytes_len} overflows bit-bound calculation"
1440 ))
1441 })?;
1442 let bit_bound_u64 = u64::try_from(bit_bound).map_err(|_| {
1443 TensogramError::Metadata(format!(
1444 "bit-bound {bit_bound} derived from {encoded_bytes_len} bytes does not fit in u64"
1445 ))
1446 })?;
1447
1448 let mut parsed_offsets = Vec::with_capacity(offsets.len());
1449 for (idx, item) in offsets.iter().enumerate() {
1450 let offset = match item {
1451 ciborium::Value::Integer(i) => {
1452 let n: i128 = (*i).into();
1453 u64::try_from(n).map_err(|_| {
1454 TensogramError::Metadata(format!(
1455 "szip_block_offsets[{idx}] = {n} out of u64 range"
1456 ))
1457 })?
1458 }
1459 other => {
1460 return Err(TensogramError::Metadata(format!(
1461 "szip_block_offsets[{idx}] must be an integer, got {other:?}"
1462 )));
1463 }
1464 };
1465
1466 if offset > bit_bound_u64 {
1467 return Err(TensogramError::Metadata(format!(
1468 "szip_block_offsets[{idx}] = {offset} exceeds bit bound {bit_bound_u64} (encoded_bytes_len = {encoded_bytes_len} bytes, {bit_bound_u64} bits)"
1469 )));
1470 }
1471
1472 if idx == 0 {
1473 if offset != 0 {
1474 return Err(TensogramError::Metadata(format!(
1475 "szip_block_offsets[0] must be 0, got {offset}"
1476 )));
1477 }
1478 } else {
1479 let prev = parsed_offsets[idx - 1];
1480 if offset <= prev {
1481 return Err(TensogramError::Metadata(format!(
1482 "szip_block_offsets must be strictly increasing: szip_block_offsets[{}] = {}, szip_block_offsets[{idx}] = {offset}",
1483 idx - 1,
1484 prev
1485 )));
1486 }
1487 }
1488
1489 parsed_offsets.push(offset);
1490 }
1491
1492 Ok(())
1493}
1494
1495pub(crate) fn validate_no_szip_offsets_for_non_szip(desc: &DataObjectDescriptor) -> Result<()> {
1496 if desc.compression != "szip" && desc.params.contains_key("szip_block_offsets") {
1497 return Err(TensogramError::Metadata(format!(
1498 "szip_block_offsets provided but compression is '{}', not 'szip'",
1499 desc.compression
1500 )));
1501 }
1502 Ok(())
1503}
1504
1505pub(crate) fn compose_payload_region(
1533 mut encoded_payload: Vec<u8>,
1534 masks: MaskSet,
1535 nan_method: &MaskMethod,
1536 pos_inf_method: &MaskMethod,
1537 neg_inf_method: &MaskMethod,
1538 small_threshold: usize,
1539) -> Result<(Vec<u8>, Option<MasksMetadata>)> {
1540 if masks.is_empty() {
1541 return Ok((encoded_payload, None));
1542 }
1543
1544 let mut metadata = MasksMetadata::default();
1545 let mut region_cursor = encoded_payload.len() as u64;
1546
1547 let mut append_one =
1552 |bits_opt: Option<&Vec<bool>>, method: &MaskMethod| -> Result<Option<MaskDescriptor>> {
1553 let Some(bits) = bits_opt else {
1554 return Ok(None);
1555 };
1556 let (blob, used_method) = encode_one_mask(bits, method.clone(), small_threshold)?;
1557 let desc = MaskDescriptor {
1558 method: used_method.name().to_string(),
1559 offset: region_cursor,
1560 length: blob.len() as u64,
1561 params: mask_params_cbor(&used_method),
1562 };
1563 region_cursor += blob.len() as u64;
1564 encoded_payload.extend_from_slice(&blob);
1565 Ok(Some(desc))
1566 };
1567 metadata.nan = append_one(masks.nan.as_ref(), nan_method)?;
1568 metadata.pos_inf = append_one(masks.pos_inf.as_ref(), pos_inf_method)?;
1569 metadata.neg_inf = append_one(masks.neg_inf.as_ref(), neg_inf_method)?;
1570
1571 Ok((encoded_payload, Some(metadata)))
1572}
1573
1574fn encode_one_mask(
1580 bits: &[bool],
1581 requested: MaskMethod,
1582 small_threshold: usize,
1583) -> Result<(Vec<u8>, MaskMethod)> {
1584 use tensogram_encodings::bitmask;
1585
1586 let uncompressed_bytes = bits.len().div_ceil(8);
1590 let method = if small_threshold > 0 && uncompressed_bytes <= small_threshold {
1591 MaskMethod::None
1592 } else {
1593 requested
1594 };
1595
1596 let blob = match &method {
1597 MaskMethod::None => bitmask::codecs::encode_none(bits)
1598 .map_err(|e| TensogramError::Encoding(format!("bitmask pack: {e}")))?,
1599 MaskMethod::Rle => bitmask::rle::encode(bits),
1600 MaskMethod::Roaring => bitmask::roaring::encode(bits)
1601 .map_err(|e| TensogramError::Encoding(format!("roaring mask encode: {e}")))?,
1602 MaskMethod::Lz4 => bitmask::codecs::encode_lz4(bits)
1603 .map_err(|e| TensogramError::Encoding(format!("lz4 mask encode: {e}")))?,
1604 MaskMethod::Zstd { level } => bitmask::codecs::encode_zstd(bits, *level)
1605 .map_err(|e| TensogramError::Encoding(format!("zstd mask encode: {e}")))?,
1606 #[cfg(feature = "blosc2")]
1607 MaskMethod::Blosc2 { codec, level } => bitmask::codecs::encode_blosc2(bits, *codec, *level)
1608 .map_err(|e| TensogramError::Encoding(format!("blosc2 mask encode: {e}")))?,
1609 };
1610
1611 Ok((blob, method))
1612}
1613
1614fn mask_params_cbor(method: &MaskMethod) -> BTreeMap<String, ciborium::Value> {
1618 let mut params = BTreeMap::new();
1619 match method {
1620 MaskMethod::None | MaskMethod::Rle | MaskMethod::Roaring | MaskMethod::Lz4 => {}
1621 MaskMethod::Zstd { level } => {
1622 if let Some(l) = level {
1623 params.insert(
1624 "level".to_string(),
1625 ciborium::Value::Integer((*l as i64).into()),
1626 );
1627 }
1628 }
1629 #[cfg(feature = "blosc2")]
1630 MaskMethod::Blosc2 { codec, level } => {
1631 let codec_str = match codec {
1632 Blosc2Codec::Blosclz => "blosclz",
1633 Blosc2Codec::Lz4 => "lz4",
1634 Blosc2Codec::Lz4hc => "lz4hc",
1635 Blosc2Codec::Zlib => "zlib",
1636 Blosc2Codec::Zstd => "zstd",
1637 };
1638 params.insert(
1639 "codec".to_string(),
1640 ciborium::Value::Text(codec_str.to_string()),
1641 );
1642 params.insert(
1643 "level".to_string(),
1644 ciborium::Value::Integer((*level as i64).into()),
1645 );
1646 }
1647 }
1648 params
1649}
1650
1651#[cfg(test)]
1654mod tests {
1655 use super::*;
1656 use crate::decode::{DecodeOptions, decode};
1657 use crate::types::{ByteOrder, GlobalMetadata};
1658 use std::collections::BTreeMap;
1659
1660 fn make_descriptor(shape: Vec<u64>) -> DataObjectDescriptor {
1661 let strides = {
1662 let mut s = vec![1u64; shape.len()];
1663 for i in (0..shape.len().saturating_sub(1)).rev() {
1664 s[i] = s[i + 1] * shape[i + 1];
1665 }
1666 s
1667 };
1668 DataObjectDescriptor {
1669 obj_type: "ntensor".to_string(),
1670 ndim: shape.len() as u64,
1671 shape,
1672 strides,
1673 dtype: Dtype::Float32,
1674 byte_order: ByteOrder::native(),
1675 encoding: "none".to_string(),
1676 filter: "none".to_string(),
1677 compression: "none".to_string(),
1678 params: BTreeMap::new(),
1679 masks: None,
1680 }
1681 }
1682
1683 #[test]
1686 fn test_base_more_entries_than_descriptors_rejected() {
1687 let meta = GlobalMetadata {
1689 base: vec![
1690 BTreeMap::new(),
1691 BTreeMap::new(),
1692 BTreeMap::new(),
1693 BTreeMap::new(),
1694 BTreeMap::new(),
1695 ],
1696 ..Default::default()
1697 };
1698 let desc = make_descriptor(vec![4]);
1699 let data = vec![0u8; 16];
1700 let options = EncodeOptions {
1701 hashing: false,
1702 ..Default::default()
1703 };
1704 let result = encode(
1705 &meta,
1706 &[(&desc, data.as_slice()), (&desc, data.as_slice())],
1707 &options,
1708 );
1709 assert!(
1710 result.is_err(),
1711 "5 base entries with 2 descriptors should fail"
1712 );
1713 let err = result.unwrap_err().to_string();
1714 assert!(
1715 err.contains("5") && err.contains("2"),
1716 "error should mention counts: {err}"
1717 );
1718 }
1719
1720 #[test]
1721 fn test_base_fewer_entries_than_descriptors_auto_extended() {
1722 let meta = GlobalMetadata {
1724 base: vec![],
1725 ..Default::default()
1726 };
1727 let desc = make_descriptor(vec![2]);
1728 let data = vec![0u8; 8];
1729 let options = EncodeOptions {
1730 hashing: false,
1731 ..Default::default()
1732 };
1733 let msg = encode(
1734 &meta,
1735 &[
1736 (&desc, data.as_slice()),
1737 (&desc, data.as_slice()),
1738 (&desc, data.as_slice()),
1739 ],
1740 &options,
1741 )
1742 .unwrap();
1743
1744 let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1745 assert_eq!(decoded.base.len(), 3);
1746 for entry in &decoded.base {
1748 assert!(
1749 entry.contains_key("_reserved_"),
1750 "auto-extended base entry should have _reserved_"
1751 );
1752 }
1753 }
1754
1755 #[test]
1756 fn test_base_entry_with_top_level_key_names_no_collision() {
1757 let mut entry = BTreeMap::new();
1759 entry.insert(
1760 "version".to_string(),
1761 ciborium::Value::Text("my-version".to_string()),
1762 );
1763 entry.insert(
1764 "base".to_string(),
1765 ciborium::Value::Text("not-the-real-base".to_string()),
1766 );
1767 let meta = GlobalMetadata {
1768 base: vec![entry],
1769 ..Default::default()
1770 };
1771 let desc = make_descriptor(vec![2]);
1772 let data = vec![0u8; 8];
1773 let options = EncodeOptions {
1774 hashing: false,
1775 ..Default::default()
1776 };
1777 let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
1778 let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1779
1780 assert_eq!(
1785 decoded.base[0].get("version"),
1786 Some(&ciborium::Value::Text("my-version".to_string()))
1787 );
1788 assert_eq!(
1789 decoded.base[0].get("base"),
1790 Some(&ciborium::Value::Text("not-the-real-base".to_string()))
1791 );
1792 }
1793
1794 #[test]
1795 fn test_base_entry_with_deeply_nested_reserved_allowed() {
1796 let nested = ciborium::Value::Map(vec![(
1799 ciborium::Value::Text("_reserved_".to_string()),
1800 ciborium::Value::Text("nested-is-ok".to_string()),
1801 )]);
1802 let mut entry = BTreeMap::new();
1803 entry.insert("foo".to_string(), nested);
1804 let meta = GlobalMetadata {
1805 base: vec![entry],
1806 ..Default::default()
1807 };
1808 let desc = make_descriptor(vec![2]);
1809 let data = vec![0u8; 8];
1810 let options = EncodeOptions {
1811 hashing: false,
1812 ..Default::default()
1813 };
1814 let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
1816 let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1817 let foo = decoded.base[0].get("foo").unwrap();
1819 if let ciborium::Value::Map(pairs) = foo {
1820 assert_eq!(pairs.len(), 1);
1821 } else {
1822 panic!("expected map for foo");
1823 }
1824 }
1825
1826 #[test]
1829 fn test_reserved_rejected_at_message_level() {
1830 let mut reserved = BTreeMap::new();
1831 reserved.insert(
1832 "rogue".to_string(),
1833 ciborium::Value::Text("bad".to_string()),
1834 );
1835 let meta = GlobalMetadata {
1836 reserved,
1837 ..Default::default()
1838 };
1839 let desc = make_descriptor(vec![2]);
1840 let data = vec![0u8; 8];
1841 let result = encode(
1842 &meta,
1843 &[(&desc, data.as_slice())],
1844 &EncodeOptions::default(),
1845 );
1846 assert!(result.is_err());
1847 let err = result.unwrap_err().to_string();
1848 assert!(
1849 err.contains("_reserved_") && err.contains("message level"),
1850 "error: {err}"
1851 );
1852 }
1853
1854 #[test]
1855 fn test_reserved_rejected_in_base_entry() {
1856 let mut entry = BTreeMap::new();
1857 entry.insert("_reserved_".to_string(), ciborium::Value::Map(vec![]));
1858 let meta = GlobalMetadata {
1859 base: vec![entry],
1860 ..Default::default()
1861 };
1862 let desc = make_descriptor(vec![2]);
1863 let data = vec![0u8; 8];
1864 let result = encode(
1865 &meta,
1866 &[(&desc, data.as_slice())],
1867 &EncodeOptions::default(),
1868 );
1869 assert!(result.is_err());
1870 let err = result.unwrap_err().to_string();
1871 assert!(
1872 err.contains("_reserved_") && err.contains("base[0]"),
1873 "error: {err}"
1874 );
1875 }
1876
1877 #[test]
1878 fn test_reserved_tensor_has_four_keys_after_encode() {
1879 let meta = GlobalMetadata::default();
1880 let desc = make_descriptor(vec![3, 4]);
1881 let data = vec![0u8; 3 * 4 * 4]; let options = EncodeOptions {
1883 hashing: false,
1884 ..Default::default()
1885 };
1886 let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
1887 let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1888
1889 let reserved = decoded.base[0]
1890 .get("_reserved_")
1891 .expect("_reserved_ missing");
1892 if let ciborium::Value::Map(pairs) = reserved {
1893 let tensor_entry = pairs
1895 .iter()
1896 .find(|(k, _)| *k == ciborium::Value::Text("tensor".to_string()));
1897 assert!(tensor_entry.is_some(), "missing tensor key in _reserved_");
1898 if let Some((_, ciborium::Value::Map(tensor_pairs))) = tensor_entry {
1899 let keys: Vec<String> = tensor_pairs
1900 .iter()
1901 .filter_map(|(k, _)| {
1902 if let ciborium::Value::Text(s) = k {
1903 Some(s.clone())
1904 } else {
1905 None
1906 }
1907 })
1908 .collect();
1909 assert_eq!(keys.len(), 4, "tensor should have 4 keys, got: {keys:?}");
1910 assert!(keys.contains(&"ndim".to_string()));
1911 assert!(keys.contains(&"shape".to_string()));
1912 assert!(keys.contains(&"strides".to_string()));
1913 assert!(keys.contains(&"dtype".to_string()));
1914 } else {
1915 panic!("tensor is not a map");
1916 }
1917 } else {
1918 panic!("_reserved_ is not a map");
1919 }
1920 }
1921
1922 #[test]
1923 fn test_reserved_tensor_scalar_ndim_zero() {
1924 let desc = DataObjectDescriptor {
1926 obj_type: "ntensor".to_string(),
1927 ndim: 0,
1928 shape: vec![],
1929 strides: vec![],
1930 dtype: Dtype::Float32,
1931 byte_order: ByteOrder::native(),
1932 encoding: "none".to_string(),
1933 filter: "none".to_string(),
1934 compression: "none".to_string(),
1935 params: BTreeMap::new(),
1936 masks: None,
1937 };
1938 let data = vec![0u8; 4]; let meta = GlobalMetadata::default();
1940 let options = EncodeOptions {
1941 hashing: false,
1942 ..Default::default()
1943 };
1944 let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
1945 let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
1946
1947 let reserved = decoded.base[0]
1948 .get("_reserved_")
1949 .expect("_reserved_ missing");
1950 if let ciborium::Value::Map(pairs) = reserved {
1951 let tensor_entry = pairs
1952 .iter()
1953 .find(|(k, _)| *k == ciborium::Value::Text("tensor".to_string()));
1954 if let Some((_, ciborium::Value::Map(tensor_pairs))) = tensor_entry {
1955 let ndim = tensor_pairs
1957 .iter()
1958 .find(|(k, _)| *k == ciborium::Value::Text("ndim".to_string()));
1959 assert!(
1960 matches!(ndim, Some((_, ciborium::Value::Integer(i))) if i128::from(*i) == 0),
1961 "ndim should be 0 for scalar"
1962 );
1963 let shape = tensor_pairs
1965 .iter()
1966 .find(|(k, _)| *k == ciborium::Value::Text("shape".to_string()));
1967 assert!(
1968 matches!(shape, Some((_, ciborium::Value::Array(a))) if a.is_empty()),
1969 "shape should be [] for scalar"
1970 );
1971 } else {
1972 panic!("tensor missing or not a map");
1973 }
1974 } else {
1975 panic!("_reserved_ is not a map");
1976 }
1977 }
1978
1979 #[test]
1982 fn test_extra_with_keys_colliding_with_base_entry_keys() {
1983 let mut entry = BTreeMap::new();
1985 entry.insert(
1986 "mars".to_string(),
1987 ciborium::Value::Text("base-mars".to_string()),
1988 );
1989 let mut extra = BTreeMap::new();
1990 extra.insert(
1991 "mars".to_string(),
1992 ciborium::Value::Text("extra-mars".to_string()),
1993 );
1994 let meta = GlobalMetadata {
1995 base: vec![entry],
1996 extra,
1997 ..Default::default()
1998 };
1999 let desc = make_descriptor(vec![2]);
2000 let data = vec![0u8; 8];
2001 let options = EncodeOptions {
2002 hashing: false,
2003 ..Default::default()
2004 };
2005 let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
2006 let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
2007
2008 assert_eq!(
2009 decoded.base[0].get("mars"),
2010 Some(&ciborium::Value::Text("base-mars".to_string()))
2011 );
2012 assert_eq!(
2013 decoded.extra.get("mars"),
2014 Some(&ciborium::Value::Text("extra-mars".to_string()))
2015 );
2016 }
2017
2018 #[test]
2019 fn test_empty_extra_omitted_from_cbor() {
2020 let meta = GlobalMetadata {
2021 extra: BTreeMap::new(),
2022 ..Default::default()
2023 };
2024 let desc = make_descriptor(vec![2]);
2025 let data = vec![0u8; 8];
2026 let options = EncodeOptions {
2027 hashing: false,
2028 ..Default::default()
2029 };
2030 let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
2031 let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
2032 assert!(decoded.extra.is_empty());
2033 }
2034
2035 #[test]
2036 fn test_extra_with_nested_maps_round_trips() {
2037 let nested = ciborium::Value::Map(vec![
2038 (
2039 ciborium::Value::Text("key1".to_string()),
2040 ciborium::Value::Integer(42.into()),
2041 ),
2042 (
2043 ciborium::Value::Text("key2".to_string()),
2044 ciborium::Value::Map(vec![(
2045 ciborium::Value::Text("deep".to_string()),
2046 ciborium::Value::Bool(true),
2047 )]),
2048 ),
2049 ]);
2050 let mut extra = BTreeMap::new();
2051 extra.insert("nested".to_string(), nested.clone());
2052 let meta = GlobalMetadata {
2053 extra,
2054 ..Default::default()
2055 };
2056 let desc = make_descriptor(vec![2]);
2057 let data = vec![0u8; 8];
2058 let options = EncodeOptions {
2059 hashing: false,
2060 ..Default::default()
2061 };
2062 let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
2063 let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
2064 assert!(decoded.extra.contains_key("nested"));
2066 }
2067
2068 #[test]
2071 fn test_legacy_top_level_keys_routed_to_extra() {
2072 use ciborium::Value;
2078 let cbor = Value::Map(vec![
2079 (Value::Text("common".to_string()), Value::Map(vec![])),
2080 (Value::Text("payload".to_string()), Value::Array(vec![])),
2081 (Value::Text("version".to_string()), Value::Integer(3.into())),
2082 ]);
2083 let mut bytes = Vec::new();
2084 ciborium::into_writer(&cbor, &mut bytes).unwrap();
2085
2086 let decoded: GlobalMetadata = crate::metadata::cbor_to_global_metadata(&bytes).unwrap();
2087 assert!(decoded.base.is_empty());
2088 assert!(decoded.reserved.is_empty());
2089 assert!(decoded.extra.contains_key("common"));
2090 assert!(decoded.extra.contains_key("payload"));
2091 assert_eq!(
2092 decoded.extra.get("version"),
2093 Some(&Value::Integer(3.into()))
2094 );
2095 }
2096
2097 #[test]
2098 fn test_old_reserved_key_name_routed_to_extra() {
2099 use ciborium::Value;
2104 let cbor = Value::Map(vec![(
2105 Value::Text("reserved".to_string()),
2106 Value::Map(vec![(
2107 Value::Text("rogue".to_string()),
2108 Value::Text("value".to_string()),
2109 )]),
2110 )]);
2111 let mut bytes = Vec::new();
2112 ciborium::into_writer(&cbor, &mut bytes).unwrap();
2113
2114 let decoded: GlobalMetadata = crate::metadata::cbor_to_global_metadata(&bytes).unwrap();
2115 assert!(
2116 decoded.reserved.is_empty(),
2117 "legacy 'reserved' must NOT bleed into library-managed `_reserved_`"
2118 );
2119 assert!(
2120 decoded.extra.contains_key("reserved"),
2121 "legacy 'reserved' key must land in `_extra_`"
2122 );
2123 }
2124
2125 #[test]
2128 fn test_reserved_rejected_in_second_base_entry_only() {
2129 let mut entry0 = BTreeMap::new();
2131 entry0.insert("clean".to_string(), ciborium::Value::Text("ok".to_string()));
2132 let mut entry1 = BTreeMap::new();
2133 entry1.insert("_reserved_".to_string(), ciborium::Value::Map(vec![]));
2134 let meta = GlobalMetadata {
2135 base: vec![entry0, entry1],
2136 ..Default::default()
2137 };
2138 let desc = make_descriptor(vec![2]);
2139 let data = vec![0u8; 8];
2140 let result = encode(
2141 &meta,
2142 &[(&desc, data.as_slice()), (&desc, data.as_slice())],
2143 &EncodeOptions::default(),
2144 );
2145 assert!(result.is_err());
2146 let err = result.unwrap_err().to_string();
2147 assert!(
2148 err.contains("base[1]"),
2149 "error should mention base[1]: {err}"
2150 );
2151 }
2152
2153 #[test]
2154 fn test_reserved_accepted_when_all_base_entries_clean() {
2155 let mut e0 = BTreeMap::new();
2157 e0.insert(
2158 "key0".to_string(),
2159 ciborium::Value::Text("val0".to_string()),
2160 );
2161 let mut e1 = BTreeMap::new();
2162 e1.insert(
2163 "key1".to_string(),
2164 ciborium::Value::Text("val1".to_string()),
2165 );
2166 let meta = GlobalMetadata {
2167 base: vec![e0, e1],
2168 ..Default::default()
2169 };
2170 let desc = make_descriptor(vec![2]);
2171 let data = vec![0u8; 8];
2172 let options = EncodeOptions {
2173 hashing: false,
2174 ..Default::default()
2175 };
2176 let msg = encode(
2177 &meta,
2178 &[(&desc, data.as_slice()), (&desc, data.as_slice())],
2179 &options,
2180 )
2181 .unwrap();
2182 let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
2183 assert_eq!(decoded.base.len(), 2);
2184 assert!(decoded.base[0].contains_key("key0"));
2185 assert!(decoded.base[1].contains_key("key1"));
2186 }
2187
2188 #[test]
2191 fn test_reserved_tensor_dtype_strings_for_all_dtypes() {
2192 let dtypes_and_expected = [
2194 (Dtype::Float16, "float16"),
2195 (Dtype::Bfloat16, "bfloat16"),
2196 (Dtype::Float32, "float32"),
2197 (Dtype::Float64, "float64"),
2198 (Dtype::Complex64, "complex64"),
2199 (Dtype::Complex128, "complex128"),
2200 (Dtype::Int8, "int8"),
2201 (Dtype::Int16, "int16"),
2202 (Dtype::Int32, "int32"),
2203 (Dtype::Int64, "int64"),
2204 (Dtype::Uint8, "uint8"),
2205 (Dtype::Uint16, "uint16"),
2206 (Dtype::Uint32, "uint32"),
2207 (Dtype::Uint64, "uint64"),
2208 ];
2209
2210 for (dtype, expected_str) in dtypes_and_expected {
2211 let byte_width = dtype.byte_width();
2212 let num_elements: u64 = 4;
2213 let data_len = num_elements as usize * byte_width;
2214
2215 let desc = DataObjectDescriptor {
2216 obj_type: "ntensor".to_string(),
2217 ndim: 1,
2218 shape: vec![num_elements],
2219 strides: vec![1],
2220 dtype,
2221 byte_order: ByteOrder::native(),
2222 encoding: "none".to_string(),
2223 filter: "none".to_string(),
2224 compression: "none".to_string(),
2225 params: BTreeMap::new(),
2226 masks: None,
2227 };
2228 let data = vec![0u8; data_len];
2229 let meta = GlobalMetadata::default();
2230 let options = EncodeOptions {
2231 hashing: false,
2232 ..Default::default()
2233 };
2234 let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
2235 let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
2236
2237 let reserved = decoded.base[0]
2238 .get("_reserved_")
2239 .unwrap_or_else(|| panic!("_reserved_ missing for dtype {dtype}"));
2240 if let ciborium::Value::Map(pairs) = reserved {
2241 let tensor_entry = pairs
2242 .iter()
2243 .find(|(k, _)| *k == ciborium::Value::Text("tensor".to_string()));
2244 if let Some((_, ciborium::Value::Map(tensor_pairs))) = tensor_entry {
2245 let dtype_val = tensor_pairs
2246 .iter()
2247 .find(|(k, _)| *k == ciborium::Value::Text("dtype".to_string()));
2248 assert!(
2249 matches!(
2250 dtype_val,
2251 Some((_, ciborium::Value::Text(s))) if s == expected_str
2252 ),
2253 "dtype for {dtype} should be '{expected_str}', got: {dtype_val:?}"
2254 );
2255 } else {
2256 panic!("tensor missing or not a map for dtype {dtype}");
2257 }
2258 } else {
2259 panic!("_reserved_ is not a map for dtype {dtype}");
2260 }
2261 }
2262 }
2263
2264 #[test]
2267 fn test_global_metadata_serde_all_fields_populated() {
2268 use ciborium::Value;
2270
2271 let mut base_entry = BTreeMap::new();
2272 base_entry.insert("key".to_string(), Value::Text("base_val".to_string()));
2273 let mut reserved = BTreeMap::new();
2274 reserved.insert("encoder".to_string(), Value::Text("test".to_string()));
2275 let mut extra = BTreeMap::new();
2276 extra.insert("custom".to_string(), Value::Integer(42.into()));
2277
2278 let meta = GlobalMetadata {
2279 base: vec![base_entry],
2280 reserved,
2281 extra,
2282 };
2283
2284 let cbor_bytes = crate::metadata::global_metadata_to_cbor(&meta).unwrap();
2286 let decoded: GlobalMetadata =
2287 crate::metadata::cbor_to_global_metadata(&cbor_bytes).unwrap();
2288 assert_eq!(decoded.base.len(), 1);
2289 assert_eq!(
2290 decoded.base[0].get("key"),
2291 Some(&Value::Text("base_val".to_string()))
2292 );
2293 assert!(decoded.reserved.contains_key("encoder"));
2294 assert_eq!(
2295 decoded.extra.get("custom"),
2296 Some(&Value::Integer(42.into()))
2297 );
2298 }
2299
2300 #[test]
2303 fn test_provenance_fields_present_after_encode() {
2304 let meta = GlobalMetadata::default();
2305 let desc = make_descriptor(vec![2]);
2306 let data = vec![0u8; 8];
2307 let options = EncodeOptions {
2308 hashing: false,
2309 ..Default::default()
2310 };
2311 let msg = encode(&meta, &[(&desc, data.as_slice())], &options).unwrap();
2312 let (decoded, _) = decode(&msg, &DecodeOptions::default()).unwrap();
2313
2314 assert!(decoded.reserved.contains_key("encoder"));
2316 assert!(decoded.reserved.contains_key("time"));
2317 assert!(decoded.reserved.contains_key("uuid"));
2318
2319 if let ciborium::Value::Map(pairs) = decoded.reserved.get("encoder").unwrap() {
2321 let has_name = pairs
2322 .iter()
2323 .any(|(k, _)| *k == ciborium::Value::Text("name".to_string()));
2324 let has_version = pairs
2325 .iter()
2326 .any(|(k, _)| *k == ciborium::Value::Text("version".to_string()));
2327 assert!(has_name, "encoder map should have 'name' key");
2328 assert!(has_version, "encoder map should have 'version' key");
2329 } else {
2330 panic!("encoder should be a map");
2331 }
2332
2333 if let ciborium::Value::Text(uuid_str) = decoded.reserved.get("uuid").unwrap() {
2335 assert_eq!(uuid_str.len(), 36, "UUID should be 36 chars: {uuid_str}");
2336 assert_eq!(
2337 uuid_str.chars().filter(|c| *c == '-').count(),
2338 4,
2339 "UUID should have 4 hyphens: {uuid_str}"
2340 );
2341 } else {
2342 panic!("uuid should be a text");
2343 }
2344
2345 if let ciborium::Value::Text(time_str) = decoded.reserved.get("time").unwrap() {
2347 assert!(
2348 time_str.ends_with('Z'),
2349 "time should end with Z: {time_str}"
2350 );
2351 assert!(
2352 time_str.contains('T'),
2353 "time should contain T separator: {time_str}"
2354 );
2355 } else {
2356 panic!("time should be a text");
2357 }
2358 }
2359
2360 #[test]
2361 fn test_both_reserved_and_reserved_underscore_only_new_captured() {
2362 use ciborium::Value;
2364 let cbor = Value::Map(vec![
2365 (
2366 Value::Text("_reserved_".to_string()),
2367 Value::Map(vec![(
2368 Value::Text("encoder".to_string()),
2369 Value::Text("tensogram".to_string()),
2370 )]),
2371 ),
2372 (
2373 Value::Text("reserved".to_string()),
2374 Value::Map(vec![(
2375 Value::Text("old".to_string()),
2376 Value::Text("ignored".to_string()),
2377 )]),
2378 ),
2379 (Value::Text("version".to_string()), Value::Integer(3.into())),
2380 ]);
2381 let mut bytes = Vec::new();
2382 ciborium::into_writer(&cbor, &mut bytes).unwrap();
2383
2384 let decoded: GlobalMetadata = crate::metadata::cbor_to_global_metadata(&bytes).unwrap();
2385 assert!(decoded.reserved.contains_key("encoder"));
2386 assert!(!decoded.reserved.contains_key("old"));
2387 }
2388
2389 #[test]
2396 fn test_encode_pre_encoded_roundtrip_simple_packing() {
2397 let desc = make_descriptor(vec![4]);
2399 let raw_data: Vec<u8> = vec![0u8; 4 * 4]; let meta = GlobalMetadata::default();
2402 let options = EncodeOptions::default();
2403
2404 let msg1 = encode(&meta, &[(&desc, raw_data.as_slice())], &options).unwrap();
2406
2407 let (_, objects1) = decode(&msg1, &DecodeOptions::default()).unwrap();
2409 let (decoded_desc1, decoded_payload1) = &objects1[0];
2410
2411 let msg2 = encode_pre_encoded(
2413 &meta,
2414 &[(&decoded_desc1.clone(), decoded_payload1.as_slice())],
2415 &options,
2416 )
2417 .unwrap();
2418
2419 let (_, objects2) = decode(&msg2, &DecodeOptions::default()).unwrap();
2421 let (_, decoded_payload2) = &objects2[0];
2422
2423 assert_eq!(
2426 decoded_payload1, decoded_payload2,
2427 "decoded payloads should be equal after encode/re-encode roundtrip"
2428 );
2429 }
2430
2431 #[test]
2443 fn test_encode_pre_encoded_populates_inline_hash_slot() {
2444 use crate::framing::{decode_message, scan};
2445 use crate::hash::verify_frame_hash;
2446 use crate::wire::{FrameHeader, MessageFlags, Preamble};
2447
2448 let desc = make_descriptor(vec![2]);
2449 let data = vec![0xABu8; 8];
2450 let meta = GlobalMetadata::default();
2451 let options = EncodeOptions::default();
2452
2453 let msg = encode_pre_encoded(&meta, &[(&desc, data.as_slice())], &options).unwrap();
2454
2455 let preamble = Preamble::read_from(&msg).unwrap();
2457 assert!(preamble.flags.has(MessageFlags::HASHES_PRESENT));
2458
2459 let messages = scan(&msg);
2461 assert_eq!(messages.len(), 1);
2462 let (offset, len) = messages[0];
2463 let only_msg = &msg[offset..offset + len];
2464 let decoded = decode_message(only_msg).unwrap();
2465 for (_, _, _, frame_offset) in &decoded.objects {
2466 let frame = &only_msg[*frame_offset..];
2467 let fh = FrameHeader::read_from(frame).unwrap();
2468 let frame_bytes = &frame[..fh.total_length as usize];
2469 verify_frame_hash(frame_bytes, fh.frame_type, None)
2470 .expect("inline hash slot must verify against body");
2471 }
2472 }
2473
2474 #[test]
2475 fn test_validate_szip_block_offsets_happy_path() {
2476 let mut params = BTreeMap::new();
2477 params.insert(
2478 "szip_block_offsets".to_string(),
2479 ciborium::Value::Array(vec![0u64, 100, 200].into_iter().map(|n| n.into()).collect()),
2480 );
2481
2482 assert!(validate_szip_block_offsets(¶ms, 100).is_ok());
2483 }
2484
2485 #[test]
2486 fn test_validate_szip_block_offsets_missing_key() {
2487 let params = BTreeMap::new();
2488
2489 let err = validate_szip_block_offsets(¶ms, 100)
2490 .unwrap_err()
2491 .to_string();
2492 assert!(
2493 err.contains("missing") && err.contains("szip_block_offsets"),
2494 "error: {err}"
2495 );
2496 }
2497
2498 #[test]
2499 fn test_validate_szip_block_offsets_not_array() {
2500 let mut params = BTreeMap::new();
2501 params.insert(
2502 "szip_block_offsets".to_string(),
2503 ciborium::Value::Integer(0.into()),
2504 );
2505
2506 let err = validate_szip_block_offsets(¶ms, 100)
2507 .unwrap_err()
2508 .to_string();
2509 assert!(
2510 err.contains("array") && err.contains("szip_block_offsets"),
2511 "error: {err}"
2512 );
2513 }
2514
2515 #[test]
2516 fn test_validate_szip_block_offsets_non_integer_element() {
2517 let mut params = BTreeMap::new();
2518 params.insert(
2519 "szip_block_offsets".to_string(),
2520 ciborium::Value::Array(vec![
2521 ciborium::Value::Integer(0.into()),
2522 ciborium::Value::Text("x".to_string()),
2523 ]),
2524 );
2525
2526 let err = validate_szip_block_offsets(¶ms, 100)
2527 .unwrap_err()
2528 .to_string();
2529 assert!(
2530 err.contains("integer") && err.contains("szip_block_offsets"),
2531 "error: {err}"
2532 );
2533 }
2534
2535 #[test]
2536 fn test_validate_szip_block_offsets_nonzero_first() {
2537 let mut params = BTreeMap::new();
2538 params.insert(
2539 "szip_block_offsets".to_string(),
2540 ciborium::Value::Array(vec![5u64, 100, 200].into_iter().map(|n| n.into()).collect()),
2541 );
2542
2543 let err = validate_szip_block_offsets(¶ms, 100)
2544 .unwrap_err()
2545 .to_string();
2546 assert!(
2547 err.contains("must be 0") && err.contains("got 5"),
2548 "error: {err}"
2549 );
2550 }
2551
2552 #[test]
2553 fn test_validate_szip_block_offsets_non_monotonic() {
2554 let mut params = BTreeMap::new();
2555 params.insert(
2556 "szip_block_offsets".to_string(),
2557 ciborium::Value::Array(vec![0u64, 100, 50].into_iter().map(|n| n.into()).collect()),
2558 );
2559
2560 let err = validate_szip_block_offsets(¶ms, 100)
2561 .unwrap_err()
2562 .to_string();
2563 assert!(
2564 err.contains("increasing") || err.contains("monotonic"),
2565 "error: {err}"
2566 );
2567 }
2568
2569 #[test]
2570 fn test_validate_szip_block_offsets_offset_beyond_bound() {
2571 let mut params = BTreeMap::new();
2572 params.insert(
2573 "szip_block_offsets".to_string(),
2574 ciborium::Value::Array(vec![0u64, 100, 801].into_iter().map(|n| n.into()).collect()),
2575 );
2576
2577 let err = validate_szip_block_offsets(¶ms, 100)
2578 .unwrap_err()
2579 .to_string();
2580 assert!(err.contains("800") && err.contains("801"), "error: {err}");
2581 }
2582
2583 #[test]
2584 fn test_validate_no_szip_offsets_for_non_szip_rejects() {
2585 let mut params = BTreeMap::new();
2586 params.insert(
2587 "szip_block_offsets".to_string(),
2588 ciborium::Value::Array(vec![0u64, 1].into_iter().map(|n| n.into()).collect()),
2589 );
2590 let desc = DataObjectDescriptor {
2591 obj_type: "ntensor".to_string(),
2592 ndim: 1,
2593 shape: vec![2],
2594 strides: vec![1],
2595 dtype: Dtype::Float32,
2596 byte_order: ByteOrder::native(),
2597 encoding: "none".to_string(),
2598 filter: "none".to_string(),
2599 compression: "zstd".to_string(),
2600 params,
2601 masks: None,
2602 };
2603
2604 let err = validate_no_szip_offsets_for_non_szip(&desc)
2605 .unwrap_err()
2606 .to_string();
2607 assert!(
2608 err.contains("szip_block_offsets") && err.contains("zstd"),
2609 "error: {err}"
2610 );
2611 }
2612
2613 #[test]
2614 fn test_validate_no_szip_offsets_for_non_szip_allows_szip() {
2615 let mut params = BTreeMap::new();
2616 params.insert(
2617 "szip_block_offsets".to_string(),
2618 ciborium::Value::Array(vec![0u64, 1].into_iter().map(|n| n.into()).collect()),
2619 );
2620 let desc = DataObjectDescriptor {
2621 obj_type: "ntensor".to_string(),
2622 ndim: 1,
2623 shape: vec![2],
2624 strides: vec![1],
2625 dtype: Dtype::Float32,
2626 byte_order: ByteOrder::native(),
2627 encoding: "none".to_string(),
2628 filter: "none".to_string(),
2629 compression: "szip".to_string(),
2630 params,
2631 masks: None,
2632 };
2633
2634 assert!(validate_no_szip_offsets_for_non_szip(&desc).is_ok());
2635 }
2636
2637 #[test]
2640 fn aggregate_hash_policy_default_is_auto() {
2641 assert_eq!(AggregateHashPolicy::default(), AggregateHashPolicy::Auto);
2642 }
2643
2644 #[test]
2645 fn aggregate_hash_policy_buffered_resolves_auto_to_header() {
2646 assert_eq!(
2647 AggregateHashPolicy::Auto.resolved_buffered(),
2648 AggregateHashPolicy::Header
2649 );
2650 assert_eq!(
2652 AggregateHashPolicy::None.resolved_buffered(),
2653 AggregateHashPolicy::None
2654 );
2655 assert_eq!(
2656 AggregateHashPolicy::Footer.resolved_buffered(),
2657 AggregateHashPolicy::Footer
2658 );
2659 assert_eq!(
2660 AggregateHashPolicy::Both.resolved_buffered(),
2661 AggregateHashPolicy::Both
2662 );
2663 }
2664
2665 #[test]
2666 fn aggregate_hash_policy_streaming_rejects_header() {
2667 let err = AggregateHashPolicy::Header
2668 .resolved_streaming()
2669 .unwrap_err();
2670 assert!(matches!(err, TensogramError::Encoding(_)));
2671 }
2672
2673 #[test]
2674 fn aggregate_hash_policy_streaming_rejects_both() {
2675 let err = AggregateHashPolicy::Both.resolved_streaming().unwrap_err();
2676 assert!(matches!(err, TensogramError::Encoding(_)));
2677 }
2678
2679 #[test]
2680 fn aggregate_hash_policy_streaming_resolves_auto_to_footer() {
2681 assert_eq!(
2682 AggregateHashPolicy::Auto.resolved_streaming().unwrap(),
2683 AggregateHashPolicy::Footer
2684 );
2685 }
2686
2687 #[test]
2688 fn aggregate_hash_policy_streaming_accepts_explicit_footer_and_none() {
2689 assert_eq!(
2690 AggregateHashPolicy::Footer.resolved_streaming().unwrap(),
2691 AggregateHashPolicy::Footer
2692 );
2693 assert_eq!(
2694 AggregateHashPolicy::None.resolved_streaming().unwrap(),
2695 AggregateHashPolicy::None
2696 );
2697 }
2698
2699 #[test]
2700 fn aggregate_hash_policy_emits_flags() {
2701 assert!(AggregateHashPolicy::Header.emits_header());
2703 assert!(!AggregateHashPolicy::Header.emits_footer());
2704 assert!(!AggregateHashPolicy::Footer.emits_header());
2705 assert!(AggregateHashPolicy::Footer.emits_footer());
2706 assert!(AggregateHashPolicy::Both.emits_header());
2707 assert!(AggregateHashPolicy::Both.emits_footer());
2708 assert!(!AggregateHashPolicy::None.emits_header());
2709 assert!(!AggregateHashPolicy::None.emits_footer());
2710 }
2711
2712 #[test]
2715 fn get_i64_param_or_default_returns_default_on_absent() {
2716 let params = BTreeMap::new();
2717 assert_eq!(
2718 get_i64_param_or_default(¶ms, "zstd_level", 3).unwrap(),
2719 3
2720 );
2721 }
2722
2723 #[test]
2724 fn get_i64_param_or_default_returns_present_value() {
2725 let mut params = BTreeMap::new();
2726 params.insert(
2727 "zstd_level".to_string(),
2728 ciborium::Value::Integer(7i64.into()),
2729 );
2730 assert_eq!(
2731 get_i64_param_or_default(¶ms, "zstd_level", 3).unwrap(),
2732 7
2733 );
2734 }
2735
2736 #[test]
2737 fn get_i64_param_or_default_rejects_wrong_type() {
2738 let mut params = BTreeMap::new();
2744 params.insert(
2745 "zstd_level".to_string(),
2746 ciborium::Value::Text("high".to_string()),
2747 );
2748 let err = get_i64_param_or_default(¶ms, "zstd_level", 3).unwrap_err();
2749 match err {
2750 TensogramError::Metadata(msg) => {
2751 assert!(msg.contains("expected integer"), "msg: {msg}");
2752 assert!(msg.contains("zstd_level"), "msg: {msg}");
2753 assert!(msg.contains("default"), "msg: {msg}");
2754 }
2755 other => panic!("expected Metadata error, got: {other:?}"),
2756 }
2757 }
2758
2759 #[test]
2760 fn get_text_param_or_default_returns_default_on_absent() {
2761 let params = BTreeMap::new();
2762 assert_eq!(
2763 get_text_param_or_default(¶ms, "blosc2_codec", "lz4").unwrap(),
2764 "lz4"
2765 );
2766 }
2767
2768 #[test]
2769 fn get_text_param_or_default_returns_present_value() {
2770 let mut params = BTreeMap::new();
2771 params.insert(
2772 "blosc2_codec".to_string(),
2773 ciborium::Value::Text("zstd".to_string()),
2774 );
2775 assert_eq!(
2776 get_text_param_or_default(¶ms, "blosc2_codec", "lz4").unwrap(),
2777 "zstd"
2778 );
2779 }
2780
2781 #[test]
2782 fn get_text_param_or_default_rejects_wrong_type() {
2783 let mut params = BTreeMap::new();
2784 params.insert(
2785 "blosc2_codec".to_string(),
2786 ciborium::Value::Integer(5i64.into()),
2787 );
2788 let err = get_text_param_or_default(¶ms, "blosc2_codec", "lz4").unwrap_err();
2789 match err {
2790 TensogramError::Metadata(msg) => {
2791 assert!(msg.contains("expected text"), "msg: {msg}");
2792 assert!(msg.contains("blosc2_codec"), "msg: {msg}");
2793 }
2794 other => panic!("expected Metadata error, got: {other:?}"),
2795 }
2796 }
2797
2798 fn make_mask_desc(method: &str, params: BTreeMap<String, ciborium::Value>) -> MaskDescriptor {
2801 MaskDescriptor {
2802 method: method.to_string(),
2803 offset: 0,
2804 length: 1,
2805 params,
2806 }
2807 }
2808
2809 #[test]
2810 fn validate_mask_params_accepts_empty_for_paramless_methods() {
2811 for m in ["none", "rle", "roaring", "lz4"] {
2812 let masks = MasksMetadata {
2813 nan: Some(make_mask_desc(m, BTreeMap::new())),
2814 ..Default::default()
2815 };
2816 assert!(
2817 validate_mask_params(&masks).is_ok(),
2818 "method {m} must accept empty params"
2819 );
2820 }
2821 }
2822
2823 #[test]
2824 fn validate_mask_params_accepts_zstd_level() {
2825 let mut params = BTreeMap::new();
2826 params.insert("level".to_string(), ciborium::Value::Integer(3i64.into()));
2827 let masks = MasksMetadata {
2828 nan: Some(make_mask_desc("zstd", params)),
2829 ..Default::default()
2830 };
2831 assert!(validate_mask_params(&masks).is_ok());
2832 }
2833
2834 #[test]
2835 fn validate_mask_params_rejects_unknown_method() {
2836 let masks = MasksMetadata {
2837 nan: Some(make_mask_desc("snappy", BTreeMap::new())),
2838 ..Default::default()
2839 };
2840 let err = validate_mask_params(&masks).unwrap_err();
2841 match err {
2842 TensogramError::Metadata(msg) => {
2843 assert!(msg.contains("unknown method"), "msg: {msg}");
2844 assert!(msg.contains("snappy"), "msg: {msg}");
2845 assert!(msg.contains("expected one of"), "msg: {msg}");
2846 }
2847 other => panic!("expected Metadata error, got: {other:?}"),
2848 }
2849 }
2850
2851 #[test]
2852 fn validate_mask_params_rejects_unknown_param_for_paramless_method() {
2853 let mut params = BTreeMap::new();
2856 params.insert("level".to_string(), ciborium::Value::Integer(5i64.into()));
2857 let masks = MasksMetadata {
2858 pos_inf: Some(make_mask_desc("rle", params)),
2859 ..Default::default()
2860 };
2861 let err = validate_mask_params(&masks).unwrap_err();
2862 match err {
2863 TensogramError::Metadata(msg) => {
2864 assert!(msg.contains("unknown param"), "msg: {msg}");
2865 assert!(msg.contains("level"), "msg: {msg}");
2866 assert!(msg.contains("rle"), "msg: {msg}");
2867 assert!(msg.contains("inf+"), "kind tag missing: {msg}");
2868 }
2869 other => panic!("expected Metadata error, got: {other:?}"),
2870 }
2871 }
2872
2873 #[test]
2876 fn get_f64_param_accepts_integer_within_exact_range() {
2877 let mut params = BTreeMap::new();
2878 params.insert(
2880 "tol".to_string(),
2881 ciborium::Value::Integer((1i64 << 53).into()),
2882 );
2883 assert_eq!(get_f64_param(¶ms, "tol").unwrap(), (1u64 << 53) as f64);
2884 }
2885
2886 #[test]
2887 fn get_f64_param_rejects_integer_beyond_exact_range() {
2888 let mut params = BTreeMap::new();
2890 let too_big = i64::from((1u32 << 30) - 1) << 24; params.insert("tol".to_string(), ciborium::Value::Integer(too_big.into()));
2892 let err = get_f64_param(¶ms, "tol").unwrap_err();
2893 match err {
2894 TensogramError::Metadata(msg) => {
2895 assert!(msg.contains("exact-representable"), "msg: {msg}");
2896 assert!(msg.contains("tol"), "msg: {msg}");
2897 }
2898 other => panic!("expected Metadata error, got: {other:?}"),
2899 }
2900 }
2901
2902 #[test]
2903 fn get_f64_param_accepts_negative_integer_within_range() {
2904 let mut params = BTreeMap::new();
2905 params.insert(
2906 "tol".to_string(),
2907 ciborium::Value::Integer((-(1i64 << 53)).into()),
2908 );
2909 assert_eq!(
2910 get_f64_param(¶ms, "tol").unwrap(),
2911 -((1u64 << 53) as f64)
2912 );
2913 }
2914
2915 #[test]
2916 fn get_f64_param_rejects_large_negative_integer() {
2917 let mut params = BTreeMap::new();
2918 let too_neg = -(i64::from((1u32 << 30) - 1) << 24);
2919 params.insert("tol".to_string(), ciborium::Value::Integer(too_neg.into()));
2920 let err = get_f64_param(¶ms, "tol").unwrap_err();
2921 assert!(matches!(err, TensogramError::Metadata(_)));
2922 }
2923
2924 #[test]
2925 fn validate_mask_params_rejects_typo_param() {
2926 let mut params = BTreeMap::new();
2929 params.insert("levle".to_string(), ciborium::Value::Integer(3i64.into()));
2930 let masks = MasksMetadata {
2931 neg_inf: Some(make_mask_desc("zstd", params)),
2932 ..Default::default()
2933 };
2934 let err = validate_mask_params(&masks).unwrap_err();
2935 match err {
2936 TensogramError::Metadata(msg) => {
2937 assert!(msg.contains("unknown param"), "msg: {msg}");
2938 assert!(msg.contains("levle"), "msg: {msg}");
2939 assert!(msg.contains("zstd"), "msg: {msg}");
2940 }
2941 other => panic!("expected Metadata error, got: {other:?}"),
2942 }
2943 }
2944}