1use std::collections::BTreeMap;
10use std::io::Write;
11
12use crate::encode::{
13 EncodeOptions, build_pipeline_config, populate_base_entries, populate_reserved_provenance,
14 validate_no_szip_offsets_for_non_szip, validate_object, validate_szip_block_offsets,
15};
16use crate::error::{Result, TensogramError};
17use crate::framing::EncodedObject;
18use crate::hash::{HashAlgorithm, compute_hash};
19use crate::metadata::{self, RESERVED_KEY};
20use crate::types::{DataObjectDescriptor, GlobalMetadata, HashDescriptor, HashFrame, IndexFrame};
21use crate::wire::{
22 FRAME_END, FRAME_HEADER_SIZE, FrameHeader, FrameType, MessageFlags, PREAMBLE_SIZE, Postamble,
23 Preamble,
24};
25use tensogram_encodings::pipeline;
26
/// Incrementally writes a Tensogram message to a [`Write`] sink.
///
/// The preamble and header metadata frame are written when the encoder is
/// created. Each data object (optionally preceded by a preceder metadata
/// frame) is written as soon as it is supplied. The footer frames and the
/// postamble are written by `finish`, which also returns the sink. Because the
/// sink is only required to implement `Write` (it cannot be rewound), the
/// preamble's `total_length` is written as `0`.
pub struct StreamingEncoder<W: Write> {
    /// Destination sink; handed back to the caller by `finish`.
    writer: W,
    /// Stream offset (bytes from the start of the message) at which each
    /// data-object frame begins; emitted in the footer index.
    object_offsets: Vec<u64>,
    /// Length in bytes of each data-object frame, excluding alignment padding.
    object_lengths: Vec<u64>,
    /// Per-object `(hash_type, hash_value)`; `None` when hashing is disabled.
    hash_entries: Vec<Option<(String, String)>>,
    /// Descriptors of the objects written so far (payloads are not retained);
    /// used to build the footer metadata.
    completed_objects: Vec<EncodedObject>,
    /// Total number of bytes written to `writer` so far, including padding.
    bytes_written: u64,
    /// Hash algorithm taken from the encode options, if hashing is enabled.
    hash_algorithm: Option<HashAlgorithm>,
    /// Copy of the caller's global metadata; enriched and emitted in the footer.
    global_meta: GlobalMetadata,
    /// `true` after a preceder frame has been written and before the object
    /// it belongs to has been written.
    pending_preceder: bool,
    /// One slot per written object holding the preceder metadata that came
    /// before it (`None` if there was none). It transiently holds one extra
    /// entry while a preceder is pending.
    preceder_payloads: Vec<Option<BTreeMap<String, ciborium::Value>>>,
    /// Thread budget resolved from `EncodeOptions::threads`.
    intra_codec_threads: u32,
    /// Optional payload-size threshold forwarded to `should_parallelise`.
    parallel_threshold_bytes: Option<usize>,
}
81
82impl<W: Write> StreamingEncoder<W> {
83 pub fn new(
88 mut writer: W,
89 global_meta: &GlobalMetadata,
90 options: &EncodeOptions,
91 ) -> Result<Self> {
92 let meta_cbor = metadata::global_metadata_to_cbor(global_meta)?;
93
94 let mut flags = MessageFlags::default();
98 flags.set(MessageFlags::HEADER_METADATA);
99 flags.set(MessageFlags::FOOTER_METADATA);
100 flags.set(MessageFlags::FOOTER_INDEX);
101 flags.set(MessageFlags::PRECEDER_METADATA);
102 if options.hash_algorithm.is_some() {
103 flags.set(MessageFlags::FOOTER_HASHES);
104 }
105
106 let preamble = Preamble {
107 version: 2,
108 flags,
109 reserved: 0,
110 total_length: 0,
111 };
112 let preamble_bytes = preamble_to_bytes(&preamble);
113 writer.write_all(&preamble_bytes)?;
114 let mut bytes_written = PREAMBLE_SIZE as u64;
115
116 let frame_bytes = build_frame(FrameType::HeaderMetadata, 1, 0, &meta_cbor);
118 writer.write_all(&frame_bytes)?;
119 bytes_written += frame_bytes.len() as u64;
120
121 write_padding(&mut writer, &mut bytes_written)?;
122
123 let intra_codec_threads = crate::parallel::resolve_budget(options.threads);
127
128 Ok(Self {
129 writer,
130 object_offsets: Vec::new(),
131 object_lengths: Vec::new(),
132 hash_entries: Vec::new(),
133 completed_objects: Vec::new(),
134 bytes_written,
135 hash_algorithm: options.hash_algorithm,
136 global_meta: global_meta.clone(),
137 pending_preceder: false,
138 preceder_payloads: Vec::new(),
139 intra_codec_threads,
140 parallel_threshold_bytes: options.parallel_threshold_bytes,
141 })
142 }
143
144 pub fn write_preceder(&mut self, metadata: BTreeMap<String, ciborium::Value>) -> Result<()> {
152 if self.pending_preceder {
153 return Err(TensogramError::Framing(
154 "write_preceder called twice without an intervening write_object/write_object_pre_encoded".to_string(),
155 ));
156 }
157
158 if metadata.contains_key(RESERVED_KEY) {
161 return Err(TensogramError::Metadata(format!(
162 "client code must not write '{RESERVED_KEY}' in preceder metadata; \
163 this field is populated by the library"
164 )));
165 }
166
167 let preceder_meta = GlobalMetadata {
168 version: self.global_meta.version,
169 base: vec![metadata.clone()],
170 ..Default::default()
171 };
172 let cbor = crate::metadata::global_metadata_to_cbor(&preceder_meta)?;
173 let frame_bytes = build_frame(FrameType::PrecederMetadata, 1, 0, &cbor);
174 self.writer.write_all(&frame_bytes)?;
175 self.bytes_written += frame_bytes.len() as u64;
176
177 write_padding(&mut self.writer, &mut self.bytes_written)?;
178
179 self.pending_preceder = true;
180 self.preceder_payloads.push(Some(metadata));
182 Ok(())
183 }
184
185 pub fn write_object(&mut self, desc: &DataObjectDescriptor, data: &[u8]) -> Result<()> {
197 validate_object(desc, data.len())?;
198
199 let shape_product = desc
200 .shape
201 .iter()
202 .try_fold(1u64, |acc, &x| acc.checked_mul(x))
203 .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
204 let num_elements = usize::try_from(shape_product)
205 .map_err(|_| TensogramError::Metadata("element count overflows usize".to_string()))?;
206
207 let parallel = crate::parallel::should_parallelise(
211 self.intra_codec_threads,
212 data.len(),
213 self.parallel_threshold_bytes,
214 );
215 let intra = if parallel {
216 self.intra_codec_threads
217 } else {
218 0
219 };
220
221 let config = crate::encode::build_pipeline_config_with_backend(
222 desc,
223 num_elements,
224 desc.dtype,
225 tensogram_encodings::pipeline::CompressionBackend::default(),
226 intra,
227 )?;
228
229 let result =
230 crate::parallel::run_maybe_pooled(self.intra_codec_threads, parallel, intra, || {
231 pipeline::encode_pipeline(data, &config)
232 })
233 .map_err(|e| TensogramError::Encoding(e.to_string()))?;
234
235 let mut final_desc = desc.clone();
237
238 if let Some(offsets) = &result.block_offsets {
239 final_desc.params.insert(
240 "szip_block_offsets".to_string(),
241 ciborium::Value::Array(
242 offsets
243 .iter()
244 .map(|&o| ciborium::Value::Integer(o.into()))
245 .collect(),
246 ),
247 );
248 }
249
250 self.write_object_inner(final_desc, &result.encoded_bytes)
251 }
252
253 #[tracing::instrument(skip(self, descriptor, pre_encoded_bytes))]
270 pub fn write_object_pre_encoded(
271 &mut self,
272 descriptor: &DataObjectDescriptor,
273 pre_encoded_bytes: &[u8],
274 ) -> Result<()> {
275 validate_object(descriptor, pre_encoded_bytes.len())?;
276
277 let shape_product = descriptor
278 .shape
279 .iter()
280 .try_fold(1u64, |acc, &x| acc.checked_mul(x))
281 .ok_or_else(|| TensogramError::Metadata("shape product overflow".to_string()))?;
282 let num_elements = usize::try_from(shape_product)
283 .map_err(|_| TensogramError::Metadata("element count overflows usize".to_string()))?;
284
285 build_pipeline_config(descriptor, num_elements, descriptor.dtype)?;
287
288 validate_no_szip_offsets_for_non_szip(descriptor)?;
290 if descriptor.compression == "szip" && descriptor.params.contains_key("szip_block_offsets")
291 {
292 validate_szip_block_offsets(&descriptor.params, pre_encoded_bytes.len())?;
293 }
294
295 self.write_object_inner(descriptor.clone(), pre_encoded_bytes)
296 }
297
298 fn write_object_inner(
304 &mut self,
305 mut final_desc: DataObjectDescriptor,
306 encoded_bytes: &[u8],
307 ) -> Result<()> {
308 let hash_entry = if let Some(algorithm) = self.hash_algorithm {
310 let hash_value = compute_hash(encoded_bytes, algorithm);
311 let hash_type = algorithm.as_str().to_string();
312 let entry = Some((hash_type.clone(), hash_value.clone()));
313 final_desc.hash = Some(HashDescriptor {
314 hash_type,
315 value: hash_value,
316 });
317 entry
318 } else {
319 None
320 };
321
322 let frame_bytes =
324 crate::framing::encode_data_object_frame(&final_desc, encoded_bytes, false)?;
325
326 self.object_offsets.push(self.bytes_written);
327 self.object_lengths.push(frame_bytes.len() as u64);
328 self.hash_entries.push(hash_entry);
329 self.completed_objects.push(EncodedObject {
333 descriptor: final_desc,
334 encoded_payload: Vec::new(),
335 });
336
337 if self.pending_preceder {
340 self.pending_preceder = false;
341 } else {
342 self.preceder_payloads.push(None);
343 }
344
345 self.writer.write_all(&frame_bytes)?;
347 self.bytes_written += frame_bytes.len() as u64;
348
349 write_padding(&mut self.writer, &mut self.bytes_written)?;
351
352 Ok(())
353 }
354
355 pub fn finish(mut self) -> Result<W> {
360 if self.pending_preceder {
361 return Err(TensogramError::Framing(
362 "dangling PrecederMetadata: finish called without a following write_object/write_object_pre_encoded"
363 .to_string(),
364 ));
365 }
366
367 let footer_start = self.bytes_written;
368
369 {
375 let mut enriched_meta = self.global_meta.clone();
376 populate_base_entries(&mut enriched_meta.base, &self.completed_objects);
377 populate_reserved_provenance(&mut enriched_meta.reserved);
378
379 if self.preceder_payloads.len() != self.completed_objects.len() {
384 return Err(TensogramError::Framing(format!(
385 "internal: preceder_payloads ({}) out of sync with completed_objects ({})",
386 self.preceder_payloads.len(),
387 self.completed_objects.len()
388 )));
389 }
390 for (i, prec) in self.preceder_payloads.iter().enumerate() {
391 if let Some(prec_map) = prec
392 && i < enriched_meta.base.len()
393 {
394 for (k, v) in prec_map {
395 enriched_meta.base[i].insert(k.clone(), v.clone());
396 }
397 }
398 }
399 let meta_cbor = metadata::global_metadata_to_cbor(&enriched_meta)?;
400 let frame_bytes = build_frame(FrameType::FooterMetadata, 1, 0, &meta_cbor);
401 self.writer.write_all(&frame_bytes)?;
402 self.bytes_written += frame_bytes.len() as u64;
403 write_padding(&mut self.writer, &mut self.bytes_written)?;
404 }
405
406 let has_hashes = self.hash_entries.iter().any(|e| e.is_some());
408 if has_hashes {
409 let hash_type = self
410 .hash_algorithm
411 .map(|a| a.as_str().to_string())
412 .unwrap_or_default();
413 let hashes: Vec<String> = self
414 .hash_entries
415 .iter()
416 .map(|e| e.as_ref().map(|(_, v)| v.clone()).unwrap_or_default())
417 .collect();
418 let hash_frame = HashFrame {
419 object_count: self.object_offsets.len() as u64,
420 hash_type,
421 hashes,
422 };
423 let hash_cbor = metadata::hash_frame_to_cbor(&hash_frame)?;
424 let frame_bytes = build_frame(FrameType::FooterHash, 1, 0, &hash_cbor);
425 self.writer.write_all(&frame_bytes)?;
426 self.bytes_written += frame_bytes.len() as u64;
427
428 write_padding(&mut self.writer, &mut self.bytes_written)?;
429 }
430
431 let index = IndexFrame {
433 object_count: self.object_offsets.len() as u64,
434 offsets: self.object_offsets,
435 lengths: self.object_lengths,
436 };
437 let index_cbor = metadata::index_to_cbor(&index)?;
438 let frame_bytes = build_frame(FrameType::FooterIndex, 1, 0, &index_cbor);
439 self.writer.write_all(&frame_bytes)?;
440 self.bytes_written += frame_bytes.len() as u64;
441
442 write_padding(&mut self.writer, &mut self.bytes_written)?;
443
444 let postamble = Postamble {
446 first_footer_offset: footer_start,
447 };
448 let mut postamble_bytes = Vec::with_capacity(16);
449 postamble.write_to(&mut postamble_bytes);
450 self.writer.write_all(&postamble_bytes)?;
451
452 self.writer.flush()?;
453
454 Ok(self.writer)
455 }
456
457 pub fn object_count(&self) -> usize {
459 self.object_offsets.len()
460 }
461
462 pub fn bytes_written(&self) -> u64 {
464 self.bytes_written
465 }
466}
467
468fn preamble_to_bytes(preamble: &Preamble) -> Vec<u8> {
471 let mut out = Vec::with_capacity(PREAMBLE_SIZE);
472 preamble.write_to(&mut out);
473 out
474}
475
476fn build_frame(frame_type: FrameType, version: u16, flags: u16, payload: &[u8]) -> Vec<u8> {
477 let total_length = (FRAME_HEADER_SIZE + payload.len() + FRAME_END.len()) as u64;
478 let fh = FrameHeader {
479 frame_type,
480 version,
481 flags,
482 total_length,
483 };
484 let mut out = Vec::with_capacity(total_length as usize);
485 fh.write_to(&mut out);
486 out.extend_from_slice(payload);
487 out.extend_from_slice(FRAME_END);
488 out
489}
490
/// Source of zero bytes for alignment padding; at most 7 are ever needed.
const ZERO_PAD: [u8; 7] = [0; 7];

/// Writes the zero bytes needed to bring `*bytes_written` up to the next
/// multiple of 8 and adds them to the counter.
///
/// Nothing is written when the counter is already 8-byte aligned.
///
/// # Errors
///
/// Propagates any error from `writer.write_all`. The counter is only
/// advanced after a successful write.
fn write_padding(writer: &mut impl Write, bytes_written: &mut u64) -> std::io::Result<()> {
    let remainder = *bytes_written % 8;
    if remainder == 0 {
        return Ok(());
    }
    let pad_len = (8 - remainder) as usize;
    writer.write_all(&ZERO_PAD[..pad_len])?;
    *bytes_written += pad_len as u64;
    Ok(())
}
501
#[cfg(test)]
mod tests {
    // Tests for `StreamingEncoder`. Most tests encode with the streaming API
    // and decode the result with the crate's `decode`; one test hand-assembles
    // a message to exercise decode-side handling of preceder metadata.
    use super::*;
    use crate::Dtype;
    use crate::decode::{DecodeOptions, decode};
    use crate::encode::{EncodeOptions, encode};
    use crate::types::{ByteOrder, DataObjectDescriptor};
    use std::collections::BTreeMap;

    /// Builds a `float32` "ntensor" descriptor for `shape` with no encoding,
    /// filter or compression, and contiguous row-major strides (last stride 1).
    fn make_descriptor(shape: Vec<u64>) -> DataObjectDescriptor {
        let ndim = shape.len() as u64;
        let mut strides = vec![0u64; shape.len()];
        if !shape.is_empty() {
            strides[shape.len() - 1] = 1;
            for i in (0..shape.len() - 1).rev() {
                strides[i] = strides[i + 1] * shape[i + 1];
            }
        }
        DataObjectDescriptor {
            obj_type: "ntensor".to_string(),
            ndim,
            shape,
            strides,
            dtype: Dtype::Float32,
            byte_order: ByteOrder::native(),
            encoding: "none".to_string(),
            filter: "none".to_string(),
            compression: "none".to_string(),
            params: BTreeMap::new(),
            hash: None,
        }
    }

    // One object in, one identical payload out (4 float32 elements = 16 bytes).
    #[test]
    fn streaming_single_object_round_trip() {
        let meta = GlobalMetadata::default();
        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 4 * 4];

        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_object(&desc, &data).unwrap();
        let result = enc.finish().unwrap();

        let (decoded_meta, objects) = decode(&result, &DecodeOptions::default()).unwrap();
        assert_eq!(decoded_meta.version, 2);
        assert_eq!(objects.len(), 1);
        assert_eq!(objects[0].1, data);
    }

    // Objects keep their write order and payloads; `object_count` tracks writes.
    #[test]
    fn streaming_multi_object_round_trip() {
        let meta = GlobalMetadata::default();
        let desc1 = make_descriptor(vec![4]);
        let desc2 = make_descriptor(vec![8]);
        let data1 = vec![1u8; 4 * 4];
        let data2 = vec![2u8; 8 * 4];

        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_object(&desc1, &data1).unwrap();
        enc.write_object(&desc2, &data2).unwrap();
        assert_eq!(enc.object_count(), 2);
        let result = enc.finish().unwrap();

        let (_, objects) = decode(&result, &DecodeOptions::default()).unwrap();
        assert_eq!(objects.len(), 2);
        assert_eq!(objects[0].1, data1);
        assert_eq!(objects[1].1, data2);
    }

    // Streaming and buffered `encode` with the same options must decode to
    // the same version, descriptors, payloads and per-object hash values.
    #[test]
    fn streaming_matches_buffered_single_object() {
        let meta = GlobalMetadata::default();
        let desc = make_descriptor(vec![4]);
        let data = vec![42u8; 4 * 4];
        let options = EncodeOptions {
            compression_backend: Default::default(),
            hash_algorithm: Some(HashAlgorithm::Xxh3),
            emit_preceders: false,
            threads: 0,
            parallel_threshold_bytes: None,
        };

        let buffered = encode(&meta, &[(&desc, &data)], &options).unwrap();
        let (buf_meta, buf_objects) = decode(
            &buffered,
            &DecodeOptions {
                verify_hash: true,
                ..Default::default()
            },
        )
        .unwrap();

        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &options).unwrap();
        enc.write_object(&desc, &data).unwrap();
        let streamed = enc.finish().unwrap();
        let (str_meta, str_objects) = decode(
            &streamed,
            &DecodeOptions {
                verify_hash: true,
                ..Default::default()
            },
        )
        .unwrap();

        assert_eq!(buf_meta.version, str_meta.version);
        assert_eq!(buf_objects.len(), str_objects.len());
        assert_eq!(buf_objects[0].0.shape, str_objects[0].0.shape);
        assert_eq!(buf_objects[0].0.dtype, str_objects[0].0.dtype);
        assert_eq!(buf_objects[0].1, str_objects[0].1);
        assert_eq!(
            buf_objects[0].0.hash.as_ref().unwrap().value,
            str_objects[0].0.hash.as_ref().unwrap().value
        );
    }

    // With hashing enabled, decoding with `verify_hash` succeeds and the
    // decoded descriptor carries a hash.
    #[test]
    fn streaming_hash_verification() {
        let meta = GlobalMetadata::default();
        let desc = make_descriptor(vec![4]);
        let data = vec![42u8; 4 * 4];
        let options = EncodeOptions {
            compression_backend: Default::default(),
            hash_algorithm: Some(HashAlgorithm::Xxh3),
            emit_preceders: false,
            threads: 0,
            parallel_threshold_bytes: None,
        };

        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &options).unwrap();
        enc.write_object(&desc, &data).unwrap();
        let result = enc.finish().unwrap();

        let verify_opts = DecodeOptions {
            verify_hash: true,
            ..Default::default()
        };
        let (_, objects) = decode(&result, &verify_opts).unwrap();
        assert!(objects[0].0.hash.is_some());
    }

    // A message with zero objects (hashing disabled) still decodes cleanly.
    #[test]
    fn streaming_no_objects() {
        let meta = GlobalMetadata::default();
        let options = EncodeOptions {
            compression_backend: Default::default(),
            hash_algorithm: None,
            emit_preceders: false,
            threads: 0,
            parallel_threshold_bytes: None,
        };

        let buf = Vec::new();
        let enc = StreamingEncoder::new(buf, &meta, &options).unwrap();
        assert_eq!(enc.object_count(), 0);
        let result = enc.finish().unwrap();

        let (decoded_meta, objects) = decode(&result, &DecodeOptions::default()).unwrap();
        assert_eq!(decoded_meta.version, 2);
        assert_eq!(objects.len(), 0);
    }

    // Encoded payload bytes must not depend on the thread budget. Only the
    // object payloads are compared, not the whole message. NOTE(review):
    // presumably because footer provenance may differ between runs; confirm.
    #[test]
    fn streaming_threads_byte_identical_transparent() {
        let meta = GlobalMetadata::default();
        let desc = make_descriptor(vec![50_000]);
        let data: Vec<u8> = (0..50_000)
            .flat_map(|i| (250.0f32 + (i as f32).sin() * 30.0).to_ne_bytes())
            .collect();

        // Encode `data` with the given thread budget. The `Some(0)` threshold
        // presumably makes every payload eligible for the parallel path.
        let mk = |threads: u32| -> Vec<u8> {
            let buf = Vec::new();
            let opts = EncodeOptions {
                threads,
                parallel_threshold_bytes: Some(0),
                ..Default::default()
            };
            let mut enc = StreamingEncoder::new(buf, &meta, &opts).unwrap();
            enc.write_object(&desc, &data).unwrap();
            enc.finish().unwrap()
        };

        // Extract the raw payload bytes of every object in a message.
        let payloads = |buf: &[u8]| -> Vec<Vec<u8>> {
            crate::framing::decode_message(buf)
                .unwrap()
                .objects
                .iter()
                .map(|(_, p, _)| p.to_vec())
                .collect()
        };

        let baseline = mk(0);
        let payloads_baseline = payloads(&baseline);

        for t in [1u32, 2, 4, 8] {
            let got = mk(t);
            assert_eq!(
                payloads_baseline,
                payloads(&got),
                "streaming threads={t} payload must match sequential"
            );
        }
    }

    // Caller-supplied `extra` global metadata survives into the decoded message.
    #[test]
    fn streaming_with_metadata() {
        let mut extra = BTreeMap::new();
        extra.insert(
            "centre".to_string(),
            ciborium::Value::Text("ecmwf".to_string()),
        );
        let meta = GlobalMetadata {
            version: 2,
            extra,
            ..Default::default()
        };

        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 4 * 4];

        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_object(&desc, &data).unwrap();
        let result = enc.finish().unwrap();

        let (decoded_meta, _) = decode(&result, &DecodeOptions::default()).unwrap();
        assert_eq!(
            decoded_meta.extra.get("centre"),
            Some(&ciborium::Value::Text("ecmwf".to_string()))
        );
    }

    // Preceder metadata ends up in the decoded `base` entry of its object.
    #[test]
    fn streaming_preceder_round_trip() {
        let meta = GlobalMetadata::default();
        let desc = make_descriptor(vec![4]);
        let data = vec![42u8; 4 * 4];

        let mut prec = BTreeMap::new();
        prec.insert(
            "mars".to_string(),
            ciborium::Value::Map(vec![(
                ciborium::Value::Text("param".to_string()),
                ciborium::Value::Text("2t".to_string()),
            )]),
        );

        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_preceder(prec).unwrap();
        enc.write_object(&desc, &data).unwrap();
        let result = enc.finish().unwrap();

        let (decoded_meta, objects) = decode(&result, &DecodeOptions::default()).unwrap();
        assert_eq!(objects.len(), 1);
        assert_eq!(objects[0].1, data);

        let mars = decoded_meta.base[0].get("mars");
        assert!(mars.is_some(), "mars key should be in base[0]");
    }

    // When the same key is present in the caller's global `base` and in a
    // preceder, the decoded value is the preceder's.
    #[test]
    fn streaming_preceder_wins_over_footer() {
        let mut footer_base = BTreeMap::new();
        footer_base.insert(
            "source".to_string(),
            ciborium::Value::Text("footer".to_string()),
        );
        let meta = GlobalMetadata {
            version: 2,
            base: vec![footer_base],
            ..Default::default()
        };

        let mut prec = BTreeMap::new();
        prec.insert(
            "source".to_string(),
            ciborium::Value::Text("preceder".to_string()),
        );

        let desc = make_descriptor(vec![4]);
        let data = vec![0u8; 4 * 4];

        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_preceder(prec).unwrap();
        enc.write_object(&desc, &data).unwrap();
        let result = enc.finish().unwrap();

        let (decoded_meta, _) = decode(&result, &DecodeOptions::default()).unwrap();
        let source = decoded_meta.base[0].get("source").and_then(|v| match v {
            ciborium::Value::Text(s) => Some(s.as_str()),
            _ => None,
        });
        assert_eq!(source, Some("preceder"), "preceder should win over footer");
    }

    // Two preceders in a row, with no object between them, are rejected.
    #[test]
    fn streaming_consecutive_preceder_error() {
        let meta = GlobalMetadata::default();
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();

        enc.write_preceder(BTreeMap::new()).unwrap();
        let result = enc.write_preceder(BTreeMap::new());
        assert!(
            result.is_err(),
            "two write_preceder calls without intervening write_object should fail"
        );
    }

    // A preceder that is never followed by an object makes `finish` fail.
    #[test]
    fn streaming_dangling_preceder_error() {
        let meta = GlobalMetadata::default();
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();

        enc.write_preceder(BTreeMap::new()).unwrap();
        let result = enc.finish();
        assert!(
            result.is_err(),
            "finish with a dangling preceder should fail"
        );
    }

    // A preceder applies only to the next object, not to later ones.
    #[test]
    fn streaming_mixed_objects_with_and_without_preceders() {
        let meta = GlobalMetadata::default();
        let desc0 = make_descriptor(vec![4]);
        let desc1 = make_descriptor(vec![8]);
        let data0 = vec![1u8; 4 * 4];
        let data1 = vec![2u8; 8 * 4];

        let mut prec = BTreeMap::new();
        prec.insert(
            "note".to_string(),
            ciborium::Value::Text("only for obj 0".to_string()),
        );

        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_preceder(prec).unwrap();
        enc.write_object(&desc0, &data0).unwrap();
        enc.write_object(&desc1, &data1).unwrap();
        let result = enc.finish().unwrap();

        let (decoded_meta, objects) = decode(&result, &DecodeOptions::default()).unwrap();
        assert_eq!(objects.len(), 2);
        assert_eq!(objects[0].1, data0);
        assert_eq!(objects[1].1, data1);

        assert!(decoded_meta.base[0].contains_key("note"));
        assert!(!decoded_meta.base[1].contains_key("note"));
    }

    // Flat and nested preceder values both survive, and the library still
    // adds its own `_reserved_` entry to the same base entry.
    #[test]
    fn streaming_preceder_metadata_preservation() {
        let meta = GlobalMetadata::default();
        let desc = make_descriptor(vec![2]);
        let data = vec![0u8; 2 * 4];

        let mut prec = BTreeMap::new();
        prec.insert("units".to_string(), ciborium::Value::Text("K".to_string()));
        prec.insert(
            "mars".to_string(),
            ciborium::Value::Map(vec![
                (
                    ciborium::Value::Text("param".to_string()),
                    ciborium::Value::Text("2t".to_string()),
                ),
                (
                    ciborium::Value::Text("levtype".to_string()),
                    ciborium::Value::Text("sfc".to_string()),
                ),
            ]),
        );

        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_preceder(prec).unwrap();
        enc.write_object(&desc, &data).unwrap();
        let result = enc.finish().unwrap();

        let (decoded_meta, _) = decode(&result, &DecodeOptions::default()).unwrap();
        let p = &decoded_meta.base[0];
        assert_eq!(
            p.get("units"),
            Some(&ciborium::Value::Text("K".to_string()))
        );
        assert!(p.contains_key("mars"));
        assert!(p.contains_key("_reserved_"));
    }

    // `write_preceder` rejects caller-supplied `_reserved_` and names the key
    // in its error message.
    #[test]
    fn streaming_preceder_with_reserved_rejected() {
        let meta = GlobalMetadata::default();
        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();

        let mut prec = BTreeMap::new();
        prec.insert("_reserved_".to_string(), ciborium::Value::Map(vec![]));

        let result = enc.write_preceder(prec);
        assert!(result.is_err(), "_reserved_ in preceder should be rejected");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("_reserved_"),
            "error should mention _reserved_: {err}"
        );
    }

    // The encoder itself refuses `_reserved_` in preceders, so this test
    // assembles the message by hand to check the decoder's behaviour. The
    // frames are: preamble, header metadata, preceder metadata (carrying a
    // rogue `_reserved_`), one data object, footer metadata (carrying a
    // legitimate `_reserved_.tensor`), then the postamble. Each frame is
    // followed by zero padding to an 8-byte boundary. The test expects the
    // footer's `_reserved_` to survive and the preceder's to be discarded.
    #[test]
    fn streaming_preceder_reserved_stripped_on_decode() {
        let mut prec_entry = BTreeMap::new();
        prec_entry.insert(
            "mars".to_string(),
            ciborium::Value::Map(vec![(
                ciborium::Value::Text("param".to_string()),
                ciborium::Value::Text("2t".to_string()),
            )]),
        );
        prec_entry.insert(
            "_reserved_".to_string(),
            ciborium::Value::Map(vec![(
                ciborium::Value::Text("rogue".to_string()),
                ciborium::Value::Text("bad".to_string()),
            )]),
        );

        let preceder_meta = GlobalMetadata {
            version: 2,
            base: vec![prec_entry],
            ..Default::default()
        };
        let preceder_cbor = crate::metadata::global_metadata_to_cbor(&preceder_meta).unwrap();

        let desc_for_frame = make_descriptor(vec![4]);
        let payload = vec![0u8; 4 * 4];
        let frame =
            crate::framing::encode_data_object_frame(&desc_for_frame, &payload, false).unwrap();

        // Footer base entry carrying the library-style `_reserved_.tensor` map.
        let mut footer_base = BTreeMap::new();
        let tensor_map = ciborium::Value::Map(vec![
            (
                ciborium::Value::Text("ndim".to_string()),
                ciborium::Value::Integer(1.into()),
            ),
            (
                ciborium::Value::Text("shape".to_string()),
                ciborium::Value::Array(vec![ciborium::Value::Integer(4.into())]),
            ),
            (
                ciborium::Value::Text("strides".to_string()),
                ciborium::Value::Array(vec![ciborium::Value::Integer(1.into())]),
            ),
            (
                ciborium::Value::Text("dtype".to_string()),
                ciborium::Value::Text("float32".to_string()),
            ),
        ]);
        footer_base.insert(
            "_reserved_".to_string(),
            ciborium::Value::Map(vec![(
                ciborium::Value::Text("tensor".to_string()),
                tensor_map,
            )]),
        );
        let footer_meta = GlobalMetadata {
            version: 2,
            base: vec![footer_base],
            ..Default::default()
        };
        let footer_cbor = crate::metadata::global_metadata_to_cbor(&footer_meta).unwrap();

        use crate::wire::*;
        let header_meta_cbor =
            crate::metadata::global_metadata_to_cbor(&GlobalMetadata::default()).unwrap();

        // Reserve space for the preamble; it is filled in once the total
        // length is known.
        let mut out = Vec::new();
        out.extend_from_slice(&[0u8; PREAMBLE_SIZE]);

        // Header metadata frame.
        let total_length = (FRAME_HEADER_SIZE + header_meta_cbor.len() + FRAME_END.len()) as u64;
        let fh = FrameHeader {
            frame_type: FrameType::HeaderMetadata,
            version: 1,
            flags: 0,
            total_length,
        };
        fh.write_to(&mut out);
        out.extend_from_slice(&header_meta_cbor);
        out.extend_from_slice(FRAME_END);
        let pad = (8 - (out.len() % 8)) % 8;
        out.extend(std::iter::repeat_n(0u8, pad));

        // Preceder metadata frame containing the rogue `_reserved_`.
        let total_length = (FRAME_HEADER_SIZE + preceder_cbor.len() + FRAME_END.len()) as u64;
        let fh = FrameHeader {
            frame_type: FrameType::PrecederMetadata,
            version: 1,
            flags: 0,
            total_length,
        };
        fh.write_to(&mut out);
        out.extend_from_slice(&preceder_cbor);
        out.extend_from_slice(FRAME_END);
        let pad = (8 - (out.len() % 8)) % 8;
        out.extend(std::iter::repeat_n(0u8, pad));

        // Data object frame (already framed by `encode_data_object_frame`).
        out.extend_from_slice(&frame);
        let pad = (8 - (out.len() % 8)) % 8;
        out.extend(std::iter::repeat_n(0u8, pad));

        // Footer metadata frame.
        let total_length = (FRAME_HEADER_SIZE + footer_cbor.len() + FRAME_END.len()) as u64;
        let fh = FrameHeader {
            frame_type: FrameType::FooterMetadata,
            version: 1,
            flags: 0,
            total_length,
        };
        fh.write_to(&mut out);
        out.extend_from_slice(&footer_cbor);
        out.extend_from_slice(FRAME_END);
        let pad = (8 - (out.len() % 8)) % 8;
        out.extend(std::iter::repeat_n(0u8, pad));

        // Postamble. It is given the current length (the postamble's own
        // position) as `first_footer_offset`. NOTE(review): the footer
        // metadata frame actually starts earlier; confirm the decoder does
        // not depend on this value here.
        let postamble_offset = out.len();
        let postamble = Postamble {
            first_footer_offset: postamble_offset as u64,
        };
        postamble.write_to(&mut out);

        // Now that the total length is known, write the real preamble over
        // the placeholder. No FOOTER_INDEX flag is set because no index frame
        // was written.
        let total_length = out.len() as u64;
        let mut flags = MessageFlags::default();
        flags.set(MessageFlags::HEADER_METADATA);
        flags.set(MessageFlags::FOOTER_METADATA);
        flags.set(MessageFlags::PRECEDER_METADATA);
        let preamble = Preamble {
            version: 2,
            flags,
            reserved: 0,
            total_length,
        };
        let mut preamble_bytes = Vec::new();
        preamble.write_to(&mut preamble_bytes);
        out[0..PREAMBLE_SIZE].copy_from_slice(&preamble_bytes);

        let decoded = crate::framing::decode_message(&out).unwrap();

        let base0 = &decoded.global_metadata.base[0];
        assert!(
            base0.contains_key("mars"),
            "mars from preceder should survive"
        );
        let reserved = base0.get("_reserved_");
        assert!(
            reserved.is_some(),
            "_reserved_ from footer should be present"
        );
        // NOTE(review): if `_reserved_` decodes as something other than a
        // map, the checks below are silently skipped; consider an `else`
        // branch that panics.
        if let Some(ciborium::Value::Map(pairs)) = reserved {
            let has_tensor = pairs
                .iter()
                .any(|(k, _)| *k == ciborium::Value::Text("tensor".to_string()));
            assert!(has_tensor, "tensor key from footer should be preserved");
            let has_rogue = pairs
                .iter()
                .any(|(k, _)| *k == ciborium::Value::Text("rogue".to_string()));
            assert!(
                !has_rogue,
                "rogue key from preceder's _reserved_ should have been stripped"
            );
        }
    }

    // Encoded and pre-encoded objects can be interleaved in one message.
    // With no encoding, filter or compression, the decoded payload equals
    // the bytes supplied.
    #[test]
    fn test_streaming_mixed_mode_pre_encoded() {
        let meta = GlobalMetadata::default();

        let desc0 = make_descriptor(vec![4]);
        let desc2 = make_descriptor(vec![6]);
        let desc1 = make_descriptor(vec![5]);

        let data0 = vec![1u8; 4 * 4];
        let pre_encoded1 = vec![2u8; 5 * 4];
        let data2 = vec![3u8; 6 * 4];

        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_object(&desc0, &data0).unwrap();
        enc.write_object_pre_encoded(&desc1, &pre_encoded1).unwrap();
        enc.write_object(&desc2, &data2).unwrap();
        assert_eq!(enc.object_count(), 3);
        let result = enc.finish().unwrap();

        let (_, objects) = decode(&result, &DecodeOptions::default()).unwrap();
        assert_eq!(objects.len(), 3);
        assert_eq!(objects[0].1, data0, "object 0 payload mismatch");
        assert_eq!(objects[1].1, pre_encoded1, "object 1 payload mismatch");
        assert_eq!(objects[2].1, data2, "object 2 payload mismatch");
    }

    // A preceder can also be attached to a pre-encoded object.
    #[test]
    fn test_streaming_preceder_then_pre_encoded() {
        let meta = GlobalMetadata::default();
        let desc = make_descriptor(vec![4]);
        let pre_encoded = vec![42u8; 4 * 4];

        let mut prec = BTreeMap::new();
        prec.insert(
            "mars".to_string(),
            ciborium::Value::Map(vec![(
                ciborium::Value::Text("param".to_string()),
                ciborium::Value::Text("2t".to_string()),
            )]),
        );

        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_preceder(prec).unwrap();
        enc.write_object_pre_encoded(&desc, &pre_encoded).unwrap();
        let result = enc.finish().unwrap();

        let (decoded_meta, objects) = decode(&result, &DecodeOptions::default()).unwrap();
        assert_eq!(objects.len(), 1);
        assert_eq!(objects[0].1, pre_encoded, "pre-encoded payload mismatch");
        let mars = decoded_meta.base[0].get("mars");
        assert!(
            mars.is_some(),
            "mars key from preceder should be in base[0]"
        );
    }

    // Merging preceder keys in `finish` must not remove the library-populated
    // `_reserved_.tensor` entry from the same base entry.
    #[test]
    fn streaming_finish_preserves_preceder_does_not_clobber_reserved_tensor() {
        let meta = GlobalMetadata::default();
        let desc = make_descriptor(vec![4]);
        let data = vec![42u8; 4 * 4];

        let mut prec = BTreeMap::new();
        prec.insert("units".to_string(), ciborium::Value::Text("K".to_string()));

        let buf = Vec::new();
        let mut enc = StreamingEncoder::new(buf, &meta, &EncodeOptions::default()).unwrap();
        enc.write_preceder(prec).unwrap();
        enc.write_object(&desc, &data).unwrap();
        let result = enc.finish().unwrap();

        let (decoded_meta, _) = decode(&result, &DecodeOptions::default()).unwrap();
        let base0 = &decoded_meta.base[0];

        assert!(base0.contains_key("units"));

        let reserved = base0.get("_reserved_").expect("_reserved_ missing");
        if let ciborium::Value::Map(pairs) = reserved {
            let has_tensor = pairs
                .iter()
                .any(|(k, _)| *k == ciborium::Value::Text("tensor".to_string()));
            assert!(
                has_tensor,
                "_reserved_.tensor should be present after preceder merge"
            );
        } else {
            panic!("_reserved_ should be a map");
        }
    }
}