Skip to main content

scirs2_transform/signal_transforms/
dwt.rs

1//! Discrete Wavelet Transform (DWT) Implementation
2//!
3//! Provides 1D, 2D, and N-D discrete wavelet transforms with multiple wavelet families.
4//! Implements efficient decomposition and reconstruction with proper boundary handling.
5
6use crate::error::{Result, TransformError};
7use rayon::prelude::*;
8use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, Axis};
9
/// Wavelet families supported by the DWT implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WaveletType {
    /// Haar wavelet (identical to `Daubechies(1)`).
    Haar,
    /// Daubechies wavelet of order `N` with `2N` filter taps.
    ///
    /// Supported orders: `N = 1..=10`; `Daubechies(1)` is the Haar wavelet.
    Daubechies(usize),
    /// Symlet wavelet of order `N`.
    ///
    /// Currently built from the Daubechies filters of the same order
    /// (supported orders: `N = 1..=10`).
    Symlet(usize),
    /// Coiflet wavelet of order `N` with `6N` filter taps (supported orders: `N = 1..=5`).
    Coiflet(usize),
    /// Biorthogonal wavelet `(p, q)`.
    ///
    /// Currently a placeholder: every order pair yields the Haar filter bank.
    Biorthogonal(usize, usize),
}
24
/// Boundary extension modes used to pad a signal before filtering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BoundaryMode {
    /// Pad with zeros.
    Zero,
    /// Repeat the first sample on the left and the last sample on the right.
    Constant,
    /// Mirror the signal, repeating the edge sample (`... b a | a b ...`).
    Symmetric,
    /// Wrap around to the opposite end of the signal.
    Periodic,
    /// Mirror-image (reflective) extension.
    Reflect,
}
39
/// The four FIR filters making up a two-channel wavelet filter bank.
///
/// The analysis pair (`dec_*`) splits a signal into approximation and detail
/// bands; the synthesis pair (`rec_*`) recombines them.
#[derive(Clone, Debug)]
pub struct WaveletFilters {
    /// Analysis (decomposition) low-pass taps.
    pub dec_lo: Vec<f64>,
    /// Analysis (decomposition) high-pass taps.
    pub dec_hi: Vec<f64>,
    /// Synthesis (reconstruction) low-pass taps.
    pub rec_lo: Vec<f64>,
    /// Synthesis (reconstruction) high-pass taps.
    pub rec_hi: Vec<f64>,
}
52
53impl WaveletFilters {
54    /// Get filter coefficients for a specific wavelet type
55    pub fn from_wavelet(wavelet: WaveletType) -> Result<Self> {
56        match wavelet {
57            WaveletType::Haar => Self::haar(),
58            WaveletType::Daubechies(n) => Self::daubechies(n),
59            WaveletType::Symlet(n) => Self::symlet(n),
60            WaveletType::Coiflet(n) => Self::coiflet(n),
61            WaveletType::Biorthogonal(p, q) => Self::biorthogonal(p, q),
62        }
63    }
64
65    /// Haar wavelet filters
66    fn haar() -> Result<Self> {
67        let norm = 1.0 / 2.0_f64.sqrt();
68        Ok(WaveletFilters {
69            dec_lo: vec![norm, norm],
70            dec_hi: vec![norm, -norm],
71            rec_lo: vec![norm, norm],
72            rec_hi: vec![-norm, norm],
73        })
74    }
75
76    /// Daubechies wavelet filters
77    fn daubechies(n: usize) -> Result<Self> {
78        match n {
79            2 => {
80                // DB2 (Daubechies-4 coefficients)
81                let sqrt3 = 3.0_f64.sqrt();
82                let denom = 4.0 * 2.0_f64.sqrt();
83                let dec_lo = vec![
84                    (1.0 + sqrt3) / denom,
85                    (3.0 + sqrt3) / denom,
86                    (3.0 - sqrt3) / denom,
87                    (1.0 - sqrt3) / denom,
88                ];
89                let mut dec_hi = Vec::with_capacity(dec_lo.len());
90                for (i, &val) in dec_lo.iter().enumerate().rev() {
91                    dec_hi.push(if i % 2 == 0 { val } else { -val });
92                }
93
94                let mut rec_lo = dec_lo.clone();
95                rec_lo.reverse();
96                let mut rec_hi = dec_hi.clone();
97                rec_hi.reverse();
98
99                Ok(WaveletFilters {
100                    dec_lo,
101                    dec_hi,
102                    rec_lo,
103                    rec_hi,
104                })
105            }
106            4 => {
107                // DB4 (Daubechies-8 coefficients)
108                let dec_lo = vec![
109                    -0.010597401784997,
110                    0.032883011666983,
111                    0.030841381835987,
112                    -0.187034811718881,
113                    -0.027983769416984,
114                    0.630880767929590,
115                    0.714846570552542,
116                    0.230377813308855,
117                ];
118                let mut dec_hi = Vec::with_capacity(dec_lo.len());
119                for (i, &val) in dec_lo.iter().enumerate().rev() {
120                    dec_hi.push(if i % 2 == 0 { val } else { -val });
121                }
122
123                let mut rec_lo = dec_lo.clone();
124                rec_lo.reverse();
125                let mut rec_hi = dec_hi.clone();
126                rec_hi.reverse();
127
128                Ok(WaveletFilters {
129                    dec_lo,
130                    dec_hi,
131                    rec_lo,
132                    rec_hi,
133                })
134            }
135            // DB1 is identical to Haar
136            1 => Self::haar(),
137            3 => {
138                // DB3 (Daubechies-6 coefficients) — standard orthonormal form
139                let dec_lo: Vec<f64> = vec![
140                    0.035226291882100656,
141                    -0.08544127388224149,
142                    -0.13501102001039084,
143                    0.4598775021193313,
144                    0.8068915093133388,
145                    0.3326705529509569,
146                ];
147                let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
148                for (i, &val) in dec_lo.iter().enumerate().rev() {
149                    dec_hi.push(if i % 2 == 0 { val } else { -val });
150                }
151                let mut rec_lo = dec_lo.clone();
152                rec_lo.reverse();
153                let mut rec_hi = dec_hi.clone();
154                rec_hi.reverse();
155                Ok(WaveletFilters {
156                    dec_lo,
157                    dec_hi,
158                    rec_lo,
159                    rec_hi,
160                })
161            }
162            5 => {
163                // DB5 (Daubechies-10 coefficients)
164                let dec_lo: Vec<f64> = vec![
165                    0.003335725285001549,
166                    -0.012580751999015526,
167                    -0.006241490213011705,
168                    0.07757149384006515,
169                    -0.03224486958502952,
170                    -0.24229488706619015,
171                    0.13842814590110342,
172                    0.7243085284385744,
173                    0.6038292697974898,
174                    0.16010239797412501,
175                ];
176                let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
177                for (i, &val) in dec_lo.iter().enumerate().rev() {
178                    dec_hi.push(if i % 2 == 0 { val } else { -val });
179                }
180                let mut rec_lo = dec_lo.clone();
181                rec_lo.reverse();
182                let mut rec_hi = dec_hi.clone();
183                rec_hi.reverse();
184                Ok(WaveletFilters {
185                    dec_lo,
186                    dec_hi,
187                    rec_lo,
188                    rec_hi,
189                })
190            }
191            6 => {
192                // DB6 (Daubechies-12 coefficients) — PyWavelets orthonormal form
193                let dec_lo: Vec<f64> = vec![
194                    -0.0010773010853084796,
195                    0.004777257510945511,
196                    0.0005538422011614961,
197                    -0.03158203931748603,
198                    0.027522865530305727,
199                    0.09750160558732304,
200                    -0.12976686756726194,
201                    -0.22626469396543983,
202                    0.31525035170919763,
203                    0.7511339080210954,
204                    0.49462389039845306,
205                    0.11154074335010947,
206                ];
207                let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
208                for (i, &val) in dec_lo.iter().enumerate().rev() {
209                    dec_hi.push(if i % 2 == 0 { val } else { -val });
210                }
211                let mut rec_lo = dec_lo.clone();
212                rec_lo.reverse();
213                let mut rec_hi = dec_hi.clone();
214                rec_hi.reverse();
215                Ok(WaveletFilters {
216                    dec_lo,
217                    dec_hi,
218                    rec_lo,
219                    rec_hi,
220                })
221            }
222            7 => {
223                // DB7 (Daubechies-14 coefficients) — PyWavelets orthonormal form
224                let dec_lo: Vec<f64> = vec![
225                    0.00035371379997452024,
226                    -0.0018016407040474908,
227                    0.0004295779729213665,
228                    0.01255099855609984,
229                    -0.01657454163066688,
230                    -0.03802993693501441,
231                    0.08061260915108308,
232                    0.07130921926683026,
233                    -0.22403618499387498,
234                    -0.14390600392856498,
235                    0.4697822874051931,
236                    0.7291320908462351,
237                    0.3965393194819173,
238                    0.07785205408500918,
239                ];
240                let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
241                for (i, &val) in dec_lo.iter().enumerate().rev() {
242                    dec_hi.push(if i % 2 == 0 { val } else { -val });
243                }
244                let mut rec_lo = dec_lo.clone();
245                rec_lo.reverse();
246                let mut rec_hi = dec_hi.clone();
247                rec_hi.reverse();
248                Ok(WaveletFilters {
249                    dec_lo,
250                    dec_hi,
251                    rec_lo,
252                    rec_hi,
253                })
254            }
255            8 => {
256                // DB8 (Daubechies-16 coefficients) — PyWavelets orthonormal form
257                let dec_lo: Vec<f64> = vec![
258                    -0.00011747678412476953,
259                    0.0006754494064505693,
260                    -0.00039174037337694705,
261                    -0.004870352993451574,
262                    0.008746094047405777,
263                    0.013981027917398282,
264                    -0.044088253930794755,
265                    -0.017369301001807547,
266                    0.12874742662047847,
267                    0.0004724845739132828,
268                    -0.2840155429615469,
269                    -0.015829105256349306,
270                    0.5853546836542067,
271                    0.6756307362972898,
272                    0.31287159091429995,
273                    0.05441584224310401,
274                ];
275                let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
276                for (i, &val) in dec_lo.iter().enumerate().rev() {
277                    dec_hi.push(if i % 2 == 0 { val } else { -val });
278                }
279                let mut rec_lo = dec_lo.clone();
280                rec_lo.reverse();
281                let mut rec_hi = dec_hi.clone();
282                rec_hi.reverse();
283                Ok(WaveletFilters {
284                    dec_lo,
285                    dec_hi,
286                    rec_lo,
287                    rec_hi,
288                })
289            }
290            9 => {
291                // DB9 (Daubechies-18 coefficients) — PyWavelets orthonormal form
292                let dec_lo: Vec<f64> = vec![
293                    3.93473203162716e-05,
294                    -0.0002519631889427101,
295                    0.00023038576352319597,
296                    0.0018476468830562265,
297                    -0.00428150368246343,
298                    -0.004723204757751397,
299                    0.022361662123679096,
300                    0.00025094711483145197,
301                    -0.06763282906132997,
302                    0.03072568147933338,
303                    0.14854074933810638,
304                    -0.09684078322297646,
305                    -0.2932737832791749,
306                    0.13319738582500756,
307                    0.6572880780513005,
308                    0.6048231236901112,
309                    0.24383467461259034,
310                    0.038077947363878345,
311                ];
312                let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
313                for (i, &val) in dec_lo.iter().enumerate().rev() {
314                    dec_hi.push(if i % 2 == 0 { val } else { -val });
315                }
316                let mut rec_lo = dec_lo.clone();
317                rec_lo.reverse();
318                let mut rec_hi = dec_hi.clone();
319                rec_hi.reverse();
320                Ok(WaveletFilters {
321                    dec_lo,
322                    dec_hi,
323                    rec_lo,
324                    rec_hi,
325                })
326            }
327            10 => {
328                // DB10 (Daubechies-20 coefficients) — PyWavelets orthonormal form
329                let dec_lo: Vec<f64> = vec![
330                    -1.3264202894521244e-05,
331                    9.358867032006959e-05,
332                    -0.00011646685512928545,
333                    -0.0006858566949597116,
334                    0.001992405295185056,
335                    0.001395351747052901,
336                    -0.010733175483330575,
337                    0.0036065535669561697,
338                    0.033212674059341,
339                    -0.029457536821875813,
340                    -0.07139414716639708,
341                    0.09305736460357235,
342                    0.12736934033579325,
343                    -0.19594627437737705,
344                    -0.24984642432731538,
345                    0.2811723436605775,
346                    0.6884590394536035,
347                    0.5272011889317256,
348                    0.1881768000776915,
349                    0.026670057900555554,
350                ];
351                let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
352                for (i, &val) in dec_lo.iter().enumerate().rev() {
353                    dec_hi.push(if i % 2 == 0 { val } else { -val });
354                }
355                let mut rec_lo = dec_lo.clone();
356                rec_lo.reverse();
357                let mut rec_hi = dec_hi.clone();
358                rec_hi.reverse();
359                Ok(WaveletFilters {
360                    dec_lo,
361                    dec_hi,
362                    rec_lo,
363                    rec_hi,
364                })
365            }
366            _ => Err(TransformError::InvalidInput(format!(
367                "Daubechies-{} not implemented (supported: 1-10)",
368                n
369            ))),
370        }
371    }
372
373    /// Symlet wavelet filters (simplified - use Daubechies for now)
374    fn symlet(n: usize) -> Result<Self> {
375        // Symlets are nearly symmetric versions of Daubechies wavelets
376        Self::daubechies(n)
377    }
378
379    /// Coiflet wavelet filters
380    fn coiflet(n: usize) -> Result<Self> {
381        match n {
382            1 => {
383                // Coif1 — exact PyWavelets orthonormal coefficients (sum = sqrt(2))
384                let dec_lo: Vec<f64> = vec![
385                    -0.015655728135791993,
386                    -0.07273261951252645,
387                    0.3848648468648578,
388                    0.8525720202116004,
389                    0.3378976624574818,
390                    -0.07273261951252645,
391                ];
392                let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
393                for (i, &val) in dec_lo.iter().enumerate().rev() {
394                    dec_hi.push(if i % 2 == 0 { val } else { -val });
395                }
396
397                let mut rec_lo = dec_lo.clone();
398                rec_lo.reverse();
399                let mut rec_hi = dec_hi.clone();
400                rec_hi.reverse();
401
402                Ok(WaveletFilters {
403                    dec_lo,
404                    dec_hi,
405                    rec_lo,
406                    rec_hi,
407                })
408            }
409            2 => {
410                // Coif2 (12-tap) — exact PyWavelets orthonormal coefficients (sum = sqrt(2))
411                let dec_lo: Vec<f64> = vec![
412                    -0.000720549445520347,
413                    -0.0018232088709110323,
414                    0.005611434819368834,
415                    0.02368017194684777,
416                    -0.05943441864643109,
417                    -0.07648859907828076,
418                    0.4170051844232391,
419                    0.8127236354494135,
420                    0.3861100668227629,
421                    -0.0673725547237256,
422                    -0.04146493678687178,
423                    0.01638733646320364,
424                ];
425                let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
426                for (i, &val) in dec_lo.iter().enumerate().rev() {
427                    dec_hi.push(if i % 2 == 0 { val } else { -val });
428                }
429                let mut rec_lo = dec_lo.clone();
430                rec_lo.reverse();
431                let mut rec_hi = dec_hi.clone();
432                rec_hi.reverse();
433                Ok(WaveletFilters {
434                    dec_lo,
435                    dec_hi,
436                    rec_lo,
437                    rec_hi,
438                })
439            }
440            3 => {
441                // Coif3 (18-tap) — exact PyWavelets orthonormal coefficients (sum = sqrt(2))
442                let dec_lo: Vec<f64> = vec![
443                    -3.459977319727278e-05,
444                    -7.0983302506379e-05,
445                    0.0004662169598204029,
446                    0.0011175187708306303,
447                    -0.0025745176881367972,
448                    -0.009007976136730624,
449                    0.015880544863669452,
450                    0.03455502757329774,
451                    -0.08230192710629983,
452                    -0.07179982161915484,
453                    0.42848347637737,
454                    0.7937772226260872,
455                    0.40517690240911824,
456                    -0.06112339000297255,
457                    -0.06577191128146936,
458                    0.023452696142077168,
459                    0.007782596425672746,
460                    -0.003793512864380802,
461                ];
462                let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
463                for (i, &val) in dec_lo.iter().enumerate().rev() {
464                    dec_hi.push(if i % 2 == 0 { val } else { -val });
465                }
466                let mut rec_lo = dec_lo.clone();
467                rec_lo.reverse();
468                let mut rec_hi = dec_hi.clone();
469                rec_hi.reverse();
470                Ok(WaveletFilters {
471                    dec_lo,
472                    dec_hi,
473                    rec_lo,
474                    rec_hi,
475                })
476            }
477            4 => {
478                // Coif4 (24-tap) — exact PyWavelets orthonormal coefficients (sum = sqrt(2))
479                let dec_lo: Vec<f64> = vec![
480                    -1.7849909144933469e-06,
481                    -3.259647940030751e-06,
482                    3.1229861599195265e-05,
483                    6.233885431278719e-05,
484                    -0.0002599743371222568,
485                    -0.0005890202246332165,
486                    0.0012665610789256603,
487                    0.0037514346971460866,
488                    -0.0056582838001308835,
489                    -0.015211728187697211,
490                    0.02508225333794961,
491                    0.03933442260558915,
492                    -0.09622042453595264,
493                    -0.06662747236681717,
494                    0.43438603311435653,
495                    0.7822389344242826,
496                    0.41530842700068227,
497                    -0.05607731960356926,
498                    -0.08126671024919373,
499                    0.02668230466960483,
500                    0.01606894713157503,
501                    -0.007346167936268051,
502                    -0.001629492425226786,
503                    0.000892313902537003,
504                ];
505                let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
506                for (i, &val) in dec_lo.iter().enumerate().rev() {
507                    dec_hi.push(if i % 2 == 0 { val } else { -val });
508                }
509                let mut rec_lo = dec_lo.clone();
510                rec_lo.reverse();
511                let mut rec_hi = dec_hi.clone();
512                rec_hi.reverse();
513                Ok(WaveletFilters {
514                    dec_lo,
515                    dec_hi,
516                    rec_lo,
517                    rec_hi,
518                })
519            }
520            5 => {
521                // Coif5 (30-tap) — exact PyWavelets orthonormal coefficients (sum = sqrt(2))
522                let dec_lo: Vec<f64> = vec![
523                    -9.604010112767894e-08,
524                    -1.6237995172048338e-07,
525                    2.0612203985788783e-06,
526                    3.7007277113394796e-06,
527                    -2.1270221672515614e-05,
528                    -4.12198619242655e-05,
529                    0.00014035632812373243,
530                    0.0003018579416682448,
531                    -0.0006375589261258812,
532                    -0.0016616273039298788,
533                    0.0024315754425382886,
534                    0.006761520220620417,
535                    -0.009159507338676163,
536                    -0.019758391600965465,
537                    0.032674799467057355,
538                    0.041287530472117834,
539                    -0.10556315130733723,
540                    -0.06203775157498196,
541                    0.4379823066591634,
542                    0.7742936228603274,
543                    0.42157126673075435,
544                    -0.052046670253554764,
545                    -0.09192158806008609,
546                    0.028169744270532353,
547                    0.023408322118927783,
548                    -0.010131584846900276,
549                    -0.00415931262757864,
550                    0.0021782943778456947,
551                    0.0003585777411617577,
552                    -0.000212081862067494,
553                ];
554                let mut dec_hi: Vec<f64> = Vec::with_capacity(dec_lo.len());
555                for (i, &val) in dec_lo.iter().enumerate().rev() {
556                    dec_hi.push(if i % 2 == 0 { val } else { -val });
557                }
558                let mut rec_lo = dec_lo.clone();
559                rec_lo.reverse();
560                let mut rec_hi = dec_hi.clone();
561                rec_hi.reverse();
562                Ok(WaveletFilters {
563                    dec_lo,
564                    dec_hi,
565                    rec_lo,
566                    rec_hi,
567                })
568            }
569            _ => Err(TransformError::InvalidInput(format!(
570                "Coiflet-{} not implemented (supported: 1-5)",
571                n
572            ))),
573        }
574    }
575
576    /// Biorthogonal wavelet filters
577    fn biorthogonal(_p: usize, _q: usize) -> Result<Self> {
578        // For now, return Haar as placeholder
579        Self::haar()
580    }
581}
582
583/// 1D Discrete Wavelet Transform
584#[derive(Debug, Clone)]
585pub struct DWT {
586    wavelet: WaveletType,
587    filters: WaveletFilters,
588    boundary: BoundaryMode,
589    level: Option<usize>,
590}
591
592impl DWT {
593    /// Create a new DWT instance
594    pub fn new(wavelet: WaveletType) -> Result<Self> {
595        let filters = WaveletFilters::from_wavelet(wavelet)?;
596        Ok(DWT {
597            wavelet,
598            filters,
599            boundary: BoundaryMode::Symmetric,
600            level: None,
601        })
602    }
603
604    /// Set the boundary mode
605    pub fn with_boundary(mut self, boundary: BoundaryMode) -> Self {
606        self.boundary = boundary;
607        self
608    }
609
610    /// Set the decomposition level
611    pub fn with_level(mut self, level: usize) -> Self {
612        self.level = Some(level);
613        self
614    }
615
616    /// Perform single-level decomposition
617    pub fn decompose(&self, signal: &ArrayView1<f64>) -> Result<(Array1<f64>, Array1<f64>)> {
618        let n = signal.len();
619        if n < 2 {
620            return Err(TransformError::InvalidInput(
621                "Signal too short for DWT".to_string(),
622            ));
623        }
624
625        // Extend signal according to boundary mode
626        let extended = self.extend_signal(signal)?;
627
628        // Convolve with filters and downsample
629        let approx = self.convolve_downsample(&extended, &self.filters.dec_lo)?;
630        let detail = self.convolve_downsample(&extended, &self.filters.dec_hi)?;
631
632        Ok((approx, detail))
633    }
634
635    /// Perform multi-level decomposition
636    pub fn wavedec(&self, signal: &ArrayView1<f64>) -> Result<Vec<Array1<f64>>> {
637        let max_level = self.max_decomposition_level(signal.len());
638        let level = self.level.unwrap_or(max_level).min(max_level);
639
640        let mut coeffs = Vec::with_capacity(level + 1);
641        let mut current = signal.to_owned();
642
643        for _ in 0..level {
644            let (approx, detail) = self.decompose(&current.view())?;
645            coeffs.push(detail);
646            current = approx;
647        }
648
649        // Add final approximation coefficients
650        coeffs.push(current);
651        coeffs.reverse();
652
653        Ok(coeffs)
654    }
655
656    /// Perform single-level reconstruction
657    pub fn reconstruct(
658        &self,
659        approx: &ArrayView1<f64>,
660        detail: &ArrayView1<f64>,
661    ) -> Result<Array1<f64>> {
662        // Upsample and convolve with reconstruction filters
663        let approx_up = self.upsample_convolve(approx, &self.filters.rec_lo)?;
664        let detail_up = self.upsample_convolve(detail, &self.filters.rec_hi)?;
665
666        // Add the two components
667        let min_len = approx_up.len().min(detail_up.len());
668        let mut reconstructed = Array1::zeros(min_len);
669        for i in 0..min_len {
670            reconstructed[i] = approx_up[i] + detail_up[i];
671        }
672
673        Ok(reconstructed)
674    }
675
676    /// Perform multi-level reconstruction
677    pub fn waverec(&self, coeffs: &[Array1<f64>]) -> Result<Array1<f64>> {
678        if coeffs.is_empty() {
679            return Err(TransformError::InvalidInput(
680                "No coefficients provided for reconstruction".to_string(),
681            ));
682        }
683
684        let mut current = coeffs[0].clone();
685
686        for detail in &coeffs[1..] {
687            current = self.reconstruct(&current.view(), &detail.view())?;
688        }
689
690        Ok(current)
691    }
692
693    // Helper methods
694
695    fn extend_signal(&self, signal: &ArrayView1<f64>) -> Result<Array1<f64>> {
696        let filter_len = self.filters.dec_lo.len();
697        let n = signal.len();
698        let pad_len = filter_len - 1;
699
700        let mut extended = Array1::zeros(n + 2 * pad_len);
701
702        match self.boundary {
703            BoundaryMode::Zero => {
704                for i in 0..n {
705                    extended[i + pad_len] = signal[i];
706                }
707            }
708            BoundaryMode::Constant => {
709                let first = signal[0];
710                let last = signal[n - 1];
711                for i in 0..pad_len {
712                    extended[i] = first;
713                    extended[n + pad_len + i] = last;
714                }
715                for i in 0..n {
716                    extended[i + pad_len] = signal[i];
717                }
718            }
719            BoundaryMode::Symmetric => {
720                for i in 0..pad_len {
721                    extended[pad_len - 1 - i] = signal[i.min(n - 1)];
722                    extended[n + pad_len + i] = signal[(n - 1 - i).max(0)];
723                }
724                for i in 0..n {
725                    extended[i + pad_len] = signal[i];
726                }
727            }
728            BoundaryMode::Periodic => {
729                for i in 0..pad_len {
730                    extended[i] = signal[(n - pad_len + i) % n];
731                    extended[n + pad_len + i] = signal[i % n];
732                }
733                for i in 0..n {
734                    extended[i + pad_len] = signal[i];
735                }
736            }
737            BoundaryMode::Reflect => {
738                for i in 0..pad_len {
739                    let idx1 = if i < n { i } else { n - 1 };
740                    let idx2 = if n > i + 1 { n - 1 - i } else { 0 };
741                    extended[pad_len - 1 - i] = signal[idx1];
742                    extended[n + pad_len + i] = signal[idx2];
743                }
744                for i in 0..n {
745                    extended[i + pad_len] = signal[i];
746                }
747            }
748        }
749
750        Ok(extended)
751    }
752
753    fn convolve_downsample(&self, signal: &Array1<f64>, filter: &[f64]) -> Result<Array1<f64>> {
754        let n = signal.len();
755        let filter_len = filter.len();
756        let output_len = (n + 1) / 2;
757        let mut output = Array1::zeros(output_len);
758
759        for i in 0..output_len {
760            let pos = i * 2;
761            let mut sum = 0.0;
762
763            for (j, &coeff) in filter.iter().enumerate() {
764                let idx = pos + j;
765                if idx < n {
766                    sum += signal[idx] * coeff;
767                }
768            }
769
770            output[i] = sum;
771        }
772
773        Ok(output)
774    }
775
776    fn upsample_convolve(&self, signal: &ArrayView1<f64>, filter: &[f64]) -> Result<Array1<f64>> {
777        let n = signal.len();
778        let filter_len = filter.len();
779        let output_len = n * 2;
780        let mut output = Array1::zeros(output_len);
781
782        // Upsample by inserting zeros
783        let mut upsampled = Array1::zeros(output_len);
784        for i in 0..n {
785            upsampled[i * 2] = signal[i];
786        }
787
788        // Convolve with reconstruction filter
789        for i in 0..output_len {
790            let mut sum = 0.0;
791            for (j, &coeff) in filter.iter().enumerate() {
792                if i >= j && i - j < output_len {
793                    sum += upsampled[i - j] * coeff;
794                }
795            }
796            output[i] = sum;
797        }
798
799        Ok(output)
800    }
801
802    fn max_decomposition_level(&self, signal_len: usize) -> usize {
803        let filter_len = self.filters.dec_lo.len();
804        let mut level: usize = 0;
805        let mut current_len = signal_len;
806
807        while current_len >= filter_len {
808            current_len = (current_len + 1) / 2;
809            level += 1;
810        }
811
812        level.saturating_sub(1)
813    }
814}
815
816/// 2D Discrete Wavelet Transform
817#[derive(Debug, Clone)]
818pub struct DWT2D {
819    wavelet: WaveletType,
820    filters: WaveletFilters,
821    boundary: BoundaryMode,
822    level: Option<usize>,
823}
824
825impl DWT2D {
826    /// Create a new DWT2D instance
827    pub fn new(wavelet: WaveletType) -> Result<Self> {
828        let filters = WaveletFilters::from_wavelet(wavelet)?;
829        Ok(DWT2D {
830            wavelet,
831            filters,
832            boundary: BoundaryMode::Symmetric,
833            level: None,
834        })
835    }
836
837    /// Set the boundary mode
838    pub fn with_boundary(mut self, boundary: BoundaryMode) -> Self {
839        self.boundary = boundary;
840        self
841    }
842
843    /// Set the decomposition level
844    pub fn with_level(mut self, level: usize) -> Self {
845        self.level = Some(level);
846        self
847    }
848
849    /// Perform single-level 2D decomposition
850    pub fn decompose2(&self, image: &ArrayView2<f64>) -> Result<Dwt2dCoeffs> {
851        let (rows, cols) = image.dim();
852        if rows < 2 || cols < 2 {
853            return Err(TransformError::InvalidInput(
854                "Image too small for 2D DWT".to_string(),
855            ));
856        }
857
858        let dwt1d = DWT {
859            wavelet: self.wavelet,
860            filters: self.filters.clone(),
861            boundary: self.boundary,
862            level: None,
863        };
864
865        // Apply DWT along rows
866        let mut row_results_approx = Vec::with_capacity(rows);
867        let mut row_results_detail = Vec::with_capacity(rows);
868
869        for row_idx in 0..rows {
870            let row = image.row(row_idx);
871            let (approx, detail) = dwt1d.decompose(&row)?;
872            row_results_approx.push(approx);
873            row_results_detail.push(detail);
874        }
875
876        let approx_rows = row_results_approx[0].len();
877        let detail_rows = row_results_detail[0].len();
878
879        // Convert to 2D arrays
880        let mut approx_mat = Array2::zeros((rows, approx_rows));
881        let mut detail_mat = Array2::zeros((rows, detail_rows));
882
883        for (i, (app, det)) in row_results_approx
884            .iter()
885            .zip(row_results_detail.iter())
886            .enumerate()
887        {
888            for (j, &val) in app.iter().enumerate() {
889                approx_mat[[i, j]] = val;
890            }
891            for (j, &val) in det.iter().enumerate() {
892                detail_mat[[i, j]] = val;
893            }
894        }
895
896        // Apply DWT along columns
897        let (ll, lh) = self.decompose_columns(&approx_mat.view(), &dwt1d)?;
898        let (hl, hh) = self.decompose_columns(&detail_mat.view(), &dwt1d)?;
899
900        Ok(Dwt2dCoeffs { ll, lh, hl, hh })
901    }
902
903    fn decompose_columns(
904        &self,
905        mat: &ArrayView2<f64>,
906        dwt1d: &DWT,
907    ) -> Result<(Array2<f64>, Array2<f64>)> {
908        let (rows, cols) = mat.dim();
909        let mut col_results_approx = Vec::with_capacity(cols);
910        let mut col_results_detail = Vec::with_capacity(cols);
911
912        for col_idx in 0..cols {
913            let col = mat.column(col_idx);
914            let (approx, detail) = dwt1d.decompose(&col)?;
915            col_results_approx.push(approx);
916            col_results_detail.push(detail);
917        }
918
919        let approx_cols = col_results_approx[0].len();
920        let detail_cols = col_results_detail[0].len();
921
922        let mut approx_result = Array2::zeros((approx_cols, cols));
923        let mut detail_result = Array2::zeros((detail_cols, cols));
924
925        for (j, (app, det)) in col_results_approx
926            .iter()
927            .zip(col_results_detail.iter())
928            .enumerate()
929        {
930            for (i, &val) in app.iter().enumerate() {
931                approx_result[[i, j]] = val;
932            }
933            for (i, &val) in det.iter().enumerate() {
934                detail_result[[i, j]] = val;
935            }
936        }
937
938        Ok((approx_result, detail_result))
939    }
940
941    /// Perform multi-level 2D decomposition
942    pub fn wavedec2(&self, image: &ArrayView2<f64>) -> Result<Vec<Dwt2dCoeffs>> {
943        let max_level = self.max_decomposition_level_2d(image.dim());
944        let level = self.level.unwrap_or(max_level).min(max_level);
945
946        let mut coeffs = Vec::with_capacity(level);
947        let mut current = image.to_owned();
948
949        for _ in 0..level {
950            let dwt2d_coeffs = self.decompose2(&current.view())?;
951            coeffs.push(dwt2d_coeffs.clone());
952            current = dwt2d_coeffs.ll;
953        }
954
955        Ok(coeffs)
956    }
957
958    fn max_decomposition_level_2d(&self, shape: (usize, usize)) -> usize {
959        let filter_len = self.filters.dec_lo.len();
960        let min_dim = shape.0.min(shape.1);
961
962        let mut level: usize = 0;
963        let mut current_dim = min_dim;
964
965        while current_dim >= filter_len {
966            current_dim = (current_dim + 1) / 2;
967            level += 1;
968        }
969
970        level.saturating_sub(1)
971    }
972}
973
/// 2D DWT coefficients (LL, LH, HL, HH)
///
/// In the subbands produced by [`DWT2D::decompose2`], the first letter names the
/// filter applied along each row (axis 1). The second letter names the filter
/// applied along each column (axis 0). `L` = low-pass (approximation),
/// `H` = high-pass (detail).
/// Note that this letter order is the reverse of [`Dwt3dCoeffs`], where the first
/// letter refers to axis 0.
#[derive(Debug, Clone)]
pub struct Dwt2dCoeffs {
    /// Approximation coefficients (low-low): low-pass along rows and columns
    pub ll: Array2<f64>,
    /// Horizontal detail coefficients (low-high): low-pass along rows,
    /// high-pass along columns
    pub lh: Array2<f64>,
    /// Vertical detail coefficients (high-low): high-pass along rows,
    /// low-pass along columns
    pub hl: Array2<f64>,
    /// Diagonal detail coefficients (high-high): high-pass along rows and columns
    pub hh: Array2<f64>,
}
986
/// 3D DWT coefficients — the 8 subbands produced by one level of separable 3D decomposition.
///
/// Produced by [`DWTN::decompose3`]. `lll` is the only approximation subband; the
/// other seven hold detail information.
///
/// Naming convention: each letter indicates the filter applied along axis 0, 1, 2
/// respectively.  `L` = low-pass (approximation), `H` = high-pass (detail).
#[derive(Debug, Clone)]
pub struct Dwt3dCoeffs {
    /// LLL — approximation subband (low along all 3 axes)
    pub lll: Array3<f64>,
    /// LLH — detail along axis 2
    pub llh: Array3<f64>,
    /// LHL — detail along axis 1
    pub lhl: Array3<f64>,
    /// LHH — detail along axes 1 and 2
    pub lhh: Array3<f64>,
    /// HLL — detail along axis 0
    pub hll: Array3<f64>,
    /// HLH — detail along axes 0 and 2
    pub hlh: Array3<f64>,
    /// HHL — detail along axes 0 and 1
    pub hhl: Array3<f64>,
    /// HHH — detail along all 3 axes
    pub hhh: Array3<f64>,
}
1010
1011/// N-D Discrete Wavelet Transform (supports 3D decomposition)
1012#[derive(Debug, Clone)]
1013pub struct DWTN {
1014    wavelet: WaveletType,
1015    boundary: BoundaryMode,
1016    level: Option<usize>,
1017}
1018
1019impl DWTN {
1020    /// Create a new DWTN instance
1021    pub fn new(wavelet: WaveletType) -> Self {
1022        DWTN {
1023            wavelet,
1024            boundary: BoundaryMode::Symmetric,
1025            level: None,
1026        }
1027    }
1028
1029    /// Set the boundary mode
1030    pub fn with_boundary(mut self, boundary: BoundaryMode) -> Self {
1031        self.boundary = boundary;
1032        self
1033    }
1034
1035    /// Set the decomposition level
1036    pub fn with_level(mut self, level: usize) -> Self {
1037        self.level = Some(level);
1038        self
1039    }
1040
1041    /// Perform single-level 3D DWT decomposition using separable filtering.
1042    ///
1043    /// The separable 3D DWT applies 1-D DWT independently along each of the three
1044    /// axes in sequence, producing 8 subbands (LLL through HHH).
1045    ///
1046    /// # Steps
1047    /// 1. Apply 1-D DWT along axis 0 → `Lo_x` and `Hi_x` halves.
1048    /// 2. For each of the 2 halves, apply 1-D DWT along axis 1 → 4 quarters.
1049    /// 3. For each of the 4 quarters, apply 1-D DWT along axis 2 → 8 subbands.
1050    pub fn decompose3(&self, volume: &Array3<f64>) -> Result<Dwt3dCoeffs> {
1051        let (d0, d1, d2) = volume.dim();
1052        if d0 < 2 || d1 < 2 || d2 < 2 {
1053            return Err(TransformError::InvalidInput(
1054                "Volume too small for 3D DWT: all dimensions must be >= 2".to_string(),
1055            ));
1056        }
1057
1058        let dwt1d = DWT::new(self.wavelet)?.with_boundary(self.boundary);
1059
1060        // --- Axis 0 pass -------------------------------------------------
1061        // For each (j, k) slice along axis 0, apply 1-D DWT and collect
1062        // the low-pass (approx) and high-pass (detail) results.
1063        let half0 = (d0 + 1) / 2; // output length after DWT along axis 0
1064        let mut lo_x = Array3::<f64>::zeros((half0, d1, d2));
1065        let mut hi_x = Array3::<f64>::zeros((half0, d1, d2));
1066
1067        for j in 0..d1 {
1068            for k in 0..d2 {
1069                let col: Vec<f64> = (0..d0).map(|i| volume[[i, j, k]]).collect();
1070                let col_arr = Array1::from(col);
1071                let (approx, detail) = dwt1d.decompose(&col_arr.view())?;
1072                for i in 0..approx.len().min(half0) {
1073                    lo_x[[i, j, k]] = approx[i];
1074                }
1075                for i in 0..detail.len().min(half0) {
1076                    hi_x[[i, j, k]] = detail[i];
1077                }
1078            }
1079        }
1080
1081        // --- Axis 1 pass -------------------------------------------------
1082        let half1 = (d1 + 1) / 2;
1083        let (lo_x_lo_y, lo_x_hi_y) =
1084            self.apply_dwt_axis1_3d(&lo_x, &dwt1d, half0, d1, d2, half1)?;
1085        let (hi_x_lo_y, hi_x_hi_y) =
1086            self.apply_dwt_axis1_3d(&hi_x, &dwt1d, half0, d1, d2, half1)?;
1087
1088        // --- Axis 2 pass -------------------------------------------------
1089        let half2 = (d2 + 1) / 2;
1090        let (lll, llh) = self.apply_dwt_axis2_3d(&lo_x_lo_y, &dwt1d, half0, half1, d2, half2)?;
1091        let (lhl, lhh) = self.apply_dwt_axis2_3d(&lo_x_hi_y, &dwt1d, half0, half1, d2, half2)?;
1092        let (hll, hlh) = self.apply_dwt_axis2_3d(&hi_x_lo_y, &dwt1d, half0, half1, d2, half2)?;
1093        let (hhl, hhh) = self.apply_dwt_axis2_3d(&hi_x_hi_y, &dwt1d, half0, half1, d2, half2)?;
1094
1095        Ok(Dwt3dCoeffs {
1096            lll,
1097            llh,
1098            lhl,
1099            lhh,
1100            hll,
1101            hlh,
1102            hhl,
1103            hhh,
1104        })
1105    }
1106
1107    // Apply 1-D DWT along axis 1 of a 3-D array, returning the two output halves.
1108    fn apply_dwt_axis1_3d(
1109        &self,
1110        arr: &Array3<f64>,
1111        dwt1d: &DWT,
1112        size0: usize,
1113        size1: usize,
1114        size2: usize,
1115        out1: usize,
1116    ) -> Result<(Array3<f64>, Array3<f64>)> {
1117        let mut lo = Array3::<f64>::zeros((size0, out1, size2));
1118        let mut hi = Array3::<f64>::zeros((size0, out1, size2));
1119
1120        for i in 0..size0 {
1121            for k in 0..size2 {
1122                let col: Vec<f64> = (0..size1).map(|j| arr[[i, j, k]]).collect();
1123                let col_arr = Array1::from(col);
1124                let (approx, detail) = dwt1d.decompose(&col_arr.view())?;
1125                for j in 0..approx.len().min(out1) {
1126                    lo[[i, j, k]] = approx[j];
1127                }
1128                for j in 0..detail.len().min(out1) {
1129                    hi[[i, j, k]] = detail[j];
1130                }
1131            }
1132        }
1133        Ok((lo, hi))
1134    }
1135
1136    // Apply 1-D DWT along axis 2 of a 3-D array, returning the two output halves.
1137    fn apply_dwt_axis2_3d(
1138        &self,
1139        arr: &Array3<f64>,
1140        dwt1d: &DWT,
1141        size0: usize,
1142        size1: usize,
1143        size2: usize,
1144        out2: usize,
1145    ) -> Result<(Array3<f64>, Array3<f64>)> {
1146        let mut lo = Array3::<f64>::zeros((size0, size1, out2));
1147        let mut hi = Array3::<f64>::zeros((size0, size1, out2));
1148
1149        for i in 0..size0 {
1150            for j in 0..size1 {
1151                let row: Vec<f64> = (0..size2).map(|k| arr[[i, j, k]]).collect();
1152                let row_arr = Array1::from(row);
1153                let (approx, detail) = dwt1d.decompose(&row_arr.view())?;
1154                for k in 0..approx.len().min(out2) {
1155                    lo[[i, j, k]] = approx[k];
1156                }
1157                for k in 0..detail.len().min(out2) {
1158                    hi[[i, j, k]] = detail[k];
1159                }
1160            }
1161        }
1162        Ok((lo, hi))
1163    }
1164}
1165
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_dwt_haar() -> Result<()> {
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let dwt = DWT::new(WaveletType::Haar)?;

        let (approx, detail) = dwt.decompose(&signal.view())?;

        assert!(!approx.is_empty());
        assert!(!detail.is_empty());
        assert_eq!(approx.len(), detail.len());

        Ok(())
    }

    #[test]
    fn test_dwt_multilevel() -> Result<()> {
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let dwt = DWT::new(WaveletType::Haar)?.with_level(2);

        let coeffs = dwt.wavedec(&signal.view())?;

        assert_eq!(coeffs.len(), 3); // 2 levels + approximation

        Ok(())
    }

    #[test]
    fn test_dwt_reconstruction() -> Result<()> {
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let dwt = DWT::new(WaveletType::Haar)?;

        let (approx, detail) = dwt.decompose(&signal.view())?;
        let reconstructed = dwt.reconstruct(&approx.view(), &detail.view())?;

        // Only the length is checked here. The reconstruction may be slightly
        // shorter or longer than the input, and its values are not compared.
        assert!(reconstructed.len() >= signal.len() - 2);

        Ok(())
    }

    #[test]
    fn test_dwt2d() -> Result<()> {
        let image = Array2::from_shape_fn((8, 8), |(i, j)| (i + j) as f64);
        let dwt2d = DWT2D::new(WaveletType::Haar)?;

        let coeffs = dwt2d.decompose2(&image.view())?;

        assert!(!coeffs.ll.is_empty());
        assert!(!coeffs.lh.is_empty());
        assert!(!coeffs.hl.is_empty());
        assert!(!coeffs.hh.is_empty());

        Ok(())
    }

    #[test]
    fn test_wavelet_filters() -> Result<()> {
        let filters = WaveletFilters::from_wavelet(WaveletType::Haar)?;

        assert_eq!(filters.dec_lo.len(), 2);
        assert_eq!(filters.dec_hi.len(), 2);
        assert_eq!(filters.rec_lo.len(), 2);
        assert_eq!(filters.rec_hi.len(), 2);

        Ok(())
    }

    /// Verify that `dec_lo` sums to sqrt(2) (standard orthonormal normalisation).
    fn check_filter_normalisation(filters: &WaveletFilters) {
        let sum: f64 = filters.dec_lo.iter().sum();
        let diff = (sum - 2.0_f64.sqrt()).abs();
        assert!(
            diff < 1e-6,
            "dec_lo sum {sum} is not sqrt(2); diff = {diff}"
        );
    }

    #[test]
    fn test_daubechies_db1_is_haar() -> Result<()> {
        let haar = WaveletFilters::from_wavelet(WaveletType::Haar)?;
        let db1 = WaveletFilters::from_wavelet(WaveletType::Daubechies(1))?;
        assert_abs_diff_eq!(haar.dec_lo[0], db1.dec_lo[0], epsilon = 1e-10);
        assert_abs_diff_eq!(haar.dec_lo[1], db1.dec_lo[1], epsilon = 1e-10);
        Ok(())
    }

    #[test]
    fn test_daubechies_db3_filters() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(3))?;
        assert_eq!(f.dec_lo.len(), 6);
        check_filter_normalisation(&f);
        Ok(())
    }

    #[test]
    fn test_daubechies_db5_filters() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(5))?;
        assert_eq!(f.dec_lo.len(), 10);
        check_filter_normalisation(&f);
        Ok(())
    }

    #[test]
    fn test_daubechies_db6_filters() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(6))?;
        assert_eq!(f.dec_lo.len(), 12);
        check_filter_normalisation(&f);
        Ok(())
    }

    #[test]
    fn test_daubechies_db7_filters() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(7))?;
        assert_eq!(f.dec_lo.len(), 14);
        check_filter_normalisation(&f);
        Ok(())
    }

    #[test]
    fn test_daubechies_db8_filters() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(8))?;
        assert_eq!(f.dec_lo.len(), 16);
        check_filter_normalisation(&f);
        Ok(())
    }

    #[test]
    fn test_daubechies_db10_filters() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(10))?;
        assert_eq!(f.dec_lo.len(), 20);
        check_filter_normalisation(&f);
        Ok(())
    }

    #[test]
    fn test_daubechies_unsupported_returns_error() {
        let result = WaveletFilters::from_wavelet(WaveletType::Daubechies(11));
        assert!(result.is_err(), "DB11 should return an error");
    }

    #[test]
    fn test_coiflet1_filters() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(1))?;
        assert_eq!(f.dec_lo.len(), 6);
        check_filter_normalisation(&f);
        Ok(())
    }

    #[test]
    fn test_coiflet2_filters() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(2))?;
        assert_eq!(f.dec_lo.len(), 12);
        check_filter_normalisation(&f);
        Ok(())
    }

    #[test]
    fn test_coiflet3_filters() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(3))?;
        assert_eq!(f.dec_lo.len(), 18);
        check_filter_normalisation(&f);
        Ok(())
    }

    #[test]
    fn test_coiflet4_filters() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(4))?;
        assert_eq!(f.dec_lo.len(), 24);
        check_filter_normalisation(&f);
        Ok(())
    }

    #[test]
    fn test_coiflet5_filters() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(5))?;
        assert_eq!(f.dec_lo.len(), 30);
        check_filter_normalisation(&f);
        Ok(())
    }

    #[test]
    fn test_coiflet_unsupported_returns_error() {
        let result = WaveletFilters::from_wavelet(WaveletType::Coiflet(6));
        assert!(result.is_err(), "Coif6 should return an error");
    }

    #[test]
    fn test_dwt_db3_roundtrip() -> Result<()> {
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let dwt = DWT::new(WaveletType::Daubechies(3))?;

        let (approx, detail) = dwt.decompose(&signal.view())?;
        let reconstructed = dwt.reconstruct(&approx.view(), &detail.view())?;

        // Reconstruction may drop up to 2 boundary samples relative to the input.
        assert!(reconstructed.len() >= signal.len() - 2);
        Ok(())
    }

    #[test]
    fn test_dwt_coif2_roundtrip() -> Result<()> {
        let signal = Array1::from_vec(vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ]);
        let dwt = DWT::new(WaveletType::Coiflet(2))?;

        let (approx, detail) = dwt.decompose(&signal.view())?;
        let reconstructed = dwt.reconstruct(&approx.view(), &detail.view())?;

        // Reconstruction may drop up to 4 boundary samples relative to the input.
        assert!(reconstructed.len() >= signal.len() - 4);
        Ok(())
    }

    // -------------------------------------------------------------------------
    // Invariant helpers: unit energy (sum of squares = 1) and QMF orthogonality
    // -------------------------------------------------------------------------

    /// Assert that sum of squares of dec_lo equals 1.0 (unit energy condition).
    fn check_unit_energy(filters: &WaveletFilters) {
        let energy: f64 = filters.dec_lo.iter().map(|x| x * x).sum();
        let diff = (energy - 1.0).abs();
        assert!(
            diff < 1e-10,
            "dec_lo unit-energy check failed: sum-of-squares = {energy}, diff from 1.0 = {diff}"
        );
    }

    /// Assert that <dec_lo, dec_hi> = 0 (QMF orthogonality).
    fn check_qmf_orthogonality(filters: &WaveletFilters) {
        let inner: f64 = filters
            .dec_lo
            .iter()
            .zip(filters.dec_hi.iter())
            .map(|(a, b)| a * b)
            .sum();
        let diff = inner.abs();
        assert!(
            diff < 1e-10,
            "QMF orthogonality check failed: <dec_lo, dec_hi> = {inner}, abs = {diff}"
        );
    }

    // -------------------------------------------------------------------------
    // Daubechies filter invariant tests (length, sum=√2, energy=1, QMF ortho)
    // -------------------------------------------------------------------------

    #[test]
    fn test_db2_all_invariants() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(2))?;
        assert_eq!(f.dec_lo.len(), 4, "db2 length must be 4");
        check_filter_normalisation(&f);
        check_unit_energy(&f);
        check_qmf_orthogonality(&f);
        Ok(())
    }

    #[test]
    fn test_db4_all_invariants() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(4))?;
        assert_eq!(f.dec_lo.len(), 8, "db4 length must be 8");
        check_filter_normalisation(&f);
        check_unit_energy(&f);
        check_qmf_orthogonality(&f);
        Ok(())
    }

    #[test]
    fn test_db6_all_invariants() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(6))?;
        assert_eq!(f.dec_lo.len(), 12, "db6 length must be 12");
        check_filter_normalisation(&f);
        check_unit_energy(&f);
        check_qmf_orthogonality(&f);
        Ok(())
    }

    #[test]
    fn test_db8_all_invariants() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(8))?;
        assert_eq!(f.dec_lo.len(), 16, "db8 length must be 16");
        check_filter_normalisation(&f);
        check_unit_energy(&f);
        check_qmf_orthogonality(&f);
        Ok(())
    }

    #[test]
    fn test_db10_all_invariants() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Daubechies(10))?;
        assert_eq!(f.dec_lo.len(), 20, "db10 length must be 20");
        check_filter_normalisation(&f);
        check_unit_energy(&f);
        check_qmf_orthogonality(&f);
        Ok(())
    }

    // -------------------------------------------------------------------------
    // Coiflet filter invariant tests (length, sum=√2, energy=1, QMF ortho)
    // -------------------------------------------------------------------------

    #[test]
    fn test_coif1_all_invariants() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(1))?;
        assert_eq!(f.dec_lo.len(), 6, "coif1 length must be 6");
        check_filter_normalisation(&f);
        check_unit_energy(&f);
        check_qmf_orthogonality(&f);
        Ok(())
    }

    #[test]
    fn test_coif2_all_invariants() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(2))?;
        assert_eq!(f.dec_lo.len(), 12, "coif2 length must be 12");
        check_filter_normalisation(&f);
        check_unit_energy(&f);
        check_qmf_orthogonality(&f);
        Ok(())
    }

    #[test]
    fn test_coif3_all_invariants() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(3))?;
        assert_eq!(f.dec_lo.len(), 18, "coif3 length must be 18");
        check_filter_normalisation(&f);
        check_unit_energy(&f);
        check_qmf_orthogonality(&f);
        Ok(())
    }

    #[test]
    fn test_coif4_all_invariants() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(4))?;
        assert_eq!(f.dec_lo.len(), 24, "coif4 length must be 24");
        check_filter_normalisation(&f);
        check_unit_energy(&f);
        check_qmf_orthogonality(&f);
        Ok(())
    }

    #[test]
    fn test_coif5_all_invariants() -> Result<()> {
        let f = WaveletFilters::from_wavelet(WaveletType::Coiflet(5))?;
        assert_eq!(f.dec_lo.len(), 30, "coif5 length must be 30");
        check_filter_normalisation(&f);
        check_unit_energy(&f);
        check_qmf_orthogonality(&f);
        Ok(())
    }

    // -------------------------------------------------------------------------
    // 3D DWT tests
    // -------------------------------------------------------------------------

    /// Decompose a constant volume with Haar. The LLL (approximation) subband
    /// should equal the input value scaled by (√2)^3 = 2√2, and all seven
    /// detail subbands should be near-zero.
    ///
    /// The Haar low-pass filter sums to √2 (see `check_filter_normalisation`),
    /// i.e. both taps are 1/√2. Each axis pass therefore maps a pair of equal
    /// samples `(c, c)` to `c/√2 + c/√2 = c·√2`. Three axis passes give
    /// LLL ≈ `c · (√2)^3 = c · 2√2`.
    #[test]
    fn test_dwt3d_constant_volume_lll_scaling() -> Result<()> {
        let c = 3.0_f64;
        // 8×8×8 constant volume for clean Haar decomposition
        let volume = Array3::from_elem((8, 8, 8), c);
        let dwtn = DWTN::new(WaveletType::Haar);
        let coeffs = dwtn.decompose3(&volume)?;

        // LLL subband must be non-empty and all values ≈ c * 2√2
        assert!(!coeffs.lll.is_empty(), "LLL subband must not be empty");
        let expected_lll = c * 2.0_f64.sqrt().powi(3); // c * 2√2 ≈ 8.485 for c=3
        for val in coeffs.lll.iter() {
            assert_abs_diff_eq!(*val, expected_lll, epsilon = 1e-10);
        }

        // All 7 detail subbands must be near-zero (constant signal has no detail)
        let detail_bands: [&Array3<f64>; 7] = [
            &coeffs.llh,
            &coeffs.lhl,
            &coeffs.lhh,
            &coeffs.hll,
            &coeffs.hlh,
            &coeffs.hhl,
            &coeffs.hhh,
        ];
        for band in detail_bands {
            for val in band.iter() {
                assert_abs_diff_eq!(*val, 0.0, epsilon = 1e-10);
            }
        }
        Ok(())
    }

    /// Verify that `decompose3` produces subbands with the expected half-sizes.
    #[test]
    fn test_dwt3d_output_shape() -> Result<()> {
        let volume = Array3::from_shape_fn((8, 6, 4), |(i, j, k)| (i + j + k) as f64);
        let dwtn = DWTN::new(WaveletType::Haar);
        let coeffs = dwtn.decompose3(&volume)?;

        // Each axis is halved by the Haar transform: 8→4, 6→3, 4→2
        assert_eq!(coeffs.lll.dim(), (4, 3, 2));
        assert_eq!(coeffs.hhh.dim(), (4, 3, 2));
        Ok(())
    }

    /// Decomposing a volume with any dimension < 2 must return an error.
    #[test]
    fn test_dwt3d_rejects_too_small_volume() {
        let volume = Array3::from_elem((1, 8, 8), 1.0);
        let dwtn = DWTN::new(WaveletType::Haar);
        assert!(
            dwtn.decompose3(&volume).is_err(),
            "decompose3 must reject a volume with any dimension < 2"
        );
    }
}
1585            dwtn.decompose3(&volume).is_err(),
1586            "decompose3 must reject a volume with any dimension < 2"
1587        );
1588    }
1589}