// torsh_functional/wavelet.rs
1//! Wavelet transform operations
2//!
3//! This module provides various wavelet transform functions commonly used in signal processing
4//! and image analysis. Wavelets are particularly useful for multi-resolution analysis and
5//! feature extraction in deep learning applications.
6
7use torsh_core::{Result as TorshResult, TorshError};
8use torsh_tensor::Tensor;
9// use std::ops::{Div, Mul}; // Currently unused
10
/// Wavelet basis types supported
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WaveletType {
    /// Daubechies wavelets.
    ///
    /// NOTE(review): documented as the number of vanishing moments, but
    /// `get_wavelet_coefficients` interprets the value as the filter *tap
    /// count* (Daubechies(2) == the 2-tap Haar filter, Daubechies(4) == the
    /// classic 4-tap D4 filter) — confirm which convention is intended.
    Daubechies(usize), // Number of vanishing moments
    /// Biorthogonal wavelets (no coefficients implemented yet in `get_wavelet_coefficients`)
    Biorthogonal(usize, usize), // (Nr, Nd) - number of vanishing moments for reconstruction and decomposition
    /// Coiflets wavelets (no coefficients implemented yet in `get_wavelet_coefficients`)
    Coiflets(usize), // Number of vanishing moments
    /// Haar wavelet (special case of Daubechies-1)
    Haar,
    /// Mexican Hat (Ricker) wavelet — continuous; supported by `cwt` only
    MexicanHat,
    /// Morlet wavelet — continuous; supported by `cwt` only
    Morlet,
}
27
/// Mode for handling boundary conditions when a filter runs past the signal edge.
///
/// Intended semantics of each variant are listed below; check the convolution
/// helpers in this file for which modes they actually honor.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WaveletMode {
    /// Zero padding (out-of-range samples read as 0)
    Zero,
    /// Constant (edge) padding (out-of-range samples repeat the nearest edge value)
    Constant,
    /// Symmetric boundary conditions (mirror, edge sample repeated)
    Symmetric,
    /// Periodic boundary conditions (signal wraps around)
    Periodic,
    /// Reflection boundary conditions (mirror, edge sample not repeated)
    Reflect,
}
42
43/// 1D Discrete Wavelet Transform (DWT)
44///
45/// Decomposes a 1D signal into approximation and detail coefficients
46pub fn dwt_1d(
47    input: &Tensor,
48    wavelet: WaveletType,
49    mode: WaveletMode,
50) -> TorshResult<(Tensor, Tensor)> {
51    let shape = input.shape();
52    if shape.ndim() != 1 {
53        return Err(TorshError::InvalidArgument(format!(
54            "Expected 1D input tensor, got {}D",
55            shape.ndim()
56        )));
57    }
58
59    let length = shape.dims()[0];
60    if length < 2 {
61        return Err(TorshError::InvalidArgument(
62            "Input length must be at least 2".to_string(),
63        ));
64    }
65
66    // Get wavelet coefficients
67    let (low_pass, high_pass) = get_wavelet_coefficients(wavelet)?;
68
69    // Apply convolution with downsampling
70    let approx = convolve_downsample(input, &low_pass, mode)?;
71    let detail = convolve_downsample(input, &high_pass, mode)?;
72
73    Ok((approx, detail))
74}
75
76/// 1D Inverse Discrete Wavelet Transform (IDWT)
77///
78/// Reconstructs a signal from approximation and detail coefficients
79pub fn idwt_1d(
80    approx: &Tensor,
81    detail: &Tensor,
82    wavelet: WaveletType,
83    mode: WaveletMode,
84) -> TorshResult<Tensor> {
85    if approx.shape() != detail.shape() {
86        return Err(TorshError::ShapeMismatch {
87            expected: approx.shape().dims().to_vec(),
88            got: detail.shape().dims().to_vec(),
89        });
90    }
91
92    // Get reconstruction filters
93    let (rec_low, rec_high) = get_reconstruction_coefficients(wavelet)?;
94
95    // Upsample and convolve
96    let upsampled_approx = upsample_convolve(approx, &rec_low, mode)?;
97    let upsampled_detail = upsample_convolve(detail, &rec_high, mode)?;
98
99    // Add reconstructed components
100    upsampled_approx.add_op(&upsampled_detail)
101}
102
103/// 2D Discrete Wavelet Transform
104///
105/// Applies separable 2D DWT to an image, producing 4 subbands: LL, LH, HL, HH
106pub fn dwt_2d(
107    input: &Tensor,
108    wavelet: WaveletType,
109    mode: WaveletMode,
110) -> TorshResult<(Tensor, Tensor, Tensor, Tensor)> {
111    let shape = input.shape();
112    if shape.ndim() != 2 {
113        return Err(TorshError::InvalidArgument(format!(
114            "Expected 2D input tensor, got {}D",
115            shape.ndim()
116        )));
117    }
118
119    let (height, width) = (shape.dims()[0], shape.dims()[1]);
120    if height < 2 || width < 2 {
121        return Err(TorshError::InvalidArgument(
122            "Input dimensions must be at least 2x2".to_string(),
123        ));
124    }
125
126    // Get wavelet coefficients
127    let (_low_pass, _high_pass) = get_wavelet_coefficients(wavelet)?;
128
129    // First apply DWT along rows (width dimension)
130    let mut row_approx = Vec::new();
131    let mut row_detail = Vec::new();
132
133    for h in 0..height {
134        let row = input.narrow(0, h as i64, 1)?.squeeze(0)?; // Get single row
135        let (a, d) = dwt_1d(&row, wavelet, mode)?;
136        row_approx.push(a);
137        row_detail.push(d);
138    }
139
140    // Stack row results
141    let approx_rows = stack_tensors(&row_approx, 0)?;
142    let detail_rows = stack_tensors(&row_detail, 0)?;
143
144    // Then apply DWT along columns (height dimension)
145    let new_width = approx_rows.shape().dims()[1];
146    let mut ll_cols = Vec::new();
147    let mut lh_cols = Vec::new();
148    let mut hl_cols = Vec::new();
149    let mut hh_cols = Vec::new();
150
151    for w in 0..new_width {
152        let approx_col = approx_rows.narrow(1, w as i64, 1)?.squeeze(1)?;
153        let detail_col = detail_rows.narrow(1, w as i64, 1)?.squeeze(1)?;
154
155        let (ll, lh) = dwt_1d(&approx_col, wavelet, mode)?;
156        let (hl, hh) = dwt_1d(&detail_col, wavelet, mode)?;
157
158        ll_cols.push(ll);
159        lh_cols.push(lh);
160        hl_cols.push(hl);
161        hh_cols.push(hh);
162    }
163
164    let ll = stack_tensors(&ll_cols, 1)?;
165    let lh = stack_tensors(&lh_cols, 1)?;
166    let hl = stack_tensors(&hl_cols, 1)?;
167    let hh = stack_tensors(&hh_cols, 1)?;
168
169    Ok((ll, lh, hl, hh))
170}
171
172/// 2D Inverse Discrete Wavelet Transform
173///
174/// Reconstructs a 2D signal from its 4 wavelet subbands
175pub fn idwt_2d(
176    ll: &Tensor,
177    lh: &Tensor,
178    hl: &Tensor,
179    hh: &Tensor,
180    wavelet: WaveletType,
181    mode: WaveletMode,
182) -> TorshResult<Tensor> {
183    // Verify all subbands have the same shape
184    let ll_shape = ll.shape();
185    if lh.shape() != ll_shape || hl.shape() != ll_shape || hh.shape() != ll_shape {
186        return Err(TorshError::InvalidArgument(
187            "All wavelet subbands must have the same shape".to_string(),
188        ));
189    }
190
191    let (_sub_height, sub_width) = (ll_shape.dims()[0], ll_shape.dims()[1]);
192
193    // First, reconstruct along columns (height dimension)
194    let mut approx_cols = Vec::new();
195    let mut detail_cols = Vec::new();
196
197    for w in 0..sub_width {
198        let ll_col = ll.narrow(1, w as i64, 1)?.squeeze(1)?;
199        let lh_col = lh.narrow(1, w as i64, 1)?.squeeze(1)?;
200        let hl_col = hl.narrow(1, w as i64, 1)?.squeeze(1)?;
201        let hh_col = hh.narrow(1, w as i64, 1)?.squeeze(1)?;
202
203        let approx_reconstructed = idwt_1d(&ll_col, &lh_col, wavelet, mode)?;
204        let detail_reconstructed = idwt_1d(&hl_col, &hh_col, wavelet, mode)?;
205
206        approx_cols.push(approx_reconstructed);
207        detail_cols.push(detail_reconstructed);
208    }
209
210    let approx_rows = stack_tensors(&approx_cols, 1)?;
211    let detail_rows = stack_tensors(&detail_cols, 1)?;
212
213    // Then reconstruct along rows (width dimension)
214    let height = approx_rows.shape().dims()[0];
215    let mut final_rows = Vec::new();
216
217    for h in 0..height {
218        let approx_row = approx_rows.narrow(0, h as i64, 1)?.squeeze(0)?;
219        let detail_row = detail_rows.narrow(0, h as i64, 1)?.squeeze(0)?;
220
221        let reconstructed_row = idwt_1d(&approx_row, &detail_row, wavelet, mode)?;
222        final_rows.push(reconstructed_row);
223    }
224
225    stack_tensors(&final_rows, 0)
226}
227
228/// Continuous Wavelet Transform (CWT)
229///
230/// Computes the continuous wavelet transform using the specified mother wavelet
231pub fn cwt(input: &Tensor, scales: &[f32], wavelet: WaveletType) -> TorshResult<Tensor> {
232    let input_length = input.shape().dims()[0];
233    let num_scales = scales.len();
234
235    // Initialize output tensor [scales, time]
236    let mut cwt_coeffs = Vec::with_capacity(num_scales * input_length);
237
238    for &scale in scales {
239        let wavelet_kernel = generate_wavelet_kernel(wavelet, scale, input_length)?;
240        let convolved = convolve_same(input, &wavelet_kernel)?;
241
242        let convolved_data = convolved.data()?;
243        cwt_coeffs.extend_from_slice(&convolved_data);
244    }
245
246    Tensor::from_data(cwt_coeffs, vec![num_scales, input_length], input.device())
247}
248
249/// Multi-level wavelet decomposition
250///
251/// Performs multiple levels of DWT decomposition
252pub fn wavedec(
253    input: &Tensor,
254    wavelet: WaveletType,
255    levels: usize,
256    mode: WaveletMode,
257) -> TorshResult<Vec<Tensor>> {
258    if levels == 0 {
259        return Err(TorshError::InvalidArgument(
260            "Number of levels must be greater than 0".to_string(),
261        ));
262    }
263
264    let mut coeffs = Vec::with_capacity(levels + 1);
265    let mut current = input.clone();
266
267    for _ in 0..levels {
268        let (approx, detail) = dwt_1d(&current, wavelet, mode)?;
269        coeffs.push(detail);
270        current = approx;
271    }
272
273    // Add final approximation
274    coeffs.push(current);
275    coeffs.reverse(); // Convention: [approximation, detail_n, detail_n-1, ..., detail_1]
276
277    Ok(coeffs)
278}
279
280/// Multi-level wavelet reconstruction
281///
282/// Reconstructs signal from multi-level wavelet coefficients
283pub fn waverec(coeffs: &[Tensor], wavelet: WaveletType, mode: WaveletMode) -> TorshResult<Tensor> {
284    if coeffs.is_empty() {
285        return Err(TorshError::InvalidArgument(
286            "Coefficient list cannot be empty".to_string(),
287        ));
288    }
289
290    let mut current = coeffs[0].clone(); // Start with approximation
291
292    for i in 1..coeffs.len() {
293        current = idwt_1d(&current, &coeffs[i], wavelet, mode)?;
294    }
295
296    Ok(current)
297}
298
299// Helper functions
300
/// Returns the `(low_pass, high_pass)` decomposition (analysis) filters for
/// `wavelet`, normalized so each filter has unit energy (sum of squares = 1).
///
/// # Errors
/// Returns `UnsupportedOperation` for wavelet families/orders without
/// hard-coded coefficients (only Haar and Daubechies 2/4-tap are implemented).
///
/// NOTE(review): `Daubechies(n)` is matched on the filter *tap count*
/// (Daubechies(2) == Haar, Daubechies(4) == the classic 4-tap D4 filter),
/// while the enum documents `n` as the number of vanishing moments — confirm
/// which convention is intended.
fn get_wavelet_coefficients(wavelet: WaveletType) -> TorshResult<(Vec<f32>, Vec<f32>)> {
    match wavelet {
        WaveletType::Haar => {
            // Haar: low-pass averages sample pairs, high-pass differences them.
            let low_pass = vec![
                std::f32::consts::FRAC_1_SQRT_2,
                std::f32::consts::FRAC_1_SQRT_2,
            ]; // [1/√2, 1/√2]
            let high_pass = vec![
                std::f32::consts::FRAC_1_SQRT_2,
                -std::f32::consts::FRAC_1_SQRT_2,
            ]; // [1/√2, -1/√2]
            Ok((low_pass, high_pass))
        }
        WaveletType::Daubechies(n) => {
            match n {
                2 => {
                    // 2-tap Daubechies filter — identical to Haar.
                    let low_pass = vec![
                        std::f32::consts::FRAC_1_SQRT_2,
                        std::f32::consts::FRAC_1_SQRT_2,
                    ];
                    let high_pass = vec![
                        std::f32::consts::FRAC_1_SQRT_2,
                        -std::f32::consts::FRAC_1_SQRT_2,
                    ];
                    Ok((low_pass, high_pass))
                }
                4 => {
                    // Classic 4-tap Daubechies (D4) analysis low-pass filter.
                    let low_pass = vec![
                        0.48296291314469025,
                        0.8365163037378079,
                        0.22414386804185735,
                        -0.12940952255092145,
                    ];
                    // Quadrature-mirror high-pass: g[k] = (-1)^k * h[3 - k].
                    let high_pass = vec![
                        -0.12940952255092145,
                        -0.22414386804185735,
                        0.8365163037378079,
                        -0.48296291314469025,
                    ];
                    Ok((low_pass, high_pass))
                }
                _ => Err(TorshError::UnsupportedOperation {
                    op: format!("Daubechies-{}", n),
                    dtype: "wavelet".to_string(),
                }),
            }
        }
        _ => Err(TorshError::UnsupportedOperation {
            op: format!("{:?}", wavelet),
            dtype: "wavelet".to_string(),
        }),
    }
}
356
357fn get_reconstruction_coefficients(wavelet: WaveletType) -> TorshResult<(Vec<f32>, Vec<f32>)> {
358    let (mut low_pass, mut high_pass) = get_wavelet_coefficients(wavelet)?;
359
360    // For orthogonal wavelets, reconstruction filters are time-reversed decomposition filters
361    low_pass.reverse();
362    high_pass.reverse();
363
364    // High-pass reconstruction filter has alternating signs
365    for (i, val) in high_pass.iter_mut().enumerate() {
366        if i % 2 == 1 {
367            *val = -*val;
368        }
369    }
370
371    Ok((low_pass, high_pass))
372}
373
374fn convolve_downsample(input: &Tensor, kernel: &[f32], _mode: WaveletMode) -> TorshResult<Tensor> {
375    let input_data = input.data()?;
376    let input_len = input_data.len();
377    let _kernel_len = kernel.len();
378
379    // For downsampling by 2
380    let output_len = (input_len + 1) / 2;
381    let mut output = Vec::with_capacity(output_len);
382
383    for i in (0..input_len).step_by(2) {
384        let mut sum = 0.0;
385
386        for (k, &coeff) in kernel.iter().enumerate() {
387            let idx = i as i32 - k as i32;
388            if idx >= 0 && (idx as usize) < input_len {
389                sum += input_data[idx as usize] * coeff;
390            }
391        }
392
393        output.push(sum);
394    }
395
396    Tensor::from_data(output, vec![output_len], input.device())
397}
398
399fn upsample_convolve(input: &Tensor, kernel: &[f32], _mode: WaveletMode) -> TorshResult<Tensor> {
400    let input_data = input.data()?;
401    let input_len = input_data.len();
402    let kernel_len = kernel.len();
403
404    // Upsample by inserting zeros
405    let upsampled_len = input_len * 2;
406    let mut upsampled = vec![0.0; upsampled_len];
407
408    for (i, &val) in input_data.iter().enumerate() {
409        upsampled[i * 2] = val;
410    }
411
412    // Convolve with reconstruction filter
413    let output_len = upsampled_len + kernel_len - 1;
414    let mut output = vec![0.0; output_len];
415
416    for i in 0..upsampled_len {
417        for (k, &coeff) in kernel.iter().enumerate() {
418            output[i + k] += upsampled[i] * coeff;
419        }
420    }
421
422    // Trim to original size
423    let trimmed_len = (output_len).min(upsampled_len);
424    output.truncate(trimmed_len);
425
426    Tensor::from_data(output, vec![trimmed_len], input.device())
427}
428
429fn stack_tensors(tensors: &[Tensor], dim: usize) -> TorshResult<Tensor> {
430    if tensors.is_empty() {
431        return Err(TorshError::InvalidArgument(
432            "Cannot stack empty tensor list".to_string(),
433        ));
434    }
435
436    let first_shape = tensors[0].shape();
437    let mut stacked_shape = first_shape.dims().to_vec();
438    stacked_shape.insert(dim, tensors.len());
439
440    let element_count = first_shape.numel();
441    let mut stacked_data = Vec::with_capacity(tensors.len() * element_count);
442
443    for tensor in tensors {
444        let data = tensor.data()?;
445        stacked_data.extend_from_slice(&data);
446    }
447
448    Tensor::from_data(stacked_data, stacked_shape, tensors[0].device())
449}
450
451fn generate_wavelet_kernel(wavelet: WaveletType, scale: f32, length: usize) -> TorshResult<Tensor> {
452    match wavelet {
453        WaveletType::MexicanHat => {
454            let mut kernel = Vec::with_capacity(length);
455            let center = length as f32 / 2.0;
456
457            for i in 0..length {
458                let t = (i as f32 - center) / scale;
459                let t2 = t * t;
460                let val = (2.0 / (3.0 * scale).sqrt() * std::f32::consts::PI.powf(0.25))
461                    * (1.0 - t2)
462                    * (-t2 / 2.0).exp();
463                kernel.push(val);
464            }
465
466            Tensor::from_data(kernel, vec![length], torsh_core::device::DeviceType::Cpu)
467        }
468        WaveletType::Morlet => {
469            let mut kernel = Vec::with_capacity(length);
470            let center = length as f32 / 2.0;
471            let omega0 = 6.0; // Central frequency
472
473            for i in 0..length {
474                let t = (i as f32 - center) / scale;
475                let val = (1.0 / (scale * std::f32::consts::PI.sqrt()))
476                    * (omega0 * t).cos()
477                    * (-(t * t) / 2.0).exp();
478                kernel.push(val);
479            }
480
481            Tensor::from_data(kernel, vec![length], torsh_core::device::DeviceType::Cpu)
482        }
483        _ => Err(TorshError::UnsupportedOperation {
484            op: format!("CWT with {:?}", wavelet),
485            dtype: "wavelet".to_string(),
486        }),
487    }
488}
489
490fn convolve_same(input: &Tensor, kernel: &Tensor) -> TorshResult<Tensor> {
491    let input_data = input.data()?;
492    let kernel_data = kernel.data()?;
493    let input_len = input_data.len();
494    let kernel_len = kernel_data.len();
495
496    let mut output = vec![0.0; input_len];
497    let half_kernel = kernel_len / 2;
498
499    for i in 0..input_len {
500        for j in 0..kernel_len {
501            let input_idx = i as i32 + j as i32 - half_kernel as i32;
502            if input_idx >= 0 && (input_idx as usize) < input_len {
503                output[i] += input_data[input_idx as usize] * kernel_data[j];
504            }
505        }
506    }
507
508    Tensor::from_data(output, vec![input_len], input.device())
509}
510
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::tensor_1d;

    #[test]
    fn test_haar_dwt_1d() {
        let input = tensor_1d(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        let (approx, detail) = dwt_1d(&input, WaveletType::Haar, WaveletMode::Zero).unwrap();

        // Each band is half the input length (4 -> 2).
        assert_eq!(approx.shape().dims(), &[2]);
        assert_eq!(detail.shape().dims(), &[2]);

        // Round-trip through the inverse transform restores the original
        // length. NOTE(review): only the shape is verified here, not the
        // reconstructed values.
        let reconstructed =
            idwt_1d(&approx, &detail, WaveletType::Haar, WaveletMode::Zero).unwrap();
        assert_eq!(reconstructed.shape().dims(), &[4]);
    }

    #[test]
    fn test_daubechies4_coefficients() {
        let (low_pass, high_pass) = get_wavelet_coefficients(WaveletType::Daubechies(4)).unwrap();

        // The D4 filter bank has 4 taps per filter.
        assert_eq!(low_pass.len(), 4);
        assert_eq!(high_pass.len(), 4);

        // Energy conservation: each filter is unit-norm, so the sum of squares
        // should be 1 (the original comment said 2, contradicting the
        // assertions below).
        let low_energy: f32 = low_pass.iter().map(|x| x * x).sum();
        let high_energy: f32 = high_pass.iter().map(|x| x * x).sum();

        assert!((low_energy - 1.0).abs() < 1e-6);
        assert!((high_energy - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_multilevel_decomposition() {
        let input = tensor_1d(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let coeffs = wavedec(&input, WaveletType::Haar, 2, WaveletMode::Zero).unwrap();

        // Two levels produce 3 coefficient arrays: [approximation, detail_2, detail_1].
        assert_eq!(coeffs.len(), 3);

        // Reconstruct and verify
        let reconstructed = waverec(&coeffs, WaveletType::Haar, WaveletMode::Zero).unwrap();

        // Each IDWT level doubles the length, so the result is at least 4 long.
        assert!(reconstructed.shape().dims()[0] >= 4);
    }

    #[test]
    fn test_cwt_mexican_hat() {
        let input = tensor_1d(&[0.0, 0.0, 1.0, 0.0, 0.0, 0.0]).unwrap();
        let scales = vec![1.0, 2.0, 3.0];
        let result = cwt(&input, &scales, WaveletType::MexicanHat).unwrap();

        // Output is [num_scales, input_length].
        assert_eq!(result.shape().dims(), &[3, 6]);
    }
}