1extern crate zstd_sys;
19
20use tensogram_sz3_sys::SZ3_Config;
21
/// Prediction algorithm SZ3 uses to compress data.
///
/// Each variant is encoded as one of SZ3's `ALGO_*` codes in the compressed
/// header. Only `LorenzoRegression` carries extra switches; they are forwarded
/// to the SZ3 config fields `lorenzo`, `lorenzo2` and `regression`.
#[derive(Clone, Debug, Copy)]
#[non_exhaustive]
pub enum CompressionAlgorithm {
    /// Interpolation-based prediction (`ALGO_INTERP`).
    Interpolation,
    /// Interpolation combined with Lorenzo prediction (`ALGO_INTERP_LORENZO`).
    InterpolationLorenzo,
    /// Lorenzo and/or regression prediction (`ALGO_LORENZO_REG`).
    LorenzoRegression {
        /// Forwarded to SZ3's `lorenzo` flag.
        lorenzo: bool,
        /// Forwarded to SZ3's `lorenzo2` flag.
        lorenzo_second_order: bool,
        /// Forwarded to SZ3's `regression` flag.
        regression: bool,
    },
    /// SZ3 algorithm code `ALGO_BIOMD`.
    BiologyMolecularData,
    /// SZ3 algorithm code `ALGO_BIOMDXTC`.
    BiologyMolecularDataGromacsXtc,
    /// SZ3 algorithm code `ALGO_NOPRED`.
    NoPrediction,
    /// SZ3 algorithm code `ALGO_LOSSLESS`.
    Lossless,
}
42
43impl CompressionAlgorithm {
44 fn decode(config: SZ3_Config) -> Result<Self> {
45 match config.cmprAlgo as u32 {
46 tensogram_sz3_sys::SZ3::ALGO_ALGO_INTERP => Ok(Self::Interpolation),
47 tensogram_sz3_sys::SZ3::ALGO_ALGO_INTERP_LORENZO => Ok(Self::InterpolationLorenzo),
48 tensogram_sz3_sys::SZ3::ALGO_ALGO_LORENZO_REG => Ok(Self::LorenzoRegression {
49 lorenzo: config.lorenzo,
50 lorenzo_second_order: config.lorenzo2,
51 regression: config.regression,
52 }),
53 tensogram_sz3_sys::SZ3::ALGO_ALGO_BIOMD => Ok(Self::BiologyMolecularData),
54 tensogram_sz3_sys::SZ3::ALGO_ALGO_BIOMDXTC => Ok(Self::BiologyMolecularDataGromacsXtc),
55 tensogram_sz3_sys::SZ3::ALGO_ALGO_NOPRED => Ok(Self::NoPrediction),
56 tensogram_sz3_sys::SZ3::ALGO_ALGO_LOSSLESS => Ok(Self::Lossless),
57 algo => Err(SZ3Error::UnsupportedAlgorithm(algo)),
58 }
59 }
60
61 fn code(&self) -> u8 {
62 (match self {
63 Self::Interpolation => tensogram_sz3_sys::SZ3::ALGO_ALGO_INTERP,
64 Self::InterpolationLorenzo => tensogram_sz3_sys::SZ3::ALGO_ALGO_INTERP_LORENZO,
65 Self::LorenzoRegression { .. } => tensogram_sz3_sys::SZ3::ALGO_ALGO_LORENZO_REG,
66 Self::BiologyMolecularData => tensogram_sz3_sys::SZ3::ALGO_ALGO_BIOMD,
67 Self::BiologyMolecularDataGromacsXtc => tensogram_sz3_sys::SZ3::ALGO_ALGO_BIOMDXTC,
68 Self::NoPrediction => tensogram_sz3_sys::SZ3::ALGO_ALGO_NOPRED,
69 Self::Lossless => tensogram_sz3_sys::SZ3::ALGO_ALGO_LOSSLESS,
70 }) as _
71 }
72
73 fn lorenzo(&self) -> bool {
74 match self {
75 Self::LorenzoRegression { lorenzo, .. } => *lorenzo,
76 _ => true,
77 }
78 }
79
80 fn lorenzo_second_order(&self) -> bool {
81 match self {
82 Self::LorenzoRegression {
83 lorenzo_second_order,
84 ..
85 } => *lorenzo_second_order,
86 _ => true,
87 }
88 }
89
90 fn regression(&self) -> bool {
91 match self {
92 Self::LorenzoRegression { regression, .. } => *regression,
93 _ => true,
94 }
95 }
96
97 pub fn interpolation() -> Self {
99 Self::Interpolation
100 }
101
102 pub fn interpolation_lorenzo() -> Self {
104 Self::InterpolationLorenzo
105 }
106
107 pub fn lorenzo_regression() -> Self {
109 Self::LorenzoRegression {
110 lorenzo: true,
111 lorenzo_second_order: false,
112 regression: true,
113 }
114 }
115
116 pub fn lorenzo_regression_custom(
118 lorenzo: Option<bool>,
119 lorenzo_second_order: Option<bool>,
120 regression: Option<bool>,
121 ) -> Self {
122 Self::LorenzoRegression {
123 lorenzo: lorenzo.unwrap_or(true),
124 lorenzo_second_order: lorenzo_second_order.unwrap_or(false),
125 regression: regression.unwrap_or(true),
126 }
127 }
128
129 pub fn biology_molecular_data() -> Self {
131 Self::BiologyMolecularData
132 }
133
134 pub fn biology_molecular_data_gromacs_xtc() -> Self {
136 Self::BiologyMolecularDataGromacsXtc
137 }
138
139 pub fn no_prediction() -> Self {
141 Self::NoPrediction
142 }
143
144 pub fn lossless() -> Self {
146 Self::Lossless
147 }
148}
149
150impl Default for CompressionAlgorithm {
151 fn default() -> Self {
152 Self::interpolation_lorenzo()
153 }
154}
155
/// Error bound SZ3 is asked to respect while compressing.
///
/// The variant selects SZ3's error-bound mode (`EB_*`) and the payload fills the
/// matching bound fields of the SZ3 config.
#[derive(Clone, Debug, Copy)]
#[non_exhaustive]
pub enum ErrorBound {
    /// Mode `EB_ABS`; value goes to `absErrorBound`.
    Absolute(f64),
    /// Mode `EB_REL`; value goes to `relErrorBound`.
    Relative(f64),
    /// Mode `EB_PSNR`; value goes to `psnrErrorBound`.
    PSNR(f64),
    /// Mode `EB_L2NORM`; value goes to `l2normErrorBound`.
    L2Norm(f64),
    /// Mode `EB_ABS_AND_REL`; how the two bounds combine is defined by SZ3.
    AbsoluteAndRelative {
        /// Goes to `absErrorBound`.
        absolute_bound: f64,
        /// Goes to `relErrorBound`.
        relative_bound: f64,
    },
    /// Mode `EB_ABS_OR_REL`; how the two bounds combine is defined by SZ3.
    AbsoluteOrRelative {
        /// Goes to `absErrorBound`.
        absolute_bound: f64,
        /// Goes to `relErrorBound`.
        relative_bound: f64,
    },
}
178
179impl ErrorBound {
180 fn decode(config: SZ3_Config) -> Result<Self> {
181 match config.errorBoundMode as u32 {
182 tensogram_sz3_sys::SZ3::EB_EB_ABS => Ok(Self::Absolute(config.absErrorBound)),
183 tensogram_sz3_sys::SZ3::EB_EB_REL => Ok(Self::Relative(config.relErrorBound)),
184 tensogram_sz3_sys::SZ3::EB_EB_PSNR => Ok(Self::PSNR(config.psnrErrorBound)),
185 tensogram_sz3_sys::SZ3::EB_EB_L2NORM => Ok(Self::L2Norm(config.l2normErrorBound)),
186 tensogram_sz3_sys::SZ3::EB_EB_ABS_OR_REL => Ok(Self::AbsoluteOrRelative {
187 absolute_bound: config.absErrorBound,
188 relative_bound: config.relErrorBound,
189 }),
190 tensogram_sz3_sys::SZ3::EB_EB_ABS_AND_REL => Ok(Self::AbsoluteAndRelative {
191 absolute_bound: config.absErrorBound,
192 relative_bound: config.relErrorBound,
193 }),
194 mode => Err(SZ3Error::UnsupportedErrorBound(mode)),
195 }
196 }
197
198 fn code(&self) -> u8 {
199 (match self {
200 Self::Absolute(_) => tensogram_sz3_sys::SZ3::EB_EB_ABS,
201 Self::Relative(_) => tensogram_sz3_sys::SZ3::EB_EB_REL,
202 Self::PSNR(_) => tensogram_sz3_sys::SZ3::EB_EB_PSNR,
203 Self::L2Norm(_) => tensogram_sz3_sys::SZ3::EB_EB_L2NORM,
204 Self::AbsoluteAndRelative { .. } => tensogram_sz3_sys::SZ3::EB_EB_ABS_AND_REL,
205 Self::AbsoluteOrRelative { .. } => tensogram_sz3_sys::SZ3::EB_EB_ABS_OR_REL,
206 }) as _
207 }
208
209 fn abs_bound(&self) -> f64 {
210 match self {
211 Self::Absolute(bound) => *bound,
212 Self::AbsoluteOrRelative { absolute_bound, .. }
213 | Self::AbsoluteAndRelative { absolute_bound, .. } => *absolute_bound,
214 _ => 0.0,
215 }
216 }
217
218 fn rel_bound(&self) -> f64 {
219 match self {
220 Self::Relative(bound) => *bound,
221 Self::AbsoluteOrRelative { relative_bound, .. }
222 | Self::AbsoluteAndRelative { relative_bound, .. } => *relative_bound,
223 _ => 0.0,
224 }
225 }
226
227 fn l2norm_bound(&self) -> f64 {
228 match self {
229 Self::L2Norm(bound) => *bound,
230 _ => 0.0,
231 }
232 }
233
234 fn psnr_bound(&self) -> f64 {
235 match self {
236 Self::PSNR(bound) => *bound,
237 _ => 0.0,
238 }
239 }
240}
241
242#[derive(Clone, Debug)]
251pub struct Config {
252 compression_algorithm: CompressionAlgorithm,
253 error_bound: ErrorBound,
254 openmp: bool,
255 quantization_bincount: u32,
256 block_size: Option<u32>,
257}
258
259impl Config {
260 pub fn new(error_bound: ErrorBound) -> Self {
263 Self {
264 compression_algorithm: CompressionAlgorithm::default(),
265 error_bound,
266 openmp: false,
267 quantization_bincount: 65536,
268 block_size: None,
269 }
270 }
271
272 fn from_decompressed(config: SZ3_Config) -> Result<Self> {
273 Ok(Self {
274 compression_algorithm: CompressionAlgorithm::decode(config)?,
275 error_bound: ErrorBound::decode(config)?,
276 openmp: config.openmp,
277 quantization_bincount: config.quantbinCnt as _,
278 block_size: Some(config.blockSize as _),
279 })
280 }
281
282 pub fn compression_algorithm(mut self, compression_algorithm: CompressionAlgorithm) -> Self {
284 self.compression_algorithm = compression_algorithm;
285 self
286 }
287
288 pub fn error_bound(mut self, error_bound: ErrorBound) -> Self {
290 self.error_bound = error_bound;
291 self
292 }
293
294 #[cfg(feature = "openmp")]
296 pub fn openmp(mut self, openmp: bool) -> Self {
297 self.openmp = openmp;
298 self
299 }
300
301 pub fn quantization_bincount(mut self, quantization_bincount: u32) -> Self {
303 self.quantization_bincount = quantization_bincount;
304 self
305 }
306
307 pub fn block_size(mut self, block_size: u32) -> Self {
309 self.block_size = Some(block_size);
310 self
311 }
312
313 pub fn automatic_block_size(mut self) -> Self {
315 self.block_size = None;
316 self
317 }
318}
319
320pub trait SZ3Compressible: private::Sealed + Sized {}
330impl SZ3Compressible for f32 {}
331impl SZ3Compressible for f64 {}
332impl SZ3Compressible for u8 {}
333impl SZ3Compressible for i8 {}
334impl SZ3Compressible for u16 {}
335impl SZ3Compressible for i16 {}
336impl SZ3Compressible for u32 {}
337impl SZ3Compressible for i32 {}
338impl SZ3Compressible for u64 {}
339impl SZ3Compressible for i64 {}
340
341mod private {
342 pub trait Sealed: Copy {
343 const SZ_DATA_TYPE: u8;
344
345 unsafe fn compress_size_bound(config: tensogram_sz3_sys::SZ3_Config) -> usize;
346
347 unsafe fn compress(
348 config: tensogram_sz3_sys::SZ3_Config,
349 data: *const Self,
350 compressed_data: *mut u8,
351 compressed_capacity: usize,
352 ) -> usize;
353
354 unsafe fn decompress(
355 compressed_data: *const u8,
356 compressed_len: usize,
357 decompressed_data: *mut Self,
358 );
359 }
360
361 macro_rules! impl_sealed {
362 ($($impl_mod:ident),*) => {
363 $(impl Sealed for tensogram_sz3_sys::$impl_mod::ty {
364 const SZ_DATA_TYPE: u8 = tensogram_sz3_sys::$impl_mod::DATA_TYPE_TYPE;
365
366 unsafe fn compress_size_bound(config: tensogram_sz3_sys::SZ3_Config) -> usize {
367 unsafe { tensogram_sz3_sys::$impl_mod::compress_size_bound(config) }
368 }
369
370 unsafe fn compress(
371 config: tensogram_sz3_sys::SZ3_Config,
372 data: *const Self,
373 compressed_data: *mut u8,
374 compressed_capacity: usize,
375 ) -> usize {
376 unsafe {
377 tensogram_sz3_sys::$impl_mod::compress(
378 config,
379 data,
380 compressed_data.cast(),
381 compressed_capacity,
382 )
383 }
384 }
385
386 unsafe fn decompress(
387 compressed_data: *const u8,
388 compressed_len: usize,
389 decompressed_data: *mut Self,
390 ) {
391 unsafe {
392 tensogram_sz3_sys::$impl_mod::decompress(
393 compressed_data.cast(),
394 compressed_len,
395 decompressed_data,
396 )
397 }
398 }
399 })*
400 }
401 }
402
403 impl_sealed!(
404 impl_f32, impl_f64, impl_u8, impl_i8, impl_u16, impl_i16, impl_u32, impl_i32, impl_u64,
405 impl_i64
406 );
407}
408
409#[derive(Clone, Debug)]
419pub struct DimensionedData<V: SZ3Compressible, T: std::ops::Deref<Target = [V]>> {
420 data: T,
421 dims: Vec<usize>,
422}
423
/// In-progress dimension specification over a shared slice.
///
/// Obtained from [`DimensionedData::build`].
#[derive(Clone, Debug)]
pub struct DimensionedDataBuilder<'a, V> {
    /// The slice being described.
    data: &'a [V],
    /// Dimensions accepted so far.
    dims: Vec<usize>,
    /// Element count not yet accounted for by `dims`.
    remainder: usize,
}
431
/// In-progress dimension specification over a mutable slice.
///
/// Obtained from [`DimensionedData::build_mut`].
#[derive(Debug)]
pub struct DimensionedDataBuilderMut<'a, V> {
    /// The slice being described.
    data: &'a mut [V],
    /// Dimensions accepted so far.
    dims: Vec<usize>,
    /// Element count not yet accounted for by `dims`.
    remainder: usize,
}
439
440impl<V: SZ3Compressible, T: std::ops::Deref<Target = [V]>> DimensionedData<V, T> {
441 pub fn build<'a>(data: &'a T) -> DimensionedDataBuilder<'a, V> {
443 DimensionedDataBuilder {
444 data,
445 dims: vec![],
446 remainder: data.len(),
447 }
448 }
449
450 pub fn data(&self) -> &[V] {
452 &self.data
453 }
454
455 pub fn into_data(self) -> T {
457 self.data
458 }
459
460 pub fn dims(&self) -> &[usize] {
462 &self.dims
463 }
464
465 fn len(&self) -> usize {
466 self.data.len()
467 }
468
469 fn as_ptr(&self) -> *const V {
470 self.data.as_ptr()
471 }
472}
473
474impl<V: SZ3Compressible, T: std::ops::DerefMut<Target = [V]>> DimensionedData<V, T> {
475 pub fn build_mut<'a>(data: &'a mut T) -> DimensionedDataBuilderMut<'a, V> {
477 DimensionedDataBuilderMut {
478 remainder: data.len(),
479 data,
480 dims: vec![],
481 }
482 }
483
484 pub fn data_mut(&mut self) -> &mut [V] {
486 &mut self.data
487 }
488}
489
490#[derive(thiserror::Error, Debug)]
496pub enum SZ3Error {
497 #[error(
498 "invalid dimension specification for data of length {len}: already specified dimensions \
499 {dims:?}, and wanted to add dimension with length {wanted}, but this does not divide \
500 {remainder} cleanly"
501 )]
502 InvalidDimensionSize {
503 dims: Vec<usize>,
504 len: usize,
505 wanted: usize,
506 remainder: usize,
507 },
508 #[error("dimension with size one has no use")]
509 OneSizedDimension,
510 #[error(
511 "dimension specification {dims:?} for data of length {len} does not cover whole space, \
512 missing a dimension of {remainder}"
513 )]
514 UnderSpecifiedDimensions {
515 dims: Vec<usize>,
516 len: usize,
517 remainder: usize,
518 },
519 #[error("cannot decompress to array with a different data type")]
520 DecompressedDataTypeMismatch,
521 #[error("unsupported SZ3 compression algorithm code: {0}")]
522 UnsupportedAlgorithm(u32),
523 #[error("unsupported SZ3 error bound mode: {0}")]
524 UnsupportedErrorBound(u32),
525 #[error(
526 "cannot decompress array with dimensions {found:?} to array with different dimensions {expected:?}"
527 )]
528 DecompressedDimsMismatch {
529 found: Vec<usize>,
530 expected: Vec<usize>,
531 },
532}
533
534type Result<T> = std::result::Result<T, SZ3Error>;
535
536macro_rules! impl_dimensioned_data_builder {
541 ($($builder:ident => $data:ty),*) => {
542 $(impl<'a, V: SZ3Compressible> $builder<'a, V> {
543 pub fn dim(mut self, length: usize) -> Result<Self> {
548 if length == 1 {
549 if self.dims.is_empty() && self.remainder == 1 {
550 self.dims.push(1);
551 Ok(self)
552 } else {
553 Err(SZ3Error::OneSizedDimension)
554 }
555 } else if self.remainder % length != 0 {
556 Err(SZ3Error::InvalidDimensionSize {
557 dims: self.dims,
558 len: self.data.len(),
559 wanted: length,
560 remainder: self.remainder,
561 })
562 } else {
563 self.dims.push(length);
564 self.remainder /= length;
565 Ok(self)
566 }
567 }
568
569 pub fn remainder_dim(self) -> Result<$data> {
571 let remainder = self.remainder;
572 self.dim(remainder)?.finish()
573 }
574
575 pub fn finish(self) -> Result<$data> {
580 if self.remainder != 1 {
581 Err(SZ3Error::UnderSpecifiedDimensions {
582 dims: self.dims,
583 len: self.data.len(),
584 remainder: self.remainder,
585 })
586 } else {
587 Ok(DimensionedData {
588 data: self.data,
589 dims: self.dims,
590 })
591 }
592 }
593 })*
594 };
595}
596
597impl_dimensioned_data_builder! {
598 DimensionedDataBuilder => DimensionedData<V, &'a [V]>,
599 DimensionedDataBuilderMut => DimensionedData<V, &'a mut [V]>
600}
601
602struct ParsedConfig {
607 config: Config,
608 len: usize,
609 dims: Vec<usize>,
610 data_type: u8,
611}
612
613impl ParsedConfig {
614 fn from_compressed(compressed_data: &[u8]) -> Result<Self> {
615 let raw = unsafe {
616 tensogram_sz3_sys::sz3_decompress_config(
617 compressed_data.as_ptr().cast(),
618 compressed_data.len(),
619 )
620 };
621 let dims: Vec<usize> = (0..raw.N)
622 .map(|i| unsafe { std::ptr::read(raw.dims.add(i as usize)) })
623 .collect();
624 unsafe {
625 tensogram_sz3_sys::sz3_dealloc_size_t(raw.dims);
626 }
627 let SZ3_Config {
628 num: len,
629 dataType: data_type,
630 ..
631 } = raw;
632 let config = Config::from_decompressed(raw)?;
633 Ok(Self {
634 config,
635 len,
636 dims,
637 data_type,
638 })
639 }
640}
641
642pub fn compress<V: SZ3Compressible, T: std::ops::Deref<Target = [V]>>(
648 data: &DimensionedData<V, T>,
649 error_bound: ErrorBound,
650) -> Result<Vec<u8>> {
651 let config = Config::new(error_bound);
652 compress_with_config(data, &config)
653}
654
655pub fn compress_with_config<V: SZ3Compressible, T: std::ops::Deref<Target = [V]>>(
657 data: &DimensionedData<V, T>,
658 config: &Config,
659) -> Result<Vec<u8>> {
660 let mut compressed_data = Vec::new();
661 compress_into_with_config(data, config, &mut compressed_data)?;
662 Ok(compressed_data)
663}
664
665pub fn compress_into<V: SZ3Compressible, T: std::ops::Deref<Target = [V]>>(
667 data: &DimensionedData<V, T>,
668 error_bound: ErrorBound,
669 compressed_data: &mut Vec<u8>,
670) -> Result<()> {
671 let config = Config::new(error_bound);
672 compress_into_with_config(data, &config, compressed_data)
673}
674
675pub fn compress_into_with_config<V: SZ3Compressible, T: std::ops::Deref<Target = [V]>>(
678 data: &DimensionedData<V, T>,
679 config: &Config,
680 compressed_data: &mut Vec<u8>,
681) -> Result<()> {
682 let block_size = config.block_size.unwrap_or(match data.dims().len() {
683 1 => 128,
684 2 => 16,
685 _ => 6,
686 });
687
688 let raw_config = SZ3_Config {
689 N: data.dims().len() as _,
690 dims: data.dims.as_ptr() as _,
691 num: data.len() as _,
692 errorBoundMode: config.error_bound.code(),
693 absErrorBound: config.error_bound.abs_bound(),
694 relErrorBound: config.error_bound.rel_bound(),
695 l2normErrorBound: config.error_bound.l2norm_bound(),
696 psnrErrorBound: config.error_bound.psnr_bound(),
697 cmprAlgo: config.compression_algorithm.code(),
698 lorenzo: config.compression_algorithm.lorenzo(),
699 lorenzo2: config.compression_algorithm.lorenzo_second_order(),
700 regression: config.compression_algorithm.regression(),
701 openmp: config.openmp,
702 dataType: V::SZ_DATA_TYPE as _,
703 blockSize: block_size as _,
704 quantbinCnt: config.quantization_bincount as _,
705 };
706
707 let capacity: usize = unsafe { V::compress_size_bound(raw_config) };
708 compressed_data.reserve(capacity);
709
710 let len = unsafe {
711 V::compress(
712 raw_config,
713 data.as_ptr(),
714 compressed_data
715 .spare_capacity_mut()
716 .as_mut_ptr()
717 .cast::<u8>(),
718 capacity,
719 )
720 };
721 unsafe { compressed_data.set_len(compressed_data.len() + len) };
722
723 Ok(())
724}
725
726pub fn decompress<V: SZ3Compressible, T: std::ops::Deref<Target = [u8]>>(
729 compressed_data: T,
730) -> Result<(Config, DimensionedData<V, Vec<V>>)> {
731 let ParsedConfig {
732 config,
733 len,
734 dims,
735 data_type,
736 } = ParsedConfig::from_compressed(&compressed_data)?;
737
738 if data_type != V::SZ_DATA_TYPE {
739 return Err(SZ3Error::DecompressedDataTypeMismatch);
740 }
741
742 let decompressed_data = unsafe {
743 let mut decompressed_data: Vec<V> = Vec::with_capacity(len);
744
745 V::decompress(
746 compressed_data.as_ptr(),
747 compressed_data.len(),
748 decompressed_data
749 .spare_capacity_mut()
750 .as_mut_ptr()
751 .cast::<V>(),
752 );
753
754 decompressed_data.set_len(len);
755 decompressed_data
756 };
757
758 Ok((
759 config,
760 DimensionedData {
761 data: decompressed_data,
762 dims,
763 },
764 ))
765}
766
767pub fn decompress_into_dimensioned<
772 V: SZ3Compressible,
773 C: std::ops::Deref<Target = [u8]>,
774 D: std::ops::DerefMut<Target = [V]>,
775>(
776 compressed_data: C,
777 decompressed_data: &mut DimensionedData<V, D>,
778) -> Result<Config> {
779 let ParsedConfig {
780 config,
781 len,
782 dims,
783 data_type,
784 } = ParsedConfig::from_compressed(&compressed_data)?;
785
786 if data_type != V::SZ_DATA_TYPE {
787 return Err(SZ3Error::DecompressedDataTypeMismatch);
788 }
789
790 if decompressed_data.dims() != dims.as_slice() {
791 return Err(SZ3Error::DecompressedDimsMismatch {
792 found: dims,
793 expected: decompressed_data.dims.clone(),
794 });
795 }
796
797 assert_eq!(decompressed_data.len(), len);
798
799 unsafe {
800 V::decompress(
801 compressed_data.as_ptr(),
802 compressed_data.len(),
803 decompressed_data.data.as_mut_ptr(),
804 );
805 }
806
807 Ok(config)
808}
809
810#[cfg(test)]
815mod tests {
816 use super::*;
817
818 #[test]
819 fn round_trip_f64() {
820 let data: Vec<f64> = (0..256)
821 .map(|i| (i as f64 / 256.0 * std::f64::consts::PI).sin())
822 .collect();
823 let dimensioned = DimensionedData::<f64, _>::build(&data)
824 .dim(256)
825 .unwrap()
826 .finish()
827 .unwrap();
828 let compressed = compress(&dimensioned, ErrorBound::Absolute(1e-6)).unwrap();
829 let (_config, decompressed) = decompress::<f64, _>(&*compressed).unwrap();
830 assert_eq!(decompressed.data().len(), data.len());
831 for (orig, dec) in data.iter().zip(decompressed.data()) {
832 assert!(
833 (orig - dec).abs() <= 1e-6,
834 "orig={orig}, dec={dec}, diff={}",
835 (orig - dec).abs()
836 );
837 }
838 }
839
840 #[test]
841 fn round_trip_f32() {
842 let data: Vec<f32> = (0..256)
843 .map(|i| (i as f32 / 256.0 * std::f32::consts::PI).sin())
844 .collect();
845 let dimensioned = DimensionedData::<f32, _>::build(&data)
846 .dim(256)
847 .unwrap()
848 .finish()
849 .unwrap();
850 let compressed = compress(&dimensioned, ErrorBound::Absolute(1e-4)).unwrap();
851 let (_config, decompressed) = decompress::<f32, _>(&*compressed).unwrap();
852 assert_eq!(decompressed.data().len(), data.len());
853 for (orig, dec) in data.iter().zip(decompressed.data()) {
854 assert!(
855 (orig - dec).abs() <= 1e-4,
856 "orig={orig}, dec={dec}, diff={}",
857 (orig - dec).abs()
858 );
859 }
860 }
861
862 #[test]
863 fn dimension_errors() {
864 let data: Vec<f64> = vec![1.0; 100];
865 let err = DimensionedData::<f64, _>::build(&data).dim(1);
866 assert!(matches!(err.unwrap_err(), SZ3Error::OneSizedDimension));
867 let err = DimensionedData::<f64, _>::build(&data).dim(7);
868 assert!(matches!(
869 err.unwrap_err(),
870 SZ3Error::InvalidDimensionSize { .. }
871 ));
872 let err = DimensionedData::<f64, _>::build(&data)
873 .dim(10)
874 .unwrap()
875 .finish();
876 assert!(matches!(
877 err.unwrap_err(),
878 SZ3Error::UnderSpecifiedDimensions { .. }
879 ));
880 }
881
882 #[test]
883 fn round_trip_u8() {
884 let data: Vec<u8> = (0..256).map(|i| i as u8).collect();
885 let d = DimensionedData::<u8, _>::build(&data)
886 .dim(256)
887 .unwrap()
888 .finish()
889 .unwrap();
890 let compressed = compress(&d, ErrorBound::Absolute(0.0)).unwrap();
891 let (_, dec) = decompress::<u8, _>(&*compressed).unwrap();
892 assert_eq!(dec.data(), data.as_slice());
893 }
894
895 #[test]
896 fn round_trip_i8() {
897 let data: Vec<i8> = (-128..127).map(|i| i as i8).collect();
898 let d = DimensionedData::<i8, _>::build(&data)
899 .dim(data.len())
900 .unwrap()
901 .finish()
902 .unwrap();
903 let compressed = compress(&d, ErrorBound::Absolute(0.0)).unwrap();
904 let (_, dec) = decompress::<i8, _>(&*compressed).unwrap();
905 assert_eq!(dec.data(), data.as_slice());
906 }
907
908 #[test]
909 fn round_trip_u16() {
910 let data: Vec<u16> = (0..256).map(|i| i as u16 * 100).collect();
911 let d = DimensionedData::<u16, _>::build(&data)
912 .dim(256)
913 .unwrap()
914 .finish()
915 .unwrap();
916 let compressed = compress(&d, ErrorBound::Absolute(0.0)).unwrap();
917 let (_, dec) = decompress::<u16, _>(&*compressed).unwrap();
918 assert_eq!(dec.data(), data.as_slice());
919 }
920
921 #[test]
922 fn round_trip_i16() {
923 let data: Vec<i16> = (-128..128).map(|i| i as i16 * 50).collect();
924 let d = DimensionedData::<i16, _>::build(&data)
925 .dim(256)
926 .unwrap()
927 .finish()
928 .unwrap();
929 let compressed = compress(&d, ErrorBound::Absolute(0.0)).unwrap();
930 let (_, dec) = decompress::<i16, _>(&*compressed).unwrap();
931 assert_eq!(dec.data(), data.as_slice());
932 }
933
934 #[test]
935 fn round_trip_u32() {
936 let data: Vec<u32> = (0..256).map(|i| i as u32 * 1000).collect();
937 let d = DimensionedData::<u32, _>::build(&data)
938 .dim(256)
939 .unwrap()
940 .finish()
941 .unwrap();
942 let compressed = compress(&d, ErrorBound::Absolute(0.0)).unwrap();
943 let (_, dec) = decompress::<u32, _>(&*compressed).unwrap();
944 assert_eq!(dec.data(), data.as_slice());
945 }
946
947 #[test]
948 fn round_trip_i32() {
949 let data: Vec<i32> = (-128..128).map(|i: i32| i * 1000).collect();
950 let d = DimensionedData::<i32, _>::build(&data)
951 .dim(256)
952 .unwrap()
953 .finish()
954 .unwrap();
955 let compressed = compress(&d, ErrorBound::Absolute(0.0)).unwrap();
956 let (_, dec) = decompress::<i32, _>(&*compressed).unwrap();
957 assert_eq!(dec.data(), data.as_slice());
958 }
959
960 #[test]
961 fn round_trip_u64() {
962 let data: Vec<u64> = (0..256).map(|i| i as u64 * 10000).collect();
963 let d = DimensionedData::<u64, _>::build(&data)
964 .dim(256)
965 .unwrap()
966 .finish()
967 .unwrap();
968 let compressed = compress(&d, ErrorBound::Absolute(0.0)).unwrap();
969 let (_, dec) = decompress::<u64, _>(&*compressed).unwrap();
970 assert_eq!(dec.data(), data.as_slice());
971 }
972
973 #[test]
974 fn round_trip_i64() {
975 let data: Vec<i64> = (-128..128).map(|i| i as i64 * 10000).collect();
976 let d = DimensionedData::<i64, _>::build(&data)
977 .dim(256)
978 .unwrap()
979 .finish()
980 .unwrap();
981 let compressed = compress(&d, ErrorBound::Absolute(0.0)).unwrap();
982 let (_, dec) = decompress::<i64, _>(&*compressed).unwrap();
983 assert_eq!(dec.data(), data.as_slice());
984 }
985
986 #[test]
987 fn error_bound_relative() {
988 let data: Vec<f64> = (0..256)
989 .map(|i| (i as f64 / 256.0 * std::f64::consts::PI).sin() + 2.0)
990 .collect();
991 let d = DimensionedData::<f64, _>::build(&data)
992 .dim(256)
993 .unwrap()
994 .finish()
995 .unwrap();
996 let cfg = Config::new(ErrorBound::Relative(1e-4));
997 let c = compress_with_config(&d, &cfg).unwrap();
998 let (_, dec) = decompress::<f64, _>(&*c).unwrap();
999 assert_eq!(dec.data().len(), data.len());
1000 }
1001
1002 #[test]
1003 fn error_bound_psnr() {
1004 let data: Vec<f64> = (0..256)
1005 .map(|i| (i as f64 / 256.0 * std::f64::consts::PI).sin())
1006 .collect();
1007 let d = DimensionedData::<f64, _>::build(&data)
1008 .dim(256)
1009 .unwrap()
1010 .finish()
1011 .unwrap();
1012 let cfg = Config::new(ErrorBound::PSNR(80.0));
1013 let c = compress_with_config(&d, &cfg).unwrap();
1014 let (_, dec) = decompress::<f64, _>(&*c).unwrap();
1015 assert_eq!(dec.data().len(), data.len());
1016 }
1017
1018 #[test]
1019 fn error_bound_l2norm() {
1020 let data: Vec<f64> = (0..256)
1021 .map(|i| (i as f64 / 256.0 * std::f64::consts::PI).sin())
1022 .collect();
1023 let d = DimensionedData::<f64, _>::build(&data)
1024 .dim(256)
1025 .unwrap()
1026 .finish()
1027 .unwrap();
1028 let cfg = Config::new(ErrorBound::L2Norm(1e-3));
1029 let c = compress_with_config(&d, &cfg).unwrap();
1030 let (_, dec) = decompress::<f64, _>(&*c).unwrap();
1031 assert_eq!(dec.data().len(), data.len());
1032 }
1033
1034 #[test]
1035 fn error_bound_abs_and_rel() {
1036 let data: Vec<f64> = (0..256)
1037 .map(|i| (i as f64 / 256.0 * std::f64::consts::PI).sin() + 2.0)
1038 .collect();
1039 let d = DimensionedData::<f64, _>::build(&data)
1040 .dim(256)
1041 .unwrap()
1042 .finish()
1043 .unwrap();
1044 let cfg = Config::new(ErrorBound::AbsoluteAndRelative {
1045 absolute_bound: 1e-4,
1046 relative_bound: 1e-3,
1047 });
1048 let c = compress_with_config(&d, &cfg).unwrap();
1049 let (_, dec) = decompress::<f64, _>(&*c).unwrap();
1050 assert_eq!(dec.data().len(), data.len());
1051 }
1052
1053 #[test]
1054 fn error_bound_abs_or_rel() {
1055 let data: Vec<f64> = (0..256)
1056 .map(|i| (i as f64 / 256.0 * std::f64::consts::PI).sin() + 2.0)
1057 .collect();
1058 let d = DimensionedData::<f64, _>::build(&data)
1059 .dim(256)
1060 .unwrap()
1061 .finish()
1062 .unwrap();
1063 let cfg = Config::new(ErrorBound::AbsoluteOrRelative {
1064 absolute_bound: 1e-4,
1065 relative_bound: 1e-3,
1066 });
1067 let c = compress_with_config(&d, &cfg).unwrap();
1068 let (_, dec) = decompress::<f64, _>(&*c).unwrap();
1069 assert_eq!(dec.data().len(), data.len());
1070 }
1071
1072 #[test]
1073 fn algo_interpolation() {
1074 let data: Vec<f64> = (0..256).map(|i| i as f64).collect();
1075 let d = DimensionedData::<f64, _>::build(&data)
1076 .dim(256)
1077 .unwrap()
1078 .finish()
1079 .unwrap();
1080 let cfg = Config::new(ErrorBound::Absolute(1e-6))
1081 .compression_algorithm(CompressionAlgorithm::interpolation());
1082 let c = compress_with_config(&d, &cfg).unwrap();
1083 let (dc, dec) = decompress::<f64, _>(&*c).unwrap();
1084 assert_eq!(dec.data().len(), data.len());
1085 assert!(matches!(
1086 dc.compression_algorithm,
1087 CompressionAlgorithm::Interpolation
1088 ));
1089 }
1090
1091 #[test]
1092 fn algo_lorenzo_regression() {
1093 let data: Vec<f64> = (0..256).map(|i| i as f64).collect();
1094 let d = DimensionedData::<f64, _>::build(&data)
1095 .dim(256)
1096 .unwrap()
1097 .finish()
1098 .unwrap();
1099 let cfg = Config::new(ErrorBound::Absolute(1e-6))
1100 .compression_algorithm(CompressionAlgorithm::lorenzo_regression());
1101 let c = compress_with_config(&d, &cfg).unwrap();
1102 let (dc, dec) = decompress::<f64, _>(&*c).unwrap();
1103 assert_eq!(dec.data().len(), data.len());
1104 assert!(matches!(
1105 dc.compression_algorithm,
1106 CompressionAlgorithm::LorenzoRegression { .. }
1107 ));
1108 }
1109
1110 #[test]
1111 fn algo_lossless() {
1112 let data: Vec<f64> = (0..256).map(|i| i as f64).collect();
1113 let d = DimensionedData::<f64, _>::build(&data)
1114 .dim(256)
1115 .unwrap()
1116 .finish()
1117 .unwrap();
1118 let cfg = Config::new(ErrorBound::Absolute(0.0))
1119 .compression_algorithm(CompressionAlgorithm::lossless());
1120 let c = compress_with_config(&d, &cfg).unwrap();
1121 let (_, dec) = decompress::<f64, _>(&*c).unwrap();
1122 assert_eq!(dec.data(), data.as_slice());
1123 }
1124
1125 #[test]
1126 fn config_quantization_bincount() {
1127 let data: Vec<f64> = (0..256).map(|i| i as f64).collect();
1128 let d = DimensionedData::<f64, _>::build(&data)
1129 .dim(256)
1130 .unwrap()
1131 .finish()
1132 .unwrap();
1133 let cfg = Config::new(ErrorBound::Absolute(1e-6)).quantization_bincount(1024);
1134 let c = compress_with_config(&d, &cfg).unwrap();
1135 let (dc, _) = decompress::<f64, _>(&*c).unwrap();
1136 assert_eq!(dc.quantization_bincount, 1024);
1137 }
1138
1139 #[test]
1140 fn config_block_size() {
1141 let data: Vec<f64> = (0..256).map(|i| i as f64).collect();
1142 let d = DimensionedData::<f64, _>::build(&data)
1143 .dim(256)
1144 .unwrap()
1145 .finish()
1146 .unwrap();
1147 let cfg = Config::new(ErrorBound::Absolute(1e-6)).block_size(64);
1148 let c = compress_with_config(&d, &cfg).unwrap();
1149 let (dc, _) = decompress::<f64, _>(&*c).unwrap();
1150 assert_eq!(dc.block_size, Some(64));
1151 }
1152
1153 #[test]
1154 fn config_automatic_block_size() {
1155 let data: Vec<f64> = (0..256).map(|i| i as f64).collect();
1156 let d = DimensionedData::<f64, _>::build(&data)
1157 .dim(256)
1158 .unwrap()
1159 .finish()
1160 .unwrap();
1161 let cfg = Config::new(ErrorBound::Absolute(1e-6))
1162 .block_size(64)
1163 .automatic_block_size();
1164 assert!(cfg.block_size.is_none());
1165 let c = compress_with_config(&d, &cfg).unwrap();
1166 let (_, dec) = decompress::<f64, _>(&*c).unwrap();
1167 assert_eq!(dec.data().len(), data.len());
1168 }
1169
1170 #[test]
1171 fn config_error_bound_setter() {
1172 let cfg = Config::new(ErrorBound::Absolute(1.0)).error_bound(ErrorBound::Relative(0.5));
1173 assert!(matches!(cfg.error_bound, ErrorBound::Relative(_)));
1174 }
1175
1176 #[test]
1177 fn dimensioned_data_accessors() {
1178 let data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
1179 let dd = DimensionedData::<f64, _>::build(&data)
1180 .dim(2)
1181 .unwrap()
1182 .dim(3)
1183 .unwrap()
1184 .finish()
1185 .unwrap();
1186 assert_eq!(dd.dims(), &[2, 3]);
1187 assert_eq!(dd.data(), &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1188 let owned = dd.into_data();
1189 assert_eq!(owned, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1190 }
1191
1192 #[test]
1193 fn dimensioned_data_build_mut() {
1194 let mut data: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0];
1195 let mut dd = DimensionedData::<f64, _>::build_mut(&mut data)
1196 .dim(4)
1197 .unwrap()
1198 .finish()
1199 .unwrap();
1200 assert_eq!(dd.data(), &[1.0, 2.0, 3.0, 4.0]);
1201 dd.data_mut()[0] = 99.0;
1202 assert_eq!(dd.data()[0], 99.0);
1203 }
1204
1205 #[test]
1206 fn dimensioned_data_remainder_dim() {
1207 let data: Vec<f64> = vec![1.0; 120];
1208 let dd = DimensionedData::<f64, _>::build(&data)
1209 .dim(10)
1210 .unwrap()
1211 .remainder_dim()
1212 .unwrap();
1213 assert_eq!(dd.dims(), &[10, 12]);
1214 assert_eq!(dd.data().len(), 120);
1215 }
1216
1217 #[test]
1218 fn dimensioned_data_singleton() {
1219 let data: Vec<f64> = vec![42.0];
1220 let dd = DimensionedData::<f64, _>::build(&data)
1221 .dim(1)
1222 .unwrap()
1223 .finish()
1224 .unwrap();
1225 assert_eq!(dd.dims(), &[1]);
1226 assert_eq!(dd.data(), &[42.0]);
1227 }
1228
1229 #[test]
1230 fn decompress_into_dimensioned_round_trip() {
1231 let data: Vec<f64> = (0..256)
1232 .map(|i| (i as f64 / 256.0 * std::f64::consts::PI).sin())
1233 .collect();
1234 let d = DimensionedData::<f64, _>::build(&data)
1235 .dim(256)
1236 .unwrap()
1237 .finish()
1238 .unwrap();
1239 let compressed = compress(&d, ErrorBound::Absolute(1e-6)).unwrap();
1240 let mut output = vec![0.0f64; 256];
1241 let mut out_dim = DimensionedData::<f64, _>::build_mut(&mut output)
1242 .dim(256)
1243 .unwrap()
1244 .finish()
1245 .unwrap();
1246 let _cfg = decompress_into_dimensioned(&*compressed, &mut out_dim).unwrap();
1247 for (orig, dec) in data.iter().zip(out_dim.data()) {
1248 assert!((orig - dec).abs() <= 1e-6);
1249 }
1250 }
1251
1252 #[test]
1253 fn decompress_data_type_mismatch() {
1254 let data: Vec<f32> = (0..256).map(|i| i as f32).collect();
1255 let d = DimensionedData::<f32, _>::build(&data)
1256 .dim(256)
1257 .unwrap()
1258 .finish()
1259 .unwrap();
1260 let compressed = compress(&d, ErrorBound::Absolute(1e-4)).unwrap();
1261 let result = decompress::<f64, _>(&*compressed);
1262 assert!(matches!(
1263 result.unwrap_err(),
1264 SZ3Error::DecompressedDataTypeMismatch
1265 ));
1266 }
1267
1268 #[test]
1269 fn decompress_into_dims_mismatch() {
1270 let data: Vec<f64> = (0..256).map(|i| i as f64).collect();
1271 let d = DimensionedData::<f64, _>::build(&data)
1272 .dim(256)
1273 .unwrap()
1274 .finish()
1275 .unwrap();
1276 let compressed = compress(&d, ErrorBound::Absolute(1e-6)).unwrap();
1277 let mut output = vec![0.0f64; 256];
1278 let mut out_dim = DimensionedData::<f64, _>::build_mut(&mut output)
1279 .dim(16)
1280 .unwrap()
1281 .dim(16)
1282 .unwrap()
1283 .finish()
1284 .unwrap();
1285 let result = decompress_into_dimensioned(&*compressed, &mut out_dim);
1286 assert!(matches!(
1287 result.unwrap_err(),
1288 SZ3Error::DecompressedDimsMismatch { .. }
1289 ));
1290 }
1291
1292 #[test]
1293 fn decompress_into_type_mismatch() {
1294 let data: Vec<f32> = (0..256).map(|i| i as f32).collect();
1295 let d = DimensionedData::<f32, _>::build(&data)
1296 .dim(256)
1297 .unwrap()
1298 .finish()
1299 .unwrap();
1300 let compressed = compress(&d, ErrorBound::Absolute(1e-4)).unwrap();
1301 let mut output = vec![0.0f64; 256];
1302 let mut out_dim = DimensionedData::<f64, _>::build_mut(&mut output)
1303 .dim(256)
1304 .unwrap()
1305 .finish()
1306 .unwrap();
1307 let result = decompress_into_dimensioned(&*compressed, &mut out_dim);
1308 assert!(matches!(
1309 result.unwrap_err(),
1310 SZ3Error::DecompressedDataTypeMismatch
1311 ));
1312 }
1313
1314 #[test]
1315 fn compress_into_appends() {
1316 let data: Vec<f64> = (0..256).map(|i| i as f64).collect();
1317 let d = DimensionedData::<f64, _>::build(&data)
1318 .dim(256)
1319 .unwrap()
1320 .finish()
1321 .unwrap();
1322 let mut buf = vec![0xAA, 0xBB, 0xCC];
1323 compress_into(&d, ErrorBound::Absolute(1e-6), &mut buf).unwrap();
1324 assert_eq!(&buf[..3], &[0xAA, 0xBB, 0xCC]);
1325 assert!(buf.len() > 3);
1326 let (_, dec) = decompress::<f64, _>(&buf[3..]).unwrap();
1327 assert_eq!(dec.data().len(), data.len());
1328 }
1329
1330 #[test]
1331 fn compress_into_with_config_appends() {
1332 let data: Vec<f64> = (0..256).map(|i| i as f64).collect();
1333 let d = DimensionedData::<f64, _>::build(&data)
1334 .dim(256)
1335 .unwrap()
1336 .finish()
1337 .unwrap();
1338 let cfg = Config::new(ErrorBound::Absolute(1e-6))
1339 .compression_algorithm(CompressionAlgorithm::lossless());
1340 let mut buf = vec![0xDD, 0xEE];
1341 compress_into_with_config(&d, &cfg, &mut buf).unwrap();
1342 assert_eq!(&buf[..2], &[0xDD, 0xEE]);
1343 let (_, dec) = decompress::<f64, _>(&buf[2..]).unwrap();
1344 assert_eq!(dec.data(), data.as_slice());
1345 }
1346
1347 #[test]
1348 fn round_trip_2d() {
1349 let data: Vec<f64> = (0..256).map(|i| i as f64).collect();
1350 let d = DimensionedData::<f64, _>::build(&data)
1351 .dim(16)
1352 .unwrap()
1353 .dim(16)
1354 .unwrap()
1355 .finish()
1356 .unwrap();
1357 let c = compress(&d, ErrorBound::Absolute(1e-6)).unwrap();
1358 let (_, dec) = decompress::<f64, _>(&*c).unwrap();
1359 assert_eq!(dec.dims(), &[16, 16]);
1360 }
1361
1362 #[test]
1363 fn round_trip_3d() {
1364 let data: Vec<f64> = (0..216).map(|i| i as f64).collect();
1365 let d = DimensionedData::<f64, _>::build(&data)
1366 .dim(6)
1367 .unwrap()
1368 .dim(6)
1369 .unwrap()
1370 .dim(6)
1371 .unwrap()
1372 .finish()
1373 .unwrap();
1374 let c = compress(&d, ErrorBound::Absolute(1e-6)).unwrap();
1375 let (_, dec) = decompress::<f64, _>(&*c).unwrap();
1376 assert_eq!(dec.dims(), &[6, 6, 6]);
1377 }
1378
1379 #[test]
1380 fn algo_constructors() {
1381 assert!(matches!(
1382 CompressionAlgorithm::interpolation(),
1383 CompressionAlgorithm::Interpolation
1384 ));
1385 assert!(matches!(
1386 CompressionAlgorithm::interpolation_lorenzo(),
1387 CompressionAlgorithm::InterpolationLorenzo
1388 ));
1389 assert!(matches!(
1390 CompressionAlgorithm::lorenzo_regression(),
1391 CompressionAlgorithm::LorenzoRegression {
1392 lorenzo: true,
1393 lorenzo_second_order: false,
1394 regression: true
1395 }
1396 ));
1397 let a =
1398 CompressionAlgorithm::lorenzo_regression_custom(Some(false), Some(true), Some(false));
1399 assert!(matches!(
1400 a,
1401 CompressionAlgorithm::LorenzoRegression {
1402 lorenzo: false,
1403 lorenzo_second_order: true,
1404 regression: false
1405 }
1406 ));
1407 let a = CompressionAlgorithm::lorenzo_regression_custom(None, None, None);
1408 assert!(matches!(
1409 a,
1410 CompressionAlgorithm::LorenzoRegression {
1411 lorenzo: true,
1412 lorenzo_second_order: false,
1413 regression: true
1414 }
1415 ));
1416 assert!(matches!(
1417 CompressionAlgorithm::biology_molecular_data(),
1418 CompressionAlgorithm::BiologyMolecularData
1419 ));
1420 assert!(matches!(
1421 CompressionAlgorithm::biology_molecular_data_gromacs_xtc(),
1422 CompressionAlgorithm::BiologyMolecularDataGromacsXtc
1423 ));
1424 assert!(matches!(
1425 CompressionAlgorithm::no_prediction(),
1426 CompressionAlgorithm::NoPrediction
1427 ));
1428 assert!(matches!(
1429 CompressionAlgorithm::lossless(),
1430 CompressionAlgorithm::Lossless
1431 ));
1432 assert!(matches!(
1433 CompressionAlgorithm::default(),
1434 CompressionAlgorithm::InterpolationLorenzo
1435 ));
1436 }
1437
1438 #[test]
1439 fn error_display_messages() {
1440 let e = SZ3Error::OneSizedDimension;
1441 assert!(format!("{e}").contains("size one"));
1442 let e = SZ3Error::DecompressedDataTypeMismatch;
1443 assert!(format!("{e}").contains("different data type"));
1444 let e = SZ3Error::DecompressedDimsMismatch {
1445 found: vec![256],
1446 expected: vec![16, 16],
1447 };
1448 assert!(format!("{e}").contains("[256]"));
1449 let e = SZ3Error::InvalidDimensionSize {
1450 dims: vec![10],
1451 len: 100,
1452 wanted: 7,
1453 remainder: 10,
1454 };
1455 assert!(format!("{e}").contains("7"));
1456 let e = SZ3Error::UnderSpecifiedDimensions {
1457 dims: vec![10],
1458 len: 100,
1459 remainder: 10,
1460 };
1461 assert!(format!("{e}").contains("10"));
1462 }
1463}