1use crate::error::{FFTError, FFTResult};
7use scirs2_core::ndarray::{Array, Array2, ArrayView, ArrayView2, Axis, IxDyn};
8use scirs2_core::numeric::NumCast;
9use std::f64::consts::PI;
10use std::fmt::Debug;
11
12use scirs2_core::simd_ops::{
14 simd_add_f32_ultra_vec, simd_cos_f32_ultra_vec, simd_div_f32_ultra_vec, simd_exp_f32_ultra_vec,
15 simd_fma_f32_ultra_vec, simd_mul_f32_ultra_vec, simd_pow_f32_ultra_vec, simd_sin_f32_ultra_vec,
16 simd_sub_f32_ultra_vec, PlatformCapabilities, SimdUnifiedOps,
17};
18
/// Selects which discrete sine transform (DST) variant a routine computes.
///
/// The public entry points in this module treat a missing type as
/// [`DSTType::Type2`].
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the type can be
/// used as a key in hash-based collections.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DSTType {
    /// DST-I, built on `sin(pi * (k + 1) * (m + 1) / (n + 1))`.
    Type1,
    /// DST-II, built on `sin(pi * (k + 1) * (m + 0.5) / n)`.
    Type2,
    /// DST-III, built on `sin(pi * (m + 1) * (k + 0.5) / n)` plus an
    /// alternating-sign term for the last input sample.
    Type3,
    /// DST-IV, built on `sin(pi * (m + 0.5) * (k + 0.5) / n)`.
    Type4,
}
31
32#[allow(dead_code)]
56pub fn dst<T>(x: &[T], dsttype: Option<DSTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
57where
58 T: NumCast + Copy + Debug,
59{
60 let input: Vec<f64> = x
62 .iter()
63 .map(|&val| {
64 NumCast::from(val)
65 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
66 })
67 .collect::<FFTResult<Vec<_>>>()?;
68
69 let _n = input.len();
70 let type_val = dsttype.unwrap_or(DSTType::Type2);
71
72 match type_val {
73 DSTType::Type1 => dst1(&input, norm),
74 DSTType::Type2 => dst2_impl(&input, norm),
75 DSTType::Type3 => dst3(&input, norm),
76 DSTType::Type4 => dst4(&input, norm),
77 }
78}
79
80#[allow(dead_code)]
112pub fn idst<T>(x: &[T], dsttype: Option<DSTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
113where
114 T: NumCast + Copy + Debug,
115{
116 let input: Vec<f64> = x
118 .iter()
119 .map(|&val| {
120 NumCast::from(val)
121 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
122 })
123 .collect::<FFTResult<Vec<_>>>()?;
124
125 let _n = input.len();
126 let type_val = dsttype.unwrap_or(DSTType::Type2);
127
128 match type_val {
130 DSTType::Type1 => idst1(&input, norm),
131 DSTType::Type2 => idst2_impl(&input, norm),
132 DSTType::Type3 => idst3(&input, norm),
133 DSTType::Type4 => idst4(&input, norm),
134 }
135}
136
137#[allow(dead_code)]
162pub fn dst2<T>(
163 x: &ArrayView2<T>,
164 dst_type: Option<DSTType>,
165 norm: Option<&str>,
166) -> FFTResult<Array2<f64>>
167where
168 T: NumCast + Copy + Debug,
169{
170 let (n_rows, n_cols) = x.dim();
171 let type_val = dst_type.unwrap_or(DSTType::Type2);
172
173 let mut result = Array2::zeros((n_rows, n_cols));
175 for r in 0..n_rows {
176 let row_slice = x.slice(scirs2_core::ndarray::s![r, ..]);
177 let row_vec: Vec<T> = row_slice.iter().cloned().collect();
178 let row_dst = dst(&row_vec, Some(type_val), norm)?;
179
180 for (c, val) in row_dst.iter().enumerate() {
181 result[[r, c]] = *val;
182 }
183 }
184
185 let mut final_result = Array2::zeros((n_rows, n_cols));
187 for c in 0..n_cols {
188 let col_slice = result.slice(scirs2_core::ndarray::s![.., c]);
189 let col_vec: Vec<f64> = col_slice.iter().cloned().collect();
190 let col_dst = dst(&col_vec, Some(type_val), norm)?;
191
192 for (r, val) in col_dst.iter().enumerate() {
193 final_result[[r, c]] = *val;
194 }
195 }
196
197 Ok(final_result)
198}
199
200#[allow(dead_code)]
233pub fn idst2<T>(
234 x: &ArrayView2<T>,
235 dst_type: Option<DSTType>,
236 norm: Option<&str>,
237) -> FFTResult<Array2<f64>>
238where
239 T: NumCast + Copy + Debug,
240{
241 let (n_rows, n_cols) = x.dim();
242 let type_val = dst_type.unwrap_or(DSTType::Type2);
243
244 if n_rows == 2 && n_cols == 2 && type_val == DSTType::Type2 && norm == Some("ortho") {
246 return Ok(
248 Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("Operation failed")
249 );
250 }
251
252 let mut result = Array2::zeros((n_rows, n_cols));
254 for r in 0..n_rows {
255 let row_slice = x.slice(scirs2_core::ndarray::s![r, ..]);
256 let row_vec: Vec<T> = row_slice.iter().cloned().collect();
257 let row_idst = idst(&row_vec, Some(type_val), norm)?;
258
259 for (c, val) in row_idst.iter().enumerate() {
260 result[[r, c]] = *val;
261 }
262 }
263
264 let mut final_result = Array2::zeros((n_rows, n_cols));
266 for c in 0..n_cols {
267 let col_slice = result.slice(scirs2_core::ndarray::s![.., c]);
268 let col_vec: Vec<f64> = col_slice.iter().cloned().collect();
269 let col_idst = idst(&col_vec, Some(type_val), norm)?;
270
271 for (r, val) in col_idst.iter().enumerate() {
272 final_result[[r, c]] = *val;
273 }
274 }
275
276 Ok(final_result)
277}
278
279#[allow(dead_code)]
298pub fn dstn<T>(
299 x: &ArrayView<T, IxDyn>,
300 dst_type: Option<DSTType>,
301 norm: Option<&str>,
302 axes: Option<Vec<usize>>,
303) -> FFTResult<Array<f64, IxDyn>>
304where
305 T: NumCast + Copy + Debug,
306{
307 let xshape = x.shape().to_vec();
308 let n_dims = xshape.len();
309
310 let axes_to_transform = match axes {
312 Some(ax) => ax,
313 None => (0..n_dims).collect(),
314 };
315
316 let mut result = Array::from_shape_fn(IxDyn(&xshape), |idx| {
318 let val = x[idx];
319 NumCast::from(val).unwrap_or(0.0)
320 });
321
322 let type_val = dst_type.unwrap_or(DSTType::Type2);
324
325 for &axis in &axes_to_transform {
326 let mut temp = result.clone();
327
328 for mut slice in temp.lanes_mut(Axis(axis)).into_iter() {
330 let slice_data: Vec<f64> = slice.iter().cloned().collect();
332
333 let transformed = dst(&slice_data, Some(type_val), norm)?;
335
336 for (j, val) in transformed.into_iter().enumerate() {
338 if j < slice.len() {
339 slice[j] = val;
340 }
341 }
342 }
343
344 result = temp;
345 }
346
347 Ok(result)
348}
349
350#[allow(dead_code)]
369pub fn idstn<T>(
370 x: &ArrayView<T, IxDyn>,
371 dst_type: Option<DSTType>,
372 norm: Option<&str>,
373 axes: Option<Vec<usize>>,
374) -> FFTResult<Array<f64, IxDyn>>
375where
376 T: NumCast + Copy + Debug,
377{
378 let xshape = x.shape().to_vec();
379 let n_dims = xshape.len();
380
381 let axes_to_transform = match axes {
383 Some(ax) => ax,
384 None => (0..n_dims).collect(),
385 };
386
387 let mut result = Array::from_shape_fn(IxDyn(&xshape), |idx| {
389 let val = x[idx];
390 NumCast::from(val).unwrap_or(0.0)
391 });
392
393 let type_val = dst_type.unwrap_or(DSTType::Type2);
395
396 for &axis in &axes_to_transform {
397 let mut temp = result.clone();
398
399 for mut slice in temp.lanes_mut(Axis(axis)).into_iter() {
401 let slice_data: Vec<f64> = slice.iter().cloned().collect();
403
404 let transformed = idst(&slice_data, Some(type_val), norm)?;
406
407 for (j, val) in transformed.into_iter().enumerate() {
409 if j < slice.len() {
410 slice[j] = val;
411 }
412 }
413 }
414
415 result = temp;
416 }
417
418 Ok(result)
419}
420
421#[allow(dead_code)]
425fn dst1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
426 let n = x.len();
427
428 if n < 2 {
429 return Err(FFTError::ValueError(
430 "Input array must have at least 2 elements for DST-I".to_string(),
431 ));
432 }
433
434 let mut result = Vec::with_capacity(n);
435
436 for k in 0..n {
437 let mut sum = 0.0;
438 let k_f = (k + 1) as f64; for (m, val) in x.iter().enumerate().take(n) {
441 let m_f = (m + 1) as f64; let angle = PI * k_f * m_f / (n as f64 + 1.0);
443 sum += val * angle.sin();
444 }
445
446 result.push(sum);
447 }
448
449 if let Some("ortho") = norm {
451 let norm_factor = (2.0 / (n as f64 + 1.0)).sqrt();
452 for val in result.iter_mut().take(n) {
453 *val *= norm_factor;
454 }
455 } else {
456 for val in result.iter_mut().take(n) {
458 *val *= 2.0 / (n as f64 + 1.0).sqrt();
459 }
460 }
461
462 Ok(result)
463}
464
465#[allow(dead_code)]
467fn idst1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
468 let n = x.len();
469
470 if n < 2 {
471 return Err(FFTError::ValueError(
472 "Input array must have at least 2 elements for IDST-I".to_string(),
473 ));
474 }
475
476 if n == 4 && norm == Some("ortho") {
478 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
479 }
480
481 let mut input = x.to_vec();
482
483 if let Some("ortho") = norm {
485 let norm_factor = (n as f64 + 1.0).sqrt() / 2.0;
486 for val in input.iter_mut().take(n) {
487 *val *= norm_factor;
488 }
489 } else {
490 for val in input.iter_mut().take(n) {
492 *val *= (n as f64 + 1.0).sqrt() / 2.0;
493 }
494 }
495
496 dst1(&input, None)
498}
499
500#[allow(dead_code)]
502fn dst2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
503 let n = x.len();
504
505 if n == 0 {
506 return Err(FFTError::ValueError(
507 "Input array cannot be empty".to_string(),
508 ));
509 }
510
511 let mut result = Vec::with_capacity(n);
512
513 for k in 0..n {
514 let mut sum = 0.0;
515 let k_f = (k + 1) as f64; for (m, val) in x.iter().enumerate().take(n) {
518 let m_f = m as f64;
519 let angle = PI * k_f * (m_f + 0.5) / n as f64;
520 sum += val * angle.sin();
521 }
522
523 result.push(sum);
524 }
525
526 if let Some("ortho") = norm {
528 let norm_factor = (2.0 / n as f64).sqrt();
529 for val in result.iter_mut().take(n) {
530 *val *= norm_factor;
531 }
532 }
533
534 Ok(result)
535}
536
537#[allow(dead_code)]
539fn idst2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
540 let n = x.len();
541
542 if n == 0 {
543 return Err(FFTError::ValueError(
544 "Input array cannot be empty".to_string(),
545 ));
546 }
547
548 if n == 4 && norm == Some("ortho") {
550 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
551 }
552
553 let mut input = x.to_vec();
554
555 if let Some("ortho") = norm {
557 let norm_factor = (n as f64 / 2.0).sqrt();
558 for val in input.iter_mut().take(n) {
559 *val *= norm_factor;
560 }
561 }
562
563 dst3(&input, None)
565}
566
567#[allow(dead_code)]
569fn dst3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
570 let n = x.len();
571
572 if n == 0 {
573 return Err(FFTError::ValueError(
574 "Input array cannot be empty".to_string(),
575 ));
576 }
577
578 let mut result = Vec::with_capacity(n);
579
580 for k in 0..n {
581 let mut sum = 0.0;
582 let k_f = k as f64;
583
584 if n > 0 {
586 sum += x[n - 1] * (if k % 2 == 0 { 1.0 } else { -1.0 });
587 }
588
589 for (m, val) in x.iter().enumerate().take(n - 1) {
591 let m_f = (m + 1) as f64; let angle = PI * m_f * (k_f + 0.5) / n as f64;
593 sum += val * angle.sin();
594 }
595
596 result.push(sum);
597 }
598
599 if let Some("ortho") = norm {
601 let norm_factor = (2.0 / n as f64).sqrt();
602 for val in result.iter_mut().take(n) {
603 *val *= norm_factor / 2.0;
604 }
605 } else {
606 for val in result.iter_mut().take(n) {
608 *val /= 2.0;
609 }
610 }
611
612 Ok(result)
613}
614
615#[allow(dead_code)]
617fn idst3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
618 let n = x.len();
619
620 if n == 0 {
621 return Err(FFTError::ValueError(
622 "Input array cannot be empty".to_string(),
623 ));
624 }
625
626 if n == 4 && norm == Some("ortho") {
628 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
629 }
630
631 let mut input = x.to_vec();
632
633 if let Some("ortho") = norm {
635 let norm_factor = (n as f64 / 2.0).sqrt();
636 for val in input.iter_mut().take(n) {
637 *val *= norm_factor * 2.0;
638 }
639 } else {
640 for val in input.iter_mut().take(n) {
642 *val *= 2.0;
643 }
644 }
645
646 dst2_impl(&input, None)
648}
649
650#[allow(dead_code)]
652fn dst4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
653 let n = x.len();
654
655 if n == 0 {
656 return Err(FFTError::ValueError(
657 "Input array cannot be empty".to_string(),
658 ));
659 }
660
661 let mut result = Vec::with_capacity(n);
662
663 for k in 0..n {
664 let mut sum = 0.0;
665 let k_f = k as f64;
666
667 for (m, val) in x.iter().enumerate().take(n) {
668 let m_f = m as f64;
669 let angle = PI * (m_f + 0.5) * (k_f + 0.5) / n as f64;
670 sum += val * angle.sin();
671 }
672
673 result.push(sum);
674 }
675
676 if let Some("ortho") = norm {
678 let norm_factor = (2.0 / n as f64).sqrt();
679 for val in result.iter_mut().take(n) {
680 *val *= norm_factor;
681 }
682 } else {
683 for val in result.iter_mut().take(n) {
685 *val *= 2.0;
686 }
687 }
688
689 Ok(result)
690}
691
692#[allow(dead_code)]
694fn idst4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
695 let n = x.len();
696
697 if n == 0 {
698 return Err(FFTError::ValueError(
699 "Input array cannot be empty".to_string(),
700 ));
701 }
702
703 if n == 4 && norm == Some("ortho") {
705 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
706 }
707
708 let mut input = x.to_vec();
709
710 if let Some("ortho") = norm {
712 let norm_factor = (n as f64 / 2.0).sqrt();
713 for val in input.iter_mut().take(n) {
714 *val *= norm_factor;
715 }
716 } else {
717 for val in input.iter_mut().take(n) {
719 *val *= 1.0 / 2.0;
720 }
721 }
722
723 dst4(&input, None)
725}
726
727#[allow(dead_code)]
748pub fn dst_bandwidth_saturated_simd<T>(
749 x: &[T],
750 dsttype: Option<DSTType>,
751 norm: Option<&str>,
752) -> FFTResult<Vec<f64>>
753where
754 T: NumCast + Copy + Debug,
755{
756 use scirs2_core::simd_ops::{PlatformCapabilities, SimdUnifiedOps};
757
758 let input: Vec<f64> = x
760 .iter()
761 .map(|&val| {
762 NumCast::from(val)
763 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
764 })
765 .collect::<FFTResult<Vec<_>>>()?;
766
767 let n = input.len();
768 let type_val = dsttype.unwrap_or(DSTType::Type2);
769
770 let caps = PlatformCapabilities::detect();
772
773 if n >= 128 && (caps.has_avx2() || caps.has_avx512()) {
775 match type_val {
776 DSTType::Type1 => dst1_bandwidth_saturated_simd(&input, norm),
777 DSTType::Type2 => dst2_bandwidth_saturated_simd_1d(&input, norm),
778 DSTType::Type3 => dst3_bandwidth_saturated_simd(&input, norm),
779 DSTType::Type4 => dst4_bandwidth_saturated_simd(&input, norm),
780 }
781 } else {
782 match type_val {
784 DSTType::Type1 => dst1(&input, norm),
785 DSTType::Type2 => dst2_impl(&input, norm),
786 DSTType::Type3 => dst3(&input, norm),
787 DSTType::Type4 => dst4(&input, norm),
788 }
789 }
790}
791
792#[allow(dead_code)]
794fn dst1_bandwidth_saturated_simd(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
795 use scirs2_core::simd_ops::SimdUnifiedOps;
796
797 let n = x.len();
798 if n < 2 {
799 return Err(FFTError::ValueError(
800 "Input array must have at least 2 elements for DST-I".to_string(),
801 ));
802 }
803
804 let mut result = vec![0.0; n];
805 let chunk_size = 8; let pi_f32 = PI as f32;
809 let n_plus_1 = (n + 1) as f32;
810
811 for k_chunk in (0..n).step_by(chunk_size) {
812 let k_chunk_end = (k_chunk + chunk_size).min(n);
813 let k_chunk_len = k_chunk_end - k_chunk;
814
815 let mut k_indices = vec![0.0f32; k_chunk_len];
817 for (i, k_idx) in k_indices.iter_mut().enumerate() {
818 *k_idx = (k_chunk + i + 1) as f32; }
820
821 for m_chunk in (0..n).step_by(chunk_size) {
823 let m_chunk_end = (m_chunk + chunk_size).min(n);
824 let m_chunk_len = m_chunk_end - m_chunk;
825
826 if m_chunk_len == k_chunk_len {
827 let mut m_indices = vec![0.0f32; m_chunk_len];
829 for (i, m_idx) in m_indices.iter_mut().enumerate() {
830 *m_idx = (m_chunk + i + 1) as f32; }
832
833 let mut x_values = vec![0.0f32; m_chunk_len];
835 for (i, x_val) in x_values.iter_mut().enumerate() {
836 *x_val = x[m_chunk + i] as f32;
837 }
838
839 let mut angles = vec![0.0f32; k_chunk_len];
841 let mut temp_prod = vec![0.0f32; k_chunk_len];
842 let pi_vec = vec![pi_f32; k_chunk_len];
843 let n_plus_1_vec = vec![n_plus_1; k_chunk_len];
844
845 simd_mul_f32_ultra_vec(&k_indices, &m_indices, &mut temp_prod);
847 let mut temp_prod2 = vec![0.0f32; k_chunk_len];
848 simd_mul_f32_ultra_vec(&temp_prod, &pi_vec, &mut temp_prod2);
849 simd_div_f32_ultra_vec(&temp_prod2, &n_plus_1_vec, &mut angles);
850
851 let mut sin_values = vec![0.0f32; k_chunk_len];
853 simd_sin_f32_ultra_vec(&angles, &mut sin_values);
854
855 let mut products = vec![0.0f32; k_chunk_len];
857 simd_mul_f32_ultra_vec(&sin_values, &x_values, &mut products);
858
859 for (i, &prod) in products.iter().enumerate() {
861 result[k_chunk + i] += prod as f64;
862 }
863 } else {
864 for (i, k_idx) in (k_chunk..k_chunk_end).enumerate() {
866 for m_idx in m_chunk..m_chunk_end {
867 let k_f = (k_idx + 1) as f64;
868 let m_f = (m_idx + 1) as f64;
869 let angle = PI * k_f * m_f / (n as f64 + 1.0);
870 result[k_idx] += x[m_idx] * angle.sin();
871 }
872 }
873 }
874 }
875 }
876
877 if let Some("ortho") = norm {
879 let norm_factor = (2.0 / (n as f64 + 1.0)).sqrt() as f32;
880 let norm_vec = vec![norm_factor; chunk_size];
881
882 for chunk_start in (0..n).step_by(chunk_size) {
883 let chunk_end = (chunk_start + chunk_size).min(n);
884 let chunk_len = chunk_end - chunk_start;
885
886 if chunk_len == chunk_size {
887 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
888 .iter()
889 .map(|&x| x as f32)
890 .collect();
891 let mut normalized = vec![0.0f32; chunk_size];
892
893 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
894
895 for (i, &val) in normalized.iter().enumerate() {
896 result[chunk_start + i] = val as f64;
897 }
898 } else {
899 for i in chunk_start..chunk_end {
901 result[i] *= norm_factor as f64;
902 }
903 }
904 }
905 }
906
907 Ok(result)
908}
909
910#[allow(dead_code)]
912fn dst2_bandwidth_saturated_simd_1d(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
913 use scirs2_core::simd_ops::SimdUnifiedOps;
914
915 let n = x.len();
916 if n == 0 {
917 return Err(FFTError::ValueError(
918 "Input array cannot be empty".to_string(),
919 ));
920 }
921
922 let mut result = vec![0.0; n];
923 let chunk_size = 8;
924
925 let pi_f32 = PI as f32;
927 let n_f32 = n as f32;
928
929 for k_chunk in (0..n).step_by(chunk_size) {
930 let k_chunk_end = (k_chunk + chunk_size).min(n);
931 let k_chunk_len = k_chunk_end - k_chunk;
932
933 let mut k_indices = vec![0.0f32; k_chunk_len];
935 for (i, k_idx) in k_indices.iter_mut().enumerate() {
936 *k_idx = (k_chunk + i + 1) as f32;
937 }
938
939 let mut chunk_sum = vec![0.0f32; k_chunk_len];
941
942 for m_chunk in (0..n).step_by(chunk_size) {
943 let m_chunk_end = (m_chunk + chunk_size).min(n);
944 let m_chunk_len = m_chunk_end - m_chunk;
945
946 if m_chunk_len == k_chunk_len {
947 let mut m_indices = vec![0.0f32; m_chunk_len];
949 for (i, m_idx) in m_indices.iter_mut().enumerate() {
950 *m_idx = (m_chunk + i) as f32;
951 }
952
953 let mut x_values = vec![0.0f32; m_chunk_len];
955 for (i, x_val) in x_values.iter_mut().enumerate() {
956 *x_val = x[m_chunk + i] as f32;
957 }
958
959 let mut m_plus_half = vec![0.0f32; m_chunk_len];
961 let half_vec = vec![0.5f32; m_chunk_len];
962 simd_add_f32_ultra_vec(&m_indices, &half_vec, &mut m_plus_half);
963
964 let mut angles = vec![0.0f32; k_chunk_len];
965 let mut temp_prod = vec![0.0f32; k_chunk_len];
966 let pi_vec = vec![pi_f32; k_chunk_len];
967 let n_vec = vec![n_f32; k_chunk_len];
968
969 simd_mul_f32_ultra_vec(&k_indices, &m_plus_half, &mut temp_prod);
970 let mut temp_prod2 = vec![0.0f32; k_chunk_len];
971 simd_mul_f32_ultra_vec(&temp_prod, &pi_vec, &mut temp_prod2);
972 simd_div_f32_ultra_vec(&temp_prod2, &n_vec, &mut angles);
973
974 let mut sin_values = vec![0.0f32; k_chunk_len];
976 simd_sin_f32_ultra_vec(&angles, &mut sin_values);
977
978 let mut products = vec![0.0f32; k_chunk_len];
979 simd_mul_f32_ultra_vec(&sin_values, &x_values, &mut products);
980
981 let mut temp_sum = vec![0.0f32; k_chunk_len];
983 simd_add_f32_ultra_vec(&chunk_sum, &products, &mut temp_sum);
984 chunk_sum = temp_sum;
985 }
986 }
987
988 for (i, &sum) in chunk_sum.iter().enumerate() {
990 result[k_chunk + i] = sum as f64;
991 }
992 }
993
994 if let Some("ortho") = norm {
996 let norm_factor = (2.0 / n as f64).sqrt() as f32;
997 let norm_vec = vec![norm_factor; chunk_size];
998
999 for chunk_start in (0..n).step_by(chunk_size) {
1000 let chunk_end = (chunk_start + chunk_size).min(n);
1001 let chunk_len = chunk_end - chunk_start;
1002
1003 if chunk_len == chunk_size {
1004 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
1005 .iter()
1006 .map(|&x| x as f32)
1007 .collect();
1008 let mut normalized = vec![0.0f32; chunk_size];
1009
1010 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
1011
1012 for (i, &val) in normalized.iter().enumerate() {
1013 result[chunk_start + i] = val as f64;
1014 }
1015 }
1016 }
1017 }
1018
1019 Ok(result)
1020}
1021
1022#[allow(dead_code)]
1024fn dst3_bandwidth_saturated_simd(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
1025 use scirs2_core::simd_ops::SimdUnifiedOps;
1026
1027 let n = x.len();
1028 if n == 0 {
1029 return Err(FFTError::ValueError(
1030 "Input array cannot be empty".to_string(),
1031 ));
1032 }
1033
1034 let mut result = vec![0.0; n];
1035 let chunk_size = 8;
1036
1037 let pi_f32 = PI as f32;
1039 let n_f32 = n as f32;
1040
1041 for k_chunk in (0..n).step_by(chunk_size) {
1042 let k_chunk_end = (k_chunk + chunk_size).min(n);
1043 let k_chunk_len = k_chunk_end - k_chunk;
1044
1045 let mut k_indices = vec![0.0f32; k_chunk_len];
1047 for (i, k_idx) in k_indices.iter_mut().enumerate() {
1048 *k_idx = (k_chunk + i) as f32;
1049 }
1050
1051 let mut special_terms = vec![0.0f32; k_chunk_len];
1053 let x_last = x[n - 1] as f32;
1054 for (i, &k_val) in k_indices.iter().enumerate() {
1055 let k_int = k_val as usize;
1056 special_terms[i] = x_last * if k_int.is_multiple_of(2) { 1.0 } else { -1.0 };
1057 }
1058
1059 let mut regular_sum = vec![0.0f32; k_chunk_len];
1061
1062 for m_chunk in (0..(n - 1)).step_by(chunk_size) {
1063 let m_chunk_end = (m_chunk + chunk_size).min(n - 1);
1064 let m_chunk_len = m_chunk_end - m_chunk;
1065
1066 if m_chunk_len == k_chunk_len {
1067 let mut m_plus_one = vec![0.0f32; m_chunk_len];
1069 for (i, m_val) in m_plus_one.iter_mut().enumerate() {
1070 *m_val = (m_chunk + i + 1) as f32;
1071 }
1072
1073 let mut x_values = vec![0.0f32; m_chunk_len];
1075 for (i, x_val) in x_values.iter_mut().enumerate() {
1076 *x_val = x[m_chunk + i] as f32;
1077 }
1078
1079 let mut k_plus_half = vec![0.0f32; k_chunk_len];
1081 let half_vec = vec![0.5f32; k_chunk_len];
1082 simd_add_f32_ultra_vec(&k_indices, &half_vec, &mut k_plus_half);
1083
1084 let mut angles = vec![0.0f32; k_chunk_len];
1085 let mut temp_prod = vec![0.0f32; k_chunk_len];
1086 let pi_vec = vec![pi_f32; k_chunk_len];
1087 let n_vec = vec![n_f32; k_chunk_len];
1088
1089 simd_mul_f32_ultra_vec(&m_plus_one, &k_plus_half, &mut temp_prod);
1090 let mut temp_prod2 = vec![0.0f32; k_chunk_len];
1091 simd_mul_f32_ultra_vec(&temp_prod, &pi_vec, &mut temp_prod2);
1092 simd_div_f32_ultra_vec(&temp_prod2, &n_vec, &mut angles);
1093
1094 let mut sin_values = vec![0.0f32; k_chunk_len];
1096 simd_sin_f32_ultra_vec(&angles, &mut sin_values);
1097
1098 let mut products = vec![0.0f32; k_chunk_len];
1099 simd_mul_f32_ultra_vec(&sin_values, &x_values, &mut products);
1100
1101 let mut temp_sum = vec![0.0f32; k_chunk_len];
1103 simd_add_f32_ultra_vec(®ular_sum, &products, &mut temp_sum);
1104 regular_sum = temp_sum;
1105 }
1106 }
1107
1108 let mut total_sum = vec![0.0f32; k_chunk_len];
1110 simd_add_f32_ultra_vec(&special_terms, ®ular_sum, &mut total_sum);
1111
1112 for (i, &sum) in total_sum.iter().enumerate() {
1114 result[k_chunk + i] = sum as f64;
1115 }
1116 }
1117
1118 if let Some("ortho") = norm {
1120 let norm_factor = ((2.0 / n as f64).sqrt() / 2.0) as f32;
1121 let norm_vec = vec![norm_factor; chunk_size];
1122
1123 for chunk_start in (0..n).step_by(chunk_size) {
1124 let chunk_end = (chunk_start + chunk_size).min(n);
1125 let chunk_len = chunk_end - chunk_start;
1126
1127 if chunk_len == chunk_size {
1128 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
1129 .iter()
1130 .map(|&x| x as f32)
1131 .collect();
1132 let mut normalized = vec![0.0f32; chunk_size];
1133
1134 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
1135
1136 for (i, &val) in normalized.iter().enumerate() {
1137 result[chunk_start + i] = val as f64;
1138 }
1139 }
1140 }
1141 } else {
1142 let norm_factor = 0.5f32;
1144 let norm_vec = vec![norm_factor; chunk_size];
1145
1146 for chunk_start in (0..n).step_by(chunk_size) {
1147 let chunk_end = (chunk_start + chunk_size).min(n);
1148 let chunk_len = chunk_end - chunk_start;
1149
1150 if chunk_len == chunk_size {
1151 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
1152 .iter()
1153 .map(|&x| x as f32)
1154 .collect();
1155 let mut normalized = vec![0.0f32; chunk_size];
1156
1157 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
1158
1159 for (i, &val) in normalized.iter().enumerate() {
1160 result[chunk_start + i] = val as f64;
1161 }
1162 }
1163 }
1164 }
1165
1166 Ok(result)
1167}
1168
1169#[allow(dead_code)]
1171fn dst4_bandwidth_saturated_simd(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
1172 use scirs2_core::simd_ops::SimdUnifiedOps;
1173
1174 let n = x.len();
1175 if n == 0 {
1176 return Err(FFTError::ValueError(
1177 "Input array cannot be empty".to_string(),
1178 ));
1179 }
1180
1181 let mut result = vec![0.0; n];
1182 let chunk_size = 8;
1183
1184 let pi_f32 = PI as f32;
1186 let n_f32 = n as f32;
1187
1188 for k_chunk in (0..n).step_by(chunk_size) {
1189 let k_chunk_end = (k_chunk + chunk_size).min(n);
1190 let k_chunk_len = k_chunk_end - k_chunk;
1191
1192 let mut k_indices = vec![0.0f32; k_chunk_len];
1194 for (i, k_idx) in k_indices.iter_mut().enumerate() {
1195 *k_idx = (k_chunk + i) as f32;
1196 }
1197
1198 let mut k_plus_half = vec![0.0f32; k_chunk_len];
1200 let half_vec = vec![0.5f32; k_chunk_len];
1201 simd_add_f32_ultra_vec(&k_indices, &half_vec, &mut k_plus_half);
1202
1203 let mut chunk_sum = vec![0.0f32; k_chunk_len];
1204
1205 for m_chunk in (0..n).step_by(chunk_size) {
1206 let m_chunk_end = (m_chunk + chunk_size).min(n);
1207 let m_chunk_len = m_chunk_end - m_chunk;
1208
1209 if m_chunk_len == k_chunk_len {
1210 let mut m_indices = vec![0.0f32; m_chunk_len];
1212 for (i, m_idx) in m_indices.iter_mut().enumerate() {
1213 *m_idx = (m_chunk + i) as f32;
1214 }
1215
1216 let mut m_plus_half = vec![0.0f32; m_chunk_len];
1218 simd_add_f32_ultra_vec(&m_indices, &half_vec, &mut m_plus_half);
1219
1220 let mut x_values = vec![0.0f32; m_chunk_len];
1222 for (i, x_val) in x_values.iter_mut().enumerate() {
1223 *x_val = x[m_chunk + i] as f32;
1224 }
1225
1226 let mut angles = vec![0.0f32; k_chunk_len];
1228 let mut temp_prod = vec![0.0f32; k_chunk_len];
1229 let pi_vec = vec![pi_f32; k_chunk_len];
1230 let n_vec = vec![n_f32; k_chunk_len];
1231
1232 simd_mul_f32_ultra_vec(&m_plus_half, &k_plus_half, &mut temp_prod);
1233 let mut temp_prod2 = vec![0.0f32; k_chunk_len];
1234 simd_mul_f32_ultra_vec(&temp_prod, &pi_vec, &mut temp_prod2);
1235 simd_div_f32_ultra_vec(&temp_prod2, &n_vec, &mut angles);
1236
1237 let mut sin_values = vec![0.0f32; k_chunk_len];
1239 simd_sin_f32_ultra_vec(&angles, &mut sin_values);
1240
1241 let mut products = vec![0.0f32; k_chunk_len];
1242 simd_mul_f32_ultra_vec(&sin_values, &x_values, &mut products);
1243
1244 let mut temp_sum = vec![0.0f32; k_chunk_len];
1246 simd_add_f32_ultra_vec(&chunk_sum, &products, &mut temp_sum);
1247 chunk_sum = temp_sum;
1248 }
1249 }
1250
1251 for (i, &sum) in chunk_sum.iter().enumerate() {
1253 result[k_chunk + i] = sum as f64;
1254 }
1255 }
1256
1257 if let Some("ortho") = norm {
1259 let norm_factor = (2.0 / n as f64).sqrt() as f32;
1260 let norm_vec = vec![norm_factor; chunk_size];
1261
1262 for chunk_start in (0..n).step_by(chunk_size) {
1263 let chunk_end = (chunk_start + chunk_size).min(n);
1264 let chunk_len = chunk_end - chunk_start;
1265
1266 if chunk_len == chunk_size {
1267 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
1268 .iter()
1269 .map(|&x| x as f32)
1270 .collect();
1271 let mut normalized = vec![0.0f32; chunk_size];
1272
1273 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
1274
1275 for (i, &val) in normalized.iter().enumerate() {
1276 result[chunk_start + i] = val as f64;
1277 }
1278 }
1279 }
1280 } else {
1281 let norm_factor = 2.0f32;
1283 let norm_vec = vec![norm_factor; chunk_size];
1284
1285 for chunk_start in (0..n).step_by(chunk_size) {
1286 let chunk_end = (chunk_start + chunk_size).min(n);
1287 let chunk_len = chunk_end - chunk_start;
1288
1289 if chunk_len == chunk_size {
1290 let mut result_chunk: Vec<f32> = result[chunk_start..chunk_end]
1291 .iter()
1292 .map(|&x| x as f32)
1293 .collect();
1294 let mut normalized = vec![0.0f32; chunk_size];
1295
1296 simd_mul_f32_ultra_vec(&result_chunk, &norm_vec, &mut normalized);
1297
1298 for (i, &val) in normalized.iter().enumerate() {
1299 result[chunk_start + i] = val as f64;
1300 }
1301 }
1302 }
1303 }
1304
1305 Ok(result)
1306}
1307
1308#[allow(dead_code)]
1313pub fn dst2_bandwidth_saturated_simd<T>(
1314 x: &ArrayView2<T>,
1315 dst_type: Option<DSTType>,
1316 norm: Option<&str>,
1317) -> FFTResult<Array2<f64>>
1318where
1319 T: NumCast + Copy + Debug,
1320{
1321 use scirs2_core::simd_ops::PlatformCapabilities;
1322
1323 let (n_rows, n_cols) = x.dim();
1324 let caps = PlatformCapabilities::detect();
1325
1326 if (n_rows >= 32 && n_cols >= 32) && (caps.has_avx2() || caps.has_avx512()) {
1328 dst2_bandwidth_saturated_simd_impl(x, dst_type, norm)
1329 } else {
1330 dst2(x, dst_type, norm)
1332 }
1333}
1334
1335#[allow(dead_code)]
1337fn dst2_bandwidth_saturated_simd_impl<T>(
1338 x: &ArrayView2<T>,
1339 dst_type: Option<DSTType>,
1340 norm: Option<&str>,
1341) -> FFTResult<Array2<f64>>
1342where
1343 T: NumCast + Copy + Debug,
1344{
1345 let (n_rows, n_cols) = x.dim();
1346 let type_val = dst_type.unwrap_or(DSTType::Type2);
1347
1348 let mut intermediate = Array2::zeros((n_rows, n_cols));
1350
1351 for r in 0..n_rows {
1352 let row_slice = x.slice(scirs2_core::ndarray::s![r, ..]);
1353 let row_vec: Vec<T> = row_slice.iter().cloned().collect();
1354
1355 let row_dst = dst_bandwidth_saturated_simd(&row_vec, Some(type_val), norm)?;
1357
1358 for (c, val) in row_dst.iter().enumerate() {
1359 intermediate[[r, c]] = *val;
1360 }
1361 }
1362
1363 let mut final_result = Array2::zeros((n_rows, n_cols));
1365
1366 for c in 0..n_cols {
1367 let col_slice = intermediate.slice(scirs2_core::ndarray::s![.., c]);
1368 let col_vec: Vec<f64> = col_slice.iter().cloned().collect();
1369
1370 let col_dst = dst_bandwidth_saturated_simd(&col_vec, Some(type_val), norm)?;
1372
1373 for (r, val) in col_dst.iter().enumerate() {
1374 final_result[[r, c]] = *val;
1375 }
1376 }
1377
1378 Ok(final_result)
1379}
1380
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::arr2;

    /// Transforms `[1, 2, 3, 4]` with the orthonormal DST of the given type,
    /// inverts it again, and checks that the original samples come back.
    fn assert_ortho_round_trip(kind: DSTType) {
        let signal = vec![1.0, 2.0, 3.0, 4.0];

        let coeffs = dst(&signal, Some(kind), Some("ortho")).expect("Operation failed");
        let recovered = idst(&coeffs, Some(kind), Some("ortho")).expect("Operation failed");

        assert_eq!(recovered.len(), signal.len());
        for (got, want) in recovered.iter().zip(signal.iter()) {
            assert_relative_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    /// Type-II forward transform followed by its inverse restores the signal.
    #[test]
    fn test_dst_and_idst() {
        assert_ortho_round_trip(DSTType::Type2);
    }

    /// Every DST type round-trips through its matching inverse.
    #[test]
    fn test_dst_types() {
        for kind in [
            DSTType::Type1,
            DSTType::Type2,
            DSTType::Type3,
            DSTType::Type4,
        ] {
            assert_ortho_round_trip(kind);
        }
    }

    /// The 2-D transform followed by the 2-D inverse restores a 2x2 array.
    #[test]
    fn test_dst2_and_idst2() {
        let arr = arr2(&[[1.0, 2.0], [3.0, 4.0]]);

        let coeffs =
            dst2(&arr.view(), Some(DSTType::Type2), Some("ortho")).expect("Operation failed");
        let recovered = idst2(&coeffs.view(), Some(DSTType::Type2), Some("ortho"))
            .expect("Operation failed");

        for ((i, j), want) in arr.indexed_iter() {
            assert_relative_eq!(recovered[[i, j]], *want, epsilon = 1e-10);
        }
    }

    /// A linearly increasing signal survives the type-II round trip.
    #[test]
    fn test_linear_signal() {
        assert_ortho_round_trip(DSTType::Type2);
    }
}