1use crate::error::{Result, TransformError};
7use rayon::prelude::*;
8use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, Axis};
9
/// Wavelet families understood by the transforms in this module.
///
/// The `usize` parameters select a member of the family. Which members actually have
/// coefficient tables is decided by [`WaveletFilters::from_wavelet`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WaveletType {
    /// Haar wavelet (two-tap filters).
    Haar,
    /// Daubechies wavelet of order `n` (filters with `2 * n` taps).
    Daubechies(usize),
    /// Symlet of order `n` (filters with `2 * n` taps).
    Symlet(usize),
    /// Coiflet of order `n`.
    Coiflet(usize),
    /// Biorthogonal wavelet with orders `(p, q)`.
    Biorthogonal(usize, usize),
}
24
/// How a signal is extended past its ends before filtering.
///
/// The examples show two padding samples around the signal `a b c d`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BoundaryMode {
    /// Pad with zeros: `0 0 | a b c d | 0 0`.
    Zero,
    /// Repeat the edge samples: `a a | a b c d | d d`.
    Constant,
    /// Mirror including the edge sample (half-sample symmetric): `b a | a b c d | d c`.
    Symmetric,
    /// Wrap around: `c d | a b c d | a b`.
    Periodic,
    /// Reflection padding; see the DWT signal-extension routine for the exact sample order.
    Reflect,
}
39
/// Analysis (decomposition) and synthesis (reconstruction) filter bank of one wavelet.
#[derive(Clone, Debug)]
pub struct WaveletFilters {
    /// Decomposition low-pass filter; produces the approximation coefficients.
    pub dec_lo: Vec<f64>,
    /// Decomposition high-pass filter; produces the detail coefficients.
    pub dec_hi: Vec<f64>,
    /// Reconstruction low-pass filter; applied to the approximation coefficients.
    pub rec_lo: Vec<f64>,
    /// Reconstruction high-pass filter; applied to the detail coefficients.
    pub rec_hi: Vec<f64>,
}
52
53impl WaveletFilters {
54 pub fn from_wavelet(wavelet: WaveletType) -> Result<Self> {
56 match wavelet {
57 WaveletType::Haar => Self::haar(),
58 WaveletType::Daubechies(n) => Self::daubechies(n),
59 WaveletType::Symlet(n) => Self::symlet(n),
60 WaveletType::Coiflet(n) => Self::coiflet(n),
61 WaveletType::Biorthogonal(p, q) => Self::biorthogonal(p, q),
62 }
63 }
64
65 fn haar() -> Result<Self> {
67 let norm = 1.0 / 2.0_f64.sqrt();
68 Ok(WaveletFilters {
69 dec_lo: vec![norm, norm],
70 dec_hi: vec![norm, -norm],
71 rec_lo: vec![norm, norm],
72 rec_hi: vec![-norm, norm],
73 })
74 }
75
76 fn daubechies(n: usize) -> Result<Self> {
78 match n {
79 2 => {
80 let sqrt3 = 3.0_f64.sqrt();
82 let denom = 4.0 * 2.0_f64.sqrt();
83 let dec_lo = vec![
84 (1.0 + sqrt3) / denom,
85 (3.0 + sqrt3) / denom,
86 (3.0 - sqrt3) / denom,
87 (1.0 - sqrt3) / denom,
88 ];
89 let mut dec_hi = Vec::with_capacity(dec_lo.len());
90 for (i, &val) in dec_lo.iter().enumerate().rev() {
91 dec_hi.push(if i % 2 == 0 { val } else { -val });
92 }
93
94 let mut rec_lo = dec_lo.clone();
95 rec_lo.reverse();
96 let mut rec_hi = dec_hi.clone();
97 rec_hi.reverse();
98
99 Ok(WaveletFilters {
100 dec_lo,
101 dec_hi,
102 rec_lo,
103 rec_hi,
104 })
105 }
106 4 => {
107 let dec_lo = vec![
109 -0.010597401784997,
110 0.032883011666983,
111 0.030841381835987,
112 -0.187034811718881,
113 -0.027983769416984,
114 0.630880767929590,
115 0.714846570552542,
116 0.230377813308855,
117 ];
118 let mut dec_hi = Vec::with_capacity(dec_lo.len());
119 for (i, &val) in dec_lo.iter().enumerate().rev() {
120 dec_hi.push(if i % 2 == 0 { val } else { -val });
121 }
122
123 let mut rec_lo = dec_lo.clone();
124 rec_lo.reverse();
125 let mut rec_hi = dec_hi.clone();
126 rec_hi.reverse();
127
128 Ok(WaveletFilters {
129 dec_lo,
130 dec_hi,
131 rec_lo,
132 rec_hi,
133 })
134 }
135 _ => Err(TransformError::InvalidInput(format!(
136 "Daubechies-{} not yet implemented",
137 n
138 ))),
139 }
140 }
141
142 fn symlet(n: usize) -> Result<Self> {
144 Self::daubechies(n)
146 }
147
148 fn coiflet(n: usize) -> Result<Self> {
150 match n {
151 1 => {
152 let sqrt2 = 2.0_f64.sqrt();
154 let dec_lo = vec![
155 -0.01565572813546454 / sqrt2,
156 -0.07268974908697540 / sqrt2,
157 0.38486484686420286 / sqrt2,
158 0.85257202021225542 / sqrt2,
159 0.33789766245780093 / sqrt2,
160 -0.07268974908697540 / sqrt2,
161 ];
162 let mut dec_hi = Vec::with_capacity(dec_lo.len());
163 for (i, &val) in dec_lo.iter().enumerate().rev() {
164 dec_hi.push(if i % 2 == 0 { val } else { -val });
165 }
166
167 let mut rec_lo = dec_lo.clone();
168 rec_lo.reverse();
169 let mut rec_hi = dec_hi.clone();
170 rec_hi.reverse();
171
172 Ok(WaveletFilters {
173 dec_lo,
174 dec_hi,
175 rec_lo,
176 rec_hi,
177 })
178 }
179 _ => Err(TransformError::InvalidInput(format!(
180 "Coiflet-{} not yet implemented",
181 n
182 ))),
183 }
184 }
185
186 fn biorthogonal(_p: usize, _q: usize) -> Result<Self> {
188 Self::haar()
190 }
191}
192
193#[derive(Debug, Clone)]
195pub struct DWT {
196 wavelet: WaveletType,
197 filters: WaveletFilters,
198 boundary: BoundaryMode,
199 level: Option<usize>,
200}
201
202impl DWT {
203 pub fn new(wavelet: WaveletType) -> Result<Self> {
205 let filters = WaveletFilters::from_wavelet(wavelet)?;
206 Ok(DWT {
207 wavelet,
208 filters,
209 boundary: BoundaryMode::Symmetric,
210 level: None,
211 })
212 }
213
214 pub fn with_boundary(mut self, boundary: BoundaryMode) -> Self {
216 self.boundary = boundary;
217 self
218 }
219
220 pub fn with_level(mut self, level: usize) -> Self {
222 self.level = Some(level);
223 self
224 }
225
226 pub fn decompose(&self, signal: &ArrayView1<f64>) -> Result<(Array1<f64>, Array1<f64>)> {
228 let n = signal.len();
229 if n < 2 {
230 return Err(TransformError::InvalidInput(
231 "Signal too short for DWT".to_string(),
232 ));
233 }
234
235 let extended = self.extend_signal(signal)?;
237
238 let approx = self.convolve_downsample(&extended, &self.filters.dec_lo)?;
240 let detail = self.convolve_downsample(&extended, &self.filters.dec_hi)?;
241
242 Ok((approx, detail))
243 }
244
245 pub fn wavedec(&self, signal: &ArrayView1<f64>) -> Result<Vec<Array1<f64>>> {
247 let max_level = self.max_decomposition_level(signal.len());
248 let level = self.level.unwrap_or(max_level).min(max_level);
249
250 let mut coeffs = Vec::with_capacity(level + 1);
251 let mut current = signal.to_owned();
252
253 for _ in 0..level {
254 let (approx, detail) = self.decompose(¤t.view())?;
255 coeffs.push(detail);
256 current = approx;
257 }
258
259 coeffs.push(current);
261 coeffs.reverse();
262
263 Ok(coeffs)
264 }
265
266 pub fn reconstruct(
268 &self,
269 approx: &ArrayView1<f64>,
270 detail: &ArrayView1<f64>,
271 ) -> Result<Array1<f64>> {
272 let approx_up = self.upsample_convolve(approx, &self.filters.rec_lo)?;
274 let detail_up = self.upsample_convolve(detail, &self.filters.rec_hi)?;
275
276 let min_len = approx_up.len().min(detail_up.len());
278 let mut reconstructed = Array1::zeros(min_len);
279 for i in 0..min_len {
280 reconstructed[i] = approx_up[i] + detail_up[i];
281 }
282
283 Ok(reconstructed)
284 }
285
286 pub fn waverec(&self, coeffs: &[Array1<f64>]) -> Result<Array1<f64>> {
288 if coeffs.is_empty() {
289 return Err(TransformError::InvalidInput(
290 "No coefficients provided for reconstruction".to_string(),
291 ));
292 }
293
294 let mut current = coeffs[0].clone();
295
296 for detail in &coeffs[1..] {
297 current = self.reconstruct(¤t.view(), &detail.view())?;
298 }
299
300 Ok(current)
301 }
302
303 fn extend_signal(&self, signal: &ArrayView1<f64>) -> Result<Array1<f64>> {
306 let filter_len = self.filters.dec_lo.len();
307 let n = signal.len();
308 let pad_len = filter_len - 1;
309
310 let mut extended = Array1::zeros(n + 2 * pad_len);
311
312 match self.boundary {
313 BoundaryMode::Zero => {
314 for i in 0..n {
315 extended[i + pad_len] = signal[i];
316 }
317 }
318 BoundaryMode::Constant => {
319 let first = signal[0];
320 let last = signal[n - 1];
321 for i in 0..pad_len {
322 extended[i] = first;
323 extended[n + pad_len + i] = last;
324 }
325 for i in 0..n {
326 extended[i + pad_len] = signal[i];
327 }
328 }
329 BoundaryMode::Symmetric => {
330 for i in 0..pad_len {
331 extended[pad_len - 1 - i] = signal[i.min(n - 1)];
332 extended[n + pad_len + i] = signal[(n - 1 - i).max(0)];
333 }
334 for i in 0..n {
335 extended[i + pad_len] = signal[i];
336 }
337 }
338 BoundaryMode::Periodic => {
339 for i in 0..pad_len {
340 extended[i] = signal[(n - pad_len + i) % n];
341 extended[n + pad_len + i] = signal[i % n];
342 }
343 for i in 0..n {
344 extended[i + pad_len] = signal[i];
345 }
346 }
347 BoundaryMode::Reflect => {
348 for i in 0..pad_len {
349 let idx1 = if i < n { i } else { n - 1 };
350 let idx2 = if n > i + 1 { n - 1 - i } else { 0 };
351 extended[pad_len - 1 - i] = signal[idx1];
352 extended[n + pad_len + i] = signal[idx2];
353 }
354 for i in 0..n {
355 extended[i + pad_len] = signal[i];
356 }
357 }
358 }
359
360 Ok(extended)
361 }
362
363 fn convolve_downsample(&self, signal: &Array1<f64>, filter: &[f64]) -> Result<Array1<f64>> {
364 let n = signal.len();
365 let filter_len = filter.len();
366 let output_len = (n + 1) / 2;
367 let mut output = Array1::zeros(output_len);
368
369 for i in 0..output_len {
370 let pos = i * 2;
371 let mut sum = 0.0;
372
373 for (j, &coeff) in filter.iter().enumerate() {
374 let idx = pos + j;
375 if idx < n {
376 sum += signal[idx] * coeff;
377 }
378 }
379
380 output[i] = sum;
381 }
382
383 Ok(output)
384 }
385
386 fn upsample_convolve(&self, signal: &ArrayView1<f64>, filter: &[f64]) -> Result<Array1<f64>> {
387 let n = signal.len();
388 let filter_len = filter.len();
389 let output_len = n * 2;
390 let mut output = Array1::zeros(output_len);
391
392 let mut upsampled = Array1::zeros(output_len);
394 for i in 0..n {
395 upsampled[i * 2] = signal[i];
396 }
397
398 for i in 0..output_len {
400 let mut sum = 0.0;
401 for (j, &coeff) in filter.iter().enumerate() {
402 if i >= j && i - j < output_len {
403 sum += upsampled[i - j] * coeff;
404 }
405 }
406 output[i] = sum;
407 }
408
409 Ok(output)
410 }
411
412 fn max_decomposition_level(&self, signal_len: usize) -> usize {
413 let filter_len = self.filters.dec_lo.len();
414 let mut level: usize = 0;
415 let mut current_len = signal_len;
416
417 while current_len >= filter_len {
418 current_len = (current_len + 1) / 2;
419 level += 1;
420 }
421
422 level.saturating_sub(1)
423 }
424}
425
426#[derive(Debug, Clone)]
428pub struct DWT2D {
429 wavelet: WaveletType,
430 filters: WaveletFilters,
431 boundary: BoundaryMode,
432 level: Option<usize>,
433}
434
435impl DWT2D {
436 pub fn new(wavelet: WaveletType) -> Result<Self> {
438 let filters = WaveletFilters::from_wavelet(wavelet)?;
439 Ok(DWT2D {
440 wavelet,
441 filters,
442 boundary: BoundaryMode::Symmetric,
443 level: None,
444 })
445 }
446
447 pub fn with_boundary(mut self, boundary: BoundaryMode) -> Self {
449 self.boundary = boundary;
450 self
451 }
452
453 pub fn with_level(mut self, level: usize) -> Self {
455 self.level = Some(level);
456 self
457 }
458
459 pub fn decompose2(&self, image: &ArrayView2<f64>) -> Result<Dwt2dCoeffs> {
461 let (rows, cols) = image.dim();
462 if rows < 2 || cols < 2 {
463 return Err(TransformError::InvalidInput(
464 "Image too small for 2D DWT".to_string(),
465 ));
466 }
467
468 let dwt1d = DWT {
469 wavelet: self.wavelet,
470 filters: self.filters.clone(),
471 boundary: self.boundary,
472 level: None,
473 };
474
475 let mut row_results_approx = Vec::with_capacity(rows);
477 let mut row_results_detail = Vec::with_capacity(rows);
478
479 for row_idx in 0..rows {
480 let row = image.row(row_idx);
481 let (approx, detail) = dwt1d.decompose(&row)?;
482 row_results_approx.push(approx);
483 row_results_detail.push(detail);
484 }
485
486 let approx_rows = row_results_approx[0].len();
487 let detail_rows = row_results_detail[0].len();
488
489 let mut approx_mat = Array2::zeros((rows, approx_rows));
491 let mut detail_mat = Array2::zeros((rows, detail_rows));
492
493 for (i, (app, det)) in row_results_approx
494 .iter()
495 .zip(row_results_detail.iter())
496 .enumerate()
497 {
498 for (j, &val) in app.iter().enumerate() {
499 approx_mat[[i, j]] = val;
500 }
501 for (j, &val) in det.iter().enumerate() {
502 detail_mat[[i, j]] = val;
503 }
504 }
505
506 let (ll, lh) = self.decompose_columns(&approx_mat.view(), &dwt1d)?;
508 let (hl, hh) = self.decompose_columns(&detail_mat.view(), &dwt1d)?;
509
510 Ok(Dwt2dCoeffs { ll, lh, hl, hh })
511 }
512
513 fn decompose_columns(
514 &self,
515 mat: &ArrayView2<f64>,
516 dwt1d: &DWT,
517 ) -> Result<(Array2<f64>, Array2<f64>)> {
518 let (rows, cols) = mat.dim();
519 let mut col_results_approx = Vec::with_capacity(cols);
520 let mut col_results_detail = Vec::with_capacity(cols);
521
522 for col_idx in 0..cols {
523 let col = mat.column(col_idx);
524 let (approx, detail) = dwt1d.decompose(&col)?;
525 col_results_approx.push(approx);
526 col_results_detail.push(detail);
527 }
528
529 let approx_cols = col_results_approx[0].len();
530 let detail_cols = col_results_detail[0].len();
531
532 let mut approx_result = Array2::zeros((approx_cols, cols));
533 let mut detail_result = Array2::zeros((detail_cols, cols));
534
535 for (j, (app, det)) in col_results_approx
536 .iter()
537 .zip(col_results_detail.iter())
538 .enumerate()
539 {
540 for (i, &val) in app.iter().enumerate() {
541 approx_result[[i, j]] = val;
542 }
543 for (i, &val) in det.iter().enumerate() {
544 detail_result[[i, j]] = val;
545 }
546 }
547
548 Ok((approx_result, detail_result))
549 }
550
551 pub fn wavedec2(&self, image: &ArrayView2<f64>) -> Result<Vec<Dwt2dCoeffs>> {
553 let max_level = self.max_decomposition_level_2d(image.dim());
554 let level = self.level.unwrap_or(max_level).min(max_level);
555
556 let mut coeffs = Vec::with_capacity(level);
557 let mut current = image.to_owned();
558
559 for _ in 0..level {
560 let dwt2d_coeffs = self.decompose2(¤t.view())?;
561 coeffs.push(dwt2d_coeffs.clone());
562 current = dwt2d_coeffs.ll;
563 }
564
565 Ok(coeffs)
566 }
567
568 fn max_decomposition_level_2d(&self, shape: (usize, usize)) -> usize {
569 let filter_len = self.filters.dec_lo.len();
570 let min_dim = shape.0.min(shape.1);
571
572 let mut level: usize = 0;
573 let mut current_dim = min_dim;
574
575 while current_dim >= filter_len {
576 current_dim = (current_dim + 1) / 2;
577 level += 1;
578 }
579
580 level.saturating_sub(1)
581 }
582}
583
584#[derive(Debug, Clone)]
586pub struct Dwt2dCoeffs {
587 pub ll: Array2<f64>,
589 pub lh: Array2<f64>,
591 pub hl: Array2<f64>,
593 pub hh: Array2<f64>,
595}
596
597#[derive(Debug, Clone)]
599pub struct DWTN {
600 wavelet: WaveletType,
601 boundary: BoundaryMode,
602 level: Option<usize>,
603}
604
605impl DWTN {
606 pub fn new(wavelet: WaveletType) -> Self {
608 DWTN {
609 wavelet,
610 boundary: BoundaryMode::Symmetric,
611 level: None,
612 }
613 }
614
615 pub fn with_boundary(mut self, boundary: BoundaryMode) -> Self {
617 self.boundary = boundary;
618 self
619 }
620
621 pub fn with_level(mut self, level: usize) -> Self {
623 self.level = Some(level);
624 self
625 }
626
627 pub fn decompose3(&self, _volume: &Array3<f64>) -> Result<Array3<f64>> {
629 Err(TransformError::NotImplemented(
630 "3D DWT not yet fully implemented".to_string(),
631 ))
632 }
633}
634
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// The ramp signal 1, 2, ..., 8 shared by the 1D tests.
    fn ramp() -> Array1<f64> {
        Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
    }

    #[test]
    fn test_dwt_haar() -> Result<()> {
        let signal = ramp();
        let dwt = DWT::new(WaveletType::Haar)?;

        let (approx, detail) = dwt.decompose(&signal.view())?;

        assert!(!approx.is_empty());
        assert!(!detail.is_empty());
        assert_eq!(approx.len(), detail.len());
        Ok(())
    }

    #[test]
    fn test_dwt_multilevel() -> Result<()> {
        let signal = ramp();
        let dwt = DWT::new(WaveletType::Haar)?.with_level(2);

        let coeffs = dwt.wavedec(&signal.view())?;

        // One approximation plus one detail vector per level.
        assert_eq!(coeffs.len(), 3);
        Ok(())
    }

    #[test]
    fn test_dwt_reconstruction() -> Result<()> {
        let signal = ramp();
        let dwt = DWT::new(WaveletType::Haar)?;

        let (approx, detail) = dwt.decompose(&signal.view())?;
        let reconstructed = dwt.reconstruct(&approx.view(), &detail.view())?;

        assert!(reconstructed.len() >= signal.len() - 2);
        Ok(())
    }

    #[test]
    fn test_dwt2d() -> Result<()> {
        let image = Array2::from_shape_fn((8, 8), |(i, j)| (i + j) as f64);
        let dwt2d = DWT2D::new(WaveletType::Haar)?;

        let coeffs = dwt2d.decompose2(&image.view())?;

        assert!(!coeffs.ll.is_empty());
        assert!(!coeffs.lh.is_empty());
        assert!(!coeffs.hl.is_empty());
        assert!(!coeffs.hh.is_empty());
        Ok(())
    }

    #[test]
    fn test_wavelet_filters() -> Result<()> {
        let filters = WaveletFilters::from_wavelet(WaveletType::Haar)?;

        assert_eq!(filters.dec_lo.len(), 2);
        assert_eq!(filters.dec_hi.len(), 2);
        assert_eq!(filters.rec_lo.len(), 2);
        assert_eq!(filters.rec_hi.len(), 2);
        Ok(())
    }

    #[test]
    fn test_decompose_rejects_short_signal() -> Result<()> {
        let dwt = DWT::new(WaveletType::Haar)?;
        let signal = Array1::from_vec(vec![1.0]);

        assert!(dwt.decompose(&signal.view()).is_err());
        Ok(())
    }

    #[test]
    fn test_orthonormal_filters() -> Result<()> {
        for wavelet in [
            WaveletType::Haar,
            WaveletType::Daubechies(2),
            WaveletType::Daubechies(4),
        ] {
            let filters = WaveletFilters::from_wavelet(wavelet)?;
            let sum: f64 = filters.dec_lo.iter().sum();
            let energy: f64 = filters.dec_lo.iter().map(|c| c * c).sum();
            assert_abs_diff_eq!(sum, 2.0_f64.sqrt(), epsilon = 1e-10);
            assert_abs_diff_eq!(energy, 1.0, epsilon = 1e-10);
        }
        Ok(())
    }
}