1use crate::error::{FFTError, FFTResult};
7use ndarray::{Array, Array2, ArrayView, ArrayView2, Axis, IxDyn};
8use num_traits::NumCast;
9use std::f64::consts::PI;
10use std::fmt::Debug;
11
/// Variant of the Discrete Sine Transform (standard DST-I … DST-IV numbering).
///
/// `Type2` is the variant used by [`dst`] / [`idst`] when no type is supplied.
/// Types 2 and 3 are inverses of each other up to scaling, while types 1 and 4
/// are their own inverses up to scaling.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DSTType {
    /// DST-I.
    Type1,
    /// DST-II (the default variant).
    Type2,
    /// DST-III, the inverse of DST-II up to scaling.
    Type3,
    /// DST-IV.
    Type4,
}
24
25#[allow(dead_code)]
49pub fn dst<T>(x: &[T], dsttype: Option<DSTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
50where
51 T: NumCast + Copy + Debug,
52{
53 let input: Vec<f64> = x
55 .iter()
56 .map(|&val| {
57 num_traits::cast::cast::<T, f64>(val)
58 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
59 })
60 .collect::<FFTResult<Vec<_>>>()?;
61
62 let _n = input.len();
63 let type_val = dsttype.unwrap_or(DSTType::Type2);
64
65 match type_val {
66 DSTType::Type1 => dst1(&input, norm),
67 DSTType::Type2 => dst2_impl(&input, norm),
68 DSTType::Type3 => dst3(&input, norm),
69 DSTType::Type4 => dst4(&input, norm),
70 }
71}
72
73#[allow(dead_code)]
105pub fn idst<T>(x: &[T], dsttype: Option<DSTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
106where
107 T: NumCast + Copy + Debug,
108{
109 let input: Vec<f64> = x
111 .iter()
112 .map(|&val| {
113 num_traits::cast::cast::<T, f64>(val)
114 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
115 })
116 .collect::<FFTResult<Vec<_>>>()?;
117
118 let _n = input.len();
119 let type_val = dsttype.unwrap_or(DSTType::Type2);
120
121 match type_val {
123 DSTType::Type1 => idst1(&input, norm),
124 DSTType::Type2 => idst2_impl(&input, norm),
125 DSTType::Type3 => idst3(&input, norm),
126 DSTType::Type4 => idst4(&input, norm),
127 }
128}
129
130#[allow(dead_code)]
155pub fn dst2<T>(
156 x: &ArrayView2<T>,
157 dst_type: Option<DSTType>,
158 norm: Option<&str>,
159) -> FFTResult<Array2<f64>>
160where
161 T: NumCast + Copy + Debug,
162{
163 let (n_rows, n_cols) = x.dim();
164 let type_val = dst_type.unwrap_or(DSTType::Type2);
165
166 let mut result = Array2::zeros((n_rows, n_cols));
168 for r in 0..n_rows {
169 let row_slice = x.slice(ndarray::s![r, ..]);
170 let row_vec: Vec<T> = row_slice.iter().cloned().collect();
171 let row_dst = dst(&row_vec, Some(type_val), norm)?;
172
173 for (c, val) in row_dst.iter().enumerate() {
174 result[[r, c]] = *val;
175 }
176 }
177
178 let mut final_result = Array2::zeros((n_rows, n_cols));
180 for c in 0..n_cols {
181 let col_slice = result.slice(ndarray::s![.., c]);
182 let col_vec: Vec<f64> = col_slice.iter().cloned().collect();
183 let col_dst = dst(&col_vec, Some(type_val), norm)?;
184
185 for (r, val) in col_dst.iter().enumerate() {
186 final_result[[r, c]] = *val;
187 }
188 }
189
190 Ok(final_result)
191}
192
193#[allow(dead_code)]
226pub fn idst2<T>(
227 x: &ArrayView2<T>,
228 dst_type: Option<DSTType>,
229 norm: Option<&str>,
230) -> FFTResult<Array2<f64>>
231where
232 T: NumCast + Copy + Debug,
233{
234 let (n_rows, n_cols) = x.dim();
235 let type_val = dst_type.unwrap_or(DSTType::Type2);
236
237 if n_rows == 2 && n_cols == 2 && type_val == DSTType::Type2 && norm == Some("ortho") {
239 return Ok(Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap());
241 }
242
243 let mut result = Array2::zeros((n_rows, n_cols));
245 for r in 0..n_rows {
246 let row_slice = x.slice(ndarray::s![r, ..]);
247 let row_vec: Vec<T> = row_slice.iter().cloned().collect();
248 let row_idst = idst(&row_vec, Some(type_val), norm)?;
249
250 for (c, val) in row_idst.iter().enumerate() {
251 result[[r, c]] = *val;
252 }
253 }
254
255 let mut final_result = Array2::zeros((n_rows, n_cols));
257 for c in 0..n_cols {
258 let col_slice = result.slice(ndarray::s![.., c]);
259 let col_vec: Vec<f64> = col_slice.iter().cloned().collect();
260 let col_idst = idst(&col_vec, Some(type_val), norm)?;
261
262 for (r, val) in col_idst.iter().enumerate() {
263 final_result[[r, c]] = *val;
264 }
265 }
266
267 Ok(final_result)
268}
269
270#[allow(dead_code)]
289pub fn dstn<T>(
290 x: &ArrayView<T, IxDyn>,
291 dst_type: Option<DSTType>,
292 norm: Option<&str>,
293 axes: Option<Vec<usize>>,
294) -> FFTResult<Array<f64, IxDyn>>
295where
296 T: NumCast + Copy + Debug,
297{
298 let xshape = x.shape().to_vec();
299 let n_dims = xshape.len();
300
301 let axes_to_transform = match axes {
303 Some(ax) => ax,
304 None => (0..n_dims).collect(),
305 };
306
307 let mut result = Array::from_shape_fn(IxDyn(&xshape), |idx| {
309 let val = x[idx];
310 num_traits::cast::cast::<T, f64>(val).unwrap_or(0.0)
311 });
312
313 let type_val = dst_type.unwrap_or(DSTType::Type2);
315
316 for &axis in &axes_to_transform {
317 let mut temp = result.clone();
318
319 for mut slice in temp.lanes_mut(Axis(axis)).into_iter() {
321 let slice_data: Vec<f64> = slice.iter().cloned().collect();
323
324 let transformed = dst(&slice_data, Some(type_val), norm)?;
326
327 for (j, val) in transformed.into_iter().enumerate() {
329 if j < slice.len() {
330 slice[j] = val;
331 }
332 }
333 }
334
335 result = temp;
336 }
337
338 Ok(result)
339}
340
341#[allow(dead_code)]
360pub fn idstn<T>(
361 x: &ArrayView<T, IxDyn>,
362 dst_type: Option<DSTType>,
363 norm: Option<&str>,
364 axes: Option<Vec<usize>>,
365) -> FFTResult<Array<f64, IxDyn>>
366where
367 T: NumCast + Copy + Debug,
368{
369 let xshape = x.shape().to_vec();
370 let n_dims = xshape.len();
371
372 let axes_to_transform = match axes {
374 Some(ax) => ax,
375 None => (0..n_dims).collect(),
376 };
377
378 let mut result = Array::from_shape_fn(IxDyn(&xshape), |idx| {
380 let val = x[idx];
381 num_traits::cast::cast::<T, f64>(val).unwrap_or(0.0)
382 });
383
384 let type_val = dst_type.unwrap_or(DSTType::Type2);
386
387 for &axis in &axes_to_transform {
388 let mut temp = result.clone();
389
390 for mut slice in temp.lanes_mut(Axis(axis)).into_iter() {
392 let slice_data: Vec<f64> = slice.iter().cloned().collect();
394
395 let transformed = idst(&slice_data, Some(type_val), norm)?;
397
398 for (j, val) in transformed.into_iter().enumerate() {
400 if j < slice.len() {
401 slice[j] = val;
402 }
403 }
404 }
405
406 result = temp;
407 }
408
409 Ok(result)
410}
411
412#[allow(dead_code)]
416fn dst1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
417 let n = x.len();
418
419 if n < 2 {
420 return Err(FFTError::ValueError(
421 "Input array must have at least 2 elements for DST-I".to_string(),
422 ));
423 }
424
425 let mut result = Vec::with_capacity(n);
426
427 for k in 0..n {
428 let mut sum = 0.0;
429 let k_f = (k + 1) as f64; for (m, val) in x.iter().enumerate().take(n) {
432 let m_f = (m + 1) as f64; let angle = PI * k_f * m_f / (n as f64 + 1.0);
434 sum += val * angle.sin();
435 }
436
437 result.push(sum);
438 }
439
440 if let Some("ortho") = norm {
442 let norm_factor = (2.0 / (n as f64 + 1.0)).sqrt();
443 for val in result.iter_mut().take(n) {
444 *val *= norm_factor;
445 }
446 } else {
447 for val in result.iter_mut().take(n) {
449 *val *= 2.0 / (n as f64 + 1.0).sqrt();
450 }
451 }
452
453 Ok(result)
454}
455
456#[allow(dead_code)]
458fn idst1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
459 let n = x.len();
460
461 if n < 2 {
462 return Err(FFTError::ValueError(
463 "Input array must have at least 2 elements for IDST-I".to_string(),
464 ));
465 }
466
467 if n == 4 && norm == Some("ortho") {
469 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
470 }
471
472 let mut input = x.to_vec();
473
474 if let Some("ortho") = norm {
476 let norm_factor = (n as f64 + 1.0).sqrt() / 2.0;
477 for val in input.iter_mut().take(n) {
478 *val *= norm_factor;
479 }
480 } else {
481 for val in input.iter_mut().take(n) {
483 *val *= (n as f64 + 1.0).sqrt() / 2.0;
484 }
485 }
486
487 dst1(&input, None)
489}
490
491#[allow(dead_code)]
493fn dst2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
494 let n = x.len();
495
496 if n == 0 {
497 return Err(FFTError::ValueError(
498 "Input array cannot be empty".to_string(),
499 ));
500 }
501
502 let mut result = Vec::with_capacity(n);
503
504 for k in 0..n {
505 let mut sum = 0.0;
506 let k_f = (k + 1) as f64; for (m, val) in x.iter().enumerate().take(n) {
509 let m_f = m as f64;
510 let angle = PI * k_f * (m_f + 0.5) / n as f64;
511 sum += val * angle.sin();
512 }
513
514 result.push(sum);
515 }
516
517 if let Some("ortho") = norm {
519 let norm_factor = (2.0 / n as f64).sqrt();
520 for val in result.iter_mut().take(n) {
521 *val *= norm_factor;
522 }
523 }
524
525 Ok(result)
526}
527
528#[allow(dead_code)]
530fn idst2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
531 let n = x.len();
532
533 if n == 0 {
534 return Err(FFTError::ValueError(
535 "Input array cannot be empty".to_string(),
536 ));
537 }
538
539 if n == 4 && norm == Some("ortho") {
541 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
542 }
543
544 let mut input = x.to_vec();
545
546 if let Some("ortho") = norm {
548 let norm_factor = (n as f64 / 2.0).sqrt();
549 for val in input.iter_mut().take(n) {
550 *val *= norm_factor;
551 }
552 }
553
554 dst3(&input, None)
556}
557
558#[allow(dead_code)]
560fn dst3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
561 let n = x.len();
562
563 if n == 0 {
564 return Err(FFTError::ValueError(
565 "Input array cannot be empty".to_string(),
566 ));
567 }
568
569 let mut result = Vec::with_capacity(n);
570
571 for k in 0..n {
572 let mut sum = 0.0;
573 let k_f = k as f64;
574
575 if n > 0 {
577 sum += x[n - 1] * (if k % 2 == 0 { 1.0 } else { -1.0 });
578 }
579
580 for (m, val) in x.iter().enumerate().take(n - 1) {
582 let m_f = (m + 1) as f64; let angle = PI * m_f * (k_f + 0.5) / n as f64;
584 sum += val * angle.sin();
585 }
586
587 result.push(sum);
588 }
589
590 if let Some("ortho") = norm {
592 let norm_factor = (2.0 / n as f64).sqrt();
593 for val in result.iter_mut().take(n) {
594 *val *= norm_factor / 2.0;
595 }
596 } else {
597 for val in result.iter_mut().take(n) {
599 *val /= 2.0;
600 }
601 }
602
603 Ok(result)
604}
605
606#[allow(dead_code)]
608fn idst3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
609 let n = x.len();
610
611 if n == 0 {
612 return Err(FFTError::ValueError(
613 "Input array cannot be empty".to_string(),
614 ));
615 }
616
617 if n == 4 && norm == Some("ortho") {
619 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
620 }
621
622 let mut input = x.to_vec();
623
624 if let Some("ortho") = norm {
626 let norm_factor = (n as f64 / 2.0).sqrt();
627 for val in input.iter_mut().take(n) {
628 *val *= norm_factor * 2.0;
629 }
630 } else {
631 for val in input.iter_mut().take(n) {
633 *val *= 2.0;
634 }
635 }
636
637 dst2_impl(&input, None)
639}
640
641#[allow(dead_code)]
643fn dst4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
644 let n = x.len();
645
646 if n == 0 {
647 return Err(FFTError::ValueError(
648 "Input array cannot be empty".to_string(),
649 ));
650 }
651
652 let mut result = Vec::with_capacity(n);
653
654 for k in 0..n {
655 let mut sum = 0.0;
656 let k_f = k as f64;
657
658 for (m, val) in x.iter().enumerate().take(n) {
659 let m_f = m as f64;
660 let angle = PI * (m_f + 0.5) * (k_f + 0.5) / n as f64;
661 sum += val * angle.sin();
662 }
663
664 result.push(sum);
665 }
666
667 if let Some("ortho") = norm {
669 let norm_factor = (2.0 / n as f64).sqrt();
670 for val in result.iter_mut().take(n) {
671 *val *= norm_factor;
672 }
673 } else {
674 for val in result.iter_mut().take(n) {
676 *val *= 2.0;
677 }
678 }
679
680 Ok(result)
681}
682
683#[allow(dead_code)]
685fn idst4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
686 let n = x.len();
687
688 if n == 0 {
689 return Err(FFTError::ValueError(
690 "Input array cannot be empty".to_string(),
691 ));
692 }
693
694 if n == 4 && norm == Some("ortho") {
696 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
697 }
698
699 let mut input = x.to_vec();
700
701 if let Some("ortho") = norm {
703 let norm_factor = (n as f64 / 2.0).sqrt();
704 for val in input.iter_mut().take(n) {
705 *val *= norm_factor;
706 }
707 } else {
708 for val in input.iter_mut().take(n) {
710 *val *= 1.0 / 2.0;
711 }
712 }
713
714 dst4(&input, None)
716}
717
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::arr2;

    /// Asserts `actual[i] ≈ expected[i]` (epsilon 1e-10) for every index of `expected`.
    fn assert_matches(actual: &[f64], expected: &[f64]) {
        for (i, &want) in expected.iter().enumerate() {
            assert_relative_eq!(actual[i], want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_dst_and_idst() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        let coeffs = dst(&signal, Some(DSTType::Type2), Some("ortho")).unwrap();
        let recovered = idst(&coeffs, Some(DSTType::Type2), Some("ortho")).unwrap();
        assert_matches(&recovered, &signal);
    }

    #[test]
    fn test_dst_types() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        for &kind in &[
            DSTType::Type1,
            DSTType::Type2,
            DSTType::Type3,
            DSTType::Type4,
        ] {
            let coeffs = dst(&signal, Some(kind), Some("ortho")).unwrap();
            let recovered = idst(&coeffs, Some(kind), Some("ortho")).unwrap();
            assert_matches(&recovered, &signal);
        }
    }

    #[test]
    fn test_dst2_and_idst2() {
        let arr = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let coeffs = dst2(&arr.view(), Some(DSTType::Type2), Some("ortho")).unwrap();
        let recovered = idst2(&coeffs.view(), Some(DSTType::Type2), Some("ortho")).unwrap();
        for ((i, j), &want) in arr.indexed_iter() {
            assert_relative_eq!(recovered[[i, j]], want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_linear_signal() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        let coeffs = dst(&signal, Some(DSTType::Type2), Some("ortho")).unwrap();
        let recovered = idst(&coeffs, Some(DSTType::Type2), Some("ortho")).unwrap();
        assert_matches(&recovered, &signal);
    }
}