use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
/// Wavelet families understood by the transforms in this module.
///
/// The discrete transforms (`dwt_1d`, `dwt_2d`, `wavedec`, ...) obtain their
/// filter banks from `get_wavelet_coefficients`, which handles `Haar`,
/// `Daubechies(2)` and `Daubechies(4)`. The continuous transform (`cwt`)
/// handles `MexicanHat` and `Morlet`. Each transform rejects the families it
/// does not implement with `TorshError::UnsupportedOperation`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WaveletType {
    /// Daubechies family, parameterised by filter length in taps
    /// (`Daubechies(4)` is the four-tap filter).
    Daubechies(usize),
    /// Biorthogonal family. NOTE(review): the two parameters presumably select
    /// the decomposition / reconstruction orders — confirm before implementing.
    Biorthogonal(usize, usize),
    /// Coiflet family (parameter assumed to be the order — TODO confirm).
    Coiflets(usize),
    /// Haar wavelet: the two-tap filter pair built from `1/√2`.
    Haar,
    /// Mexican-hat (Ricker) wavelet, used by the continuous transform.
    MexicanHat,
    /// Real-valued Morlet wavelet, used by the continuous transform.
    Morlet,
}

/// Signal-extension mode requested from the discrete wavelet transforms.
///
/// Each variant names a conventional rule for the samples that fall outside
/// the input signal; the convolution helpers behind `dwt_1d` / `idwt_1d` are
/// responsible for applying it.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WaveletMode {
    /// Treat samples outside the signal as zero.
    Zero,
    /// Repeat the nearest edge sample.
    Constant,
    /// Mirror the signal, repeating the edge sample (`… b a | a b c …`).
    Symmetric,
    /// Wrap the signal around periodically.
    Periodic,
    /// Mirror the signal without repeating the edge sample (`… c b | a b c …`).
    Reflect,
}

43pub fn dwt_1d(
47 input: &Tensor,
48 wavelet: WaveletType,
49 mode: WaveletMode,
50) -> TorshResult<(Tensor, Tensor)> {
51 let shape = input.shape();
52 if shape.ndim() != 1 {
53 return Err(TorshError::InvalidArgument(format!(
54 "Expected 1D input tensor, got {}D",
55 shape.ndim()
56 )));
57 }
58
59 let length = shape.dims()[0];
60 if length < 2 {
61 return Err(TorshError::InvalidArgument(
62 "Input length must be at least 2".to_string(),
63 ));
64 }
65
66 let (low_pass, high_pass) = get_wavelet_coefficients(wavelet)?;
68
69 let approx = convolve_downsample(input, &low_pass, mode)?;
71 let detail = convolve_downsample(input, &high_pass, mode)?;
72
73 Ok((approx, detail))
74}
75
76pub fn idwt_1d(
80 approx: &Tensor,
81 detail: &Tensor,
82 wavelet: WaveletType,
83 mode: WaveletMode,
84) -> TorshResult<Tensor> {
85 if approx.shape() != detail.shape() {
86 return Err(TorshError::ShapeMismatch {
87 expected: approx.shape().dims().to_vec(),
88 got: detail.shape().dims().to_vec(),
89 });
90 }
91
92 let (rec_low, rec_high) = get_reconstruction_coefficients(wavelet)?;
94
95 let upsampled_approx = upsample_convolve(approx, &rec_low, mode)?;
97 let upsampled_detail = upsample_convolve(detail, &rec_high, mode)?;
98
99 upsampled_approx.add_op(&upsampled_detail)
101}
102
103pub fn dwt_2d(
107 input: &Tensor,
108 wavelet: WaveletType,
109 mode: WaveletMode,
110) -> TorshResult<(Tensor, Tensor, Tensor, Tensor)> {
111 let shape = input.shape();
112 if shape.ndim() != 2 {
113 return Err(TorshError::InvalidArgument(format!(
114 "Expected 2D input tensor, got {}D",
115 shape.ndim()
116 )));
117 }
118
119 let (height, width) = (shape.dims()[0], shape.dims()[1]);
120 if height < 2 || width < 2 {
121 return Err(TorshError::InvalidArgument(
122 "Input dimensions must be at least 2x2".to_string(),
123 ));
124 }
125
126 let (_low_pass, _high_pass) = get_wavelet_coefficients(wavelet)?;
128
129 let mut row_approx = Vec::new();
131 let mut row_detail = Vec::new();
132
133 for h in 0..height {
134 let row = input.narrow(0, h as i64, 1)?.squeeze(0)?; let (a, d) = dwt_1d(&row, wavelet, mode)?;
136 row_approx.push(a);
137 row_detail.push(d);
138 }
139
140 let approx_rows = stack_tensors(&row_approx, 0)?;
142 let detail_rows = stack_tensors(&row_detail, 0)?;
143
144 let new_width = approx_rows.shape().dims()[1];
146 let mut ll_cols = Vec::new();
147 let mut lh_cols = Vec::new();
148 let mut hl_cols = Vec::new();
149 let mut hh_cols = Vec::new();
150
151 for w in 0..new_width {
152 let approx_col = approx_rows.narrow(1, w as i64, 1)?.squeeze(1)?;
153 let detail_col = detail_rows.narrow(1, w as i64, 1)?.squeeze(1)?;
154
155 let (ll, lh) = dwt_1d(&approx_col, wavelet, mode)?;
156 let (hl, hh) = dwt_1d(&detail_col, wavelet, mode)?;
157
158 ll_cols.push(ll);
159 lh_cols.push(lh);
160 hl_cols.push(hl);
161 hh_cols.push(hh);
162 }
163
164 let ll = stack_tensors(&ll_cols, 1)?;
165 let lh = stack_tensors(&lh_cols, 1)?;
166 let hl = stack_tensors(&hl_cols, 1)?;
167 let hh = stack_tensors(&hh_cols, 1)?;
168
169 Ok((ll, lh, hl, hh))
170}
171
172pub fn idwt_2d(
176 ll: &Tensor,
177 lh: &Tensor,
178 hl: &Tensor,
179 hh: &Tensor,
180 wavelet: WaveletType,
181 mode: WaveletMode,
182) -> TorshResult<Tensor> {
183 let ll_shape = ll.shape();
185 if lh.shape() != ll_shape || hl.shape() != ll_shape || hh.shape() != ll_shape {
186 return Err(TorshError::InvalidArgument(
187 "All wavelet subbands must have the same shape".to_string(),
188 ));
189 }
190
191 let (_sub_height, sub_width) = (ll_shape.dims()[0], ll_shape.dims()[1]);
192
193 let mut approx_cols = Vec::new();
195 let mut detail_cols = Vec::new();
196
197 for w in 0..sub_width {
198 let ll_col = ll.narrow(1, w as i64, 1)?.squeeze(1)?;
199 let lh_col = lh.narrow(1, w as i64, 1)?.squeeze(1)?;
200 let hl_col = hl.narrow(1, w as i64, 1)?.squeeze(1)?;
201 let hh_col = hh.narrow(1, w as i64, 1)?.squeeze(1)?;
202
203 let approx_reconstructed = idwt_1d(&ll_col, &lh_col, wavelet, mode)?;
204 let detail_reconstructed = idwt_1d(&hl_col, &hh_col, wavelet, mode)?;
205
206 approx_cols.push(approx_reconstructed);
207 detail_cols.push(detail_reconstructed);
208 }
209
210 let approx_rows = stack_tensors(&approx_cols, 1)?;
211 let detail_rows = stack_tensors(&detail_cols, 1)?;
212
213 let height = approx_rows.shape().dims()[0];
215 let mut final_rows = Vec::new();
216
217 for h in 0..height {
218 let approx_row = approx_rows.narrow(0, h as i64, 1)?.squeeze(0)?;
219 let detail_row = detail_rows.narrow(0, h as i64, 1)?.squeeze(0)?;
220
221 let reconstructed_row = idwt_1d(&approx_row, &detail_row, wavelet, mode)?;
222 final_rows.push(reconstructed_row);
223 }
224
225 stack_tensors(&final_rows, 0)
226}
227
228pub fn cwt(input: &Tensor, scales: &[f32], wavelet: WaveletType) -> TorshResult<Tensor> {
232 let input_length = input.shape().dims()[0];
233 let num_scales = scales.len();
234
235 let mut cwt_coeffs = Vec::with_capacity(num_scales * input_length);
237
238 for &scale in scales {
239 let wavelet_kernel = generate_wavelet_kernel(wavelet, scale, input_length)?;
240 let convolved = convolve_same(input, &wavelet_kernel)?;
241
242 let convolved_data = convolved.data()?;
243 cwt_coeffs.extend_from_slice(&convolved_data);
244 }
245
246 Tensor::from_data(cwt_coeffs, vec![num_scales, input_length], input.device())
247}
248
249pub fn wavedec(
253 input: &Tensor,
254 wavelet: WaveletType,
255 levels: usize,
256 mode: WaveletMode,
257) -> TorshResult<Vec<Tensor>> {
258 if levels == 0 {
259 return Err(TorshError::InvalidArgument(
260 "Number of levels must be greater than 0".to_string(),
261 ));
262 }
263
264 let mut coeffs = Vec::with_capacity(levels + 1);
265 let mut current = input.clone();
266
267 for _ in 0..levels {
268 let (approx, detail) = dwt_1d(¤t, wavelet, mode)?;
269 coeffs.push(detail);
270 current = approx;
271 }
272
273 coeffs.push(current);
275 coeffs.reverse(); Ok(coeffs)
278}
279
280pub fn waverec(coeffs: &[Tensor], wavelet: WaveletType, mode: WaveletMode) -> TorshResult<Tensor> {
284 if coeffs.is_empty() {
285 return Err(TorshError::InvalidArgument(
286 "Coefficient list cannot be empty".to_string(),
287 ));
288 }
289
290 let mut current = coeffs[0].clone(); for i in 1..coeffs.len() {
293 current = idwt_1d(¤t, &coeffs[i], wavelet, mode)?;
294 }
295
296 Ok(current)
297}
298
299fn get_wavelet_coefficients(wavelet: WaveletType) -> TorshResult<(Vec<f32>, Vec<f32>)> {
302 match wavelet {
303 WaveletType::Haar => {
304 let low_pass = vec![
305 std::f32::consts::FRAC_1_SQRT_2,
306 std::f32::consts::FRAC_1_SQRT_2,
307 ]; let high_pass = vec![
309 std::f32::consts::FRAC_1_SQRT_2,
310 -std::f32::consts::FRAC_1_SQRT_2,
311 ]; Ok((low_pass, high_pass))
313 }
314 WaveletType::Daubechies(n) => {
315 match n {
316 2 => {
317 let low_pass = vec![
319 std::f32::consts::FRAC_1_SQRT_2,
320 std::f32::consts::FRAC_1_SQRT_2,
321 ];
322 let high_pass = vec![
323 std::f32::consts::FRAC_1_SQRT_2,
324 -std::f32::consts::FRAC_1_SQRT_2,
325 ];
326 Ok((low_pass, high_pass))
327 }
328 4 => {
329 let low_pass = vec![
331 0.48296291314469025,
332 0.8365163037378079,
333 0.22414386804185735,
334 -0.12940952255092145,
335 ];
336 let high_pass = vec![
337 -0.12940952255092145,
338 -0.22414386804185735,
339 0.8365163037378079,
340 -0.48296291314469025,
341 ];
342 Ok((low_pass, high_pass))
343 }
344 _ => Err(TorshError::UnsupportedOperation {
345 op: format!("Daubechies-{}", n),
346 dtype: "wavelet".to_string(),
347 }),
348 }
349 }
350 _ => Err(TorshError::UnsupportedOperation {
351 op: format!("{:?}", wavelet),
352 dtype: "wavelet".to_string(),
353 }),
354 }
355}
356
357fn get_reconstruction_coefficients(wavelet: WaveletType) -> TorshResult<(Vec<f32>, Vec<f32>)> {
358 let (mut low_pass, mut high_pass) = get_wavelet_coefficients(wavelet)?;
359
360 low_pass.reverse();
362 high_pass.reverse();
363
364 for (i, val) in high_pass.iter_mut().enumerate() {
366 if i % 2 == 1 {
367 *val = -*val;
368 }
369 }
370
371 Ok((low_pass, high_pass))
372}
373
374fn convolve_downsample(input: &Tensor, kernel: &[f32], _mode: WaveletMode) -> TorshResult<Tensor> {
375 let input_data = input.data()?;
376 let input_len = input_data.len();
377 let _kernel_len = kernel.len();
378
379 let output_len = (input_len + 1) / 2;
381 let mut output = Vec::with_capacity(output_len);
382
383 for i in (0..input_len).step_by(2) {
384 let mut sum = 0.0;
385
386 for (k, &coeff) in kernel.iter().enumerate() {
387 let idx = i as i32 - k as i32;
388 if idx >= 0 && (idx as usize) < input_len {
389 sum += input_data[idx as usize] * coeff;
390 }
391 }
392
393 output.push(sum);
394 }
395
396 Tensor::from_data(output, vec![output_len], input.device())
397}
398
399fn upsample_convolve(input: &Tensor, kernel: &[f32], _mode: WaveletMode) -> TorshResult<Tensor> {
400 let input_data = input.data()?;
401 let input_len = input_data.len();
402 let kernel_len = kernel.len();
403
404 let upsampled_len = input_len * 2;
406 let mut upsampled = vec![0.0; upsampled_len];
407
408 for (i, &val) in input_data.iter().enumerate() {
409 upsampled[i * 2] = val;
410 }
411
412 let output_len = upsampled_len + kernel_len - 1;
414 let mut output = vec![0.0; output_len];
415
416 for i in 0..upsampled_len {
417 for (k, &coeff) in kernel.iter().enumerate() {
418 output[i + k] += upsampled[i] * coeff;
419 }
420 }
421
422 let trimmed_len = (output_len).min(upsampled_len);
424 output.truncate(trimmed_len);
425
426 Tensor::from_data(output, vec![trimmed_len], input.device())
427}
428
429fn stack_tensors(tensors: &[Tensor], dim: usize) -> TorshResult<Tensor> {
430 if tensors.is_empty() {
431 return Err(TorshError::InvalidArgument(
432 "Cannot stack empty tensor list".to_string(),
433 ));
434 }
435
436 let first_shape = tensors[0].shape();
437 let mut stacked_shape = first_shape.dims().to_vec();
438 stacked_shape.insert(dim, tensors.len());
439
440 let element_count = first_shape.numel();
441 let mut stacked_data = Vec::with_capacity(tensors.len() * element_count);
442
443 for tensor in tensors {
444 let data = tensor.data()?;
445 stacked_data.extend_from_slice(&data);
446 }
447
448 Tensor::from_data(stacked_data, stacked_shape, tensors[0].device())
449}
450
451fn generate_wavelet_kernel(wavelet: WaveletType, scale: f32, length: usize) -> TorshResult<Tensor> {
452 match wavelet {
453 WaveletType::MexicanHat => {
454 let mut kernel = Vec::with_capacity(length);
455 let center = length as f32 / 2.0;
456
457 for i in 0..length {
458 let t = (i as f32 - center) / scale;
459 let t2 = t * t;
460 let val = (2.0 / (3.0 * scale).sqrt() * std::f32::consts::PI.powf(0.25))
461 * (1.0 - t2)
462 * (-t2 / 2.0).exp();
463 kernel.push(val);
464 }
465
466 Tensor::from_data(kernel, vec![length], torsh_core::device::DeviceType::Cpu)
467 }
468 WaveletType::Morlet => {
469 let mut kernel = Vec::with_capacity(length);
470 let center = length as f32 / 2.0;
471 let omega0 = 6.0; for i in 0..length {
474 let t = (i as f32 - center) / scale;
475 let val = (1.0 / (scale * std::f32::consts::PI.sqrt()))
476 * (omega0 * t).cos()
477 * (-(t * t) / 2.0).exp();
478 kernel.push(val);
479 }
480
481 Tensor::from_data(kernel, vec![length], torsh_core::device::DeviceType::Cpu)
482 }
483 _ => Err(TorshError::UnsupportedOperation {
484 op: format!("CWT with {:?}", wavelet),
485 dtype: "wavelet".to_string(),
486 }),
487 }
488}
489
490fn convolve_same(input: &Tensor, kernel: &Tensor) -> TorshResult<Tensor> {
491 let input_data = input.data()?;
492 let kernel_data = kernel.data()?;
493 let input_len = input_data.len();
494 let kernel_len = kernel_data.len();
495
496 let mut output = vec![0.0; input_len];
497 let half_kernel = kernel_len / 2;
498
499 for i in 0..input_len {
500 for j in 0..kernel_len {
501 let input_idx = i as i32 + j as i32 - half_kernel as i32;
502 if input_idx >= 0 && (input_idx as usize) < input_len {
503 output[i] += input_data[input_idx as usize] * kernel_data[j];
504 }
505 }
506 }
507
508 Tensor::from_data(output, vec![input_len], input.device())
509}
510
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::tensor_1d;

    /// A single-level Haar DWT halves a 4-sample signal; the inverse restores the length.
    #[test]
    fn test_haar_dwt_1d() {
        let input = tensor_1d(&[1.0, 2.0, 3.0, 4.0]).unwrap();
        let (approx, detail) = dwt_1d(&input, WaveletType::Haar, WaveletMode::Zero).unwrap();

        assert_eq!(approx.shape().dims(), &[2]);
        assert_eq!(detail.shape().dims(), &[2]);

        let reconstructed =
            idwt_1d(&approx, &detail, WaveletType::Haar, WaveletMode::Zero).unwrap();
        assert_eq!(reconstructed.shape().dims(), &[4]);
    }

    /// The Daubechies-4 analysis filters have unit energy, as an orthonormal
    /// filter bank requires.
    #[test]
    fn test_daubechies4_coefficients() {
        let (low_pass, high_pass) = get_wavelet_coefficients(WaveletType::Daubechies(4)).unwrap();

        assert_eq!(low_pass.len(), 4);
        assert_eq!(high_pass.len(), 4);

        let low_energy: f32 = low_pass.iter().map(|x| x * x).sum();
        let high_energy: f32 = high_pass.iter().map(|x| x * x).sum();

        assert!((low_energy - 1.0).abs() < 1e-6);
        assert!((high_energy - 1.0).abs() < 1e-6);
    }

    /// A two-level Haar decomposition of 8 samples yields `[cA2, cD2, cD1]` with
    /// lengths 2, 2 and 4, and `waverec` rebuilds all 8 samples.
    #[test]
    fn test_multilevel_decomposition() {
        let input = tensor_1d(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]).unwrap();
        let coeffs = wavedec(&input, WaveletType::Haar, 2, WaveletMode::Zero).unwrap();

        assert_eq!(coeffs.len(), 3);
        assert_eq!(coeffs[0].shape().dims(), &[2]);
        assert_eq!(coeffs[1].shape().dims(), &[2]);
        assert_eq!(coeffs[2].shape().dims(), &[4]);

        let reconstructed = waverec(&coeffs, WaveletType::Haar, WaveletMode::Zero).unwrap();

        // The reconstructed length is fully determined; `>= 4` was too weak.
        assert_eq!(reconstructed.shape().dims(), &[8]);
    }

    /// The CWT has one row per scale and one column per input sample, all finite.
    #[test]
    fn test_cwt_mexican_hat() {
        let input = tensor_1d(&[0.0, 0.0, 1.0, 0.0, 0.0, 0.0]).unwrap();
        let scales = vec![1.0, 2.0, 3.0];
        let result = cwt(&input, &scales, WaveletType::MexicanHat).unwrap();

        assert_eq!(result.shape().dims(), &[3, 6]);
        assert!(result.data().unwrap().iter().all(|v| v.is_finite()));
    }
}