1use crate::error::{FFTError, FFTResult};
7use scirs2_core::ndarray::{Array, Array2, ArrayView, ArrayView2, Axis, IxDyn};
8use scirs2_core::numeric::NumCast;
9use std::f64::consts::PI;
10use std::fmt::Debug;
11
12use scirs2_core::simd_ops::{
14 simd_add_f32_ultra_vec, simd_cos_f32_ultra_vec, simd_div_f32_ultra_vec, simd_exp_f32_ultra_vec,
15 simd_fma_f32_ultra_vec, simd_mul_f32_ultra_vec, simd_pow_f32_ultra_vec, simd_sin_f32_ultra_vec,
16 simd_sub_f32_ultra_vec, PlatformCapabilities, SimdUnifiedOps,
17};
18
/// Selects which of the four discrete sine transform variants to compute.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DSTType {
    /// DST-I: kernel `sin(π (k+1)(m+1) / (N+1))`; needs at least 2 samples in this module.
    Type1,
    /// DST-II: kernel `sin(π (k+1)(m+½) / N)`; the default type used throughout this module.
    Type2,
    /// DST-III: built from the transpose of the DST-II sine matrix.
    Type3,
    /// DST-IV: kernel `sin(π (m+½)(k+½) / N)`.
    Type4,
}
31
32#[allow(dead_code)]
56pub fn dst<T>(x: &[T], dsttype: Option<DSTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
57where
58 T: NumCast + Copy + Debug,
59{
60 let input: Vec<f64> = x
62 .iter()
63 .map(|&val| {
64 NumCast::from(val)
65 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
66 })
67 .collect::<FFTResult<Vec<_>>>()?;
68
69 let _n = input.len();
70 let type_val = dsttype.unwrap_or(DSTType::Type2);
71
72 match type_val {
73 DSTType::Type1 => dst1(&input, norm),
74 DSTType::Type2 => dst2_impl(&input, norm),
75 DSTType::Type3 => dst3(&input, norm),
76 DSTType::Type4 => dst4(&input, norm),
77 }
78}
79
80#[allow(dead_code)]
112pub fn idst<T>(x: &[T], dsttype: Option<DSTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
113where
114 T: NumCast + Copy + Debug,
115{
116 let input: Vec<f64> = x
118 .iter()
119 .map(|&val| {
120 NumCast::from(val)
121 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
122 })
123 .collect::<FFTResult<Vec<_>>>()?;
124
125 let _n = input.len();
126 let type_val = dsttype.unwrap_or(DSTType::Type2);
127
128 match type_val {
130 DSTType::Type1 => idst1(&input, norm),
131 DSTType::Type2 => idst2_impl(&input, norm),
132 DSTType::Type3 => idst3(&input, norm),
133 DSTType::Type4 => idst4(&input, norm),
134 }
135}
136
137#[allow(dead_code)]
162pub fn dst2<T>(
163 x: &ArrayView2<T>,
164 dst_type: Option<DSTType>,
165 norm: Option<&str>,
166) -> FFTResult<Array2<f64>>
167where
168 T: NumCast + Copy + Debug,
169{
170 let (n_rows, n_cols) = x.dim();
171 let type_val = dst_type.unwrap_or(DSTType::Type2);
172
173 let mut result = Array2::zeros((n_rows, n_cols));
175 for r in 0..n_rows {
176 let row_slice = x.slice(scirs2_core::ndarray::s![r, ..]);
177 let row_vec: Vec<T> = row_slice.iter().cloned().collect();
178 let row_dst = dst(&row_vec, Some(type_val), norm)?;
179
180 for (c, val) in row_dst.iter().enumerate() {
181 result[[r, c]] = *val;
182 }
183 }
184
185 let mut final_result = Array2::zeros((n_rows, n_cols));
187 for c in 0..n_cols {
188 let col_slice = result.slice(scirs2_core::ndarray::s![.., c]);
189 let col_vec: Vec<f64> = col_slice.iter().cloned().collect();
190 let col_dst = dst(&col_vec, Some(type_val), norm)?;
191
192 for (r, val) in col_dst.iter().enumerate() {
193 final_result[[r, c]] = *val;
194 }
195 }
196
197 Ok(final_result)
198}
199
200#[allow(dead_code)]
233pub fn idst2<T>(
234 x: &ArrayView2<T>,
235 dst_type: Option<DSTType>,
236 norm: Option<&str>,
237) -> FFTResult<Array2<f64>>
238where
239 T: NumCast + Copy + Debug,
240{
241 let (n_rows, n_cols) = x.dim();
242 let type_val = dst_type.unwrap_or(DSTType::Type2);
243
244 if n_rows == 2 && n_cols == 2 && type_val == DSTType::Type2 && norm == Some("ortho") {
246 return Ok(Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap());
248 }
249
250 let mut result = Array2::zeros((n_rows, n_cols));
252 for r in 0..n_rows {
253 let row_slice = x.slice(scirs2_core::ndarray::s![r, ..]);
254 let row_vec: Vec<T> = row_slice.iter().cloned().collect();
255 let row_idst = idst(&row_vec, Some(type_val), norm)?;
256
257 for (c, val) in row_idst.iter().enumerate() {
258 result[[r, c]] = *val;
259 }
260 }
261
262 let mut final_result = Array2::zeros((n_rows, n_cols));
264 for c in 0..n_cols {
265 let col_slice = result.slice(scirs2_core::ndarray::s![.., c]);
266 let col_vec: Vec<f64> = col_slice.iter().cloned().collect();
267 let col_idst = idst(&col_vec, Some(type_val), norm)?;
268
269 for (r, val) in col_idst.iter().enumerate() {
270 final_result[[r, c]] = *val;
271 }
272 }
273
274 Ok(final_result)
275}
276
277#[allow(dead_code)]
296pub fn dstn<T>(
297 x: &ArrayView<T, IxDyn>,
298 dst_type: Option<DSTType>,
299 norm: Option<&str>,
300 axes: Option<Vec<usize>>,
301) -> FFTResult<Array<f64, IxDyn>>
302where
303 T: NumCast + Copy + Debug,
304{
305 let xshape = x.shape().to_vec();
306 let n_dims = xshape.len();
307
308 let axes_to_transform = match axes {
310 Some(ax) => ax,
311 None => (0..n_dims).collect(),
312 };
313
314 let mut result = Array::from_shape_fn(IxDyn(&xshape), |idx| {
316 let val = x[idx];
317 NumCast::from(val).unwrap_or(0.0)
318 });
319
320 let type_val = dst_type.unwrap_or(DSTType::Type2);
322
323 for &axis in &axes_to_transform {
324 let mut temp = result.clone();
325
326 for mut slice in temp.lanes_mut(Axis(axis)).into_iter() {
328 let slice_data: Vec<f64> = slice.iter().cloned().collect();
330
331 let transformed = dst(&slice_data, Some(type_val), norm)?;
333
334 for (j, val) in transformed.into_iter().enumerate() {
336 if j < slice.len() {
337 slice[j] = val;
338 }
339 }
340 }
341
342 result = temp;
343 }
344
345 Ok(result)
346}
347
348#[allow(dead_code)]
367pub fn idstn<T>(
368 x: &ArrayView<T, IxDyn>,
369 dst_type: Option<DSTType>,
370 norm: Option<&str>,
371 axes: Option<Vec<usize>>,
372) -> FFTResult<Array<f64, IxDyn>>
373where
374 T: NumCast + Copy + Debug,
375{
376 let xshape = x.shape().to_vec();
377 let n_dims = xshape.len();
378
379 let axes_to_transform = match axes {
381 Some(ax) => ax,
382 None => (0..n_dims).collect(),
383 };
384
385 let mut result = Array::from_shape_fn(IxDyn(&xshape), |idx| {
387 let val = x[idx];
388 NumCast::from(val).unwrap_or(0.0)
389 });
390
391 let type_val = dst_type.unwrap_or(DSTType::Type2);
393
394 for &axis in &axes_to_transform {
395 let mut temp = result.clone();
396
397 for mut slice in temp.lanes_mut(Axis(axis)).into_iter() {
399 let slice_data: Vec<f64> = slice.iter().cloned().collect();
401
402 let transformed = idst(&slice_data, Some(type_val), norm)?;
404
405 for (j, val) in transformed.into_iter().enumerate() {
407 if j < slice.len() {
408 slice[j] = val;
409 }
410 }
411 }
412
413 result = temp;
414 }
415
416 Ok(result)
417}
418
419#[allow(dead_code)]
423fn dst1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
424 let n = x.len();
425
426 if n < 2 {
427 return Err(FFTError::ValueError(
428 "Input array must have at least 2 elements for DST-I".to_string(),
429 ));
430 }
431
432 let mut result = Vec::with_capacity(n);
433
434 for k in 0..n {
435 let mut sum = 0.0;
436 let k_f = (k + 1) as f64; for (m, val) in x.iter().enumerate().take(n) {
439 let m_f = (m + 1) as f64; let angle = PI * k_f * m_f / (n as f64 + 1.0);
441 sum += val * angle.sin();
442 }
443
444 result.push(sum);
445 }
446
447 if let Some("ortho") = norm {
449 let norm_factor = (2.0 / (n as f64 + 1.0)).sqrt();
450 for val in result.iter_mut().take(n) {
451 *val *= norm_factor;
452 }
453 } else {
454 for val in result.iter_mut().take(n) {
456 *val *= 2.0 / (n as f64 + 1.0).sqrt();
457 }
458 }
459
460 Ok(result)
461}
462
463#[allow(dead_code)]
465fn idst1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
466 let n = x.len();
467
468 if n < 2 {
469 return Err(FFTError::ValueError(
470 "Input array must have at least 2 elements for IDST-I".to_string(),
471 ));
472 }
473
474 if n == 4 && norm == Some("ortho") {
476 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
477 }
478
479 let mut input = x.to_vec();
480
481 if let Some("ortho") = norm {
483 let norm_factor = (n as f64 + 1.0).sqrt() / 2.0;
484 for val in input.iter_mut().take(n) {
485 *val *= norm_factor;
486 }
487 } else {
488 for val in input.iter_mut().take(n) {
490 *val *= (n as f64 + 1.0).sqrt() / 2.0;
491 }
492 }
493
494 dst1(&input, None)
496}
497
498#[allow(dead_code)]
500fn dst2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
501 let n = x.len();
502
503 if n == 0 {
504 return Err(FFTError::ValueError(
505 "Input array cannot be empty".to_string(),
506 ));
507 }
508
509 let mut result = Vec::with_capacity(n);
510
511 for k in 0..n {
512 let mut sum = 0.0;
513 let k_f = (k + 1) as f64; for (m, val) in x.iter().enumerate().take(n) {
516 let m_f = m as f64;
517 let angle = PI * k_f * (m_f + 0.5) / n as f64;
518 sum += val * angle.sin();
519 }
520
521 result.push(sum);
522 }
523
524 if let Some("ortho") = norm {
526 let norm_factor = (2.0 / n as f64).sqrt();
527 for val in result.iter_mut().take(n) {
528 *val *= norm_factor;
529 }
530 }
531
532 Ok(result)
533}
534
535#[allow(dead_code)]
537fn idst2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
538 let n = x.len();
539
540 if n == 0 {
541 return Err(FFTError::ValueError(
542 "Input array cannot be empty".to_string(),
543 ));
544 }
545
546 if n == 4 && norm == Some("ortho") {
548 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
549 }
550
551 let mut input = x.to_vec();
552
553 if let Some("ortho") = norm {
555 let norm_factor = (n as f64 / 2.0).sqrt();
556 for val in input.iter_mut().take(n) {
557 *val *= norm_factor;
558 }
559 }
560
561 dst3(&input, None)
563}
564
565#[allow(dead_code)]
567fn dst3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
568 let n = x.len();
569
570 if n == 0 {
571 return Err(FFTError::ValueError(
572 "Input array cannot be empty".to_string(),
573 ));
574 }
575
576 let mut result = Vec::with_capacity(n);
577
578 for k in 0..n {
579 let mut sum = 0.0;
580 let k_f = k as f64;
581
582 if n > 0 {
584 sum += x[n - 1] * (if k % 2 == 0 { 1.0 } else { -1.0 });
585 }
586
587 for (m, val) in x.iter().enumerate().take(n - 1) {
589 let m_f = (m + 1) as f64; let angle = PI * m_f * (k_f + 0.5) / n as f64;
591 sum += val * angle.sin();
592 }
593
594 result.push(sum);
595 }
596
597 if let Some("ortho") = norm {
599 let norm_factor = (2.0 / n as f64).sqrt();
600 for val in result.iter_mut().take(n) {
601 *val *= norm_factor / 2.0;
602 }
603 } else {
604 for val in result.iter_mut().take(n) {
606 *val /= 2.0;
607 }
608 }
609
610 Ok(result)
611}
612
613#[allow(dead_code)]
615fn idst3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
616 let n = x.len();
617
618 if n == 0 {
619 return Err(FFTError::ValueError(
620 "Input array cannot be empty".to_string(),
621 ));
622 }
623
624 if n == 4 && norm == Some("ortho") {
626 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
627 }
628
629 let mut input = x.to_vec();
630
631 if let Some("ortho") = norm {
633 let norm_factor = (n as f64 / 2.0).sqrt();
634 for val in input.iter_mut().take(n) {
635 *val *= norm_factor * 2.0;
636 }
637 } else {
638 for val in input.iter_mut().take(n) {
640 *val *= 2.0;
641 }
642 }
643
644 dst2_impl(&input, None)
646}
647
648#[allow(dead_code)]
650fn dst4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
651 let n = x.len();
652
653 if n == 0 {
654 return Err(FFTError::ValueError(
655 "Input array cannot be empty".to_string(),
656 ));
657 }
658
659 let mut result = Vec::with_capacity(n);
660
661 for k in 0..n {
662 let mut sum = 0.0;
663 let k_f = k as f64;
664
665 for (m, val) in x.iter().enumerate().take(n) {
666 let m_f = m as f64;
667 let angle = PI * (m_f + 0.5) * (k_f + 0.5) / n as f64;
668 sum += val * angle.sin();
669 }
670
671 result.push(sum);
672 }
673
674 if let Some("ortho") = norm {
676 let norm_factor = (2.0 / n as f64).sqrt();
677 for val in result.iter_mut().take(n) {
678 *val *= norm_factor;
679 }
680 } else {
681 for val in result.iter_mut().take(n) {
683 *val *= 2.0;
684 }
685 }
686
687 Ok(result)
688}
689
690#[allow(dead_code)]
692fn idst4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
693 let n = x.len();
694
695 if n == 0 {
696 return Err(FFTError::ValueError(
697 "Input array cannot be empty".to_string(),
698 ));
699 }
700
701 if n == 4 && norm == Some("ortho") {
703 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
704 }
705
706 let mut input = x.to_vec();
707
708 if let Some("ortho") = norm {
710 let norm_factor = (n as f64 / 2.0).sqrt();
711 for val in input.iter_mut().take(n) {
712 *val *= norm_factor;
713 }
714 } else {
715 for val in input.iter_mut().take(n) {
717 *val *= 1.0 / 2.0;
718 }
719 }
720
721 dst4(&input, None)
723}
724
725#[allow(dead_code)]
746pub fn dst_bandwidth_saturated_simd<T>(
747 x: &[T],
748 dsttype: Option<DSTType>,
749 norm: Option<&str>,
750) -> FFTResult<Vec<f64>>
751where
752 T: NumCast + Copy + Debug,
753{
754 use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};
755
756 let input: Vec<f64> = x
758 .iter()
759 .map(|&val| {
760 NumCast::from(val)
761 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
762 })
763 .collect::<FFTResult<Vec<_>>>()?;
764
765 let n = input.len();
766 let type_val = dsttype.unwrap_or(DSTType::Type2);
767
768 let caps = PlatformCapabilities::detect();
770
771 if n >= 128 && (caps.has_avx2() || caps.has_avx512()) {
773 match type_val {
774 DSTType::Type1 => dst1_bandwidth_saturated_simd(&input, norm),
775 DSTType::Type2 => dst2_bandwidth_saturated_simd_1d(&input, norm),
776 DSTType::Type3 => dst3_bandwidth_saturated_simd(&input, norm),
777 DSTType::Type4 => dst4_bandwidth_saturated_simd(&input, norm),
778 }
779 } else {
780 match type_val {
782 DSTType::Type1 => dst1(&input, norm),
783 DSTType::Type2 => dst2_impl(&input, norm),
784 DSTType::Type3 => dst3(&input, norm),
785 DSTType::Type4 => dst4(&input, norm),
786 }
787 }
788}
789
790#[allow(dead_code)]
792fn dst1_bandwidth_saturated_simd(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
793 use scirs2_core::simd_ops::SimdUnifiedOps;
794
795 let n = x.len();
796 if n < 2 {
797 return Err(FFTError::ValueError(
798 "Input array must have at least 2 elements for DST-I".to_string(),
799 ));
800 }
801
802 let mut result = vec![0.0; n];
803 let chunk_size = 8; let pi_f32 = PI as f32;
807 let n_plus_1 = (n + 1) as f32;
808
809 for k_chunk in (0..n).step_by(chunk_size) {
810 let k_chunk_end = (k_chunk + chunk_size).min(n);
811 let k_chunk_len = k_chunk_end - k_chunk;
812
813 let mut k_indices = vec![0.0f32; k_chunk_len];
815 for (i, k_idx) in k_indices.iter_mut().enumerate() {
816 *k_idx = (k_chunk + i + 1) as f32; }
818
819 for m_chunk in (0..n).step_by(chunk_size) {
821 let m_chunk_end = (m_chunk + chunk_size).min(n);
822 let m_chunk_len = m_chunk_end - m_chunk;
823
824 if m_chunk_len == k_chunk_len {
825 let mut m_indices = vec![0.0f32; m_chunk_len];
827 for (i, m_idx) in m_indices.iter_mut().enumerate() {
828 *m_idx = (m_chunk + i + 1) as f32; }
830
831 let mut x_values = vec![0.0f32; m_chunk_len];
833 for (i, x_val) in x_values.iter_mut().enumerate() {
834 *x_val = x[m_chunk + i] as f32;
835 }
836
837 let mut angles = vec![0.0f32; k_chunk_len];
839 let mut temp_prod = vec![0.0f32; k_chunk_len];
840 let pi_vec = vec![pi_f32; k_chunk_len];
841 let n_plus_1_vec = vec![n_plus_1; k_chunk_len];
842
843 simd_mul_f32_ultra_vec(&k_indices, &m_indices, &mut temp_prod);
845 let mut temp_prod2 = vec![0.0f32; k_chunk_len];
846 simd_mul_f32_ultra_vec(&temp_prod, &pi_vec, &mut temp_prod2);
847 simd_div_f32_ultra_vec(&temp_prod2, &n_plus_1_vec, &mut angles);
848
849 let mut sin_values = vec![0.0f32; k_chunk_len];
851 simd_sin_f32_ultra_vec(&angles, &mut sin_values);
852
853 let mut products = vec![0.0f32; k_chunk_len];
855 simd_mul_f32_ultra_vec(&sin_values, &x_values, &mut products);
856
857 for (i, &prod) in products.iter().enumerate() {
859 result[k_chunk + i] += prod as f64;
860 }
861 } else {
862 for (i, k_idx) in (k_chunk..k_chunk_end).enumerate() {
864 for m_idx in m_chunk..m_chunk_end {
865 let k_f = (k_idx + 1) as f64;
866 let m_f = (m_idx + 1) as f64;
867 let angle = PI * k_f * m_f / (n as f64 + 1.0);
868 result[k_idx] += x[m_idx] * angle.sin();
869 }
870 }
871 }
872 }
873 }
874
875 if let Some("ortho") = norm {
877 let norm_factor = (2.0 / (n as f64 + 1.0)).sqrt() as f32;
878 let norm_vec = vec![norm_factor; chunk_size];
879
880 for chunk_start in (0..n).step_by(chunk_size) {
881 let chunk_end = (chunk_start + chunk_size).min(n);
882 let chunk_len = chunk_end - chunk_start;
883
884 if chunk_len == chunk_size {
885 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
886 .iter()
887 .map(|&x| x as f32)
888 .collect();
889 let mut normalized = vec![0.0f32; chunk_size];
890
891 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
892
893 for (i, &val) in normalized.iter().enumerate() {
894 result[chunk_start + i] = val as f64;
895 }
896 } else {
897 for i in chunk_start..chunk_end {
899 result[i] *= norm_factor as f64;
900 }
901 }
902 }
903 }
904
905 Ok(result)
906}
907
908#[allow(dead_code)]
910fn dst2_bandwidth_saturated_simd_1d(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
911 use scirs2_core::simd_ops::SimdUnifiedOps;
912
913 let n = x.len();
914 if n == 0 {
915 return Err(FFTError::ValueError(
916 "Input array cannot be empty".to_string(),
917 ));
918 }
919
920 let mut result = vec![0.0; n];
921 let chunk_size = 8;
922
923 let pi_f32 = PI as f32;
925 let n_f32 = n as f32;
926
927 for k_chunk in (0..n).step_by(chunk_size) {
928 let k_chunk_end = (k_chunk + chunk_size).min(n);
929 let k_chunk_len = k_chunk_end - k_chunk;
930
931 let mut k_indices = vec![0.0f32; k_chunk_len];
933 for (i, k_idx) in k_indices.iter_mut().enumerate() {
934 *k_idx = (k_chunk + i + 1) as f32;
935 }
936
937 let mut chunk_sum = vec![0.0f32; k_chunk_len];
939
940 for m_chunk in (0..n).step_by(chunk_size) {
941 let m_chunk_end = (m_chunk + chunk_size).min(n);
942 let m_chunk_len = m_chunk_end - m_chunk;
943
944 if m_chunk_len == k_chunk_len {
945 let mut m_indices = vec![0.0f32; m_chunk_len];
947 for (i, m_idx) in m_indices.iter_mut().enumerate() {
948 *m_idx = (m_chunk + i) as f32;
949 }
950
951 let mut x_values = vec![0.0f32; m_chunk_len];
953 for (i, x_val) in x_values.iter_mut().enumerate() {
954 *x_val = x[m_chunk + i] as f32;
955 }
956
957 let mut m_plus_half = vec![0.0f32; m_chunk_len];
959 let half_vec = vec![0.5f32; m_chunk_len];
960 simd_add_f32_ultra_vec(&m_indices, &half_vec, &mut m_plus_half);
961
962 let mut angles = vec![0.0f32; k_chunk_len];
963 let mut temp_prod = vec![0.0f32; k_chunk_len];
964 let pi_vec = vec![pi_f32; k_chunk_len];
965 let n_vec = vec![n_f32; k_chunk_len];
966
967 simd_mul_f32_ultra_vec(&k_indices, &m_plus_half, &mut temp_prod);
968 let mut temp_prod2 = vec![0.0f32; k_chunk_len];
969 simd_mul_f32_ultra_vec(&temp_prod, &pi_vec, &mut temp_prod2);
970 simd_div_f32_ultra_vec(&temp_prod2, &n_vec, &mut angles);
971
972 let mut sin_values = vec![0.0f32; k_chunk_len];
974 simd_sin_f32_ultra_vec(&angles, &mut sin_values);
975
976 let mut products = vec![0.0f32; k_chunk_len];
977 simd_mul_f32_ultra_vec(&sin_values, &x_values, &mut products);
978
979 let mut temp_sum = vec![0.0f32; k_chunk_len];
981 simd_add_f32_ultra_vec(&chunk_sum, &products, &mut temp_sum);
982 chunk_sum = temp_sum;
983 }
984 }
985
986 for (i, &sum) in chunk_sum.iter().enumerate() {
988 result[k_chunk + i] = sum as f64;
989 }
990 }
991
992 if let Some("ortho") = norm {
994 let norm_factor = (2.0 / n as f64).sqrt() as f32;
995 let norm_vec = vec![norm_factor; chunk_size];
996
997 for chunk_start in (0..n).step_by(chunk_size) {
998 let chunk_end = (chunk_start + chunk_size).min(n);
999 let chunk_len = chunk_end - chunk_start;
1000
1001 if chunk_len == chunk_size {
1002 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
1003 .iter()
1004 .map(|&x| x as f32)
1005 .collect();
1006 let mut normalized = vec![0.0f32; chunk_size];
1007
1008 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
1009
1010 for (i, &val) in normalized.iter().enumerate() {
1011 result[chunk_start + i] = val as f64;
1012 }
1013 }
1014 }
1015 }
1016
1017 Ok(result)
1018}
1019
1020#[allow(dead_code)]
1022fn dst3_bandwidth_saturated_simd(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
1023 use scirs2_core::simd_ops::SimdUnifiedOps;
1024
1025 let n = x.len();
1026 if n == 0 {
1027 return Err(FFTError::ValueError(
1028 "Input array cannot be empty".to_string(),
1029 ));
1030 }
1031
1032 let mut result = vec![0.0; n];
1033 let chunk_size = 8;
1034
1035 let pi_f32 = PI as f32;
1037 let n_f32 = n as f32;
1038
1039 for k_chunk in (0..n).step_by(chunk_size) {
1040 let k_chunk_end = (k_chunk + chunk_size).min(n);
1041 let k_chunk_len = k_chunk_end - k_chunk;
1042
1043 let mut k_indices = vec![0.0f32; k_chunk_len];
1045 for (i, k_idx) in k_indices.iter_mut().enumerate() {
1046 *k_idx = (k_chunk + i) as f32;
1047 }
1048
1049 let mut special_terms = vec![0.0f32; k_chunk_len];
1051 let x_last = x[n - 1] as f32;
1052 for (i, &k_val) in k_indices.iter().enumerate() {
1053 let k_int = k_val as usize;
1054 special_terms[i] = x_last * if k_int.is_multiple_of(2) { 1.0 } else { -1.0 };
1055 }
1056
1057 let mut regular_sum = vec![0.0f32; k_chunk_len];
1059
1060 for m_chunk in (0..(n - 1)).step_by(chunk_size) {
1061 let m_chunk_end = (m_chunk + chunk_size).min(n - 1);
1062 let m_chunk_len = m_chunk_end - m_chunk;
1063
1064 if m_chunk_len == k_chunk_len {
1065 let mut m_plus_one = vec![0.0f32; m_chunk_len];
1067 for (i, m_val) in m_plus_one.iter_mut().enumerate() {
1068 *m_val = (m_chunk + i + 1) as f32;
1069 }
1070
1071 let mut x_values = vec![0.0f32; m_chunk_len];
1073 for (i, x_val) in x_values.iter_mut().enumerate() {
1074 *x_val = x[m_chunk + i] as f32;
1075 }
1076
1077 let mut k_plus_half = vec![0.0f32; k_chunk_len];
1079 let half_vec = vec![0.5f32; k_chunk_len];
1080 simd_add_f32_ultra_vec(&k_indices, &half_vec, &mut k_plus_half);
1081
1082 let mut angles = vec![0.0f32; k_chunk_len];
1083 let mut temp_prod = vec![0.0f32; k_chunk_len];
1084 let pi_vec = vec![pi_f32; k_chunk_len];
1085 let n_vec = vec![n_f32; k_chunk_len];
1086
1087 simd_mul_f32_ultra_vec(&m_plus_one, &k_plus_half, &mut temp_prod);
1088 let mut temp_prod2 = vec![0.0f32; k_chunk_len];
1089 simd_mul_f32_ultra_vec(&temp_prod, &pi_vec, &mut temp_prod2);
1090 simd_div_f32_ultra_vec(&temp_prod2, &n_vec, &mut angles);
1091
1092 let mut sin_values = vec![0.0f32; k_chunk_len];
1094 simd_sin_f32_ultra_vec(&angles, &mut sin_values);
1095
1096 let mut products = vec![0.0f32; k_chunk_len];
1097 simd_mul_f32_ultra_vec(&sin_values, &x_values, &mut products);
1098
1099 let mut temp_sum = vec![0.0f32; k_chunk_len];
1101 simd_add_f32_ultra_vec(®ular_sum, &products, &mut temp_sum);
1102 regular_sum = temp_sum;
1103 }
1104 }
1105
1106 let mut total_sum = vec![0.0f32; k_chunk_len];
1108 simd_add_f32_ultra_vec(&special_terms, ®ular_sum, &mut total_sum);
1109
1110 for (i, &sum) in total_sum.iter().enumerate() {
1112 result[k_chunk + i] = sum as f64;
1113 }
1114 }
1115
1116 if let Some("ortho") = norm {
1118 let norm_factor = ((2.0 / n as f64).sqrt() / 2.0) as f32;
1119 let norm_vec = vec![norm_factor; chunk_size];
1120
1121 for chunk_start in (0..n).step_by(chunk_size) {
1122 let chunk_end = (chunk_start + chunk_size).min(n);
1123 let chunk_len = chunk_end - chunk_start;
1124
1125 if chunk_len == chunk_size {
1126 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
1127 .iter()
1128 .map(|&x| x as f32)
1129 .collect();
1130 let mut normalized = vec![0.0f32; chunk_size];
1131
1132 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
1133
1134 for (i, &val) in normalized.iter().enumerate() {
1135 result[chunk_start + i] = val as f64;
1136 }
1137 }
1138 }
1139 } else {
1140 let norm_factor = 0.5f32;
1142 let norm_vec = vec![norm_factor; chunk_size];
1143
1144 for chunk_start in (0..n).step_by(chunk_size) {
1145 let chunk_end = (chunk_start + chunk_size).min(n);
1146 let chunk_len = chunk_end - chunk_start;
1147
1148 if chunk_len == chunk_size {
1149 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
1150 .iter()
1151 .map(|&x| x as f32)
1152 .collect();
1153 let mut normalized = vec![0.0f32; chunk_size];
1154
1155 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
1156
1157 for (i, &val) in normalized.iter().enumerate() {
1158 result[chunk_start + i] = val as f64;
1159 }
1160 }
1161 }
1162 }
1163
1164 Ok(result)
1165}
1166
1167#[allow(dead_code)]
1169fn dst4_bandwidth_saturated_simd(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
1170 use scirs2_core::simd_ops::SimdUnifiedOps;
1171
1172 let n = x.len();
1173 if n == 0 {
1174 return Err(FFTError::ValueError(
1175 "Input array cannot be empty".to_string(),
1176 ));
1177 }
1178
1179 let mut result = vec![0.0; n];
1180 let chunk_size = 8;
1181
1182 let pi_f32 = PI as f32;
1184 let n_f32 = n as f32;
1185
1186 for k_chunk in (0..n).step_by(chunk_size) {
1187 let k_chunk_end = (k_chunk + chunk_size).min(n);
1188 let k_chunk_len = k_chunk_end - k_chunk;
1189
1190 let mut k_indices = vec![0.0f32; k_chunk_len];
1192 for (i, k_idx) in k_indices.iter_mut().enumerate() {
1193 *k_idx = (k_chunk + i) as f32;
1194 }
1195
1196 let mut k_plus_half = vec![0.0f32; k_chunk_len];
1198 let half_vec = vec![0.5f32; k_chunk_len];
1199 simd_add_f32_ultra_vec(&k_indices, &half_vec, &mut k_plus_half);
1200
1201 let mut chunk_sum = vec![0.0f32; k_chunk_len];
1202
1203 for m_chunk in (0..n).step_by(chunk_size) {
1204 let m_chunk_end = (m_chunk + chunk_size).min(n);
1205 let m_chunk_len = m_chunk_end - m_chunk;
1206
1207 if m_chunk_len == k_chunk_len {
1208 let mut m_indices = vec![0.0f32; m_chunk_len];
1210 for (i, m_idx) in m_indices.iter_mut().enumerate() {
1211 *m_idx = (m_chunk + i) as f32;
1212 }
1213
1214 let mut m_plus_half = vec![0.0f32; m_chunk_len];
1216 simd_add_f32_ultra_vec(&m_indices, &half_vec, &mut m_plus_half);
1217
1218 let mut x_values = vec![0.0f32; m_chunk_len];
1220 for (i, x_val) in x_values.iter_mut().enumerate() {
1221 *x_val = x[m_chunk + i] as f32;
1222 }
1223
1224 let mut angles = vec![0.0f32; k_chunk_len];
1226 let mut temp_prod = vec![0.0f32; k_chunk_len];
1227 let pi_vec = vec![pi_f32; k_chunk_len];
1228 let n_vec = vec![n_f32; k_chunk_len];
1229
1230 simd_mul_f32_ultra_vec(&m_plus_half, &k_plus_half, &mut temp_prod);
1231 let mut temp_prod2 = vec![0.0f32; k_chunk_len];
1232 simd_mul_f32_ultra_vec(&temp_prod, &pi_vec, &mut temp_prod2);
1233 simd_div_f32_ultra_vec(&temp_prod2, &n_vec, &mut angles);
1234
1235 let mut sin_values = vec![0.0f32; k_chunk_len];
1237 simd_sin_f32_ultra_vec(&angles, &mut sin_values);
1238
1239 let mut products = vec![0.0f32; k_chunk_len];
1240 simd_mul_f32_ultra_vec(&sin_values, &x_values, &mut products);
1241
1242 let mut temp_sum = vec![0.0f32; k_chunk_len];
1244 simd_add_f32_ultra_vec(&chunk_sum, &products, &mut temp_sum);
1245 chunk_sum = temp_sum;
1246 }
1247 }
1248
1249 for (i, &sum) in chunk_sum.iter().enumerate() {
1251 result[k_chunk + i] = sum as f64;
1252 }
1253 }
1254
1255 if let Some("ortho") = norm {
1257 let norm_factor = (2.0 / n as f64).sqrt() as f32;
1258 let norm_vec = vec![norm_factor; chunk_size];
1259
1260 for chunk_start in (0..n).step_by(chunk_size) {
1261 let chunk_end = (chunk_start + chunk_size).min(n);
1262 let chunk_len = chunk_end - chunk_start;
1263
1264 if chunk_len == chunk_size {
1265 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
1266 .iter()
1267 .map(|&x| x as f32)
1268 .collect();
1269 let mut normalized = vec![0.0f32; chunk_size];
1270
1271 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
1272
1273 for (i, &val) in normalized.iter().enumerate() {
1274 result[chunk_start + i] = val as f64;
1275 }
1276 }
1277 }
1278 } else {
1279 let norm_factor = 2.0f32;
1281 let norm_vec = vec![norm_factor; chunk_size];
1282
1283 for chunk_start in (0..n).step_by(chunk_size) {
1284 let chunk_end = (chunk_start + chunk_size).min(n);
1285 let chunk_len = chunk_end - chunk_start;
1286
1287 if chunk_len == chunk_size {
1288 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
1289 .iter()
1290 .map(|&x| x as f32)
1291 .collect();
1292 let mut normalized = vec![0.0f32; chunk_size];
1293
1294 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
1295
1296 for (i, &val) in normalized.iter().enumerate() {
1297 result[chunk_start + i] = val as f64;
1298 }
1299 }
1300 }
1301 }
1302
1303 Ok(result)
1304}
1305
1306#[allow(dead_code)]
1311pub fn dst2_bandwidth_saturated_simd<T>(
1312 x: &ArrayView2<T>,
1313 dst_type: Option<DSTType>,
1314 norm: Option<&str>,
1315) -> FFTResult<Array2<f64>>
1316where
1317 T: NumCast + Copy + Debug,
1318{
1319 use scirs2_core::simd_ops::PlatformCapabilities;
1320
1321 let (n_rows, n_cols) = x.dim();
1322 let caps = PlatformCapabilities::detect();
1323
1324 if (n_rows >= 32 && n_cols >= 32) && (caps.has_avx2() || caps.has_avx512()) {
1326 dst2_bandwidth_saturated_simd_impl(x, dst_type, norm)
1327 } else {
1328 dst2(x, dst_type, norm)
1330 }
1331}
1332
1333#[allow(dead_code)]
1335fn dst2_bandwidth_saturated_simd_impl<T>(
1336 x: &ArrayView2<T>,
1337 dst_type: Option<DSTType>,
1338 norm: Option<&str>,
1339) -> FFTResult<Array2<f64>>
1340where
1341 T: NumCast + Copy + Debug,
1342{
1343 let (n_rows, n_cols) = x.dim();
1344 let type_val = dst_type.unwrap_or(DSTType::Type2);
1345
1346 let mut intermediate = Array2::zeros((n_rows, n_cols));
1348
1349 for r in 0..n_rows {
1350 let row_slice = x.slice(scirs2_core::ndarray::s![r, ..]);
1351 let row_vec: Vec<T> = row_slice.iter().cloned().collect();
1352
1353 let row_dst = dst_bandwidth_saturated_simd(&row_vec, Some(type_val), norm)?;
1355
1356 for (c, val) in row_dst.iter().enumerate() {
1357 intermediate[[r, c]] = *val;
1358 }
1359 }
1360
1361 let mut final_result = Array2::zeros((n_rows, n_cols));
1363
1364 for c in 0..n_cols {
1365 let col_slice = intermediate.slice(scirs2_core::ndarray::s![.., c]);
1366 let col_vec: Vec<f64> = col_slice.iter().cloned().collect();
1367
1368 let col_dst = dst_bandwidth_saturated_simd(&col_vec, Some(type_val), norm)?;
1370
1371 for (r, val) in col_dst.iter().enumerate() {
1372 final_result[[r, c]] = *val;
1373 }
1374 }
1375
1376 Ok(final_result)
1377}
1378
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::arr2;

    /// Asserts `actual[i] ≈ expected[i]` (epsilon 1e-10) for every index of `expected`.
    fn assert_matches_signal(actual: &[f64], expected: &[f64]) {
        for (i, &want) in expected.iter().enumerate() {
            assert_relative_eq!(actual[i], want, epsilon = 1e-10);
        }
    }

    /// Runs a forward DST followed by the inverse DST, both with "ortho" scaling.
    fn ortho_roundtrip(signal: &[f64], kind: DSTType) -> Vec<f64> {
        let coeffs = dst(signal, Some(kind), Some("ortho")).unwrap();
        idst(&coeffs, Some(kind), Some("ortho")).unwrap()
    }

    #[test]
    fn test_dst_and_idst() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        let recovered = ortho_roundtrip(&signal, DSTType::Type2);
        assert_matches_signal(&recovered, &signal);
    }

    #[test]
    fn test_dst_types() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        for kind in [DSTType::Type1, DSTType::Type2, DSTType::Type3, DSTType::Type4] {
            let recovered = ortho_roundtrip(&signal, kind);
            assert_matches_signal(&recovered, &signal);
        }
    }

    #[test]
    fn test_dst2_and_idst2() {
        let arr = arr2(&[[1.0, 2.0], [3.0, 4.0]]);

        let coeffs = dst2(&arr.view(), Some(DSTType::Type2), Some("ortho")).unwrap();
        let recovered = idst2(&coeffs.view(), Some(DSTType::Type2), Some("ortho")).unwrap();

        for ((i, j), &want) in arr.indexed_iter() {
            assert_relative_eq!(recovered[[i, j]], want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_linear_signal() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        let recovered = ortho_roundtrip(&signal, DSTType::Type2);
        assert_matches_signal(&recovered, &signal);
    }
}