1use crate::encode::build_pipeline_config_with_backend;
10use crate::error::{Result, TensogramError};
11use crate::framing;
12use crate::hash;
13use crate::types::{DataObjectDescriptor, DecodedObject, GlobalMetadata};
14use tensogram_encodings::pipeline;
15
16fn extract_block_offsets(
17 params: &std::collections::BTreeMap<String, ciborium::Value>,
18) -> Result<Vec<u64>> {
19 match params.get("szip_block_offsets") {
20 Some(ciborium::Value::Array(arr)) => arr
21 .iter()
22 .map(|v| match v {
23 ciborium::Value::Integer(i) => {
24 let n: i128 = (*i).into();
25 u64::try_from(n).map_err(|_| {
26 TensogramError::Metadata("szip_block_offset out of u64 range".to_string())
27 })
28 }
29 _ => Err(TensogramError::Metadata(
30 "szip_block_offsets must contain integers".to_string(),
31 )),
32 })
33 .collect(),
34 Some(_) => Err(TensogramError::Metadata(
35 "szip_block_offsets must be an array".to_string(),
36 )),
37 None => Err(TensogramError::Compression(
38 "missing szip_block_offsets in payload metadata (required for partial range decode)"
39 .to_string(),
40 )),
41 }
42}
43
44#[derive(Debug, Clone)]
46pub struct DecodeOptions {
47 pub verify_hash: bool,
49 pub native_byte_order: bool,
54 pub compression_backend: pipeline::CompressionBackend,
57 pub threads: u32,
66 pub parallel_threshold_bytes: Option<usize>,
70}
71
72impl Default for DecodeOptions {
73 fn default() -> Self {
74 Self {
75 verify_hash: false,
76 native_byte_order: true,
77 compression_backend: pipeline::CompressionBackend::default(),
78 threads: 0,
79 parallel_threshold_bytes: None,
80 }
81 }
82}
83
84#[tracing::instrument(skip(buf, options), fields(buf_len = buf.len()))]
93pub fn decode(buf: &[u8], options: &DecodeOptions) -> Result<(GlobalMetadata, Vec<DecodedObject>)> {
94 let msg = framing::decode_message(buf)?;
95
96 let budget = crate::parallel::resolve_budget(options.threads);
97 let total_bytes: usize = msg.objects.iter().map(|(_, p, _)| p.len()).sum();
98 let parallel =
99 crate::parallel::should_parallelise(budget, total_bytes, options.parallel_threshold_bytes);
100 let any_axis_b = msg.objects.iter().any(|(d, _, _)| {
101 crate::parallel::is_axis_b_friendly(&d.encoding, &d.filter, &d.compression)
102 });
103 let use_axis_a = parallel && crate::parallel::use_axis_a(msg.objects.len(), budget, any_axis_b);
104 let intra_codec_threads = if parallel && !use_axis_a { budget } else { 0 };
105
106 let decode_one = |(desc, payload_bytes, _offset): &(DataObjectDescriptor, &[u8], usize)|
107 -> Result<DecodedObject> {
108 let decoded = decode_single_object_with_backend(
109 desc,
110 payload_bytes,
111 options,
112 options.compression_backend,
113 intra_codec_threads,
114 )?;
115 Ok((desc.clone(), decoded))
116 };
117
118 let data_objects: Vec<DecodedObject> = if use_axis_a {
119 #[cfg(feature = "threads")]
120 {
121 use rayon::prelude::*;
122 crate::parallel::with_pool(budget, || {
123 msg.objects
124 .par_iter()
125 .map(&decode_one)
126 .collect::<Result<Vec<_>>>()
127 })?
128 }
129 #[cfg(not(feature = "threads"))]
130 {
131 msg.objects.iter().map(decode_one).collect::<Result<_>>()?
132 }
133 } else {
134 crate::parallel::run_maybe_pooled(budget, parallel, intra_codec_threads, || {
135 msg.objects.iter().map(decode_one).collect::<Result<_>>()
136 })?
137 };
138
139 Ok((msg.global_metadata, data_objects))
140}
141
142pub fn decode_metadata(buf: &[u8]) -> Result<GlobalMetadata> {
144 framing::decode_metadata_only(buf)
145}
146
147pub fn decode_descriptors(buf: &[u8]) -> Result<(GlobalMetadata, Vec<DataObjectDescriptor>)> {
155 let msg = framing::decode_message(buf)?;
156 let descriptors = msg.objects.into_iter().map(|(desc, _, _)| desc).collect();
157 Ok((msg.global_metadata, descriptors))
158}
159
160pub fn decode_object(
163 buf: &[u8],
164 index: usize,
165 options: &DecodeOptions,
166) -> Result<(GlobalMetadata, DataObjectDescriptor, Vec<u8>)> {
167 let msg = framing::decode_message(buf)?;
168
169 if index >= msg.objects.len() {
170 return Err(TensogramError::Object(format!(
171 "object index {} out of range (num_objects={})",
172 index,
173 msg.objects.len()
174 )));
175 }
176
177 let (desc, payload_bytes, _) = &msg.objects[index];
178
179 let budget = crate::parallel::resolve_budget(options.threads);
182 let parallel = crate::parallel::should_parallelise(
183 budget,
184 payload_bytes.len(),
185 options.parallel_threshold_bytes,
186 );
187 let intra_codec_threads = if parallel { budget } else { 0 };
188
189 let decoded = crate::parallel::run_maybe_pooled(budget, parallel, intra_codec_threads, || {
190 decode_single_object_with_backend(
191 desc,
192 payload_bytes,
193 options,
194 options.compression_backend,
195 intra_codec_threads,
196 )
197 })?;
198
199 Ok((msg.global_metadata, desc.clone(), decoded))
200}
201
202pub fn decode_range(
210 buf: &[u8],
211 object_index: usize,
212 ranges: &[(u64, u64)],
213 options: &DecodeOptions,
214) -> Result<(DataObjectDescriptor, Vec<Vec<u8>>)> {
215 let msg = framing::decode_message(buf)?;
216
217 if object_index >= msg.objects.len() {
218 return Err(TensogramError::Object(format!(
219 "object index {} out of range (num_objects={})",
220 object_index,
221 msg.objects.len()
222 )));
223 }
224
225 let (desc, payload_bytes, _) = &msg.objects[object_index];
226 let parts = decode_range_from_payload(desc, payload_bytes, ranges, options)?;
227 Ok((desc.clone(), parts))
228}
229
230pub fn decode_range_from_payload(
231 desc: &DataObjectDescriptor,
232 payload_bytes: &[u8],
233 ranges: &[(u64, u64)],
234 options: &DecodeOptions,
235) -> Result<Vec<Vec<u8>>> {
236 if desc.filter != "none" {
237 return Err(TensogramError::Encoding(
238 "decode_range is not supported when a filter (e.g. shuffle) is applied".to_string(),
239 ));
240 }
241
242 if desc.dtype.byte_width() == 0 {
243 return Err(TensogramError::Encoding(
244 "partial range decode not supported for bitmask dtype".to_string(),
245 ));
246 }
247
248 if options.verify_hash
249 && let Some(ref hash_desc) = desc.hash
250 {
251 hash::verify_hash(payload_bytes, hash_desc)?;
252 }
253
254 let shape_product = desc
255 .shape
256 .iter()
257 .try_fold(1u64, |acc, &x| acc.checked_mul(x))
258 .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
259 let num_elements = usize::try_from(shape_product)
260 .map_err(|_| TensogramError::Metadata("element count overflows usize".to_string()))?;
261 let budget = crate::parallel::resolve_budget(options.threads);
267 let elem_bytes = desc.dtype.byte_width();
270 let total_bytes: usize = ranges
271 .iter()
272 .map(|(_, c)| (*c as usize).saturating_mul(elem_bytes))
273 .sum();
274 let parallel =
275 crate::parallel::should_parallelise(budget, total_bytes, options.parallel_threshold_bytes);
276 let axis_b_friendly =
277 crate::parallel::is_axis_b_friendly(&desc.encoding, &desc.filter, &desc.compression);
278 let use_axis_a = parallel && crate::parallel::use_axis_a(ranges.len(), budget, axis_b_friendly);
279 let intra_codec_threads = if parallel && !use_axis_a { budget } else { 0 };
280
281 let config = build_pipeline_config_with_backend(
282 desc,
283 num_elements,
284 desc.dtype,
285 options.compression_backend,
286 intra_codec_threads,
287 )?;
288
289 let block_offsets = if desc.compression == "szip" {
290 extract_block_offsets(&desc.params)?
291 } else {
292 Vec::new()
293 };
294
295 let decode_one = |offset: u64, count: u64| -> Result<Vec<u8>> {
296 pipeline::decode_range_pipeline(
297 payload_bytes,
298 &config,
299 &block_offsets,
300 offset,
301 count,
302 options.native_byte_order,
303 )
304 .map_err(|e| {
305 TensogramError::Encoding(format!("range (offset={offset}, count={count}): {e}"))
306 })
307 };
308
309 let run_seq = || -> Result<Vec<Vec<u8>>> {
310 ranges
311 .iter()
312 .map(|&(offset, count)| decode_one(offset, count))
313 .collect()
314 };
315
316 let results: Vec<Vec<u8>> = if use_axis_a {
317 #[cfg(feature = "threads")]
318 {
319 use rayon::prelude::*;
320 crate::parallel::with_pool(budget, || {
321 ranges
322 .par_iter()
323 .map(|&(offset, count)| decode_one(offset, count))
324 .collect::<Result<Vec<_>>>()
325 })?
326 }
327 #[cfg(not(feature = "threads"))]
328 {
329 run_seq()?
330 }
331 } else {
332 crate::parallel::run_maybe_pooled(budget, parallel, intra_codec_threads, run_seq)?
333 };
334
335 Ok(results)
336}
337
/// Crate-internal helper (built only with the `remote` feature) that decodes
/// one object payload using the backend selected in `options` and `0`
/// intra-codec threads.
#[cfg(feature = "remote")]
pub(crate) fn decode_single_object(
    desc: &DataObjectDescriptor,
    payload_bytes: &[u8],
    options: &DecodeOptions,
) -> Result<Vec<u8>> {
    let backend = options.compression_backend;
    let intra_codec_threads = 0;
    decode_single_object_with_backend(desc, payload_bytes, options, backend, intra_codec_threads)
}
346
347fn decode_single_object_with_backend(
352 desc: &DataObjectDescriptor,
353 payload_bytes: &[u8],
354 options: &DecodeOptions,
355 backend: pipeline::CompressionBackend,
356 intra_codec_threads: u32,
357) -> Result<Vec<u8>> {
358 if options.verify_hash
359 && let Some(ref hash_desc) = desc.hash
360 {
361 hash::verify_hash(payload_bytes, hash_desc)?;
362 }
363
364 let shape_product = desc
365 .shape
366 .iter()
367 .try_fold(1u64, |acc, &x| acc.checked_mul(x))
368 .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
369 let num_elements = usize::try_from(shape_product)
370 .map_err(|_| TensogramError::Metadata("element count overflows usize".to_string()))?;
371 let config = build_pipeline_config_with_backend(
372 desc,
373 num_elements,
374 desc.dtype,
375 backend,
376 intra_codec_threads,
377 )?;
378 let decoded = pipeline::decode_pipeline(payload_bytes, &config, options.native_byte_order)
379 .map_err(|e| TensogramError::Encoding(e.to_string()))?;
380
381 Ok(decoded)
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtype::Dtype;
    use crate::encode::{EncodeOptions, encode};
    use crate::types::ByteOrder;
    use std::collections::BTreeMap;

    /// Version-2 global metadata with no extra entries.
    fn make_global_meta() -> GlobalMetadata {
        GlobalMetadata {
            version: 2,
            extra: BTreeMap::new(),
            ..Default::default()
        }
    }

    /// Builds an unencoded, unfiltered, uncompressed float32 descriptor for
    /// `shape`, with strides derived so that the last axis has stride 1.
    fn make_descriptor(shape: Vec<u64>) -> DataObjectDescriptor {
        let strides = if shape.is_empty() {
            vec![]
        } else {
            let mut s = vec![1u64; shape.len()];
            for i in (0..shape.len() - 1).rev() {
                s[i] = s[i + 1] * shape[i + 1];
            }
            s
        };
        DataObjectDescriptor {
            obj_type: "ntensor".to_string(),
            ndim: shape.len() as u64,
            shape,
            strides,
            dtype: Dtype::Float32,
            byte_order: ByteOrder::native(),
            encoding: "none".to_string(),
            filter: "none".to_string(),
            compression: "none".to_string(),
            params: BTreeMap::new(),
            hash: None,
        }
    }

    // ── decode: malformed input ──────────────────────────────────────────

    #[test]
    fn test_decode_corrupt_message_bytes() {
        let garbage = vec![0xDE, 0xAD, 0xBE, 0xEF, 0x00, 0x01, 0x02, 0x03];
        let result = decode(&garbage, &DecodeOptions::default());
        assert!(result.is_err(), "decoding garbage should fail");
    }

    #[test]
    fn test_decode_truncated_message() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();

        let truncated = &encoded[..encoded.len() / 2];
        let result = decode(truncated, &DecodeOptions::default());
        assert!(result.is_err(), "decoding truncated message should fail");
    }

    #[test]
    fn test_decode_corrupted_cbor_in_message() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![42u8; 16];
        let mut encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();

        // Overwrite a stretch of bytes starting at offset 40 with 0xFF.
        let cbor_start = 40;
        let corrupt_end = (cbor_start + 30).min(encoded.len());
        for byte in &mut encoded[cbor_start..corrupt_end] {
            *byte = 0xFF;
        }

        let result = decode(&encoded, &DecodeOptions::default());
        assert!(result.is_err(), "decoding corrupted CBOR should fail");
    }

    // ── decode_object ────────────────────────────────────────────────────

    #[test]
    fn test_decode_object_index_out_of_range() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();

        // Exactly one past the last valid index.
        let result = decode_object(&encoded, 1, &DecodeOptions::default());
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("out of range"),
            "expected 'out of range', got: {msg}"
        );

        // Far past the end.
        let result = decode_object(&encoded, 999, &DecodeOptions::default());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("out of range"));
    }

    #[test]
    fn test_decode_object_valid_index() {
        let meta = make_global_meta();
        let desc0 = make_descriptor(vec![2]);
        let data0 = vec![10u8; 8];
        let desc1 = make_descriptor(vec![3]);
        let data1 = vec![20u8; 12];

        let encoded = encode(
            &meta,
            &[(&desc0, data0.as_slice()), (&desc1, data1.as_slice())],
            &EncodeOptions::default(),
        )
        .unwrap();

        let (_, ret_desc, ret_data) =
            decode_object(&encoded, 0, &DecodeOptions::default()).unwrap();
        assert_eq!(ret_desc.shape, vec![2]);
        assert_eq!(ret_data, data0);

        let (_, ret_desc, ret_data) =
            decode_object(&encoded, 1, &DecodeOptions::default()).unwrap();
        assert_eq!(ret_desc.shape, vec![3]);
        assert_eq!(ret_data, data1);
    }

    // ── decode_range ─────────────────────────────────────────────────────

    #[test]
    fn test_decode_range_object_index_out_of_range() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();

        let result = decode_range(&encoded, 5, &[(0, 2)], &DecodeOptions::default());
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("out of range"),
            "expected 'out of range', got: {msg}"
        );
    }

    #[test]
    fn test_decode_range_exceeds_payload() {
        let meta = make_global_meta();
        // 4 float32 elements = 16 bytes.
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();

        // Offset 2 + count 10 runs past the 4 available elements.
        let result = decode_range(&encoded, 0, &[(2, 10)], &DecodeOptions::default());
        assert!(result.is_err(), "range exceeding payload should fail");
    }

    #[test]
    fn test_decode_range_valid() {
        let meta = make_global_meta();
        // 8 float32 elements = 32 bytes.
        let desc = make_descriptor(vec![8]);
        let data: Vec<u8> = (0..32).collect();
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();

        let (ret_desc, parts) =
            decode_range(&encoded, 0, &[(0, 4)], &DecodeOptions::default()).unwrap();
        assert_eq!(ret_desc.shape, vec![8]);
        assert_eq!(parts.len(), 1);
        // 4 elements * 4 bytes each.
        assert_eq!(parts[0].len(), 16);
    }

    #[test]
    fn test_decode_range_empty_ranges() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();

        let (_, parts) = decode_range(&encoded, 0, &[], &DecodeOptions::default()).unwrap();
        assert!(parts.is_empty());
    }

    // ── decode_metadata ──────────────────────────────────────────────────

    #[test]
    fn test_decode_metadata_valid() {
        let meta = make_global_meta();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 16];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();

        let decoded_meta = decode_metadata(&encoded).unwrap();
        assert_eq!(decoded_meta.version, 2);
    }

    #[test]
    fn test_decode_metadata_corrupt() {
        let garbage = vec![0xFF; 50];
        let result = decode_metadata(&garbage);
        assert!(result.is_err(), "decode_metadata on garbage should fail");
    }

    // ── decode_descriptors ───────────────────────────────────────────────

    #[test]
    fn test_decode_descriptors_valid() {
        let meta = make_global_meta();
        let desc0 = make_descriptor(vec![4]);
        let desc1 = make_descriptor(vec![2, 3]);
        let data0 = vec![0u8; 16];
        let data1 = vec![0u8; 24];
        let encoded = encode(
            &meta,
            &[(&desc0, data0.as_slice()), (&desc1, data1.as_slice())],
            &EncodeOptions::default(),
        )
        .unwrap();

        let (decoded_meta, descs) = decode_descriptors(&encoded).unwrap();
        assert_eq!(decoded_meta.version, 2);
        assert_eq!(descs.len(), 2);
        assert_eq!(descs[0].shape, vec![4]);
        assert_eq!(descs[1].shape, vec![2, 3]);
    }

    // ── decode_range: unsupported descriptors ────────────────────────────

    #[test]
    fn test_decode_range_filter_shuffle_rejected() {
        let meta = make_global_meta();
        let mut desc = make_descriptor(vec![100]);
        desc.filter = "shuffle".to_string();
        desc.params.insert(
            "shuffle_element_size".to_string(),
            ciborium::Value::Integer(4.into()),
        );
        let data: Vec<u8> = (0..400).map(|i| (i % 256) as u8).collect();

        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();

        let result = decode_range(&encoded, 0, &[(0, 10)], &DecodeOptions::default());
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("filter") || msg.contains("shuffle"),
            "expected filter/shuffle error, got: {msg}"
        );
    }

    #[test]
    fn test_decode_range_bitmask_dtype_rejected() {
        let meta = make_global_meta();
        let desc = DataObjectDescriptor {
            obj_type: "ntensor".to_string(),
            ndim: 1,
            shape: vec![16],
            strides: vec![1],
            dtype: Dtype::Bitmask,
            byte_order: ByteOrder::native(),
            encoding: "none".to_string(),
            filter: "none".to_string(),
            compression: "none".to_string(),
            params: BTreeMap::new(),
            hash: None,
        };
        let data = vec![0xFF; 2];
        let encoded = encode(&meta, &[(&desc, &data)], &EncodeOptions::default()).unwrap();

        let result = decode_range(&encoded, 0, &[(0, 8)], &DecodeOptions::default());
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("bitmask"),
            "expected bitmask error, got: {msg}"
        );
    }

    // ── DecodeOptions ────────────────────────────────────────────────────

    #[test]
    fn test_decode_options_defaults() {
        let opts = DecodeOptions::default();
        assert!(!opts.verify_hash);
        assert!(opts.native_byte_order);
        assert_eq!(opts.threads, 0);
        assert!(opts.parallel_threshold_bytes.is_none());
    }

    // ── pipeline configuration errors ────────────────────────────────────

    #[test]
    fn test_decode_unknown_encoding_in_descriptor() {
        let mut desc = make_descriptor(vec![4]);
        desc.encoding = "foobar".to_string();

        let result = crate::encode::build_pipeline_config_with_backend(
            &desc,
            4,
            Dtype::Float32,
            pipeline::CompressionBackend::default(),
            0,
        );
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("unknown encoding"),
            "expected 'unknown encoding', got: {msg}"
        );
    }

    #[test]
    fn test_decode_unknown_compression_in_descriptor() {
        let mut desc = make_descriptor(vec![4]);
        desc.compression = "quantum_compress".to_string();

        let result = crate::encode::build_pipeline_config_with_backend(
            &desc,
            4,
            Dtype::Float32,
            pipeline::CompressionBackend::default(),
            0,
        );
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("unknown compression"),
            "expected 'unknown compression', got: {msg}"
        );
    }

    // ── extract_block_offsets ────────────────────────────────────────────

    #[test]
    fn test_extract_block_offsets_missing() {
        let params = BTreeMap::new();
        let result = extract_block_offsets(&params);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("szip_block_offsets"),
            "expected szip_block_offsets error, got: {msg}"
        );
    }

    #[test]
    fn test_extract_block_offsets_wrong_type() {
        let mut params = BTreeMap::new();
        params.insert(
            "szip_block_offsets".to_string(),
            ciborium::Value::Text("not an array".to_string()),
        );
        let result = extract_block_offsets(&params);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("must be an array"),
            "expected 'must be an array', got: {msg}"
        );
    }

    #[test]
    fn test_extract_block_offsets_non_integer_elements() {
        let mut params = BTreeMap::new();
        params.insert(
            "szip_block_offsets".to_string(),
            ciborium::Value::Array(vec![ciborium::Value::Float(1.5)]),
        );
        let result = extract_block_offsets(&params);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("integers"),
            "expected integers error, got: {msg}"
        );
    }

    #[test]
    fn test_extract_block_offsets_valid() {
        let mut params = BTreeMap::new();
        params.insert(
            "szip_block_offsets".to_string(),
            ciborium::Value::Array(vec![
                ciborium::Value::Integer(0.into()),
                ciborium::Value::Integer(100.into()),
                ciborium::Value::Integer(200.into()),
            ]),
        );
        let result = extract_block_offsets(&params).unwrap();
        assert_eq!(result, vec![0, 100, 200]);
    }
}