1use crate::error::{Result, TransformError};
7use rayon::prelude::*;
8use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, Axis};
9
/// Wavelet families that can be turned into a [`WaveletFilters`] bank.
///
/// The `usize` payload is the order of the family member; the supported orders are listed on
/// each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WaveletType {
    /// Haar wavelet (two taps).
    Haar,
    /// Daubechies wavelet of order `n` (`2 * n` taps); orders 1-10 are tabulated.
    Daubechies(usize),
    /// Symlet of order `n`; currently served by the Daubechies filters of the same order.
    Symlet(usize),
    /// Coiflet of order `n` (`6 * n` taps); orders 1-5 are tabulated.
    Coiflet(usize),
    /// Biorthogonal wavelet with the given pair of orders; currently served by the Haar
    /// filters regardless of the orders.
    Biorthogonal(usize, usize),
}
24
/// How a finite signal is extended past its ends before filtering.
///
/// The examples show the signal `1 2 3` with two extension samples on each side.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BoundaryMode {
    /// Pad with zeros: `0 0 | 1 2 3 | 0 0`.
    Zero,
    /// Repeat the edge sample: `1 1 | 1 2 3 | 3 3`.
    Constant,
    /// Mirror including the edge sample: `2 1 | 1 2 3 | 3 2`.
    Symmetric,
    /// Wrap around: `2 3 | 1 2 3 | 1 2`.
    Periodic,
    /// Mirror about the edge sample without repeating it: `3 2 | 1 2 3 | 2 1`.
    Reflect,
}
39
/// Analysis (decomposition) and synthesis (reconstruction) filters of a wavelet.
///
/// Every bank built by `WaveletFilters::from_wavelet` has `rec_lo` equal to `dec_lo` reversed
/// and `rec_hi` equal to `dec_hi` reversed.
#[derive(Debug, Clone)]
pub struct WaveletFilters {
    /// Decomposition low-pass filter; produces the approximation coefficients.
    pub dec_lo: Vec<f64>,
    /// Decomposition high-pass filter; produces the detail coefficients.
    pub dec_hi: Vec<f64>,
    /// Reconstruction low-pass filter; applied to approximation coefficients.
    pub rec_lo: Vec<f64>,
    /// Reconstruction high-pass filter; applied to detail coefficients.
    pub rec_hi: Vec<f64>,
}
52
53impl WaveletFilters {
54 pub fn from_wavelet(wavelet: WaveletType) -> Result<Self> {
56 match wavelet {
57 WaveletType::Haar => Self::haar(),
58 WaveletType::Daubechies(n) => Self::daubechies(n),
59 WaveletType::Symlet(n) => Self::symlet(n),
60 WaveletType::Coiflet(n) => Self::coiflet(n),
61 WaveletType::Biorthogonal(p, q) => Self::biorthogonal(p, q),
62 }
63 }
64
65 fn haar() -> Result<Self> {
67 let norm = 1.0 / 2.0_f64.sqrt();
68 Ok(WaveletFilters {
69 dec_lo: vec![norm, norm],
70 dec_hi: vec![norm, -norm],
71 rec_lo: vec![norm, norm],
72 rec_hi: vec![-norm, norm],
73 })
74 }
75
76 fn daubechies(n: usize) -> Result<Self> {
78 match n {
79 2 => {
80 let sqrt3 = 3.0_f64.sqrt();
82 let denom = 4.0 * 2.0_f64.sqrt();
83 let dec_lo = vec![
84 (1.0 + sqrt3) / denom,
85 (3.0 + sqrt3) / denom,
86 (3.0 - sqrt3) / denom,
87 (1.0 - sqrt3) / denom,
88 ];
89 let mut dec_hi = Vec::with_capacity(dec_lo.len());
90 for (i, &val) in dec_lo.iter().enumerate().rev() {
91 dec_hi.push(if i % 2 == 0 { val } else { -val });
92 }
93
94 let mut rec_lo = dec_lo.clone();
95 rec_lo.reverse();
96 let mut rec_hi = dec_hi.clone();
97 rec_hi.reverse();
98
99 Ok(WaveletFilters {
100 dec_lo,
101 dec_hi,
102 rec_lo,
103 rec_hi,
104 })
105 }
106 4 => {
107 let dec_lo = vec![
109 -0.010597401784997,
110 0.032883011666983,
111 0.030841381835987,
112 -0.187034811718881,
113 -0.027983769416984,
114 0.630880767929590,
115 0.714846570552542,
116 0.230377813308855,
117 ];
118 let mut dec_hi = Vec::with_capacity(dec_lo.len());
119 for (i, &val) in dec_lo.iter().enumerate().rev() {
120 dec_hi.push(if i % 2 == 0 { val } else { -val });
121 }
122
123 let mut rec_lo = dec_lo.clone();
124 rec_lo.reverse();
125 let mut rec_hi = dec_hi.clone();
126 rec_hi.reverse();
127
128 Ok(WaveletFilters {
129 dec_lo,
130 dec_hi,
131 rec_lo,
132 rec_hi,
133 })
134 }
135 1 => Self::haar(),
137 3 => {
138 let dec_lo: Vec<f64> = vec![
140 0.035226291882100656,
141 -0.08544127388224149,
142 -0.13501102001039084,
143 0.4598775021193313,
144 0.8068915093133388,
145 0.3326705529509569,
146 ];
147 let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
148 for (i, &val) in dec_lo.iter().enumerate().rev() {
149 dec_hi.push(if i % 2 == 0 { val } else { -val });
150 }
151 let mut rec_lo = dec_lo.clone();
152 rec_lo.reverse();
153 let mut rec_hi = dec_hi.clone();
154 rec_hi.reverse();
155 Ok(WaveletFilters {
156 dec_lo,
157 dec_hi,
158 rec_lo,
159 rec_hi,
160 })
161 }
162 5 => {
163 let dec_lo: Vec<f64> = vec![
165 0.003335725285001549,
166 -0.012580751999015526,
167 -0.006241490213011705,
168 0.07757149384006515,
169 -0.03224486958502952,
170 -0.24229488706619015,
171 0.13842814590110342,
172 0.7243085284385744,
173 0.6038292697974898,
174 0.16010239797412501,
175 ];
176 let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
177 for (i, &val) in dec_lo.iter().enumerate().rev() {
178 dec_hi.push(if i % 2 == 0 { val } else { -val });
179 }
180 let mut rec_lo = dec_lo.clone();
181 rec_lo.reverse();
182 let mut rec_hi = dec_hi.clone();
183 rec_hi.reverse();
184 Ok(WaveletFilters {
185 dec_lo,
186 dec_hi,
187 rec_lo,
188 rec_hi,
189 })
190 }
191 6 => {
192 let dec_lo: Vec<f64> = vec![
194 -0.0010773010853084796,
195 0.004777257510945511,
196 0.0005538422011614961,
197 -0.03158203931748603,
198 0.027522865530305727,
199 0.09750160558732304,
200 -0.12976686756726194,
201 -0.22626469396543983,
202 0.31525035170919763,
203 0.7511339080210954,
204 0.49462389039845306,
205 0.11154074335010947,
206 ];
207 let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
208 for (i, &val) in dec_lo.iter().enumerate().rev() {
209 dec_hi.push(if i % 2 == 0 { val } else { -val });
210 }
211 let mut rec_lo = dec_lo.clone();
212 rec_lo.reverse();
213 let mut rec_hi = dec_hi.clone();
214 rec_hi.reverse();
215 Ok(WaveletFilters {
216 dec_lo,
217 dec_hi,
218 rec_lo,
219 rec_hi,
220 })
221 }
222 7 => {
223 let dec_lo: Vec<f64> = vec![
225 0.00035371379997452024,
226 -0.0018016407040474908,
227 0.0004295779729213665,
228 0.01255099855609984,
229 -0.01657454163066688,
230 -0.03802993693501441,
231 0.08061260915108308,
232 0.07130921926683026,
233 -0.22403618499387498,
234 -0.14390600392856498,
235 0.4697822874051931,
236 0.7291320908462351,
237 0.3965393194819173,
238 0.07785205408500918,
239 ];
240 let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
241 for (i, &val) in dec_lo.iter().enumerate().rev() {
242 dec_hi.push(if i % 2 == 0 { val } else { -val });
243 }
244 let mut rec_lo = dec_lo.clone();
245 rec_lo.reverse();
246 let mut rec_hi = dec_hi.clone();
247 rec_hi.reverse();
248 Ok(WaveletFilters {
249 dec_lo,
250 dec_hi,
251 rec_lo,
252 rec_hi,
253 })
254 }
255 8 => {
256 let dec_lo: Vec<f64> = vec![
258 -0.00011747678412476953,
259 0.0006754494064505693,
260 -0.00039174037337694705,
261 -0.004870352993451574,
262 0.008746094047405777,
263 0.013981027917398282,
264 -0.044088253930794755,
265 -0.017369301001807547,
266 0.12874742662047847,
267 0.0004724845739132828,
268 -0.2840155429615469,
269 -0.015829105256349306,
270 0.5853546836542067,
271 0.6756307362972898,
272 0.31287159091429995,
273 0.05441584224310401,
274 ];
275 let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
276 for (i, &val) in dec_lo.iter().enumerate().rev() {
277 dec_hi.push(if i % 2 == 0 { val } else { -val });
278 }
279 let mut rec_lo = dec_lo.clone();
280 rec_lo.reverse();
281 let mut rec_hi = dec_hi.clone();
282 rec_hi.reverse();
283 Ok(WaveletFilters {
284 dec_lo,
285 dec_hi,
286 rec_lo,
287 rec_hi,
288 })
289 }
290 9 => {
291 let dec_lo: Vec<f64> = vec![
293 3.93473203162716e-05,
294 -0.0002519631889427101,
295 0.00023038576352319597,
296 0.0018476468830562265,
297 -0.00428150368246343,
298 -0.004723204757751397,
299 0.022361662123679096,
300 0.00025094711483145197,
301 -0.06763282906132997,
302 0.03072568147933338,
303 0.14854074933810638,
304 -0.09684078322297646,
305 -0.2932737832791749,
306 0.13319738582500756,
307 0.6572880780513005,
308 0.6048231236901112,
309 0.24383467461259034,
310 0.038077947363878345,
311 ];
312 let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
313 for (i, &val) in dec_lo.iter().enumerate().rev() {
314 dec_hi.push(if i % 2 == 0 { val } else { -val });
315 }
316 let mut rec_lo = dec_lo.clone();
317 rec_lo.reverse();
318 let mut rec_hi = dec_hi.clone();
319 rec_hi.reverse();
320 Ok(WaveletFilters {
321 dec_lo,
322 dec_hi,
323 rec_lo,
324 rec_hi,
325 })
326 }
327 10 => {
328 let dec_lo: Vec<f64> = vec![
330 -1.3264202894521244e-05,
331 9.358867032006959e-05,
332 -0.00011646685512928545,
333 -0.0006858566949597116,
334 0.001992405295185056,
335 0.001395351747052901,
336 -0.010733175483330575,
337 0.0036065535669561697,
338 0.033212674059341,
339 -0.029457536821875813,
340 -0.07139414716639708,
341 0.09305736460357235,
342 0.12736934033579325,
343 -0.19594627437737705,
344 -0.24984642432731538,
345 0.2811723436605775,
346 0.6884590394536035,
347 0.5272011889317256,
348 0.1881768000776915,
349 0.026670057900555554,
350 ];
351 let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
352 for (i, &val) in dec_lo.iter().enumerate().rev() {
353 dec_hi.push(if i % 2 == 0 { val } else { -val });
354 }
355 let mut rec_lo = dec_lo.clone();
356 rec_lo.reverse();
357 let mut rec_hi = dec_hi.clone();
358 rec_hi.reverse();
359 Ok(WaveletFilters {
360 dec_lo,
361 dec_hi,
362 rec_lo,
363 rec_hi,
364 })
365 }
366 _ => Err(TransformError::InvalidInput(format!(
367 "Daubechies-{} not implemented (supported: 1-10)",
368 n
369 ))),
370 }
371 }
372
373 fn symlet(n: usize) -> Result<Self> {
375 Self::daubechies(n)
377 }
378
379 fn coiflet(n: usize) -> Result<Self> {
381 match n {
382 1 => {
383 let dec_lo: Vec<f64> = vec![
385 -0.015655728135791993,
386 -0.07273261951252645,
387 0.3848648468648578,
388 0.8525720202116004,
389 0.3378976624574818,
390 -0.07273261951252645,
391 ];
392 let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
393 for (i, &val) in dec_lo.iter().enumerate().rev() {
394 dec_hi.push(if i % 2 == 0 { val } else { -val });
395 }
396
397 let mut rec_lo = dec_lo.clone();
398 rec_lo.reverse();
399 let mut rec_hi = dec_hi.clone();
400 rec_hi.reverse();
401
402 Ok(WaveletFilters {
403 dec_lo,
404 dec_hi,
405 rec_lo,
406 rec_hi,
407 })
408 }
409 2 => {
410 let dec_lo: Vec<f64> = vec![
412 -0.000720549445520347,
413 -0.0018232088709110323,
414 0.005611434819368834,
415 0.02368017194684777,
416 -0.05943441864643109,
417 -0.07648859907828076,
418 0.4170051844232391,
419 0.8127236354494135,
420 0.3861100668227629,
421 -0.0673725547237256,
422 -0.04146493678687178,
423 0.01638733646320364,
424 ];
425 let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
426 for (i, &val) in dec_lo.iter().enumerate().rev() {
427 dec_hi.push(if i % 2 == 0 { val } else { -val });
428 }
429 let mut rec_lo = dec_lo.clone();
430 rec_lo.reverse();
431 let mut rec_hi = dec_hi.clone();
432 rec_hi.reverse();
433 Ok(WaveletFilters {
434 dec_lo,
435 dec_hi,
436 rec_lo,
437 rec_hi,
438 })
439 }
440 3 => {
441 let dec_lo: Vec<f64> = vec![
443 -3.459977319727278e-05,
444 -7.0983302506379e-05,
445 0.0004662169598204029,
446 0.0011175187708306303,
447 -0.0025745176881367972,
448 -0.009007976136730624,
449 0.015880544863669452,
450 0.03455502757329774,
451 -0.08230192710629983,
452 -0.07179982161915484,
453 0.42848347637737,
454 0.7937772226260872,
455 0.40517690240911824,
456 -0.06112339000297255,
457 -0.06577191128146936,
458 0.023452696142077168,
459 0.007782596425672746,
460 -0.003793512864380802,
461 ];
462 let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
463 for (i, &val) in dec_lo.iter().enumerate().rev() {
464 dec_hi.push(if i % 2 == 0 { val } else { -val });
465 }
466 let mut rec_lo = dec_lo.clone();
467 rec_lo.reverse();
468 let mut rec_hi = dec_hi.clone();
469 rec_hi.reverse();
470 Ok(WaveletFilters {
471 dec_lo,
472 dec_hi,
473 rec_lo,
474 rec_hi,
475 })
476 }
477 4 => {
478 let dec_lo: Vec<f64> = vec![
480 -1.7849909144933469e-06,
481 -3.259647940030751e-06,
482 3.1229861599195265e-05,
483 6.233885431278719e-05,
484 -0.0002599743371222568,
485 -0.0005890202246332165,
486 0.0012665610789256603,
487 0.0037514346971460866,
488 -0.0056582838001308835,
489 -0.015211728187697211,
490 0.02508225333794961,
491 0.03933442260558915,
492 -0.09622042453595264,
493 -0.06662747236681717,
494 0.43438603311435653,
495 0.7822389344242826,
496 0.41530842700068227,
497 -0.05607731960356926,
498 -0.08126671024919373,
499 0.02668230466960483,
500 0.01606894713157503,
501 -0.007346167936268051,
502 -0.001629492425226786,
503 0.000892313902537003,
504 ];
505 let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
506 for (i, &val) in dec_lo.iter().enumerate().rev() {
507 dec_hi.push(if i % 2 == 0 { val } else { -val });
508 }
509 let mut rec_lo = dec_lo.clone();
510 rec_lo.reverse();
511 let mut rec_hi = dec_hi.clone();
512 rec_hi.reverse();
513 Ok(WaveletFilters {
514 dec_lo,
515 dec_hi,
516 rec_lo,
517 rec_hi,
518 })
519 }
520 5 => {
521 let dec_lo: Vec<f64> = vec![
523 -9.604010112767894e-08,
524 -1.6237995172048338e-07,
525 2.0612203985788783e-06,
526 3.7007277113394796e-06,
527 -2.1270221672515614e-05,
528 -4.12198619242655e-05,
529 0.00014035632812373243,
530 0.0003018579416682448,
531 -0.0006375589261258812,
532 -0.0016616273039298788,
533 0.0024315754425382886,
534 0.006761520220620417,
535 -0.009159507338676163,
536 -0.019758391600965465,
537 0.032674799467057355,
538 0.041287530472117834,
539 -0.10556315130733723,
540 -0.06203775157498196,
541 0.4379823066591634,
542 0.7742936228603274,
543 0.42157126673075435,
544 -0.052046670253554764,
545 -0.09192158806008609,
546 0.028169744270532353,
547 0.023408322118927783,
548 -0.010131584846900276,
549 -0.00415931262757864,
550 0.0021782943778456947,
551 0.0003585777411617577,
552 -0.000212081862067494,
553 ];
554 let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
555 for (i, &val) in dec_lo.iter().enumerate().rev() {
556 dec_hi.push(if i % 2 == 0 { val } else { -val });
557 }
558 let mut rec_lo = dec_lo.clone();
559 rec_lo.reverse();
560 let mut rec_hi = dec_hi.clone();
561 rec_hi.reverse();
562 Ok(WaveletFilters {
563 dec_lo,
564 dec_hi,
565 rec_lo,
566 rec_hi,
567 })
568 }
569 _ => Err(TransformError::InvalidInput(format!(
570 "Coiflet-{} not implemented (supported: 1-5)",
571 n
572 ))),
573 }
574 }
575
576 fn biorthogonal(_p: usize, _q: usize) -> Result<Self> {
578 Self::haar()
580 }
581}
582
/// One-dimensional discrete wavelet transform configured with a wavelet, its filter bank, a
/// boundary-extension mode, and an optional decomposition depth.
#[derive(Debug, Clone)]
pub struct DWT {
    /// Wavelet the filters were built from.
    wavelet: WaveletType,
    /// Analysis and synthesis filters for `wavelet`.
    filters: WaveletFilters,
    /// How the signal is extended past its ends before filtering.
    boundary: BoundaryMode,
    /// Number of levels requested for `wavedec`. `None` means the maximum level; an explicit
    /// value is capped at that maximum.
    level: Option<usize>,
}
591
592impl DWT {
593 pub fn new(wavelet: WaveletType) -> Result<Self> {
595 let filters = WaveletFilters::from_wavelet(wavelet)?;
596 Ok(DWT {
597 wavelet,
598 filters,
599 boundary: BoundaryMode::Symmetric,
600 level: None,
601 })
602 }
603
604 pub fn with_boundary(mut self, boundary: BoundaryMode) -> Self {
606 self.boundary = boundary;
607 self
608 }
609
610 pub fn with_level(mut self, level: usize) -> Self {
612 self.level = Some(level);
613 self
614 }
615
616 pub fn decompose(&self, signal: &ArrayView1<f64>) -> Result<(Array1<f64>, Array1<f64>)> {
618 let n = signal.len();
619 if n < 2 {
620 return Err(TransformError::InvalidInput(
621 "Signal too short for DWT".to_string(),
622 ));
623 }
624
625 let extended = self.extend_signal(signal)?;
627
628 let approx = self.convolve_downsample(&extended, &self.filters.dec_lo)?;
630 let detail = self.convolve_downsample(&extended, &self.filters.dec_hi)?;
631
632 Ok((approx, detail))
633 }
634
635 pub fn wavedec(&self, signal: &ArrayView1<f64>) -> Result<Vec<Array1<f64>>> {
637 let max_level = self.max_decomposition_level(signal.len());
638 let level = self.level.unwrap_or(max_level).min(max_level);
639
640 let mut coeffs = Vec::with_capacity(level + 1);
641 let mut current = signal.to_owned();
642
643 for _ in 0..level {
644 let (approx, detail) = self.decompose(¤t.view())?;
645 coeffs.push(detail);
646 current = approx;
647 }
648
649 coeffs.push(current);
651 coeffs.reverse();
652
653 Ok(coeffs)
654 }
655
656 pub fn reconstruct(
658 &self,
659 approx: &ArrayView1<f64>,
660 detail: &ArrayView1<f64>,
661 ) -> Result<Array1<f64>> {
662 let approx_up = self.upsample_convolve(approx, &self.filters.rec_lo)?;
664 let detail_up = self.upsample_convolve(detail, &self.filters.rec_hi)?;
665
666 let min_len = approx_up.len().min(detail_up.len());
668 let mut reconstructed = Array1::zeros(min_len);
669 for i in 0..min_len {
670 reconstructed[i] = approx_up[i] + detail_up[i];
671 }
672
673 Ok(reconstructed)
674 }
675
676 pub fn waverec(&self, coeffs: &[Array1<f64>]) -> Result<Array1<f64>> {
678 if coeffs.is_empty() {
679 return Err(TransformError::InvalidInput(
680 "No coefficients provided for reconstruction".to_string(),
681 ));
682 }
683
684 let mut current = coeffs[0].clone();
685
686 for detail in &coeffs[1..] {
687 current = self.reconstruct(¤t.view(), &detail.view())?;
688 }
689
690 Ok(current)
691 }
692
693 fn extend_signal(&self, signal: &ArrayView1<f64>) -> Result<Array1<f64>> {
696 let filter_len = self.filters.dec_lo.len();
697 let n = signal.len();
698 let pad_len = filter_len - 1;
699
700 let mut extended = Array1::zeros(n + 2 * pad_len);
701
702 match self.boundary {
703 BoundaryMode::Zero => {
704 for i in 0..n {
705 extended[i + pad_len] = signal[i];
706 }
707 }
708 BoundaryMode::Constant => {
709 let first = signal[0];
710 let last = signal[n - 1];
711 for i in 0..pad_len {
712 extended[i] = first;
713 extended[n + pad_len + i] = last;
714 }
715 for i in 0..n {
716 extended[i + pad_len] = signal[i];
717 }
718 }
719 BoundaryMode::Symmetric => {
720 for i in 0..pad_len {
721 extended[pad_len - 1 - i] = signal[i.min(n - 1)];
722 extended[n + pad_len + i] = signal[(n - 1 - i).max(0)];
723 }
724 for i in 0..n {
725 extended[i + pad_len] = signal[i];
726 }
727 }
728 BoundaryMode::Periodic => {
729 for i in 0..pad_len {
730 extended[i] = signal[(n - pad_len + i) % n];
731 extended[n + pad_len + i] = signal[i % n];
732 }
733 for i in 0..n {
734 extended[i + pad_len] = signal[i];
735 }
736 }
737 BoundaryMode::Reflect => {
738 for i in 0..pad_len {
739 let idx1 = if i < n { i } else { n - 1 };
740 let idx2 = if n > i + 1 { n - 1 - i } else { 0 };
741 extended[pad_len - 1 - i] = signal[idx1];
742 extended[n + pad_len + i] = signal[idx2];
743 }
744 for i in 0..n {
745 extended[i + pad_len] = signal[i];
746 }
747 }
748 }
749
750 Ok(extended)
751 }
752
753 fn convolve_downsample(&self, signal: &Array1<f64>, filter: &[f64]) -> Result<Array1<f64>> {
754 let n = signal.len();
755 let filter_len = filter.len();
756 let output_len = (n + 1) / 2;
757 let mut output = Array1::zeros(output_len);
758
759 for i in 0..output_len {
760 let pos = i * 2;
761 let mut sum = 0.0;
762
763 for (j, &coeff) in filter.iter().enumerate() {
764 let idx = pos + j;
765 if idx < n {
766 sum += signal[idx] * coeff;
767 }
768 }
769
770 output[i] = sum;
771 }
772
773 Ok(output)
774 }
775
776 fn upsample_convolve(&self, signal: &ArrayView1<f64>, filter: &[f64]) -> Result<Array1<f64>> {
777 let n = signal.len();
778 let filter_len = filter.len();
779 let output_len = n * 2;
780 let mut output = Array1::zeros(output_len);
781
782 let mut upsampled = Array1::zeros(output_len);
784 for i in 0..n {
785 upsampled[i * 2] = signal[i];
786 }
787
788 for i in 0..output_len {
790 let mut sum = 0.0;
791 for (j, &coeff) in filter.iter().enumerate() {
792 if i >= j && i - j < output_len {
793 sum += upsampled[i - j] * coeff;
794 }
795 }
796 output[i] = sum;
797 }
798
799 Ok(output)
800 }
801
802 fn max_decomposition_level(&self, signal_len: usize) -> usize {
803 let filter_len = self.filters.dec_lo.len();
804 let mut level: usize = 0;
805 let mut current_len = signal_len;
806
807 while current_len >= filter_len {
808 current_len = (current_len + 1) / 2;
809 level += 1;
810 }
811
812 level.saturating_sub(1)
813 }
814}
815
/// Separable two-dimensional discrete wavelet transform configuration.
#[derive(Debug, Clone)]
pub struct DWT2D {
    /// Wavelet the filters were built from.
    wavelet: WaveletType,
    /// Filter bank shared by the row and column passes.
    filters: WaveletFilters,
    /// Boundary-extension mode handed to the 1-D row/column transforms.
    boundary: BoundaryMode,
    /// Number of levels requested for `wavedec2`. `None` means the maximum level; an explicit
    /// value is capped at that maximum.
    level: Option<usize>,
}
824
825impl DWT2D {
826 pub fn new(wavelet: WaveletType) -> Result<Self> {
828 let filters = WaveletFilters::from_wavelet(wavelet)?;
829 Ok(DWT2D {
830 wavelet,
831 filters,
832 boundary: BoundaryMode::Symmetric,
833 level: None,
834 })
835 }
836
837 pub fn with_boundary(mut self, boundary: BoundaryMode) -> Self {
839 self.boundary = boundary;
840 self
841 }
842
843 pub fn with_level(mut self, level: usize) -> Self {
845 self.level = Some(level);
846 self
847 }
848
849 pub fn decompose2(&self, image: &ArrayView2<f64>) -> Result<Dwt2dCoeffs> {
851 let (rows, cols) = image.dim();
852 if rows < 2 || cols < 2 {
853 return Err(TransformError::InvalidInput(
854 "Image too small for 2D DWT".to_string(),
855 ));
856 }
857
858 let dwt1d = DWT {
859 wavelet: self.wavelet,
860 filters: self.filters.clone(),
861 boundary: self.boundary,
862 level: None,
863 };
864
865 let mut row_results_approx = Vec::with_capacity(rows);
867 let mut row_results_detail = Vec::with_capacity(rows);
868
869 for row_idx in 0..rows {
870 let row = image.row(row_idx);
871 let (approx, detail) = dwt1d.decompose(&row)?;
872 row_results_approx.push(approx);
873 row_results_detail.push(detail);
874 }
875
876 let approx_rows = row_results_approx[0].len();
877 let detail_rows = row_results_detail[0].len();
878
879 let mut approx_mat = Array2::zeros((rows, approx_rows));
881 let mut detail_mat = Array2::zeros((rows, detail_rows));
882
883 for (i, (app, det)) in row_results_approx
884 .iter()
885 .zip(row_results_detail.iter())
886 .enumerate()
887 {
888 for (j, &val) in app.iter().enumerate() {
889 approx_mat[[i, j]] = val;
890 }
891 for (j, &val) in det.iter().enumerate() {
892 detail_mat[[i, j]] = val;
893 }
894 }
895
896 let (ll, lh) = self.decompose_columns(&approx_mat.view(), &dwt1d)?;
898 let (hl, hh) = self.decompose_columns(&detail_mat.view(), &dwt1d)?;
899
900 Ok(Dwt2dCoeffs { ll, lh, hl, hh })
901 }
902
903 fn decompose_columns(
904 &self,
905 mat: &ArrayView2<f64>,
906 dwt1d: &DWT,
907 ) -> Result<(Array2<f64>, Array2<f64>)> {
908 let (rows, cols) = mat.dim();
909 let mut col_results_approx = Vec::with_capacity(cols);
910 let mut col_results_detail = Vec::with_capacity(cols);
911
912 for col_idx in 0..cols {
913 let col = mat.column(col_idx);
914 let (approx, detail) = dwt1d.decompose(&col)?;
915 col_results_approx.push(approx);
916 col_results_detail.push(detail);
917 }
918
919 let approx_cols = col_results_approx[0].len();
920 let detail_cols = col_results_detail[0].len();
921
922 let mut approx_result = Array2::zeros((approx_cols, cols));
923 let mut detail_result = Array2::zeros((detail_cols, cols));
924
925 for (j, (app, det)) in col_results_approx
926 .iter()
927 .zip(col_results_detail.iter())
928 .enumerate()
929 {
930 for (i, &val) in app.iter().enumerate() {
931 approx_result[[i, j]] = val;
932 }
933 for (i, &val) in det.iter().enumerate() {
934 detail_result[[i, j]] = val;
935 }
936 }
937
938 Ok((approx_result, detail_result))
939 }
940
941 pub fn wavedec2(&self, image: &ArrayView2<f64>) -> Result<Vec<Dwt2dCoeffs>> {
943 let max_level = self.max_decomposition_level_2d(image.dim());
944 let level = self.level.unwrap_or(max_level).min(max_level);
945
946 let mut coeffs = Vec::with_capacity(level);
947 let mut current = image.to_owned();
948
949 for _ in 0..level {
950 let dwt2d_coeffs = self.decompose2(¤t.view())?;
951 coeffs.push(dwt2d_coeffs.clone());
952 current = dwt2d_coeffs.ll;
953 }
954
955 Ok(coeffs)
956 }
957
958 fn max_decomposition_level_2d(&self, shape: (usize, usize)) -> usize {
959 let filter_len = self.filters.dec_lo.len();
960 let min_dim = shape.0.min(shape.1);
961
962 let mut level: usize = 0;
963 let mut current_dim = min_dim;
964
965 while current_dim >= filter_len {
966 current_dim = (current_dim + 1) / 2;
967 level += 1;
968 }
969
970 level.saturating_sub(1)
971 }
972}
973
/// Sub-bands produced by one level of `DWT2D::decompose2`.
///
/// The first letter names the filter applied along each row, and the second the filter
/// applied along each column. `l` is the low-pass (approximation) output and `h` the high-pass
/// (detail) output.
#[derive(Debug, Clone)]
pub struct Dwt2dCoeffs {
    /// Low-pass along rows, low-pass along columns (the approximation band).
    pub ll: Array2<f64>,
    /// Low-pass along rows, high-pass along columns.
    pub lh: Array2<f64>,
    /// High-pass along rows, low-pass along columns.
    pub hl: Array2<f64>,
    /// High-pass along rows, high-pass along columns.
    pub hh: Array2<f64>,
}
986
/// Sub-bands produced by `DWTN::decompose3`.
///
/// Each name has one letter per axis, in axis order (0, 1, 2). `l` means low-pass and `h`
/// means high-pass. For an input of shape `(d0, d1, d2)`, every sub-band has shape
/// `((d0 + 1) / 2, (d1 + 1) / 2, (d2 + 1) / 2)`.
#[derive(Debug, Clone)]
pub struct Dwt3dCoeffs {
    /// Low / low / low (the approximation band).
    pub lll: Array3<f64>,
    /// Low / low / high.
    pub llh: Array3<f64>,
    /// Low / high / low.
    pub lhl: Array3<f64>,
    /// Low / high / high.
    pub lhh: Array3<f64>,
    /// High / low / low.
    pub hll: Array3<f64>,
    /// High / low / high.
    pub hlh: Array3<f64>,
    /// High / high / low.
    pub hhl: Array3<f64>,
    /// High / high / high.
    pub hhh: Array3<f64>,
}
1010
/// Separable N-dimensional DWT configuration; `decompose3` handles the 3-D case.
#[derive(Debug, Clone)]
pub struct DWTN {
    /// Wavelet used for every axis; its filters are built when `decompose3` runs.
    wavelet: WaveletType,
    /// Boundary-extension mode passed to the per-axis 1-D transform.
    boundary: BoundaryMode,
    /// Requested decomposition depth. NOTE(review): not read by any method visible here.
    level: Option<usize>,
}
1018
1019impl DWTN {
1020 pub fn new(wavelet: WaveletType) -> Self {
1022 DWTN {
1023 wavelet,
1024 boundary: BoundaryMode::Symmetric,
1025 level: None,
1026 }
1027 }
1028
1029 pub fn with_boundary(mut self, boundary: BoundaryMode) -> Self {
1031 self.boundary = boundary;
1032 self
1033 }
1034
1035 pub fn with_level(mut self, level: usize) -> Self {
1037 self.level = Some(level);
1038 self
1039 }
1040
1041 pub fn decompose3(&self, volume: &Array3<f64>) -> Result<Dwt3dCoeffs> {
1051 let (d0, d1, d2) = volume.dim();
1052 if d0 < 2 || d1 < 2 || d2 < 2 {
1053 return Err(TransformError::InvalidInput(
1054 "Volume too small for 3D DWT: all dimensions must be >= 2".to_string(),
1055 ));
1056 }
1057
1058 let dwt1d = DWT::new(self.wavelet)?.with_boundary(self.boundary);
1059
1060 let half0 = (d0 + 1) / 2; let mut lo_x = Array3::<f64>::zeros((half0, d1, d2));
1065 let mut hi_x = Array3::<f64>::zeros((half0, d1, d2));
1066
1067 for j in 0..d1 {
1068 for k in 0..d2 {
1069 let col: Vec<f64> = (0..d0).map(|i| volume[[i, j, k]]).collect();
1070 let col_arr = Array1::from(col);
1071 let (approx, detail) = dwt1d.decompose(&col_arr.view())?;
1072 for i in 0..approx.len().min(half0) {
1073 lo_x[[i, j, k]] = approx[i];
1074 }
1075 for i in 0..detail.len().min(half0) {
1076 hi_x[[i, j, k]] = detail[i];
1077 }
1078 }
1079 }
1080
1081 let half1 = (d1 + 1) / 2;
1083 let (lo_x_lo_y, lo_x_hi_y) =
1084 self.apply_dwt_axis1_3d(&lo_x, &dwt1d, half0, d1, d2, half1)?;
1085 let (hi_x_lo_y, hi_x_hi_y) =
1086 self.apply_dwt_axis1_3d(&hi_x, &dwt1d, half0, d1, d2, half1)?;
1087
1088 let half2 = (d2 + 1) / 2;
1090 let (lll, llh) = self.apply_dwt_axis2_3d(&lo_x_lo_y, &dwt1d, half0, half1, d2, half2)?;
1091 let (lhl, lhh) = self.apply_dwt_axis2_3d(&lo_x_hi_y, &dwt1d, half0, half1, d2, half2)?;
1092 let (hll, hlh) = self.apply_dwt_axis2_3d(&hi_x_lo_y, &dwt1d, half0, half1, d2, half2)?;
1093 let (hhl, hhh) = self.apply_dwt_axis2_3d(&hi_x_hi_y, &dwt1d, half0, half1, d2, half2)?;
1094
1095 Ok(Dwt3dCoeffs {
1096 lll,
1097 llh,
1098 lhl,
1099 lhh,
1100 hll,
1101 hlh,
1102 hhl,
1103 hhh,
1104 })
1105 }
1106
1107 fn apply_dwt_axis1_3d(
1109 &self,
1110 arr: &Array3<f64>,
1111 dwt1d: &DWT,
1112 size0: usize,
1113 size1: usize,
1114 size2: usize,
1115 out1: usize,
1116 ) -> Result<(Array3<f64>, Array3<f64>)> {
1117 let mut lo = Array3::<f64>::zeros((size0, out1, size2));
1118 let mut hi = Array3::<f64>::zeros((size0, out1, size2));
1119
1120 for i in 0..size0 {
1121 for k in 0..size2 {
1122 let col: Vec<f64> = (0..size1).map(|j| arr[[i, j, k]]).collect();
1123 let col_arr = Array1::from(col);
1124 let (approx, detail) = dwt1d.decompose(&col_arr.view())?;
1125 for j in 0..approx.len().min(out1) {
1126 lo[[i, j, k]] = approx[j];
1127 }
1128 for j in 0..detail.len().min(out1) {
1129 hi[[i, j, k]] = detail[j];
1130 }
1131 }
1132 }
1133 Ok((lo, hi))
1134 }
1135
1136 fn apply_dwt_axis2_3d(
1138 &self,
1139 arr: &Array3<f64>,
1140 dwt1d: &DWT,
1141 size0: usize,
1142 size1: usize,
1143 size2: usize,
1144 out2: usize,
1145 ) -> Result<(Array3<f64>, Array3<f64>)> {
1146 let mut lo = Array3::<f64>::zeros((size0, size1, out2));
1147 let mut hi = Array3::<f64>::zeros((size0, size1, out2));
1148
1149 for i in 0..size0 {
1150 for j in 0..size1 {
1151 let row: Vec<f64> = (0..size2).map(|k| arr[[i, j, k]]).collect();
1152 let row_arr = Array1::from(row);
1153 let (approx, detail) = dwt1d.decompose(&row_arr.view())?;
1154 for k in 0..approx.len().min(out2) {
1155 lo[[i, j, k]] = approx[k];
1156 }
1157 for k in 0..detail.len().min(out2) {
1158 hi[[i, j, k]] = detail[k];
1159 }
1160 }
1161 }
1162 Ok((lo, hi))
1163 }
1164}
1165
1166#[cfg(test)]
1167mod tests {
1168 use super::*;
1169 use approx::assert_abs_diff_eq;
1170
1171 #[test]
1172 fn test_dwt_haar() -> Result<()> {
1173 let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
1174 let dwt = DWT::new(WaveletType::Haar)?;
1175
1176 let (approx, detail) = dwt.decompose(&signal.view())?;
1177
1178 assert!(approx.len() > 0);
1179 assert!(detail.len() > 0);
1180 assert_eq!(approx.len(), detail.len());
1181
1182 Ok(())
1183 }
1184
1185 #[test]
1186 fn test_dwt_multilevel() -> Result<()> {
1187 let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
1188 let dwt = DWT::new(WaveletType::Haar)?.with_level(2);
1189
1190 let coeffs = dwt.wavedec(&signal.view())?;
1191
1192 assert_eq!(coeffs.len(), 3); Ok(())
1195 }
1196
1197 #[test]
1198 fn test_dwt_reconstruction() -> Result<()> {
1199 let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
1200 let dwt = DWT::new(WaveletType::Haar)?;
1201
1202 let (approx, detail) = dwt.decompose(&signal.view())?;
1203 let reconstructed = dwt.reconstruct(&approx.view(), &detail.view())?;
1204
1205 assert!(reconstructed.len() >= signal.len() - 2);
1207
1208 Ok(())
1209 }
1210
1211 #[test]
1212 fn test_dwt2d() -> Result<()> {
1213 let image = Array2::from_shape_fn((8, 8), |(i, j)| (i + j) as f64);
1214 let dwt2d = DWT2D::new(WaveletType::Haar)?;
1215
1216 let coeffs = dwt2d.decompose2(&image.view())?;
1217
1218 assert!(coeffs.ll.len() > 0);
1219 assert!(coeffs.lh.len() > 0);
1220 assert!(coeffs.hl.len() > 0);
1221 assert!(coeffs.hh.len() > 0);
1222
1223 Ok(())
1224 }
1225
1226 #[test]
1227 fn test_wavelet_filters() -> Result<()> {
1228 let filters = WaveletFilters::from_wavelet(WaveletType::Haar)?;
1229
1230 assert_eq!(filters.dec_lo.len(), 2);
1231 assert_eq!(filters.dec_hi.len(), 2);
1232 assert_eq!(filters.rec_lo.len(), 2);
1233 assert_eq!(filters.rec_hi.len(), 2);
1234
1235 Ok(())
1236 }
1237
1238 fn check_filter_normalisation(filters: &WaveletFilters) {
1240 let sum: f64 = filters.dec_lo.iter().sum();
1241 let diff = (sum - 2.0_f64.sqrt()).abs();
1242 assert!(
1243 diff < 1e-6,
1244 "dec_lo sum {sum} is not sqrt(2); diff = {diff}"
1245 );
1246 }
1247
1248 #[test]
1249 fn test_daubechies_db1_is_haar() -> Result<()> {
1250 let haar = WaveletFilters::from_wavelet(WaveletType::Haar)?;
1251 let db1 = WaveletFilters::from_wavelet(WaveletType::Daubechies(1))?;
1252 assert_abs_diff_eq!(haar.dec_lo[0], db1.dec_lo[0], epsilon = 1e-10);
1253 assert_abs_diff_eq!(haar.dec_lo[1], db1.dec_lo[1], epsilon = 1e-10);
1254 Ok(())
1255 }
1256
1257 #[test]
1258 fn test_daubechies_db3_filters() -> Result<()> {
1259 let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(3))?;
1260 assert_eq!(f.dec_lo.len(), 6);
1261 check_filter_normalisation(&f);
1262 Ok(())
1263 }
1264
1265 #[test]
1266 fn test_daubechies_db5_filters() -> Result<()> {
1267 let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(5))?;
1268 assert_eq!(f.dec_lo.len(), 10);
1269 check_filter_normalisation(&f);
1270 Ok(())
1271 }
1272
1273 #[test]
1274 fn test_daubechies_db6_filters() -> Result<()> {
1275 let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(6))?;
1276 assert_eq!(f.dec_lo.len(), 12);
1277 check_filter_normalisation(&f);
1278 Ok(())
1279 }
1280
1281 #[test]
1282 fn test_daubechies_db7_filters() -> Result<()> {
1283 let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(7))?;
1284 assert_eq!(f.dec_lo.len(), 14);
1285 check_filter_normalisation(&f);
1286 Ok(())
1287 }
1288
1289 #[test]
1290 fn test_daubechies_db8_filters() -> Result<()> {
1291 let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(8))?;
1292 assert_eq!(f.dec_lo.len(), 16);
1293 check_filter_normalisation(&f);
1294 Ok(())
1295 }
1296
1297 #[test]
1298 fn test_daubechies_db10_filters() -> Result<()> {
1299 let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(10))?;
1300 assert_eq!(f.dec_lo.len(), 20);
1301 check_filter_normalisation(&f);
1302 Ok(())
1303 }
1304
1305 #[test]
1306 fn test_daubechies_unsupported_returns_error() {
1307 let result = WaveletFilters::from_wavelet(WaveletType::Daubechies(11));
1308 assert!(result.is_err(), "DB11 should return an error");
1309 }
1310
1311 #[test]
1312 fn test_coiflet1_filters() -> Result<()> {
1313 let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(1))?;
1314 assert_eq!(f.dec_lo.len(), 6);
1315 check_filter_normalisation(&f);
1316 Ok(())
1317 }
1318
1319 #[test]
1320 fn test_coiflet2_filters() -> Result<()> {
1321 let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(2))?;
1322 assert_eq!(f.dec_lo.len(), 12);
1323 check_filter_normalisation(&f);
1324 Ok(())
1325 }
1326
1327 #[test]
1328 fn test_coiflet3_filters() -> Result<()> {
1329 let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(3))?;
1330 assert_eq!(f.dec_lo.len(), 18);
1331 check_filter_normalisation(&f);
1332 Ok(())
1333 }
1334
1335 #[test]
1336 fn test_coiflet4_filters() -> Result<()> {
1337 let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(4))?;
1338 assert_eq!(f.dec_lo.len(), 24);
1339 check_filter_normalisation(&f);
1340 Ok(())
1341 }
1342
1343 #[test]
1344 fn test_coiflet5_filters() -> Result<()> {
1345 let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(5))?;
1346 assert_eq!(f.dec_lo.len(), 30);
1347 check_filter_normalisation(&f);
1348 Ok(())
1349 }
1350
1351 #[test]
1352 fn test_coiflet_unsupported_returns_error() {
1353 let result = WaveletFilters::from_wavelet(WaveletType::Coiflet(6));
1354 assert!(result.is_err(), "Coif6 should return an error");
1355 }
1356
1357 #[test]
1358 fn test_dwt_db3_roundtrip() -> Result<()> {
1359 let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
1360 let dwt = DWT::new(WaveletType::Daubechies(3))?;
1361
1362 let (approx, detail) = dwt.decompose(&signal.view())?;
1363 let reconstructed = dwt.reconstruct(&approx.view(), &detail.view())?;
1364
1365 assert!(reconstructed.len() >= signal.len() - 2);
1367 Ok(())
1368 }
1369
1370 #[test]
1371 fn test_dwt_coif2_roundtrip() -> Result<()> {
1372 let signal = Array1::from_vec(vec![
1373 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
1374 ]);
1375 let dwt = DWT::new(WaveletType::Coiflet(2))?;
1376
1377 let (approx, detail) = dwt.decompose(&signal.view())?;
1378 let reconstructed = dwt.reconstruct(&approx.view(), &detail.view())?;
1379
1380 assert!(reconstructed.len() >= signal.len() - 4);
1381 Ok(())
1382 }
1383
1384 fn check_unit_energy(filters: &WaveletFilters) {
1390 let energy: f64 = filters.dec_lo.iter().map(|x| x * x).sum();
1391 let diff = (energy - 1.0).abs();
1392 assert!(
1393 diff < 1e-10,
1394 "dec_lo unit-energy check failed: sum-of-squares = {energy}, diff from 1.0 = {diff}"
1395 );
1396 }
1397
1398 fn check_qmf_orthogonality(filters: &WaveletFilters) {
1400 let inner: f64 = filters
1401 .dec_lo
1402 .iter()
1403 .zip(filters.dec_hi.iter())
1404 .map(|(a, b)| a * b)
1405 .sum();
1406 let diff = inner.abs();
1407 assert!(
1408 diff < 1e-10,
1409 "QMF orthogonality check failed: <dec_lo, dec_hi> = {inner}, abs = {diff}"
1410 );
1411 }
1412
1413 #[test]
1418 fn test_db2_all_invariants() -> Result<()> {
1419 let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(2))?;
1420 assert_eq!(f.dec_lo.len(), 4, "db2 length must be 4");
1421 check_filter_normalisation(&f);
1422 check_unit_energy(&f);
1423 check_qmf_orthogonality(&f);
1424 Ok(())
1425 }
1426
1427 #[test]
1428 fn test_db4_all_invariants() -> Result<()> {
1429 let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(4))?;
1430 assert_eq!(f.dec_lo.len(), 8, "db4 length must be 8");
1431 check_filter_normalisation(&f);
1432 check_unit_energy(&f);
1433 check_qmf_orthogonality(&f);
1434 Ok(())
1435 }
1436
1437 #[test]
1438 fn test_db6_all_invariants() -> Result<()> {
1439 let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(6))?;
1440 assert_eq!(f.dec_lo.len(), 12, "db6 length must be 12");
1441 check_filter_normalisation(&f);
1442 check_unit_energy(&f);
1443 check_qmf_orthogonality(&f);
1444 Ok(())
1445 }
1446
1447 #[test]
1448 fn test_db8_all_invariants() -> Result<()> {
1449 let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(8))?;
1450 assert_eq!(f.dec_lo.len(), 16, "db8 length must be 16");
1451 check_filter_normalisation(&f);
1452 check_unit_energy(&f);
1453 check_qmf_orthogonality(&f);
1454 Ok(())
1455 }
1456
1457 #[test]
1458 fn test_db10_all_invariants() -> Result<()> {
1459 let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(10))?;
1460 assert_eq!(f.dec_lo.len(), 20, "db10 length must be 20");
1461 check_filter_normalisation(&f);
1462 check_unit_energy(&f);
1463 check_qmf_orthogonality(&f);
1464 Ok(())
1465 }
1466
1467 #[test]
1472 fn test_coif1_all_invariants() -> Result<()> {
1473 let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(1))?;
1474 assert_eq!(f.dec_lo.len(), 6, "coif1 length must be 6");
1475 check_filter_normalisation(&f);
1476 check_unit_energy(&f);
1477 check_qmf_orthogonality(&f);
1478 Ok(())
1479 }
1480
1481 #[test]
1482 fn test_coif2_all_invariants() -> Result<()> {
1483 let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(2))?;
1484 assert_eq!(f.dec_lo.len(), 12, "coif2 length must be 12");
1485 check_filter_normalisation(&f);
1486 check_unit_energy(&f);
1487 check_qmf_orthogonality(&f);
1488 Ok(())
1489 }
1490
1491 #[test]
1492 fn test_coif3_all_invariants() -> Result<()> {
1493 let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(3))?;
1494 assert_eq!(f.dec_lo.len(), 18, "coif3 length must be 18");
1495 check_filter_normalisation(&f);
1496 check_unit_energy(&f);
1497 check_qmf_orthogonality(&f);
1498 Ok(())
1499 }
1500
1501 #[test]
1502 fn test_coif4_all_invariants() -> Result<()> {
1503 let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(4))?;
1504 assert_eq!(f.dec_lo.len(), 24, "coif4 length must be 24");
1505 check_filter_normalisation(&f);
1506 check_unit_energy(&f);
1507 check_qmf_orthogonality(&f);
1508 Ok(())
1509 }
1510
1511 #[test]
1512 fn test_coif5_all_invariants() -> Result<()> {
1513 let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(5))?;
1514 assert_eq!(f.dec_lo.len(), 30, "coif5 length must be 30");
1515 check_filter_normalisation(&f);
1516 check_unit_energy(&f);
1517 check_qmf_orthogonality(&f);
1518 Ok(())
1519 }
1520
1521 #[test]
1534 fn test_dwt3d_constant_volume_lll_scaling() -> Result<()> {
1535 let c = 3.0_f64;
1536 let volume = Array3::from_elem((8, 8, 8), c);
1538 let dwtn = DWTN::new(WaveletType::Haar);
1539 let coeffs = dwtn.decompose3(&volume)?;
1540
1541 assert!(coeffs.lll.len() > 0, "LLL subband must not be empty");
1543 let expected_lll = c * 2.0_f64.sqrt().powi(3); for val in coeffs.lll.iter() {
1545 assert_abs_diff_eq!(*val, expected_lll, epsilon = 1e-10);
1546 }
1547
1548 let detail_bands: [&Array3<f64>; 7] = [
1550 &coeffs.llh,
1551 &coeffs.lhl,
1552 &coeffs.lhh,
1553 &coeffs.hll,
1554 &coeffs.hlh,
1555 &coeffs.hhl,
1556 &coeffs.hhh,
1557 ];
1558 for band in detail_bands {
1559 for val in band.iter() {
1560 assert_abs_diff_eq!(*val, 0.0, epsilon = 1e-10);
1561 }
1562 }
1563 Ok(())
1564 }
1565
1566 #[test]
1568 fn test_dwt3d_output_shape() -> Result<()> {
1569 let volume = Array3::from_shape_fn((8, 6, 4), |(i, j, k)| (i + j + k) as f64);
1570 let dwtn = DWTN::new(WaveletType::Haar);
1571 let coeffs = dwtn.decompose3(&volume)?;
1572
1573 assert_eq!(coeffs.lll.dim(), (4, 3, 2));
1575 assert_eq!(coeffs.hhh.dim(), (4, 3, 2));
1576 Ok(())
1577 }
1578
1579 #[test]
1581 fn test_dwt3d_rejects_too_small_volume() {
1582 let volume = Array3::from_elem((1, 8, 8), 1.0);
1583 let dwtn = DWTN::new(WaveletType::Haar);
1584 assert!(
1585 dwtn.decompose3(&volume).is_err(),
1586 "decompose3 must reject a volume with any dimension < 2"
1587 );
1588 }
1589}