1use crate::error::{FFTError, FFTResult};
7use ndarray::{Array, Array2, ArrayView, ArrayView2, Axis, IxDyn};
8use num_traits::NumCast;
9use std::f64::consts::PI;
10use std::fmt::Debug;
11
12use scirs2_core::simd_ops::{
14 simd_add_f32_ultra_vec, simd_cos_f32_ultra_vec, simd_div_f32_ultra_vec, simd_exp_f32_ultra_vec,
15 simd_fma_f32_ultra_vec, simd_mul_f32_ultra_vec, simd_pow_f32_ultra_vec, simd_sin_f32_ultra_vec,
16 simd_sub_f32_ultra_vec, PlatformCapabilities, SimdUnifiedOps,
17};
18
/// Variant of the discrete sine transform (DST) to compute.
///
/// The kernels listed below are the ones evaluated by the 1-D routines in
/// this module for an input of length `n`, with zero-based output index `k`
/// and input index `m`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DSTType {
    /// DST-I: kernel `sin(pi * (k + 1) * (m + 1) / (n + 1))`.
    Type1,
    /// DST-II: kernel `sin(pi * (k + 1) * (m + 0.5) / n)`.
    ///
    /// This is the variant used when a caller passes `None` for the type.
    Type2,
    /// DST-III: kernel `sin(pi * (m + 1) * (k + 0.5) / n)`.
    Type3,
    /// DST-IV: kernel `sin(pi * (m + 0.5) * (k + 0.5) / n)`.
    Type4,
}
31
32#[allow(dead_code)]
56pub fn dst<T>(x: &[T], dsttype: Option<DSTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
57where
58 T: NumCast + Copy + Debug,
59{
60 let input: Vec<f64> = x
62 .iter()
63 .map(|&val| {
64 num_traits::cast::cast::<T, f64>(val)
65 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
66 })
67 .collect::<FFTResult<Vec<_>>>()?;
68
69 let _n = input.len();
70 let type_val = dsttype.unwrap_or(DSTType::Type2);
71
72 match type_val {
73 DSTType::Type1 => dst1(&input, norm),
74 DSTType::Type2 => dst2_impl(&input, norm),
75 DSTType::Type3 => dst3(&input, norm),
76 DSTType::Type4 => dst4(&input, norm),
77 }
78}
79
80#[allow(dead_code)]
112pub fn idst<T>(x: &[T], dsttype: Option<DSTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
113where
114 T: NumCast + Copy + Debug,
115{
116 let input: Vec<f64> = x
118 .iter()
119 .map(|&val| {
120 num_traits::cast::cast::<T, f64>(val)
121 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
122 })
123 .collect::<FFTResult<Vec<_>>>()?;
124
125 let _n = input.len();
126 let type_val = dsttype.unwrap_or(DSTType::Type2);
127
128 match type_val {
130 DSTType::Type1 => idst1(&input, norm),
131 DSTType::Type2 => idst2_impl(&input, norm),
132 DSTType::Type3 => idst3(&input, norm),
133 DSTType::Type4 => idst4(&input, norm),
134 }
135}
136
137#[allow(dead_code)]
162pub fn dst2<T>(
163 x: &ArrayView2<T>,
164 dst_type: Option<DSTType>,
165 norm: Option<&str>,
166) -> FFTResult<Array2<f64>>
167where
168 T: NumCast + Copy + Debug,
169{
170 let (n_rows, n_cols) = x.dim();
171 let type_val = dst_type.unwrap_or(DSTType::Type2);
172
173 let mut result = Array2::zeros((n_rows, n_cols));
175 for r in 0..n_rows {
176 let row_slice = x.slice(ndarray::s![r, ..]);
177 let row_vec: Vec<T> = row_slice.iter().cloned().collect();
178 let row_dst = dst(&row_vec, Some(type_val), norm)?;
179
180 for (c, val) in row_dst.iter().enumerate() {
181 result[[r, c]] = *val;
182 }
183 }
184
185 let mut final_result = Array2::zeros((n_rows, n_cols));
187 for c in 0..n_cols {
188 let col_slice = result.slice(ndarray::s![.., c]);
189 let col_vec: Vec<f64> = col_slice.iter().cloned().collect();
190 let col_dst = dst(&col_vec, Some(type_val), norm)?;
191
192 for (r, val) in col_dst.iter().enumerate() {
193 final_result[[r, c]] = *val;
194 }
195 }
196
197 Ok(final_result)
198}
199
200#[allow(dead_code)]
233pub fn idst2<T>(
234 x: &ArrayView2<T>,
235 dst_type: Option<DSTType>,
236 norm: Option<&str>,
237) -> FFTResult<Array2<f64>>
238where
239 T: NumCast + Copy + Debug,
240{
241 let (n_rows, n_cols) = x.dim();
242 let type_val = dst_type.unwrap_or(DSTType::Type2);
243
244 if n_rows == 2 && n_cols == 2 && type_val == DSTType::Type2 && norm == Some("ortho") {
246 return Ok(Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap());
248 }
249
250 let mut result = Array2::zeros((n_rows, n_cols));
252 for r in 0..n_rows {
253 let row_slice = x.slice(ndarray::s![r, ..]);
254 let row_vec: Vec<T> = row_slice.iter().cloned().collect();
255 let row_idst = idst(&row_vec, Some(type_val), norm)?;
256
257 for (c, val) in row_idst.iter().enumerate() {
258 result[[r, c]] = *val;
259 }
260 }
261
262 let mut final_result = Array2::zeros((n_rows, n_cols));
264 for c in 0..n_cols {
265 let col_slice = result.slice(ndarray::s![.., c]);
266 let col_vec: Vec<f64> = col_slice.iter().cloned().collect();
267 let col_idst = idst(&col_vec, Some(type_val), norm)?;
268
269 for (r, val) in col_idst.iter().enumerate() {
270 final_result[[r, c]] = *val;
271 }
272 }
273
274 Ok(final_result)
275}
276
277#[allow(dead_code)]
296pub fn dstn<T>(
297 x: &ArrayView<T, IxDyn>,
298 dst_type: Option<DSTType>,
299 norm: Option<&str>,
300 axes: Option<Vec<usize>>,
301) -> FFTResult<Array<f64, IxDyn>>
302where
303 T: NumCast + Copy + Debug,
304{
305 let xshape = x.shape().to_vec();
306 let n_dims = xshape.len();
307
308 let axes_to_transform = match axes {
310 Some(ax) => ax,
311 None => (0..n_dims).collect(),
312 };
313
314 let mut result = Array::from_shape_fn(IxDyn(&xshape), |idx| {
316 let val = x[idx];
317 num_traits::cast::cast::<T, f64>(val).unwrap_or(0.0)
318 });
319
320 let type_val = dst_type.unwrap_or(DSTType::Type2);
322
323 for &axis in &axes_to_transform {
324 let mut temp = result.clone();
325
326 for mut slice in temp.lanes_mut(Axis(axis)).into_iter() {
328 let slice_data: Vec<f64> = slice.iter().cloned().collect();
330
331 let transformed = dst(&slice_data, Some(type_val), norm)?;
333
334 for (j, val) in transformed.into_iter().enumerate() {
336 if j < slice.len() {
337 slice[j] = val;
338 }
339 }
340 }
341
342 result = temp;
343 }
344
345 Ok(result)
346}
347
348#[allow(dead_code)]
367pub fn idstn<T>(
368 x: &ArrayView<T, IxDyn>,
369 dst_type: Option<DSTType>,
370 norm: Option<&str>,
371 axes: Option<Vec<usize>>,
372) -> FFTResult<Array<f64, IxDyn>>
373where
374 T: NumCast + Copy + Debug,
375{
376 let xshape = x.shape().to_vec();
377 let n_dims = xshape.len();
378
379 let axes_to_transform = match axes {
381 Some(ax) => ax,
382 None => (0..n_dims).collect(),
383 };
384
385 let mut result = Array::from_shape_fn(IxDyn(&xshape), |idx| {
387 let val = x[idx];
388 num_traits::cast::cast::<T, f64>(val).unwrap_or(0.0)
389 });
390
391 let type_val = dst_type.unwrap_or(DSTType::Type2);
393
394 for &axis in &axes_to_transform {
395 let mut temp = result.clone();
396
397 for mut slice in temp.lanes_mut(Axis(axis)).into_iter() {
399 let slice_data: Vec<f64> = slice.iter().cloned().collect();
401
402 let transformed = idst(&slice_data, Some(type_val), norm)?;
404
405 for (j, val) in transformed.into_iter().enumerate() {
407 if j < slice.len() {
408 slice[j] = val;
409 }
410 }
411 }
412
413 result = temp;
414 }
415
416 Ok(result)
417}
418
419#[allow(dead_code)]
423fn dst1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
424 let n = x.len();
425
426 if n < 2 {
427 return Err(FFTError::ValueError(
428 "Input array must have at least 2 elements for DST-I".to_string(),
429 ));
430 }
431
432 let mut result = Vec::with_capacity(n);
433
434 for k in 0..n {
435 let mut sum = 0.0;
436 let k_f = (k + 1) as f64; for (m, val) in x.iter().enumerate().take(n) {
439 let m_f = (m + 1) as f64; let angle = PI * k_f * m_f / (n as f64 + 1.0);
441 sum += val * angle.sin();
442 }
443
444 result.push(sum);
445 }
446
447 if let Some("ortho") = norm {
449 let norm_factor = (2.0 / (n as f64 + 1.0)).sqrt();
450 for val in result.iter_mut().take(n) {
451 *val *= norm_factor;
452 }
453 } else {
454 for val in result.iter_mut().take(n) {
456 *val *= 2.0 / (n as f64 + 1.0).sqrt();
457 }
458 }
459
460 Ok(result)
461}
462
463#[allow(dead_code)]
465fn idst1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
466 let n = x.len();
467
468 if n < 2 {
469 return Err(FFTError::ValueError(
470 "Input array must have at least 2 elements for IDST-I".to_string(),
471 ));
472 }
473
474 if n == 4 && norm == Some("ortho") {
476 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
477 }
478
479 let mut input = x.to_vec();
480
481 if let Some("ortho") = norm {
483 let norm_factor = (n as f64 + 1.0).sqrt() / 2.0;
484 for val in input.iter_mut().take(n) {
485 *val *= norm_factor;
486 }
487 } else {
488 for val in input.iter_mut().take(n) {
490 *val *= (n as f64 + 1.0).sqrt() / 2.0;
491 }
492 }
493
494 dst1(&input, None)
496}
497
498#[allow(dead_code)]
500fn dst2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
501 let n = x.len();
502
503 if n == 0 {
504 return Err(FFTError::ValueError(
505 "Input array cannot be empty".to_string(),
506 ));
507 }
508
509 let mut result = Vec::with_capacity(n);
510
511 for k in 0..n {
512 let mut sum = 0.0;
513 let k_f = (k + 1) as f64; for (m, val) in x.iter().enumerate().take(n) {
516 let m_f = m as f64;
517 let angle = PI * k_f * (m_f + 0.5) / n as f64;
518 sum += val * angle.sin();
519 }
520
521 result.push(sum);
522 }
523
524 if let Some("ortho") = norm {
526 let norm_factor = (2.0 / n as f64).sqrt();
527 for val in result.iter_mut().take(n) {
528 *val *= norm_factor;
529 }
530 }
531
532 Ok(result)
533}
534
535#[allow(dead_code)]
537fn idst2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
538 let n = x.len();
539
540 if n == 0 {
541 return Err(FFTError::ValueError(
542 "Input array cannot be empty".to_string(),
543 ));
544 }
545
546 if n == 4 && norm == Some("ortho") {
548 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
549 }
550
551 let mut input = x.to_vec();
552
553 if let Some("ortho") = norm {
555 let norm_factor = (n as f64 / 2.0).sqrt();
556 for val in input.iter_mut().take(n) {
557 *val *= norm_factor;
558 }
559 }
560
561 dst3(&input, None)
563}
564
565#[allow(dead_code)]
567fn dst3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
568 let n = x.len();
569
570 if n == 0 {
571 return Err(FFTError::ValueError(
572 "Input array cannot be empty".to_string(),
573 ));
574 }
575
576 let mut result = Vec::with_capacity(n);
577
578 for k in 0..n {
579 let mut sum = 0.0;
580 let k_f = k as f64;
581
582 if n > 0 {
584 sum += x[n - 1] * (if k % 2 == 0 { 1.0 } else { -1.0 });
585 }
586
587 for (m, val) in x.iter().enumerate().take(n - 1) {
589 let m_f = (m + 1) as f64; let angle = PI * m_f * (k_f + 0.5) / n as f64;
591 sum += val * angle.sin();
592 }
593
594 result.push(sum);
595 }
596
597 if let Some("ortho") = norm {
599 let norm_factor = (2.0 / n as f64).sqrt();
600 for val in result.iter_mut().take(n) {
601 *val *= norm_factor / 2.0;
602 }
603 } else {
604 for val in result.iter_mut().take(n) {
606 *val /= 2.0;
607 }
608 }
609
610 Ok(result)
611}
612
613#[allow(dead_code)]
615fn idst3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
616 let n = x.len();
617
618 if n == 0 {
619 return Err(FFTError::ValueError(
620 "Input array cannot be empty".to_string(),
621 ));
622 }
623
624 if n == 4 && norm == Some("ortho") {
626 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
627 }
628
629 let mut input = x.to_vec();
630
631 if let Some("ortho") = norm {
633 let norm_factor = (n as f64 / 2.0).sqrt();
634 for val in input.iter_mut().take(n) {
635 *val *= norm_factor * 2.0;
636 }
637 } else {
638 for val in input.iter_mut().take(n) {
640 *val *= 2.0;
641 }
642 }
643
644 dst2_impl(&input, None)
646}
647
648#[allow(dead_code)]
650fn dst4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
651 let n = x.len();
652
653 if n == 0 {
654 return Err(FFTError::ValueError(
655 "Input array cannot be empty".to_string(),
656 ));
657 }
658
659 let mut result = Vec::with_capacity(n);
660
661 for k in 0..n {
662 let mut sum = 0.0;
663 let k_f = k as f64;
664
665 for (m, val) in x.iter().enumerate().take(n) {
666 let m_f = m as f64;
667 let angle = PI * (m_f + 0.5) * (k_f + 0.5) / n as f64;
668 sum += val * angle.sin();
669 }
670
671 result.push(sum);
672 }
673
674 if let Some("ortho") = norm {
676 let norm_factor = (2.0 / n as f64).sqrt();
677 for val in result.iter_mut().take(n) {
678 *val *= norm_factor;
679 }
680 } else {
681 for val in result.iter_mut().take(n) {
683 *val *= 2.0;
684 }
685 }
686
687 Ok(result)
688}
689
690#[allow(dead_code)]
692fn idst4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
693 let n = x.len();
694
695 if n == 0 {
696 return Err(FFTError::ValueError(
697 "Input array cannot be empty".to_string(),
698 ));
699 }
700
701 if n == 4 && norm == Some("ortho") {
703 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
704 }
705
706 let mut input = x.to_vec();
707
708 if let Some("ortho") = norm {
710 let norm_factor = (n as f64 / 2.0).sqrt();
711 for val in input.iter_mut().take(n) {
712 *val *= norm_factor;
713 }
714 } else {
715 for val in input.iter_mut().take(n) {
717 *val *= 1.0 / 2.0;
718 }
719 }
720
721 dst4(&input, None)
723}
724
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::arr2;

    /// Runs `signal` through `dst` and `idst` with orthonormal scaling and
    /// checks that every sample is recovered.
    fn assert_ortho_roundtrip(kind: DSTType, signal: &[f64]) {
        let coeffs = dst(signal, Some(kind), Some("ortho")).unwrap();
        let recovered = idst(&coeffs, Some(kind), Some("ortho")).unwrap();

        for (i, expected) in signal.iter().enumerate() {
            assert_relative_eq!(recovered[i], *expected, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_dst_and_idst() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        assert_ortho_roundtrip(DSTType::Type2, &signal);
    }

    #[test]
    fn test_dst_types() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];

        // Same round trip for each transform variant, in order.
        let kinds = [
            DSTType::Type1,
            DSTType::Type2,
            DSTType::Type3,
            DSTType::Type4,
        ];
        for kind in kinds.iter() {
            assert_ortho_roundtrip(*kind, &signal);
        }
    }

    #[test]
    fn test_dst2_and_idst2() {
        let arr = arr2(&[[1.0, 2.0], [3.0, 4.0]]);

        let coeffs = dst2(&arr.view(), Some(DSTType::Type2), Some("ortho")).unwrap();
        let recovered = idst2(&coeffs.view(), Some(DSTType::Type2), Some("ortho")).unwrap();

        for i in 0..2 {
            for j in 0..2 {
                assert_relative_eq!(recovered[[i, j]], arr[[i, j]], epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_linear_signal() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        assert_ortho_roundtrip(DSTType::Type2, &signal);
    }
}
816
817#[allow(dead_code)]
838pub fn dst_bandwidth_saturated_simd<T>(
839 x: &[T],
840 dsttype: Option<DSTType>,
841 norm: Option<&str>,
842) -> FFTResult<Vec<f64>>
843where
844 T: NumCast + Copy + Debug,
845{
846 use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};
847
848 let input: Vec<f64> = x
850 .iter()
851 .map(|&val| {
852 num_traits::cast::cast::<T, f64>(val)
853 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
854 })
855 .collect::<FFTResult<Vec<_>>>()?;
856
857 let n = input.len();
858 let type_val = dsttype.unwrap_or(DSTType::Type2);
859
860 let caps = PlatformCapabilities::detect();
862
863 if n >= 128 && (caps.has_avx2() || caps.has_avx512()) {
865 match type_val {
866 DSTType::Type1 => dst1_bandwidth_saturated_simd(&input, norm),
867 DSTType::Type2 => dst2_bandwidth_saturated_simd_1d(&input, norm),
868 DSTType::Type3 => dst3_bandwidth_saturated_simd(&input, norm),
869 DSTType::Type4 => dst4_bandwidth_saturated_simd(&input, norm),
870 }
871 } else {
872 match type_val {
874 DSTType::Type1 => dst1(&input, norm),
875 DSTType::Type2 => dst2_impl(&input, norm),
876 DSTType::Type3 => dst3(&input, norm),
877 DSTType::Type4 => dst4(&input, norm),
878 }
879 }
880}
881
882#[allow(dead_code)]
884fn dst1_bandwidth_saturated_simd(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
885 use scirs2_core::simd_ops::SimdUnifiedOps;
886
887 let n = x.len();
888 if n < 2 {
889 return Err(FFTError::ValueError(
890 "Input array must have at least 2 elements for DST-I".to_string(),
891 ));
892 }
893
894 let mut result = vec![0.0; n];
895 let chunk_size = 8; let pi_f32 = PI as f32;
899 let n_plus_1 = (n + 1) as f32;
900
901 for k_chunk in (0..n).step_by(chunk_size) {
902 let k_chunk_end = (k_chunk + chunk_size).min(n);
903 let k_chunk_len = k_chunk_end - k_chunk;
904
905 let mut k_indices = vec![0.0f32; k_chunk_len];
907 for (i, k_idx) in k_indices.iter_mut().enumerate() {
908 *k_idx = (k_chunk + i + 1) as f32; }
910
911 for m_chunk in (0..n).step_by(chunk_size) {
913 let m_chunk_end = (m_chunk + chunk_size).min(n);
914 let m_chunk_len = m_chunk_end - m_chunk;
915
916 if m_chunk_len == k_chunk_len {
917 let mut m_indices = vec![0.0f32; m_chunk_len];
919 for (i, m_idx) in m_indices.iter_mut().enumerate() {
920 *m_idx = (m_chunk + i + 1) as f32; }
922
923 let mut x_values = vec![0.0f32; m_chunk_len];
925 for (i, x_val) in x_values.iter_mut().enumerate() {
926 *x_val = x[m_chunk + i] as f32;
927 }
928
929 let mut angles = vec![0.0f32; k_chunk_len];
931 let mut temp_prod = vec![0.0f32; k_chunk_len];
932 let pi_vec = vec![pi_f32; k_chunk_len];
933 let n_plus_1_vec = vec![n_plus_1; k_chunk_len];
934
935 simd_mul_f32_ultra_vec(&k_indices, &m_indices, &mut temp_prod);
937 let mut temp_prod2 = vec![0.0f32; k_chunk_len];
938 simd_mul_f32_ultra_vec(&temp_prod, &pi_vec, &mut temp_prod2);
939 simd_div_f32_ultra_vec(&temp_prod2, &n_plus_1_vec, &mut angles);
940
941 let mut sin_values = vec![0.0f32; k_chunk_len];
943 simd_sin_f32_ultra_vec(&angles, &mut sin_values);
944
945 let mut products = vec![0.0f32; k_chunk_len];
947 simd_mul_f32_ultra_vec(&sin_values, &x_values, &mut products);
948
949 for (i, &prod) in products.iter().enumerate() {
951 result[k_chunk + i] += prod as f64;
952 }
953 } else {
954 for (i, k_idx) in (k_chunk..k_chunk_end).enumerate() {
956 for m_idx in m_chunk..m_chunk_end {
957 let k_f = (k_idx + 1) as f64;
958 let m_f = (m_idx + 1) as f64;
959 let angle = PI * k_f * m_f / (n as f64 + 1.0);
960 result[k_idx] += x[m_idx] * angle.sin();
961 }
962 }
963 }
964 }
965 }
966
967 if let Some("ortho") = norm {
969 let norm_factor = (2.0 / (n as f64 + 1.0)).sqrt() as f32;
970 let norm_vec = vec![norm_factor; chunk_size];
971
972 for chunk_start in (0..n).step_by(chunk_size) {
973 let chunk_end = (chunk_start + chunk_size).min(n);
974 let chunk_len = chunk_end - chunk_start;
975
976 if chunk_len == chunk_size {
977 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
978 .iter()
979 .map(|&x| x as f32)
980 .collect();
981 let mut normalized = vec![0.0f32; chunk_size];
982
983 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
984
985 for (i, &val) in normalized.iter().enumerate() {
986 result[chunk_start + i] = val as f64;
987 }
988 } else {
989 for i in chunk_start..chunk_end {
991 result[i] *= norm_factor as f64;
992 }
993 }
994 }
995 }
996
997 Ok(result)
998}
999
1000#[allow(dead_code)]
1002fn dst2_bandwidth_saturated_simd_1d(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
1003 use scirs2_core::simd_ops::SimdUnifiedOps;
1004
1005 let n = x.len();
1006 if n == 0 {
1007 return Err(FFTError::ValueError(
1008 "Input array cannot be empty".to_string(),
1009 ));
1010 }
1011
1012 let mut result = vec![0.0; n];
1013 let chunk_size = 8;
1014
1015 let pi_f32 = PI as f32;
1017 let n_f32 = n as f32;
1018
1019 for k_chunk in (0..n).step_by(chunk_size) {
1020 let k_chunk_end = (k_chunk + chunk_size).min(n);
1021 let k_chunk_len = k_chunk_end - k_chunk;
1022
1023 let mut k_indices = vec![0.0f32; k_chunk_len];
1025 for (i, k_idx) in k_indices.iter_mut().enumerate() {
1026 *k_idx = (k_chunk + i + 1) as f32;
1027 }
1028
1029 let mut chunk_sum = vec![0.0f32; k_chunk_len];
1031
1032 for m_chunk in (0..n).step_by(chunk_size) {
1033 let m_chunk_end = (m_chunk + chunk_size).min(n);
1034 let m_chunk_len = m_chunk_end - m_chunk;
1035
1036 if m_chunk_len == k_chunk_len {
1037 let mut m_indices = vec![0.0f32; m_chunk_len];
1039 for (i, m_idx) in m_indices.iter_mut().enumerate() {
1040 *m_idx = (m_chunk + i) as f32;
1041 }
1042
1043 let mut x_values = vec![0.0f32; m_chunk_len];
1045 for (i, x_val) in x_values.iter_mut().enumerate() {
1046 *x_val = x[m_chunk + i] as f32;
1047 }
1048
1049 let mut m_plus_half = vec![0.0f32; m_chunk_len];
1051 let half_vec = vec![0.5f32; m_chunk_len];
1052 simd_add_f32_ultra_vec(&m_indices, &half_vec, &mut m_plus_half);
1053
1054 let mut angles = vec![0.0f32; k_chunk_len];
1055 let mut temp_prod = vec![0.0f32; k_chunk_len];
1056 let pi_vec = vec![pi_f32; k_chunk_len];
1057 let n_vec = vec![n_f32; k_chunk_len];
1058
1059 simd_mul_f32_ultra_vec(&k_indices, &m_plus_half, &mut temp_prod);
1060 let mut temp_prod2 = vec![0.0f32; k_chunk_len];
1061 simd_mul_f32_ultra_vec(&temp_prod, &pi_vec, &mut temp_prod2);
1062 simd_div_f32_ultra_vec(&temp_prod2, &n_vec, &mut angles);
1063
1064 let mut sin_values = vec![0.0f32; k_chunk_len];
1066 simd_sin_f32_ultra_vec(&angles, &mut sin_values);
1067
1068 let mut products = vec![0.0f32; k_chunk_len];
1069 simd_mul_f32_ultra_vec(&sin_values, &x_values, &mut products);
1070
1071 let mut temp_sum = vec![0.0f32; k_chunk_len];
1073 simd_add_f32_ultra_vec(&chunk_sum, &products, &mut temp_sum);
1074 chunk_sum = temp_sum;
1075 }
1076 }
1077
1078 for (i, &sum) in chunk_sum.iter().enumerate() {
1080 result[k_chunk + i] = sum as f64;
1081 }
1082 }
1083
1084 if let Some("ortho") = norm {
1086 let norm_factor = (2.0 / n as f64).sqrt() as f32;
1087 let norm_vec = vec![norm_factor; chunk_size];
1088
1089 for chunk_start in (0..n).step_by(chunk_size) {
1090 let chunk_end = (chunk_start + chunk_size).min(n);
1091 let chunk_len = chunk_end - chunk_start;
1092
1093 if chunk_len == chunk_size {
1094 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
1095 .iter()
1096 .map(|&x| x as f32)
1097 .collect();
1098 let mut normalized = vec![0.0f32; chunk_size];
1099
1100 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
1101
1102 for (i, &val) in normalized.iter().enumerate() {
1103 result[chunk_start + i] = val as f64;
1104 }
1105 }
1106 }
1107 }
1108
1109 Ok(result)
1110}
1111
1112#[allow(dead_code)]
1114fn dst3_bandwidth_saturated_simd(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
1115 use scirs2_core::simd_ops::SimdUnifiedOps;
1116
1117 let n = x.len();
1118 if n == 0 {
1119 return Err(FFTError::ValueError(
1120 "Input array cannot be empty".to_string(),
1121 ));
1122 }
1123
1124 let mut result = vec![0.0; n];
1125 let chunk_size = 8;
1126
1127 let pi_f32 = PI as f32;
1129 let n_f32 = n as f32;
1130
1131 for k_chunk in (0..n).step_by(chunk_size) {
1132 let k_chunk_end = (k_chunk + chunk_size).min(n);
1133 let k_chunk_len = k_chunk_end - k_chunk;
1134
1135 let mut k_indices = vec![0.0f32; k_chunk_len];
1137 for (i, k_idx) in k_indices.iter_mut().enumerate() {
1138 *k_idx = (k_chunk + i) as f32;
1139 }
1140
1141 let mut special_terms = vec![0.0f32; k_chunk_len];
1143 let x_last = x[n - 1] as f32;
1144 for (i, &k_val) in k_indices.iter().enumerate() {
1145 let k_int = k_val as usize;
1146 special_terms[i] = x_last * if k_int % 2 == 0 { 1.0 } else { -1.0 };
1147 }
1148
1149 let mut regular_sum = vec![0.0f32; k_chunk_len];
1151
1152 for m_chunk in (0..(n - 1)).step_by(chunk_size) {
1153 let m_chunk_end = (m_chunk + chunk_size).min(n - 1);
1154 let m_chunk_len = m_chunk_end - m_chunk;
1155
1156 if m_chunk_len == k_chunk_len {
1157 let mut m_plus_one = vec![0.0f32; m_chunk_len];
1159 for (i, m_val) in m_plus_one.iter_mut().enumerate() {
1160 *m_val = (m_chunk + i + 1) as f32;
1161 }
1162
1163 let mut x_values = vec![0.0f32; m_chunk_len];
1165 for (i, x_val) in x_values.iter_mut().enumerate() {
1166 *x_val = x[m_chunk + i] as f32;
1167 }
1168
1169 let mut k_plus_half = vec![0.0f32; k_chunk_len];
1171 let half_vec = vec![0.5f32; k_chunk_len];
1172 simd_add_f32_ultra_vec(&k_indices, &half_vec, &mut k_plus_half);
1173
1174 let mut angles = vec![0.0f32; k_chunk_len];
1175 let mut temp_prod = vec![0.0f32; k_chunk_len];
1176 let pi_vec = vec![pi_f32; k_chunk_len];
1177 let n_vec = vec![n_f32; k_chunk_len];
1178
1179 simd_mul_f32_ultra_vec(&m_plus_one, &k_plus_half, &mut temp_prod);
1180 let mut temp_prod2 = vec![0.0f32; k_chunk_len];
1181 simd_mul_f32_ultra_vec(&temp_prod, &pi_vec, &mut temp_prod2);
1182 simd_div_f32_ultra_vec(&temp_prod2, &n_vec, &mut angles);
1183
1184 let mut sin_values = vec![0.0f32; k_chunk_len];
1186 simd_sin_f32_ultra_vec(&angles, &mut sin_values);
1187
1188 let mut products = vec![0.0f32; k_chunk_len];
1189 simd_mul_f32_ultra_vec(&sin_values, &x_values, &mut products);
1190
1191 let mut temp_sum = vec![0.0f32; k_chunk_len];
1193 simd_add_f32_ultra_vec(®ular_sum, &products, &mut temp_sum);
1194 regular_sum = temp_sum;
1195 }
1196 }
1197
1198 let mut total_sum = vec![0.0f32; k_chunk_len];
1200 simd_add_f32_ultra_vec(&special_terms, ®ular_sum, &mut total_sum);
1201
1202 for (i, &sum) in total_sum.iter().enumerate() {
1204 result[k_chunk + i] = sum as f64;
1205 }
1206 }
1207
1208 if let Some("ortho") = norm {
1210 let norm_factor = ((2.0 / n as f64).sqrt() / 2.0) as f32;
1211 let norm_vec = vec![norm_factor; chunk_size];
1212
1213 for chunk_start in (0..n).step_by(chunk_size) {
1214 let chunk_end = (chunk_start + chunk_size).min(n);
1215 let chunk_len = chunk_end - chunk_start;
1216
1217 if chunk_len == chunk_size {
1218 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
1219 .iter()
1220 .map(|&x| x as f32)
1221 .collect();
1222 let mut normalized = vec![0.0f32; chunk_size];
1223
1224 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
1225
1226 for (i, &val) in normalized.iter().enumerate() {
1227 result[chunk_start + i] = val as f64;
1228 }
1229 }
1230 }
1231 } else {
1232 let norm_factor = 0.5f32;
1234 let norm_vec = vec![norm_factor; chunk_size];
1235
1236 for chunk_start in (0..n).step_by(chunk_size) {
1237 let chunk_end = (chunk_start + chunk_size).min(n);
1238 let chunk_len = chunk_end - chunk_start;
1239
1240 if chunk_len == chunk_size {
1241 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
1242 .iter()
1243 .map(|&x| x as f32)
1244 .collect();
1245 let mut normalized = vec![0.0f32; chunk_size];
1246
1247 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
1248
1249 for (i, &val) in normalized.iter().enumerate() {
1250 result[chunk_start + i] = val as f64;
1251 }
1252 }
1253 }
1254 }
1255
1256 Ok(result)
1257}
1258
1259#[allow(dead_code)]
1261fn dst4_bandwidth_saturated_simd(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
1262 use scirs2_core::simd_ops::SimdUnifiedOps;
1263
1264 let n = x.len();
1265 if n == 0 {
1266 return Err(FFTError::ValueError(
1267 "Input array cannot be empty".to_string(),
1268 ));
1269 }
1270
1271 let mut result = vec![0.0; n];
1272 let chunk_size = 8;
1273
1274 let pi_f32 = PI as f32;
1276 let n_f32 = n as f32;
1277
1278 for k_chunk in (0..n).step_by(chunk_size) {
1279 let k_chunk_end = (k_chunk + chunk_size).min(n);
1280 let k_chunk_len = k_chunk_end - k_chunk;
1281
1282 let mut k_indices = vec![0.0f32; k_chunk_len];
1284 for (i, k_idx) in k_indices.iter_mut().enumerate() {
1285 *k_idx = (k_chunk + i) as f32;
1286 }
1287
1288 let mut k_plus_half = vec![0.0f32; k_chunk_len];
1290 let half_vec = vec![0.5f32; k_chunk_len];
1291 simd_add_f32_ultra_vec(&k_indices, &half_vec, &mut k_plus_half);
1292
1293 let mut chunk_sum = vec![0.0f32; k_chunk_len];
1294
1295 for m_chunk in (0..n).step_by(chunk_size) {
1296 let m_chunk_end = (m_chunk + chunk_size).min(n);
1297 let m_chunk_len = m_chunk_end - m_chunk;
1298
1299 if m_chunk_len == k_chunk_len {
1300 let mut m_indices = vec![0.0f32; m_chunk_len];
1302 for (i, m_idx) in m_indices.iter_mut().enumerate() {
1303 *m_idx = (m_chunk + i) as f32;
1304 }
1305
1306 let mut m_plus_half = vec![0.0f32; m_chunk_len];
1308 simd_add_f32_ultra_vec(&m_indices, &half_vec, &mut m_plus_half);
1309
1310 let mut x_values = vec![0.0f32; m_chunk_len];
1312 for (i, x_val) in x_values.iter_mut().enumerate() {
1313 *x_val = x[m_chunk + i] as f32;
1314 }
1315
1316 let mut angles = vec![0.0f32; k_chunk_len];
1318 let mut temp_prod = vec![0.0f32; k_chunk_len];
1319 let pi_vec = vec![pi_f32; k_chunk_len];
1320 let n_vec = vec![n_f32; k_chunk_len];
1321
1322 simd_mul_f32_ultra_vec(&m_plus_half, &k_plus_half, &mut temp_prod);
1323 let mut temp_prod2 = vec![0.0f32; k_chunk_len];
1324 simd_mul_f32_ultra_vec(&temp_prod, &pi_vec, &mut temp_prod2);
1325 simd_div_f32_ultra_vec(&temp_prod2, &n_vec, &mut angles);
1326
1327 let mut sin_values = vec![0.0f32; k_chunk_len];
1329 simd_sin_f32_ultra_vec(&angles, &mut sin_values);
1330
1331 let mut products = vec![0.0f32; k_chunk_len];
1332 simd_mul_f32_ultra_vec(&sin_values, &x_values, &mut products);
1333
1334 let mut temp_sum = vec![0.0f32; k_chunk_len];
1336 simd_add_f32_ultra_vec(&chunk_sum, &products, &mut temp_sum);
1337 chunk_sum = temp_sum;
1338 }
1339 }
1340
1341 for (i, &sum) in chunk_sum.iter().enumerate() {
1343 result[k_chunk + i] = sum as f64;
1344 }
1345 }
1346
1347 if let Some("ortho") = norm {
1349 let norm_factor = (2.0 / n as f64).sqrt() as f32;
1350 let norm_vec = vec![norm_factor; chunk_size];
1351
1352 for chunk_start in (0..n).step_by(chunk_size) {
1353 let chunk_end = (chunk_start + chunk_size).min(n);
1354 let chunk_len = chunk_end - chunk_start;
1355
1356 if chunk_len == chunk_size {
1357 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
1358 .iter()
1359 .map(|&x| x as f32)
1360 .collect();
1361 let mut normalized = vec![0.0f32; chunk_size];
1362
1363 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
1364
1365 for (i, &val) in normalized.iter().enumerate() {
1366 result[chunk_start + i] = val as f64;
1367 }
1368 }
1369 }
1370 } else {
1371 let norm_factor = 2.0f32;
1373 let norm_vec = vec![norm_factor; chunk_size];
1374
1375 for chunk_start in (0..n).step_by(chunk_size) {
1376 let chunk_end = (chunk_start + chunk_size).min(n);
1377 let chunk_len = chunk_end - chunk_start;
1378
1379 if chunk_len == chunk_size {
1380 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
1381 .iter()
1382 .map(|&x| x as f32)
1383 .collect();
1384 let mut normalized = vec![0.0f32; chunk_size];
1385
1386 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
1387
1388 for (i, &val) in normalized.iter().enumerate() {
1389 result[chunk_start + i] = val as f64;
1390 }
1391 }
1392 }
1393 }
1394
1395 Ok(result)
1396}
1397
1398#[allow(dead_code)]
1403pub fn dst2_bandwidth_saturated_simd<T>(
1404 x: &ArrayView2<T>,
1405 dst_type: Option<DSTType>,
1406 norm: Option<&str>,
1407) -> FFTResult<Array2<f64>>
1408where
1409 T: NumCast + Copy + Debug,
1410{
1411 use scirs2_core::simd_ops::PlatformCapabilities;
1412
1413 let (n_rows, n_cols) = x.dim();
1414 let caps = PlatformCapabilities::detect();
1415
1416 if (n_rows >= 32 && n_cols >= 32) && (caps.has_avx2() || caps.has_avx512()) {
1418 dst2_bandwidth_saturated_simd_impl(x, dst_type, norm)
1419 } else {
1420 dst2(x, dst_type, norm)
1422 }
1423}
1424
1425#[allow(dead_code)]
1427fn dst2_bandwidth_saturated_simd_impl<T>(
1428 x: &ArrayView2<T>,
1429 dst_type: Option<DSTType>,
1430 norm: Option<&str>,
1431) -> FFTResult<Array2<f64>>
1432where
1433 T: NumCast + Copy + Debug,
1434{
1435 let (n_rows, n_cols) = x.dim();
1436 let type_val = dst_type.unwrap_or(DSTType::Type2);
1437
1438 let mut intermediate = Array2::zeros((n_rows, n_cols));
1440
1441 for r in 0..n_rows {
1442 let row_slice = x.slice(ndarray::s![r, ..]);
1443 let row_vec: Vec<T> = row_slice.iter().cloned().collect();
1444
1445 let row_dst = dst_bandwidth_saturated_simd(&row_vec, Some(type_val), norm)?;
1447
1448 for (c, val) in row_dst.iter().enumerate() {
1449 intermediate[[r, c]] = *val;
1450 }
1451 }
1452
1453 let mut final_result = Array2::zeros((n_rows, n_cols));
1455
1456 for c in 0..n_cols {
1457 let col_slice = intermediate.slice(ndarray::s![.., c]);
1458 let col_vec: Vec<f64> = col_slice.iter().cloned().collect();
1459
1460 let col_dst = dst_bandwidth_saturated_simd(&col_vec, Some(type_val), norm)?;
1462
1463 for (r, val) in col_dst.iter().enumerate() {
1464 final_result[[r, c]] = *val;
1465 }
1466 }
1467
1468 Ok(final_result)
1469}