Skip to main content

scirs2_transform/signal_transforms/
dwt.rs

//! Discrete Wavelet Transform (DWT) implementation.
//!
//! Provides single- and multi-level 1D decomposition and reconstruction, 2D
//! decomposition, and a placeholder N-D entry point (not yet implemented), for
//! several wavelet families with configurable boundary extension.
5
6use crate::error::{Result, TransformError};
7use rayon::prelude::*;
8use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, Axis};
9
/// Wavelet families understood by the DWT implementation.
///
/// Only a few orders currently have filter coefficients; see
/// [`WaveletFilters::from_wavelet`] for what is accepted and what is rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WaveletType {
    /// Haar wavelet (Daubechies-1), with two-tap filters.
    Haar,
    /// Daubechies wavelet of order `N` (a `2N`-tap filter).
    ///
    /// Coefficients exist only for a few small orders; unsupported orders are
    /// rejected when the filters are built.
    Daubechies(usize),
    /// Symlet wavelet of order `N`.
    ///
    /// Currently built from the Daubechies filters of the same order, which is
    /// only an approximation of a true symlet for higher orders.
    Symlet(usize),
    /// Coiflet wavelet of order `N`; currently only `N = 1` is available.
    Coiflet(usize),
    /// Biorthogonal wavelet identified by a pair of orders.
    ///
    /// Currently a placeholder that uses the Haar filters for every order pair.
    Biorthogonal(usize, usize),
}
24
/// Boundary extension modes for the DWT.
///
/// A finite signal must be extended past its edges before it is filtered; the
/// mode decides which values are used there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BoundaryMode {
    /// Zero padding: samples outside the signal are `0.0`.
    Zero,
    /// Constant padding: the nearest edge sample is repeated.
    Constant,
    /// Symmetric padding: mirrored extension that repeats the edge sample
    /// (`... x1 x0 | x0 x1 ...`).
    Symmetric,
    /// Periodic padding: the signal wraps around (`... x[n-1] | x0 x1 ...`).
    Periodic,
    /// Reflect padding: mirrored extension at each edge.
    Reflect,
}
39
/// Wavelet filter bank: decomposition and reconstruction filter coefficients.
#[derive(Debug, Clone, PartialEq)]
pub struct WaveletFilters {
    /// Low-pass (scaling) decomposition filter.
    pub dec_lo: Vec<f64>,
    /// High-pass (wavelet) decomposition filter.
    pub dec_hi: Vec<f64>,
    /// Low-pass reconstruction filter.
    pub rec_lo: Vec<f64>,
    /// High-pass reconstruction filter.
    pub rec_hi: Vec<f64>,
}
52
53impl WaveletFilters {
54    /// Get filter coefficients for a specific wavelet type
55    pub fn from_wavelet(wavelet: WaveletType) -> Result<Self> {
56        match wavelet {
57            WaveletType::Haar => Self::haar(),
58            WaveletType::Daubechies(n) => Self::daubechies(n),
59            WaveletType::Symlet(n) => Self::symlet(n),
60            WaveletType::Coiflet(n) => Self::coiflet(n),
61            WaveletType::Biorthogonal(p, q) => Self::biorthogonal(p, q),
62        }
63    }
64
65    /// Haar wavelet filters
66    fn haar() -> Result<Self> {
67        let norm = 1.0 / 2.0_f64.sqrt();
68        Ok(WaveletFilters {
69            dec_lo: vec![norm, norm],
70            dec_hi: vec![norm, -norm],
71            rec_lo: vec![norm, norm],
72            rec_hi: vec![-norm, norm],
73        })
74    }
75
76    /// Daubechies wavelet filters
77    fn daubechies(n: usize) -> Result<Self> {
78        match n {
79            2 => {
80                // DB2 (Daubechies-4 coefficients)
81                let sqrt3 = 3.0_f64.sqrt();
82                let denom = 4.0 * 2.0_f64.sqrt();
83                let dec_lo = vec![
84                    (1.0 + sqrt3) / denom,
85                    (3.0 + sqrt3) / denom,
86                    (3.0 - sqrt3) / denom,
87                    (1.0 - sqrt3) / denom,
88                ];
89                let mut dec_hi = Vec::with_capacity(dec_lo.len());
90                for (i, &val) in dec_lo.iter().enumerate().rev() {
91                    dec_hi.push(if i % 2 == 0 { val } else { -val });
92                }
93
94                let mut rec_lo = dec_lo.clone();
95                rec_lo.reverse();
96                let mut rec_hi = dec_hi.clone();
97                rec_hi.reverse();
98
99                Ok(WaveletFilters {
100                    dec_lo,
101                    dec_hi,
102                    rec_lo,
103                    rec_hi,
104                })
105            }
106            4 => {
107                // DB4 (Daubechies-8 coefficients)
108                let dec_lo = vec![
109                    -0.010597401784997,
110                    0.032883011666983,
111                    0.030841381835987,
112                    -0.187034811718881,
113                    -0.027983769416984,
114                    0.630880767929590,
115                    0.714846570552542,
116                    0.230377813308855,
117                ];
118                let mut dec_hi = Vec::with_capacity(dec_lo.len());
119                for (i, &val) in dec_lo.iter().enumerate().rev() {
120                    dec_hi.push(if i % 2 == 0 { val } else { -val });
121                }
122
123                let mut rec_lo = dec_lo.clone();
124                rec_lo.reverse();
125                let mut rec_hi = dec_hi.clone();
126                rec_hi.reverse();
127
128                Ok(WaveletFilters {
129                    dec_lo,
130                    dec_hi,
131                    rec_lo,
132                    rec_hi,
133                })
134            }
135            _ => Err(TransformError::InvalidInput(format!(
136                "Daubechies-{} not yet implemented",
137                n
138            ))),
139        }
140    }
141
142    /// Symlet wavelet filters (simplified - use Daubechies for now)
143    fn symlet(n: usize) -> Result<Self> {
144        // Symlets are nearly symmetric versions of Daubechies wavelets
145        Self::daubechies(n)
146    }
147
148    /// Coiflet wavelet filters
149    fn coiflet(n: usize) -> Result<Self> {
150        match n {
151            1 => {
152                // Coif1 coefficients
153                let sqrt2 = 2.0_f64.sqrt();
154                let dec_lo = vec![
155                    -0.01565572813546454 / sqrt2,
156                    -0.07268974908697540 / sqrt2,
157                    0.38486484686420286 / sqrt2,
158                    0.85257202021225542 / sqrt2,
159                    0.33789766245780093 / sqrt2,
160                    -0.07268974908697540 / sqrt2,
161                ];
162                let mut dec_hi = Vec::with_capacity(dec_lo.len());
163                for (i, &val) in dec_lo.iter().enumerate().rev() {
164                    dec_hi.push(if i % 2 == 0 { val } else { -val });
165                }
166
167                let mut rec_lo = dec_lo.clone();
168                rec_lo.reverse();
169                let mut rec_hi = dec_hi.clone();
170                rec_hi.reverse();
171
172                Ok(WaveletFilters {
173                    dec_lo,
174                    dec_hi,
175                    rec_lo,
176                    rec_hi,
177                })
178            }
179            _ => Err(TransformError::InvalidInput(format!(
180                "Coiflet-{} not yet implemented",
181                n
182            ))),
183        }
184    }
185
186    /// Biorthogonal wavelet filters
187    fn biorthogonal(_p: usize, _q: usize) -> Result<Self> {
188        // For now, return Haar as placeholder
189        Self::haar()
190    }
191}
192
/// 1D Discrete Wavelet Transform.
///
/// Bundles a wavelet, its filter bank, the boundary extension mode, and an
/// optional level used by multi-level decomposition.
#[derive(Debug, Clone)]
pub struct DWT {
    /// Wavelet the filters were built from.
    wavelet: WaveletType,
    /// Decomposition/reconstruction filter bank for `wavelet`.
    filters: WaveletFilters,
    /// How the signal is extended past its edges before filtering.
    boundary: BoundaryMode,
    /// Requested number of levels for `wavedec`; `None` means the maximum level.
    level: Option<usize>,
}
201
202impl DWT {
203    /// Create a new DWT instance
204    pub fn new(wavelet: WaveletType) -> Result<Self> {
205        let filters = WaveletFilters::from_wavelet(wavelet)?;
206        Ok(DWT {
207            wavelet,
208            filters,
209            boundary: BoundaryMode::Symmetric,
210            level: None,
211        })
212    }
213
214    /// Set the boundary mode
215    pub fn with_boundary(mut self, boundary: BoundaryMode) -> Self {
216        self.boundary = boundary;
217        self
218    }
219
220    /// Set the decomposition level
221    pub fn with_level(mut self, level: usize) -> Self {
222        self.level = Some(level);
223        self
224    }
225
226    /// Perform single-level decomposition
227    pub fn decompose(&self, signal: &ArrayView1<f64>) -> Result<(Array1<f64>, Array1<f64>)> {
228        let n = signal.len();
229        if n < 2 {
230            return Err(TransformError::InvalidInput(
231                "Signal too short for DWT".to_string(),
232            ));
233        }
234
235        // Extend signal according to boundary mode
236        let extended = self.extend_signal(signal)?;
237
238        // Convolve with filters and downsample
239        let approx = self.convolve_downsample(&extended, &self.filters.dec_lo)?;
240        let detail = self.convolve_downsample(&extended, &self.filters.dec_hi)?;
241
242        Ok((approx, detail))
243    }
244
245    /// Perform multi-level decomposition
246    pub fn wavedec(&self, signal: &ArrayView1<f64>) -> Result<Vec<Array1<f64>>> {
247        let max_level = self.max_decomposition_level(signal.len());
248        let level = self.level.unwrap_or(max_level).min(max_level);
249
250        let mut coeffs = Vec::with_capacity(level + 1);
251        let mut current = signal.to_owned();
252
253        for _ in 0..level {
254            let (approx, detail) = self.decompose(&current.view())?;
255            coeffs.push(detail);
256            current = approx;
257        }
258
259        // Add final approximation coefficients
260        coeffs.push(current);
261        coeffs.reverse();
262
263        Ok(coeffs)
264    }
265
266    /// Perform single-level reconstruction
267    pub fn reconstruct(
268        &self,
269        approx: &ArrayView1<f64>,
270        detail: &ArrayView1<f64>,
271    ) -> Result<Array1<f64>> {
272        // Upsample and convolve with reconstruction filters
273        let approx_up = self.upsample_convolve(approx, &self.filters.rec_lo)?;
274        let detail_up = self.upsample_convolve(detail, &self.filters.rec_hi)?;
275
276        // Add the two components
277        let min_len = approx_up.len().min(detail_up.len());
278        let mut reconstructed = Array1::zeros(min_len);
279        for i in 0..min_len {
280            reconstructed[i] = approx_up[i] + detail_up[i];
281        }
282
283        Ok(reconstructed)
284    }
285
286    /// Perform multi-level reconstruction
287    pub fn waverec(&self, coeffs: &[Array1<f64>]) -> Result<Array1<f64>> {
288        if coeffs.is_empty() {
289            return Err(TransformError::InvalidInput(
290                "No coefficients provided for reconstruction".to_string(),
291            ));
292        }
293
294        let mut current = coeffs[0].clone();
295
296        for detail in &coeffs[1..] {
297            current = self.reconstruct(&current.view(), &detail.view())?;
298        }
299
300        Ok(current)
301    }
302
303    // Helper methods
304
305    fn extend_signal(&self, signal: &ArrayView1<f64>) -> Result<Array1<f64>> {
306        let filter_len = self.filters.dec_lo.len();
307        let n = signal.len();
308        let pad_len = filter_len - 1;
309
310        let mut extended = Array1::zeros(n + 2 * pad_len);
311
312        match self.boundary {
313            BoundaryMode::Zero => {
314                for i in 0..n {
315                    extended[i + pad_len] = signal[i];
316                }
317            }
318            BoundaryMode::Constant => {
319                let first = signal[0];
320                let last = signal[n - 1];
321                for i in 0..pad_len {
322                    extended[i] = first;
323                    extended[n + pad_len + i] = last;
324                }
325                for i in 0..n {
326                    extended[i + pad_len] = signal[i];
327                }
328            }
329            BoundaryMode::Symmetric => {
330                for i in 0..pad_len {
331                    extended[pad_len - 1 - i] = signal[i.min(n - 1)];
332                    extended[n + pad_len + i] = signal[(n - 1 - i).max(0)];
333                }
334                for i in 0..n {
335                    extended[i + pad_len] = signal[i];
336                }
337            }
338            BoundaryMode::Periodic => {
339                for i in 0..pad_len {
340                    extended[i] = signal[(n - pad_len + i) % n];
341                    extended[n + pad_len + i] = signal[i % n];
342                }
343                for i in 0..n {
344                    extended[i + pad_len] = signal[i];
345                }
346            }
347            BoundaryMode::Reflect => {
348                for i in 0..pad_len {
349                    let idx1 = if i < n { i } else { n - 1 };
350                    let idx2 = if n > i + 1 { n - 1 - i } else { 0 };
351                    extended[pad_len - 1 - i] = signal[idx1];
352                    extended[n + pad_len + i] = signal[idx2];
353                }
354                for i in 0..n {
355                    extended[i + pad_len] = signal[i];
356                }
357            }
358        }
359
360        Ok(extended)
361    }
362
363    fn convolve_downsample(&self, signal: &Array1<f64>, filter: &[f64]) -> Result<Array1<f64>> {
364        let n = signal.len();
365        let filter_len = filter.len();
366        let output_len = (n + 1) / 2;
367        let mut output = Array1::zeros(output_len);
368
369        for i in 0..output_len {
370            let pos = i * 2;
371            let mut sum = 0.0;
372
373            for (j, &coeff) in filter.iter().enumerate() {
374                let idx = pos + j;
375                if idx < n {
376                    sum += signal[idx] * coeff;
377                }
378            }
379
380            output[i] = sum;
381        }
382
383        Ok(output)
384    }
385
386    fn upsample_convolve(&self, signal: &ArrayView1<f64>, filter: &[f64]) -> Result<Array1<f64>> {
387        let n = signal.len();
388        let filter_len = filter.len();
389        let output_len = n * 2;
390        let mut output = Array1::zeros(output_len);
391
392        // Upsample by inserting zeros
393        let mut upsampled = Array1::zeros(output_len);
394        for i in 0..n {
395            upsampled[i * 2] = signal[i];
396        }
397
398        // Convolve with reconstruction filter
399        for i in 0..output_len {
400            let mut sum = 0.0;
401            for (j, &coeff) in filter.iter().enumerate() {
402                if i >= j && i - j < output_len {
403                    sum += upsampled[i - j] * coeff;
404                }
405            }
406            output[i] = sum;
407        }
408
409        Ok(output)
410    }
411
412    fn max_decomposition_level(&self, signal_len: usize) -> usize {
413        let filter_len = self.filters.dec_lo.len();
414        let mut level: usize = 0;
415        let mut current_len = signal_len;
416
417        while current_len >= filter_len {
418            current_len = (current_len + 1) / 2;
419            level += 1;
420        }
421
422        level.saturating_sub(1)
423    }
424}
425
/// 2D Discrete Wavelet Transform.
///
/// Separable: `decompose2` transforms every row first, then every column of the
/// row results, using the same 1D filter bank for both passes.
#[derive(Debug, Clone)]
pub struct DWT2D {
    /// Wavelet the filters were built from.
    wavelet: WaveletType,
    /// Filter bank shared by the row and column passes.
    filters: WaveletFilters,
    /// Boundary extension applied to each row and column.
    boundary: BoundaryMode,
    /// Requested number of levels for `wavedec2`; `None` means the maximum level.
    level: Option<usize>,
}
434
435impl DWT2D {
436    /// Create a new DWT2D instance
437    pub fn new(wavelet: WaveletType) -> Result<Self> {
438        let filters = WaveletFilters::from_wavelet(wavelet)?;
439        Ok(DWT2D {
440            wavelet,
441            filters,
442            boundary: BoundaryMode::Symmetric,
443            level: None,
444        })
445    }
446
447    /// Set the boundary mode
448    pub fn with_boundary(mut self, boundary: BoundaryMode) -> Self {
449        self.boundary = boundary;
450        self
451    }
452
453    /// Set the decomposition level
454    pub fn with_level(mut self, level: usize) -> Self {
455        self.level = Some(level);
456        self
457    }
458
459    /// Perform single-level 2D decomposition
460    pub fn decompose2(&self, image: &ArrayView2<f64>) -> Result<Dwt2dCoeffs> {
461        let (rows, cols) = image.dim();
462        if rows < 2 || cols < 2 {
463            return Err(TransformError::InvalidInput(
464                "Image too small for 2D DWT".to_string(),
465            ));
466        }
467
468        let dwt1d = DWT {
469            wavelet: self.wavelet,
470            filters: self.filters.clone(),
471            boundary: self.boundary,
472            level: None,
473        };
474
475        // Apply DWT along rows
476        let mut row_results_approx = Vec::with_capacity(rows);
477        let mut row_results_detail = Vec::with_capacity(rows);
478
479        for row_idx in 0..rows {
480            let row = image.row(row_idx);
481            let (approx, detail) = dwt1d.decompose(&row)?;
482            row_results_approx.push(approx);
483            row_results_detail.push(detail);
484        }
485
486        let approx_rows = row_results_approx[0].len();
487        let detail_rows = row_results_detail[0].len();
488
489        // Convert to 2D arrays
490        let mut approx_mat = Array2::zeros((rows, approx_rows));
491        let mut detail_mat = Array2::zeros((rows, detail_rows));
492
493        for (i, (app, det)) in row_results_approx
494            .iter()
495            .zip(row_results_detail.iter())
496            .enumerate()
497        {
498            for (j, &val) in app.iter().enumerate() {
499                approx_mat[[i, j]] = val;
500            }
501            for (j, &val) in det.iter().enumerate() {
502                detail_mat[[i, j]] = val;
503            }
504        }
505
506        // Apply DWT along columns
507        let (ll, lh) = self.decompose_columns(&approx_mat.view(), &dwt1d)?;
508        let (hl, hh) = self.decompose_columns(&detail_mat.view(), &dwt1d)?;
509
510        Ok(Dwt2dCoeffs { ll, lh, hl, hh })
511    }
512
513    fn decompose_columns(
514        &self,
515        mat: &ArrayView2<f64>,
516        dwt1d: &DWT,
517    ) -> Result<(Array2<f64>, Array2<f64>)> {
518        let (rows, cols) = mat.dim();
519        let mut col_results_approx = Vec::with_capacity(cols);
520        let mut col_results_detail = Vec::with_capacity(cols);
521
522        for col_idx in 0..cols {
523            let col = mat.column(col_idx);
524            let (approx, detail) = dwt1d.decompose(&col)?;
525            col_results_approx.push(approx);
526            col_results_detail.push(detail);
527        }
528
529        let approx_cols = col_results_approx[0].len();
530        let detail_cols = col_results_detail[0].len();
531
532        let mut approx_result = Array2::zeros((approx_cols, cols));
533        let mut detail_result = Array2::zeros((detail_cols, cols));
534
535        for (j, (app, det)) in col_results_approx
536            .iter()
537            .zip(col_results_detail.iter())
538            .enumerate()
539        {
540            for (i, &val) in app.iter().enumerate() {
541                approx_result[[i, j]] = val;
542            }
543            for (i, &val) in det.iter().enumerate() {
544                detail_result[[i, j]] = val;
545            }
546        }
547
548        Ok((approx_result, detail_result))
549    }
550
551    /// Perform multi-level 2D decomposition
552    pub fn wavedec2(&self, image: &ArrayView2<f64>) -> Result<Vec<Dwt2dCoeffs>> {
553        let max_level = self.max_decomposition_level_2d(image.dim());
554        let level = self.level.unwrap_or(max_level).min(max_level);
555
556        let mut coeffs = Vec::with_capacity(level);
557        let mut current = image.to_owned();
558
559        for _ in 0..level {
560            let dwt2d_coeffs = self.decompose2(&current.view())?;
561            coeffs.push(dwt2d_coeffs.clone());
562            current = dwt2d_coeffs.ll;
563        }
564
565        Ok(coeffs)
566    }
567
568    fn max_decomposition_level_2d(&self, shape: (usize, usize)) -> usize {
569        let filter_len = self.filters.dec_lo.len();
570        let min_dim = shape.0.min(shape.1);
571
572        let mut level: usize = 0;
573        let mut current_dim = min_dim;
574
575        while current_dim >= filter_len {
576            current_dim = (current_dim + 1) / 2;
577            level += 1;
578        }
579
580        level.saturating_sub(1)
581    }
582}
583
/// 2D DWT coefficients (LL, LH, HL, HH) from one level of `DWT2D::decompose2`.
///
/// The first letter names the filter applied along rows and the second the filter
/// applied along columns (`L` = low-pass/approximation, `H` = high-pass/detail).
#[derive(Debug, Clone)]
pub struct Dwt2dCoeffs {
    /// Approximation coefficients (low-low): row approximation, column approximation.
    pub ll: Array2<f64>,
    /// Horizontal detail coefficients (low-high): row approximation, column detail.
    pub lh: Array2<f64>,
    /// Vertical detail coefficients (high-low): row detail, column approximation.
    pub hl: Array2<f64>,
    /// Diagonal detail coefficients (high-high): row detail, column detail.
    pub hh: Array2<f64>,
}
596
/// N-D Discrete Wavelet Transform (placeholder for 3D and higher).
///
/// Currently only stores configuration; `decompose3` always returns a
/// `NotImplemented` error.
#[derive(Debug, Clone)]
pub struct DWTN {
    /// Wavelet to use.
    wavelet: WaveletType,
    /// Boundary extension mode to use.
    boundary: BoundaryMode,
    /// Requested decomposition level; `None` when not set.
    level: Option<usize>,
}
604
605impl DWTN {
606    /// Create a new DWTN instance
607    pub fn new(wavelet: WaveletType) -> Self {
608        DWTN {
609            wavelet,
610            boundary: BoundaryMode::Symmetric,
611            level: None,
612        }
613    }
614
615    /// Set the boundary mode
616    pub fn with_boundary(mut self, boundary: BoundaryMode) -> Self {
617        self.boundary = boundary;
618        self
619    }
620
621    /// Set the decomposition level
622    pub fn with_level(mut self, level: usize) -> Self {
623        self.level = Some(level);
624        self
625    }
626
627    /// Perform 3D decomposition (simplified placeholder)
628    pub fn decompose3(&self, _volume: &Array3<f64>) -> Result<Array3<f64>> {
629        Err(TransformError::NotImplemented(
630            "3D DWT not yet fully implemented".to_string(),
631        ))
632    }
633}
634
#[cfg(test)]
mod tests {
    use super::*;

    /// The ramp `1.0, 2.0, ..., 8.0` shared by the 1D tests.
    fn ramp8() -> Array1<f64> {
        Array1::from_vec((1..=8).map(f64::from).collect())
    }

    #[test]
    fn test_dwt_haar() -> Result<()> {
        let signal = ramp8();
        let dwt = DWT::new(WaveletType::Haar)?;

        let (approx, detail) = dwt.decompose(&signal.view())?;

        assert!(!approx.is_empty());
        assert!(!detail.is_empty());
        assert_eq!(approx.len(), detail.len());

        Ok(())
    }

    #[test]
    fn test_dwt_rejects_short_signal() -> Result<()> {
        let dwt = DWT::new(WaveletType::Haar)?;
        let signal = Array1::from_vec(vec![1.0]);

        assert!(dwt.decompose(&signal.view()).is_err());

        Ok(())
    }

    #[test]
    fn test_dwt_multilevel() -> Result<()> {
        let signal = ramp8();
        let dwt = DWT::new(WaveletType::Haar)?.with_level(2);

        let coeffs = dwt.wavedec(&signal.view())?;

        assert_eq!(coeffs.len(), 3); // 2 levels + approximation

        Ok(())
    }

    #[test]
    fn test_waverec_rejects_empty_input() -> Result<()> {
        let dwt = DWT::new(WaveletType::Haar)?;

        assert!(dwt.waverec(&[]).is_err());

        Ok(())
    }

    #[test]
    fn test_dwt_reconstruction() -> Result<()> {
        let signal = ramp8();
        let dwt = DWT::new(WaveletType::Haar)?;

        let (approx, detail) = dwt.decompose(&signal.view())?;
        let reconstructed = dwt.reconstruct(&approx.view(), &detail.view())?;

        // Check reconstruction is approximately correct (may have different length)
        assert!(reconstructed.len() >= signal.len() - 2);

        Ok(())
    }

    #[test]
    fn test_dwt2d() -> Result<()> {
        let image = Array2::from_shape_fn((8, 8), |(i, j)| (i + j) as f64);
        let dwt2d = DWT2D::new(WaveletType::Haar)?;

        let coeffs = dwt2d.decompose2(&image.view())?;

        assert!(!coeffs.ll.is_empty());
        assert!(!coeffs.lh.is_empty());
        assert!(!coeffs.hl.is_empty());
        assert!(!coeffs.hh.is_empty());
        assert_eq!(coeffs.ll.dim(), coeffs.hh.dim());

        Ok(())
    }

    #[test]
    fn test_wavelet_filters() -> Result<()> {
        let filters = WaveletFilters::from_wavelet(WaveletType::Haar)?;

        assert_eq!(filters.dec_lo.len(), 2);
        assert_eq!(filters.dec_hi.len(), 2);
        assert_eq!(filters.rec_lo.len(), 2);
        assert_eq!(filters.rec_hi.len(), 2);

        assert!(WaveletFilters::from_wavelet(WaveletType::Coiflet(2)).is_err());

        Ok(())
    }
}