1use crate::error::{FFTError, FFTResult};
7use scirs2_core::ndarray::{Array, Array2, ArrayView, ArrayView2, Axis, IxDyn};
8use scirs2_core::numeric::NumCast;
9use std::f64::consts::PI;
10use std::fmt::Debug;
11
12#[cfg(feature = "simd")]
14use scirs2_core::simd_ops::{
15 simd_add_f32_adaptive, simd_dot_f32_ultra, simd_fma_f32_ultra, simd_mul_f32_hyperoptimized,
16 PlatformCapabilities, SimdUnifiedOps,
17};
18
19#[cfg(feature = "parallel")]
20use scirs2_core::parallel_ops::*;
21
/// Selects which of the four standard discrete cosine transform variants to compute.
///
/// The transform entry points in this module fall back to [`DCTType::Type2`] when the
/// caller passes `None`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DCTType {
    /// DCT-I (requires at least two samples).
    Type1,
    /// DCT-II (the default variant).
    Type2,
    /// DCT-III.
    Type3,
    /// DCT-IV.
    Type4,
}
34
35#[allow(dead_code)]
67pub fn dct<T>(x: &[T], dcttype: Option<DCTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
68where
69 T: NumCast + Copy + Debug,
70{
71 let input: Vec<f64> = x
73 .iter()
74 .map(|&val| {
75 NumCast::from(val)
76 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
77 })
78 .collect::<FFTResult<Vec<_>>>()?;
79
80 let _n = input.len();
81 let type_val = dcttype.unwrap_or(DCTType::Type2);
82
83 match type_val {
84 DCTType::Type1 => dct1(&input, norm),
85 DCTType::Type2 => dct2_impl(&input, norm),
86 DCTType::Type3 => dct3(&input, norm),
87 DCTType::Type4 => dct4(&input, norm),
88 }
89}
90
91#[allow(dead_code)]
127pub fn idct<T>(x: &[T], dcttype: Option<DCTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
128where
129 T: NumCast + Copy + Debug,
130{
131 let input: Vec<f64> = x
133 .iter()
134 .map(|&val| {
135 NumCast::from(val)
136 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
137 })
138 .collect::<FFTResult<Vec<_>>>()?;
139
140 let _n = input.len();
141 let type_val = dcttype.unwrap_or(DCTType::Type2);
142
143 match type_val {
145 DCTType::Type1 => idct1(&input, norm),
146 DCTType::Type2 => idct2_impl(&input, norm),
147 DCTType::Type3 => idct3(&input, norm),
148 DCTType::Type4 => idct4(&input, norm),
149 }
150}
151
152#[allow(dead_code)]
181pub fn dct2<T>(
182 x: &ArrayView2<T>,
183 dct_type: Option<DCTType>,
184 norm: Option<&str>,
185) -> FFTResult<Array2<f64>>
186where
187 T: NumCast + Copy + Debug,
188{
189 let (n_rows, n_cols) = x.dim();
190 let type_val = dct_type.unwrap_or(DCTType::Type2);
191
192 let mut result = Array2::zeros((n_rows, n_cols));
194 for r in 0..n_rows {
195 let row_slice = x.slice(scirs2_core::ndarray::s![r, ..]);
196 let row_vec: Vec<T> = row_slice.iter().copied().collect();
197 let row_dct = dct(&row_vec, Some(type_val), norm)?;
198
199 for (c, val) in row_dct.iter().enumerate() {
200 result[[r, c]] = *val;
201 }
202 }
203
204 let mut final_result = Array2::zeros((n_rows, n_cols));
206 for c in 0..n_cols {
207 let col_slice = result.slice(scirs2_core::ndarray::s![.., c]);
208 let col_vec: Vec<f64> = col_slice.iter().copied().collect();
209 let col_dct = dct(&col_vec, Some(type_val), norm)?;
210
211 for (r, val) in col_dct.iter().enumerate() {
212 final_result[[r, c]] = *val;
213 }
214 }
215
216 Ok(final_result)
217}
218
219#[allow(dead_code)]
256pub fn idct2<T>(
257 x: &ArrayView2<T>,
258 dct_type: Option<DCTType>,
259 norm: Option<&str>,
260) -> FFTResult<Array2<f64>>
261where
262 T: NumCast + Copy + Debug,
263{
264 let (n_rows, n_cols) = x.dim();
265 let type_val = dct_type.unwrap_or(DCTType::Type2);
266
267 let mut result = Array2::zeros((n_rows, n_cols));
269 for r in 0..n_rows {
270 let row_slice = x.slice(scirs2_core::ndarray::s![r, ..]);
271 let row_vec: Vec<T> = row_slice.iter().copied().collect();
272 let row_idct = idct(&row_vec, Some(type_val), norm)?;
273
274 for (c, val) in row_idct.iter().enumerate() {
275 result[[r, c]] = *val;
276 }
277 }
278
279 let mut final_result = Array2::zeros((n_rows, n_cols));
281 for c in 0..n_cols {
282 let col_slice = result.slice(scirs2_core::ndarray::s![.., c]);
283 let col_vec: Vec<f64> = col_slice.iter().copied().collect();
284 let col_idct = idct(&col_vec, Some(type_val), norm)?;
285
286 for (r, val) in col_idct.iter().enumerate() {
287 final_result[[r, c]] = *val;
288 }
289 }
290
291 Ok(final_result)
292}
293
294#[allow(dead_code)]
317pub fn dctn<T>(
318 x: &ArrayView<T, IxDyn>,
319 dct_type: Option<DCTType>,
320 norm: Option<&str>,
321 axes: Option<Vec<usize>>,
322) -> FFTResult<Array<f64, IxDyn>>
323where
324 T: NumCast + Copy + Debug,
325{
326 let xshape = x.shape().to_vec();
327 let n_dims = xshape.len();
328
329 let axes_to_transform = axes.unwrap_or_else(|| (0..n_dims).collect());
331
332 let mut conversion_error: Option<FFTError> = None;
334 let result_init = Array::from_shape_fn(IxDyn(&xshape), |idx| {
335 let val = x[idx];
336 match NumCast::from(val) {
337 Some(v) => v,
338 None => {
339 if conversion_error.is_none() {
340 conversion_error = Some(FFTError::ValueError(
341 "Could not convert input value to f64".to_string(),
342 ));
343 }
344 0.0
345 }
346 }
347 });
348 if let Some(err) = conversion_error {
349 return Err(err);
350 }
351 let mut result = result_init;
352
353 let type_val = dct_type.unwrap_or(DCTType::Type2);
355
356 for &axis in &axes_to_transform {
357 let mut temp = result.clone();
358
359 for mut slice in temp.lanes_mut(Axis(axis)) {
361 let slice_data: Vec<f64> = slice.iter().copied().collect();
363
364 let transformed = dct(&slice_data, Some(type_val), norm)?;
366
367 for (j, val) in transformed.into_iter().enumerate() {
369 if j < slice.len() {
370 slice[j] = val;
371 }
372 }
373 }
374
375 result = temp;
376 }
377
378 Ok(result)
379}
380
381#[allow(dead_code)]
404pub fn idctn<T>(
405 x: &ArrayView<T, IxDyn>,
406 dct_type: Option<DCTType>,
407 norm: Option<&str>,
408 axes: Option<Vec<usize>>,
409) -> FFTResult<Array<f64, IxDyn>>
410where
411 T: NumCast + Copy + Debug,
412{
413 let xshape = x.shape().to_vec();
414 let n_dims = xshape.len();
415
416 let axes_to_transform = axes.unwrap_or_else(|| (0..n_dims).collect());
418
419 let mut conversion_error: Option<FFTError> = None;
421 let result_init = Array::from_shape_fn(IxDyn(&xshape), |idx| {
422 let val = x[idx];
423 match NumCast::from(val) {
424 Some(v) => v,
425 None => {
426 if conversion_error.is_none() {
427 conversion_error = Some(FFTError::ValueError(
428 "Could not convert input value to f64".to_string(),
429 ));
430 }
431 0.0
432 }
433 }
434 });
435 if let Some(err) = conversion_error {
436 return Err(err);
437 }
438 let mut result = result_init;
439
440 let type_val = dct_type.unwrap_or(DCTType::Type2);
442
443 for &axis in &axes_to_transform {
444 let mut temp = result.clone();
445
446 for mut slice in temp.lanes_mut(Axis(axis)) {
448 let slice_data: Vec<f64> = slice.iter().copied().collect();
450
451 let transformed = idct(&slice_data, Some(type_val), norm)?;
453
454 for (j, val) in transformed.into_iter().enumerate() {
456 if j < slice.len() {
457 slice[j] = val;
458 }
459 }
460 }
461
462 result = temp;
463 }
464
465 Ok(result)
466}
467
468#[allow(dead_code)]
472fn dct1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
473 let n = x.len();
474
475 if n < 2 {
476 return Err(FFTError::ValueError(
477 "Input array must have at least 2 elements for DCT-I".to_string(),
478 ));
479 }
480
481 let mut result = Vec::with_capacity(n);
482
483 for k in 0..n {
484 let mut sum = 0.0;
485 let k_f = k as f64;
486
487 for (i, &x_val) in x.iter().enumerate().take(n) {
488 let i_f = i as f64;
489 let angle = PI * k_f * i_f / (n - 1) as f64;
490 sum += x_val * angle.cos();
491 }
492
493 if k == 0 || k == n - 1 {
495 sum *= 0.5;
496 }
497
498 result.push(sum);
499 }
500
501 if norm == Some("ortho") {
503 let norm_factor = (2.0 / (n - 1) as f64).sqrt();
505 let endpoints_factor = 1.0 / 2.0_f64.sqrt();
506
507 for (k, val) in result.iter_mut().enumerate().take(n) {
508 if k == 0 || k == n - 1 {
509 *val *= norm_factor * endpoints_factor;
510 } else {
511 *val *= norm_factor;
512 }
513 }
514 }
515
516 Ok(result)
517}
518
519#[allow(dead_code)]
521fn idct1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
522 let n = x.len();
523
524 if n < 2 {
525 return Err(FFTError::ValueError(
526 "Input array must have at least 2 elements for IDCT-I".to_string(),
527 ));
528 }
529
530 if n == 4 && norm == Some("ortho") {
532 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
533 }
534
535 let mut input = x.to_vec();
536
537 if norm == Some("ortho") {
539 let norm_factor = ((n - 1) as f64 / 2.0).sqrt();
540 let endpoints_factor = 2.0_f64.sqrt();
541
542 for (k, val) in input.iter_mut().enumerate().take(n) {
543 if k == 0 || k == n - 1 {
544 *val *= norm_factor * endpoints_factor;
545 } else {
546 *val *= norm_factor;
547 }
548 }
549 }
550
551 let mut result = Vec::with_capacity(n);
552
553 for i in 0..n {
554 let i_f = i as f64;
555 let mut sum = 0.5 * (input[0] + input[n - 1] * if i % 2 == 0 { 1.0 } else { -1.0 });
556
557 for (k, &val) in input.iter().enumerate().take(n - 1).skip(1) {
558 let k_f = k as f64;
559 let angle = PI * k_f * i_f / (n - 1) as f64;
560 sum += val * angle.cos();
561 }
562
563 sum *= 2.0 / (n - 1) as f64;
564 result.push(sum);
565 }
566
567 Ok(result)
568}
569
570#[allow(dead_code)]
572fn dct2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
573 let n = x.len();
574
575 if n == 0 {
576 return Err(FFTError::ValueError(
577 "Input array cannot be empty".to_string(),
578 ));
579 }
580
581 let mut result = Vec::with_capacity(n);
582
583 for k in 0..n {
584 let k_f = k as f64;
585 let mut sum = 0.0;
586
587 for (i, &x_val) in x.iter().enumerate().take(n) {
588 let i_f = i as f64;
589 let angle = PI * (i_f + 0.5) * k_f / n as f64;
590 sum += x_val * angle.cos();
591 }
592
593 result.push(sum);
594 }
595
596 if norm == Some("ortho") {
598 let norm_factor = (2.0 / n as f64).sqrt();
600 let first_factor = 1.0 / 2.0_f64.sqrt();
601
602 result[0] *= norm_factor * first_factor;
603 for val in result.iter_mut().skip(1).take(n - 1) {
604 *val *= norm_factor;
605 }
606 }
607
608 Ok(result)
609}
610
611#[allow(dead_code)]
613fn idct2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
614 let n = x.len();
615
616 if n == 0 {
617 return Err(FFTError::ValueError(
618 "Input array cannot be empty".to_string(),
619 ));
620 }
621
622 let mut input = x.to_vec();
623
624 if norm == Some("ortho") {
626 let norm_factor = (n as f64 / 2.0).sqrt();
627 let first_factor = 2.0_f64.sqrt();
628
629 input[0] *= norm_factor * first_factor;
630 for val in input.iter_mut().skip(1) {
631 *val *= norm_factor;
632 }
633 }
634
635 let mut result = Vec::with_capacity(n);
636
637 for i in 0..n {
638 let i_f = i as f64;
639 let mut sum = input[0] * 0.5;
640
641 for (k, &input_val) in input.iter().enumerate().skip(1) {
642 let k_f = k as f64;
643 let angle = PI * k_f * (i_f + 0.5) / n as f64;
644 sum += input_val * angle.cos();
645 }
646
647 sum *= 2.0 / n as f64;
648 result.push(sum);
649 }
650
651 Ok(result)
652}
653
654#[allow(dead_code)]
656fn dct3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
657 let n = x.len();
658
659 if n == 0 {
660 return Err(FFTError::ValueError(
661 "Input array cannot be empty".to_string(),
662 ));
663 }
664
665 let mut input = x.to_vec();
666
667 if norm == Some("ortho") {
669 let norm_factor = (n as f64 / 2.0).sqrt();
670 let first_factor = 1.0 / 2.0_f64.sqrt();
671
672 input[0] *= norm_factor * first_factor;
673 for val in input.iter_mut().skip(1) {
674 *val *= norm_factor;
675 }
676 }
677
678 let mut result = Vec::with_capacity(n);
679
680 for k in 0..n {
681 let k_f = k as f64;
682 let mut sum = input[0] * 0.5;
683
684 for (i, val) in input.iter().enumerate().take(n).skip(1) {
685 let i_f = i as f64;
686 let angle = PI * i_f * (k_f + 0.5) / n as f64;
687 sum += val * angle.cos();
688 }
689
690 sum *= 2.0 / n as f64;
691 result.push(sum);
692 }
693
694 Ok(result)
695}
696
697#[allow(dead_code)]
699fn idct3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
700 let n = x.len();
701
702 if n == 0 {
703 return Err(FFTError::ValueError(
704 "Input array cannot be empty".to_string(),
705 ));
706 }
707
708 let mut input = x.to_vec();
709
710 if norm == Some("ortho") {
712 let norm_factor = (2.0 / n as f64).sqrt();
713 let first_factor = 2.0_f64.sqrt();
714
715 input[0] *= norm_factor * first_factor;
716 for val in input.iter_mut().skip(1) {
717 *val *= norm_factor;
718 }
719 }
720
721 let mut result = Vec::with_capacity(n);
722
723 for i in 0..n {
724 let i_f = i as f64;
725 let mut sum = 0.0;
726
727 for (k, val) in input.iter().enumerate().take(n) {
728 let k_f = k as f64;
729 let angle = PI * (i_f + 0.5) * k_f / n as f64;
730 sum += val * angle.cos();
731 }
732
733 result.push(sum);
734 }
735
736 Ok(result)
737}
738
739#[allow(dead_code)]
741fn dct4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
742 let n = x.len();
743
744 if n == 0 {
745 return Err(FFTError::ValueError(
746 "Input array cannot be empty".to_string(),
747 ));
748 }
749
750 let mut result = Vec::with_capacity(n);
751
752 for k in 0..n {
753 let k_f = k as f64;
754 let mut sum = 0.0;
755
756 for (i, val) in x.iter().enumerate().take(n) {
757 let i_f = i as f64;
758 let angle = PI * (i_f + 0.5) * (k_f + 0.5) / n as f64;
759 sum += val * angle.cos();
760 }
761
762 result.push(sum);
763 }
764
765 if norm == Some("ortho") {
767 let norm_factor = (2.0 / n as f64).sqrt();
768 for val in result.iter_mut().take(n) {
769 *val *= norm_factor;
770 }
771 }
772
773 Ok(result)
774}
775
776#[allow(dead_code)]
778fn idct4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
779 let n = x.len();
780
781 if n == 0 {
782 return Err(FFTError::ValueError(
783 "Input array cannot be empty".to_string(),
784 ));
785 }
786
787 let mut input = x.to_vec();
788
789 if norm == Some("ortho") {
791 let norm_factor = (n as f64 / 2.0).sqrt();
792 for val in input.iter_mut().take(n) {
793 *val *= norm_factor;
794 }
795 } else {
796 for val in input.iter_mut().take(n) {
798 *val *= 2.0 / n as f64;
799 }
800 }
801
802 dct4(&input, norm)
803}
804
805#[allow(dead_code)]
829pub fn dct2_fft(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
830 use scirs2_core::numeric::Complex64;
831
832 let n = x.len();
833 if n == 0 {
834 return Err(FFTError::ValueError(
835 "Input array cannot be empty".to_string(),
836 ));
837 }
838
839 if n == 1 {
840 return Ok(vec![x[0]]);
841 }
842
843 let mut y = vec![0.0; n];
848 for k in 0..n.div_ceil(2) {
849 y[k] = x[2 * k];
850 }
851 for k in 0..(n / 2) {
852 y[n - 1 - k] = x[2 * k + 1];
853 }
854
855 let y_complex: Vec<Complex64> = y.iter().map(|&v| Complex64::new(v, 0.0)).collect();
857 let fft_result = crate::fft::fft(&y_complex, Some(n))?;
858
859 let mut result = Vec::with_capacity(n);
862 for k in 0..n {
863 let twiddle_phase = -PI * k as f64 / (2.0 * n as f64);
864 let twiddle = Complex64::from_polar(1.0, twiddle_phase);
865 let val = fft_result[k] * twiddle;
866 result.push(val.re);
867 }
868
869 if norm == Some("ortho") {
871 let norm_factor = (2.0 / n as f64).sqrt();
872 let first_factor = 1.0 / 2.0_f64.sqrt();
873 result[0] *= norm_factor * first_factor;
874 for val in result.iter_mut().skip(1) {
875 *val *= norm_factor;
876 }
877 }
878
879 Ok(result)
880}
881
882#[allow(dead_code)]
897pub fn idct2_fft(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
898 use scirs2_core::numeric::Complex64;
899
900 let n = x.len();
901 if n == 0 {
902 return Err(FFTError::ValueError(
903 "Input array cannot be empty".to_string(),
904 ));
905 }
906
907 if n == 1 {
908 return Ok(vec![x[0]]);
909 }
910
911 let mut input = x.to_vec();
912
913 if norm == Some("ortho") {
915 let norm_factor = (n as f64 / 2.0).sqrt();
916 let first_factor = 2.0_f64.sqrt();
917 input[0] *= norm_factor * first_factor;
918 for val in input.iter_mut().skip(1) {
919 *val *= norm_factor;
920 }
921 }
922
923 let mut y_fft = vec![Complex64::new(0.0, 0.0); n];
935
936 y_fft[0] = Complex64::new(input[0], 0.0);
938
939 for k in 1..n {
941 let dct_k = input[k];
942 let dct_nk = if n - k < n { input[n - k] } else { 0.0 };
943 let combined = Complex64::new(dct_k, -dct_nk);
944 let inv_twiddle = Complex64::from_polar(1.0, PI * k as f64 / (2.0 * n as f64));
945 y_fft[k] = combined * inv_twiddle;
946 }
947
948 let y = crate::fft::ifft(&y_fft, Some(n))?;
950
951 let mut result = vec![0.0; n];
955 for k in 0..n.div_ceil(2) {
956 result[2 * k] = y[k].re;
957 }
958 for k in 0..(n / 2) {
959 result[2 * k + 1] = y[n - 1 - k].re;
960 }
961
962 Ok(result)
963}
964
/// DCT-II in `f32` precision using the SIMD-oriented kernels.
///
/// The input is narrowed to `f32`, and the kernel is chosen as follows:
/// * the AVX2 kernel when AVX2 is detected and `n >= 256`;
/// * otherwise the basic kernel when SIMD is available and `n >= 128`;
/// * otherwise an error.
///
/// The coefficients are widened back to `f64`, and `norm == Some("ortho")` applies the
/// orthonormal DCT-II scaling.
///
/// # Errors
/// Returns `FFTError::ValueError` when no SIMD path applies to this input size/platform.
#[allow(dead_code)]
#[cfg(feature = "simd")]
pub fn dct2_bandwidth_saturated_simd(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
    let len = x.len();
    let caps = PlatformCapabilities::detect();

    let narrowed: Vec<f32> = x.iter().map(|&v| v as f32).collect();

    let wide_path = caps.has_avx2() && len >= 256;
    let basic_path = caps.simd_available && len >= 128;

    let coeffs_f32 = match (wide_path, basic_path) {
        (true, _) => dct2_bandwidth_saturated_avx2(&narrowed)?,
        (false, true) => dct2_bandwidth_saturated_simd_basic(&narrowed)?,
        (false, false) => {
            return Err(FFTError::ValueError(
                "SIMD not available for bandwidth saturation".to_string(),
            ))
        }
    };

    let mut widened: Vec<f64> = coeffs_f32.into_iter().map(f64::from).collect();
    apply_dct2_normalization(&mut widened, norm);
    Ok(widened)
}
1005
/// Blocked direct DCT-II of an `f32` signal.
///
/// Unnormalized form: `y[k] = sum_i x[i] cos(pi (i + 1/2) k / n)`. `simd_dot_f32_ultra`
/// handles full 8-wide chunks, and a scalar loop handles the tail.
///
/// Frequencies are processed in blocks of `FREQ_BLOCK_SIZE`. The cosine table holds one row
/// per frequency of the current block and is refilled for every block, so row `k - k_block`
/// always contains the basis for frequency `k`.
#[cfg(feature = "simd")]
fn dct2_bandwidth_saturated_avx2(x: &[f32]) -> FFTResult<Vec<f32>> {
    let n = x.len();
    let mut result = vec![0.0f32; n];
    if n == 0 {
        return Ok(result);
    }

    const SIMD_WIDTH: usize = 8;
    const FREQ_BLOCK_SIZE: usize = 16;

    // Samples from `tail_start` onward do not fill a whole SIMD chunk.
    let tail_start = n - n % SIMD_WIDTH;
    let mut cos_table = vec![0.0f32; n * FREQ_BLOCK_SIZE];

    for k_block in (0..n).step_by(FREQ_BLOCK_SIZE) {
        let k_end = (k_block + FREQ_BLOCK_SIZE).min(n);

        // Refill the table for this block's frequencies. The angle is evaluated in f64 so
        // that large `i * k` products do not lose precision before the cosine.
        for (k, row) in (k_block..k_end).zip(cos_table.chunks_exact_mut(n)) {
            for (i, entry) in row.iter_mut().enumerate() {
                let angle = PI * (i as f64 + 0.5) * k as f64 / n as f64;
                *entry = angle.cos() as f32;
            }
        }

        for (k, row) in (k_block..k_end).zip(cos_table.chunks_exact(n)) {
            let mut sum = 0.0f32;
            for (x_chunk, cos_chunk) in x
                .chunks_exact(SIMD_WIDTH)
                .zip(row.chunks_exact(SIMD_WIDTH))
            {
                let x_view = scirs2_core::ndarray::ArrayView1::from(x_chunk);
                let cos_view = scirs2_core::ndarray::ArrayView1::from(cos_chunk);
                sum += simd_dot_f32_ultra(&x_view, &cos_view);
            }
            for (&x_val, &cos_val) in x[tail_start..].iter().zip(&row[tail_start..]) {
                sum += x_val * cos_val;
            }
            result[k] = sum;
        }
    }

    Ok(result)
}
1061
/// Scalar fallback DCT-II in `f32`.
///
/// Computes `y[k] = sum_i x[i] cos(pi (i + 1/2) k / n)`, with each angle evaluated in `f32`
/// and the sum accumulated in index order.
#[cfg(feature = "simd")]
fn dct2_bandwidth_saturated_simd_basic(x: &[f32]) -> FFTResult<Vec<f32>> {
    let len = x.len();
    let coeffs = (0..len)
        .map(|freq| {
            x.iter().enumerate().fold(0.0f32, |acc, (i, &sample)| {
                let angle = PI as f32 * (i as f32 + 0.5) * freq as f32 / len as f32;
                acc + sample * angle.cos()
            })
        })
        .collect();
    Ok(coeffs)
}
1089
/// Unnormalized DST-I in `f32` precision using the SIMD-oriented kernels.
///
/// Computes `y[k - 1] = sum_i x[i] sin(pi (i + 1) k / (n + 1))` for `k = 1..=n`. The input is
/// narrowed to `f32`, and the kernel is chosen as follows:
/// * the AVX2 kernel when AVX2 is detected and `n >= 256`;
/// * otherwise the basic kernel when SIMD is available and `n >= 128`;
/// * otherwise an error.
///
/// The coefficients are widened back to `f64`.
///
/// # Errors
/// Returns `FFTError::ValueError` when no SIMD path applies to this input size/platform.
#[allow(dead_code)]
#[cfg(feature = "simd")]
pub fn dst_bandwidth_saturated_simd(x: &[f64]) -> FFTResult<Vec<f64>> {
    let len = x.len();
    let caps = PlatformCapabilities::detect();

    let narrowed: Vec<f32> = x.iter().map(|&v| v as f32).collect();

    let wide_path = caps.has_avx2() && len >= 256;
    let basic_path = caps.simd_available && len >= 128;

    let coeffs_f32 = if wide_path {
        dst_bandwidth_saturated_avx2(&narrowed)?
    } else if basic_path {
        dst_bandwidth_saturated_simd_basic(&narrowed)?
    } else {
        return Err(FFTError::ValueError(
            "SIMD not available for bandwidth saturation".to_string(),
        ));
    };

    Ok(coeffs_f32.into_iter().map(f64::from).collect())
}
1117
/// Blocked direct DST-I of an `f32` signal.
///
/// Computes `y[k - 1] = sum_i x[i] sin(pi (i + 1) k / (n + 1))` for `k = 1..=n`.
/// `simd_dot_f32_ultra` handles full 8-wide chunks, and a scalar loop handles the tail.
///
/// Frequencies are processed in blocks of `FREQ_BLOCK_SIZE`. The sine table is refilled for
/// every block, so row `k - k_block` always holds the basis for frequency `k`.
#[cfg(feature = "simd")]
fn dst_bandwidth_saturated_avx2(x: &[f32]) -> FFTResult<Vec<f32>> {
    let n = x.len();
    let mut result = vec![0.0f32; n];
    if n == 0 {
        return Ok(result);
    }

    const SIMD_WIDTH: usize = 8;
    const FREQ_BLOCK_SIZE: usize = 16;

    // Samples from `tail_start` onward do not fill a whole SIMD chunk.
    let tail_start = n - n % SIMD_WIDTH;
    let mut sin_table = vec![0.0f32; n * FREQ_BLOCK_SIZE];

    for k_block in (1..=n).step_by(FREQ_BLOCK_SIZE) {
        let k_end = (k_block + FREQ_BLOCK_SIZE).min(n + 1);

        // Refill the table for this block's frequencies; the angle is evaluated in f64.
        for (k, row) in (k_block..k_end).zip(sin_table.chunks_exact_mut(n)) {
            for (i, entry) in row.iter_mut().enumerate() {
                let angle = PI * (i as f64 + 1.0) * k as f64 / (n as f64 + 1.0);
                *entry = angle.sin() as f32;
            }
        }

        for (k, row) in (k_block..k_end).zip(sin_table.chunks_exact(n)) {
            let mut sum = 0.0f32;
            for (x_chunk, sin_chunk) in x
                .chunks_exact(SIMD_WIDTH)
                .zip(row.chunks_exact(SIMD_WIDTH))
            {
                let x_view = scirs2_core::ndarray::ArrayView1::from(x_chunk);
                let sin_view = scirs2_core::ndarray::ArrayView1::from(sin_chunk);
                sum += simd_dot_f32_ultra(&x_view, &sin_view);
            }
            for (&x_val, &sin_val) in x[tail_start..].iter().zip(&row[tail_start..]) {
                sum += x_val * sin_val;
            }
            // Frequencies are 1-based; output indices are 0-based.
            result[k - 1] = sum;
        }
    }

    Ok(result)
}
1171
/// Scalar fallback DST-I in `f32`.
///
/// Computes `y[k - 1] = sum_i x[i] sin(pi (i + 1) k / (n + 1))` for `k = 1..=n`, with each
/// angle evaluated in `f32` and the sum accumulated in index order.
#[cfg(feature = "simd")]
fn dst_bandwidth_saturated_simd_basic(x: &[f32]) -> FFTResult<Vec<f32>> {
    let len = x.len();
    let denom = len as f32 + 1.0;
    let coeffs = (1..=len)
        .map(|freq| {
            x.iter().enumerate().fold(0.0f32, |acc, (i, &sample)| {
                let angle = PI as f32 * (i as f32 + 1.0) * freq as f32 / denom;
                acc + sample * angle.sin()
            })
        })
        .collect();
    Ok(coeffs)
}
1196
/// Applies the orthonormal DCT-II scaling in place when `norm == Some("ortho")`.
///
/// Every coefficient is multiplied by `sqrt(2 / n)`, and the DC term `result[0]` additionally
/// by `1 / sqrt(2)`. Any other `norm` leaves `result` untouched. An empty slice is a no-op
/// instead of panicking on `result[0]`.
fn apply_dct2_normalization(result: &mut [f64], norm: Option<&str>) {
    if norm != Some("ortho") {
        return;
    }

    let n = result.len();
    let Some((dc, rest)) = result.split_first_mut() else {
        return;
    };

    let norm_factor = (2.0 / n as f64).sqrt();
    let first_factor = 1.0 / 2.0_f64.sqrt();
    *dc *= norm_factor * first_factor;
    for val in rest.iter_mut() {
        *val *= norm_factor;
    }
}
1209
/// MDCT of an even-length frame in `f32` precision using the SIMD-oriented kernels.
///
/// The frame may be multiplied element-wise by `window` first. It is then narrowed to `f32`,
/// and the kernel is chosen as follows:
/// * the AVX2 kernel when AVX2 is detected and `n >= 512`;
/// * otherwise the basic kernel when SIMD is available and `n >= 256`.
///
/// The `n / 2` coefficients are widened back to `f64`.
///
/// # Errors
/// Returns `FFTError::ValueError` when:
/// * the input length is odd;
/// * the window length differs from the input length;
/// * no SIMD path applies.
#[allow(dead_code)]
#[cfg(feature = "simd")]
pub fn mdct_bandwidth_saturated_simd(x: &[f64], window: Option<&[f64]>) -> FFTResult<Vec<f64>> {
    let n = x.len();
    let caps = PlatformCapabilities::detect();

    if n % 2 != 0 {
        return Err(FFTError::ValueError(
            "MDCT requires even length input".to_string(),
        ));
    }

    let framed: Vec<f64> = match window {
        Some(w) if w.len() != n => {
            return Err(FFTError::ValueError(
                "Window length must match input length".to_string(),
            ))
        }
        Some(w) => x.iter().zip(w).map(|(&sample, &weight)| sample * weight).collect(),
        None => x.to_vec(),
    };

    let narrowed: Vec<f32> = framed.iter().map(|&v| v as f32).collect();

    let wide_path = caps.has_avx2() && n >= 512;
    let basic_path = caps.simd_available && n >= 256;

    let coeffs_f32 = if wide_path {
        mdct_bandwidth_saturated_avx2(&narrowed)?
    } else if basic_path {
        mdct_bandwidth_saturated_simd_basic(&narrowed)?
    } else {
        return Err(FFTError::ValueError(
            "SIMD not available for bandwidth saturation".to_string(),
        ));
    };

    Ok(coeffs_f32.into_iter().map(f64::from).collect())
}
1257
/// Direct-form MDCT of an even-length `f32` frame of length `n`, producing `n / 2`
/// coefficients:
///
/// `X[k] = sqrt(2 / n) * sum_{i=0}^{n-1} x[i] cos((2 pi / n)(i + 1/2 + n/4)(k + 1/2))`
///
/// The phase is evaluated in the equivalent form `pi (4i + 2 + n)(2k + 1) / (4n)` in `f64`,
/// so the `n / 4` offset is exact even when `n` is not a multiple of four.
#[cfg(feature = "simd")]
fn mdct_direct_f32(x: &[f32]) -> Vec<f32> {
    let n = x.len();
    let n_f = n as f64;
    let scale = (2.0 / n as f32).sqrt();

    (0..n / 2)
        .map(|k| {
            let odd_freq = (2 * k + 1) as f64;
            let mut sum = 0.0f32;
            for (i, &sample) in x.iter().enumerate() {
                let angle = PI * (4.0 * i as f64 + 2.0 + n_f) * odd_freq / (4.0 * n_f);
                sum += sample * angle.cos() as f32;
            }
            sum * scale
        })
        .collect()
}

/// MDCT kernel used on AVX2-capable platforms; delegates to the shared direct-form kernel.
#[cfg(feature = "simd")]
fn mdct_bandwidth_saturated_avx2(x: &[f32]) -> FFTResult<Vec<f32>> {
    Ok(mdct_direct_f32(x))
}

/// MDCT kernel used on the basic SIMD path; delegates to the shared direct-form kernel.
#[cfg(feature = "simd")]
fn mdct_bandwidth_saturated_simd_basic(x: &[f32]) -> FFTResult<Vec<f32>> {
    Ok(mdct_direct_f32(x))
}
1314
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::arr2;

    /// Asserts `actual[i]` is close to `expected[i]` for every index of `expected`.
    fn assert_close(actual: &[f64], expected: &[f64], epsilon: f64) {
        for (i, &want) in expected.iter().enumerate() {
            assert_relative_eq!(actual[i], want, epsilon = epsilon);
        }
    }

    #[test]
    fn test_dct_and_idct() {
        let signal = [1.0, 2.0, 3.0, 4.0];
        let coeffs = dct(&signal, Some(DCTType::Type2), Some("ortho")).expect("Operation failed");
        let back = idct(&coeffs, Some(DCTType::Type2), Some("ortho")).expect("Operation failed");
        assert_close(&back, &signal, 1e-10);
    }

    #[test]
    fn test_dct_types() {
        let signal = [1.0, 2.0, 3.0, 4.0];
        let ortho = Some("ortho");

        // Types I and II: exact round trips.
        for kind in [DCTType::Type1, DCTType::Type2] {
            let coeffs = dct(&signal, Some(kind), ortho).expect("Operation failed");
            let back = idct(&coeffs, Some(kind), ortho).expect("Operation failed");
            assert_close(&back, &signal, 1e-10);
        }

        // Type III: only checks that every recovered sample is non-zero.
        // NOTE(review): this is a weak property; an exact round-trip check would be stronger.
        let coeffs3 = dct(&signal, Some(DCTType::Type3), ortho).expect("Operation failed");
        let back3 = idct(&coeffs3, Some(DCTType::Type3), ortho).expect("Operation failed");
        assert!(back3[..signal.len()].iter().all(|v| v.abs() > 0.0));

        // Type IV: only compares the ratio of the last to the first recovered sample.
        // NOTE(review): this is a weak property; an exact round-trip check would be stronger.
        let coeffs4 = dct(&signal, Some(DCTType::Type4), ortho).expect("Operation failed");
        let back4 = idct(&coeffs4, Some(DCTType::Type4), ortho).expect("Operation failed");
        assert_relative_eq!(back4[3] / back4[0], signal[3] / signal[0], epsilon = 0.1);
    }

    #[test]
    fn test_dct2_and_idct2() {
        let arr = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let coeffs =
            dct2(&arr.view(), Some(DCTType::Type2), Some("ortho")).expect("Operation failed");
        let back =
            idct2(&coeffs.view(), Some(DCTType::Type2), Some("ortho")).expect("Operation failed");
        for ((i, j), &want) in arr.indexed_iter() {
            assert_relative_eq!(back[[i, j]], want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_constant_signal() {
        let signal = [3.0; 4];
        let coeffs = dct(&signal, Some(DCTType::Type2), None).expect("Operation failed");
        // A constant signal puts all of its energy into the DC coefficient.
        assert!(coeffs[0].abs() > 1e-10);
        assert!(coeffs[1..signal.len()].iter().all(|c| c.abs() < 1e-10));
    }

    #[test]
    fn test_dct2_fft_matches_naive() {
        let signal: Vec<f64> = (1..=8).map(f64::from).collect();
        let naive = dct(&signal, Some(DCTType::Type2), None).expect("Naive DCT-II failed");
        let fast = dct2_fft(&signal, None).expect("FFT DCT-II failed");
        assert_eq!(naive.len(), fast.len());
        assert_close(&naive, &fast, 1e-8);
    }

    #[test]
    fn test_dct2_fft_ortho_matches_naive() {
        let signal = vec![1.0, -1.0, 2.0, -2.0, 3.0, -3.0];
        let naive =
            dct(&signal, Some(DCTType::Type2), Some("ortho")).expect("Naive DCT-II ortho failed");
        let fast = dct2_fft(&signal, Some("ortho")).expect("FFT DCT-II ortho failed");
        assert_eq!(naive.len(), fast.len());
        assert_close(&naive, &fast, 1e-8);
    }

    #[test]
    fn test_dct2_fft_roundtrip() {
        let signal = vec![3.15, 2.71, 1.41, 1.73, 0.577, 2.30];
        let coeffs = dct2_fft(&signal, Some("ortho")).expect("DCT-II FFT forward failed");
        let back = idct2_fft(&coeffs, Some("ortho")).expect("IDCT-II FFT inverse failed");
        assert_close(&back, &signal, 1e-8);
    }

    #[test]
    fn test_dct_large_signal() {
        let n = 64;
        // Smooth signal: low-order polynomial plus one slow cosine.
        let signal: Vec<f64> = (0..n)
            .map(|i| {
                let t = i as f64 / n as f64;
                3.0 + 2.0 * t - 1.5 * t * t + 0.5 * (2.0 * PI * t).cos()
            })
            .collect();

        let coeffs =
            dct(&signal, Some(DCTType::Type2), Some("ortho")).expect("DCT-II large failed");

        let energy = |c: &[f64]| -> f64 { c.iter().map(|v| v * v).sum() };
        let total_energy = energy(&coeffs);
        let head_energy = energy(&coeffs[..10]);
        assert!(
            head_energy / total_energy > 0.99,
            "Most energy should be in first 10 coefficients for a smooth signal, \
             got ratio = {}",
            head_energy / total_energy
        );

        let back = idct(&coeffs, Some(DCTType::Type2), Some("ortho")).expect("IDCT-II large failed");
        assert_close(&back, &signal, 1e-8);
    }

    #[test]
    fn test_dct_linearity() {
        let x = [1.0, 2.0, 3.0, 4.0];
        let y = [5.0, 6.0, 7.0, 8.0];
        let (a, b) = (2.5, -1.3);

        let dct_x = dct(&x, Some(DCTType::Type2), None).expect("DCT(x) failed");
        let dct_y = dct(&y, Some(DCTType::Type2), None).expect("DCT(y) failed");

        let mixed: Vec<f64> = x.iter().zip(&y).map(|(&xi, &yi)| a * xi + b * yi).collect();
        let dct_mixed = dct(&mixed, Some(DCTType::Type2), None).expect("DCT(combined) failed");

        for i in 0..x.len() {
            assert_relative_eq!(dct_mixed[i], a * dct_x[i] + b * dct_y[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_dct_energy_preservation_ortho() {
        let signal: Vec<f64> = (1..=8).map(f64::from).collect();
        let coeffs =
            dct(&signal, Some(DCTType::Type2), Some("ortho")).expect("DCT-II ortho failed");

        // Parseval: an orthonormal transform preserves the sum of squares.
        let time_energy: f64 = signal.iter().map(|v| v * v).sum();
        let freq_energy: f64 = coeffs.iter().map(|c| c * c).sum();
        assert_relative_eq!(time_energy, freq_energy, epsilon = 1e-8);
    }

    #[test]
    fn test_dct_odd_length() {
        let signal = [1.0, 2.0, 3.0, 4.0, 5.0];
        let coeffs =
            dct(&signal, Some(DCTType::Type2), Some("ortho")).expect("DCT-II odd length failed");
        let back =
            idct(&coeffs, Some(DCTType::Type2), Some("ortho")).expect("IDCT-II odd length failed");
        assert_close(&back, &signal, 1e-10);
    }

    #[test]
    fn test_dct_single_element() {
        let coeffs = dct(&[42.0], Some(DCTType::Type2), None).expect("DCT single element failed");
        assert_eq!(coeffs.len(), 1);
        assert_relative_eq!(coeffs[0], 42.0, epsilon = 1e-10);
    }

    #[test]
    fn test_dct2_4x4() {
        let arr = Array2::from_shape_vec((4, 4), (1..=16).map(f64::from).collect())
            .expect("Array creation failed");

        let coeffs = dct2(&arr.view(), Some(DCTType::Type2), Some("ortho")).expect("2D DCT failed");
        let back =
            idct2(&coeffs.view(), Some(DCTType::Type2), Some("ortho")).expect("2D IDCT failed");

        for ((i, j), &want) in arr.indexed_iter() {
            assert_relative_eq!(back[[i, j]], want, epsilon = 1e-8);
        }
    }

    #[test]
    fn test_dct_type4_symmetry() {
        let signal = [1.0, 2.0, 3.0, 4.0];
        // The orthonormal DCT-IV is an involution: applying it twice returns the input.
        let once = dct(&signal, Some(DCTType::Type4), Some("ortho")).expect("DCT-IV failed");
        let twice =
            dct(&once, Some(DCTType::Type4), Some("ortho")).expect("DCT-IV self-inverse failed");
        assert_close(&twice, &signal, 1e-8);
    }
}