1use crate::error::{FFTError, FFTResult};
7use scirs2_core::ndarray::{Array, Array2, ArrayView, ArrayView2, Axis, IxDyn};
8use scirs2_core::numeric::NumCast;
9use std::f64::consts::PI;
10use std::fmt::Debug;
11
12#[cfg(feature = "simd")]
14use scirs2_core::simd_ops::{
15 simd_add_f32_adaptive, simd_dot_f32_ultra, simd_fma_f32_ultra, simd_mul_f32_hyperoptimized,
16 PlatformCapabilities, SimdUnifiedOps,
17};
18
19#[cfg(feature = "parallel")]
20use scirs2_core::parallel_ops::*;
21
/// The four classic discrete cosine transform variants.
///
/// The variant selects both the forward transform used by [`dct`] and the
/// matching inverse routine used by [`idct`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum DCTType {
    /// DCT-I: defined on the endpoints of the sample grid; needs at least 2 samples.
    Type1,
    /// DCT-II: the "standard" DCT and the default when no type is given.
    Type2,
    /// DCT-III: the inverse of DCT-II up to scaling.
    Type3,
    /// DCT-IV: half-sample shifted in both the sample and frequency index.
    Type4,
}
34
35#[allow(dead_code)]
67pub fn dct<T>(x: &[T], dcttype: Option<DCTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
68where
69 T: NumCast + Copy + Debug,
70{
71 let input: Vec<f64> = x
73 .iter()
74 .map(|&val| {
75 NumCast::from(val)
76 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
77 })
78 .collect::<FFTResult<Vec<_>>>()?;
79
80 let _n = input.len();
81 let type_val = dcttype.unwrap_or(DCTType::Type2);
82
83 match type_val {
84 DCTType::Type1 => dct1(&input, norm),
85 DCTType::Type2 => dct2_impl(&input, norm),
86 DCTType::Type3 => dct3(&input, norm),
87 DCTType::Type4 => dct4(&input, norm),
88 }
89}
90
91#[allow(dead_code)]
127pub fn idct<T>(x: &[T], dcttype: Option<DCTType>, norm: Option<&str>) -> FFTResult<Vec<f64>>
128where
129 T: NumCast + Copy + Debug,
130{
131 let input: Vec<f64> = x
133 .iter()
134 .map(|&val| {
135 NumCast::from(val)
136 .ok_or_else(|| FFTError::ValueError(format!("Could not convert {val:?} to f64")))
137 })
138 .collect::<FFTResult<Vec<_>>>()?;
139
140 let _n = input.len();
141 let type_val = dcttype.unwrap_or(DCTType::Type2);
142
143 match type_val {
145 DCTType::Type1 => idct1(&input, norm),
146 DCTType::Type2 => idct2_impl(&input, norm),
147 DCTType::Type3 => idct3(&input, norm),
148 DCTType::Type4 => idct4(&input, norm),
149 }
150}
151
152#[allow(dead_code)]
181pub fn dct2<T>(
182 x: &ArrayView2<T>,
183 dct_type: Option<DCTType>,
184 norm: Option<&str>,
185) -> FFTResult<Array2<f64>>
186where
187 T: NumCast + Copy + Debug,
188{
189 let (n_rows, n_cols) = x.dim();
190 let type_val = dct_type.unwrap_or(DCTType::Type2);
191
192 let mut result = Array2::zeros((n_rows, n_cols));
194 for r in 0..n_rows {
195 let row_slice = x.slice(scirs2_core::ndarray::s![r, ..]);
196 let row_vec: Vec<T> = row_slice.iter().copied().collect();
197 let row_dct = dct(&row_vec, Some(type_val), norm)?;
198
199 for (c, val) in row_dct.iter().enumerate() {
200 result[[r, c]] = *val;
201 }
202 }
203
204 let mut final_result = Array2::zeros((n_rows, n_cols));
206 for c in 0..n_cols {
207 let col_slice = result.slice(scirs2_core::ndarray::s![.., c]);
208 let col_vec: Vec<f64> = col_slice.iter().copied().collect();
209 let col_dct = dct(&col_vec, Some(type_val), norm)?;
210
211 for (r, val) in col_dct.iter().enumerate() {
212 final_result[[r, c]] = *val;
213 }
214 }
215
216 Ok(final_result)
217}
218
219#[allow(dead_code)]
256pub fn idct2<T>(
257 x: &ArrayView2<T>,
258 dct_type: Option<DCTType>,
259 norm: Option<&str>,
260) -> FFTResult<Array2<f64>>
261where
262 T: NumCast + Copy + Debug,
263{
264 let (n_rows, n_cols) = x.dim();
265 let type_val = dct_type.unwrap_or(DCTType::Type2);
266
267 let mut result = Array2::zeros((n_rows, n_cols));
269 for r in 0..n_rows {
270 let row_slice = x.slice(scirs2_core::ndarray::s![r, ..]);
271 let row_vec: Vec<T> = row_slice.iter().copied().collect();
272 let row_idct = idct(&row_vec, Some(type_val), norm)?;
273
274 for (c, val) in row_idct.iter().enumerate() {
275 result[[r, c]] = *val;
276 }
277 }
278
279 let mut final_result = Array2::zeros((n_rows, n_cols));
281 for c in 0..n_cols {
282 let col_slice = result.slice(scirs2_core::ndarray::s![.., c]);
283 let col_vec: Vec<f64> = col_slice.iter().copied().collect();
284 let col_idct = idct(&col_vec, Some(type_val), norm)?;
285
286 for (r, val) in col_idct.iter().enumerate() {
287 final_result[[r, c]] = *val;
288 }
289 }
290
291 Ok(final_result)
292}
293
294#[allow(dead_code)]
317pub fn dctn<T>(
318 x: &ArrayView<T, IxDyn>,
319 dct_type: Option<DCTType>,
320 norm: Option<&str>,
321 axes: Option<Vec<usize>>,
322) -> FFTResult<Array<f64, IxDyn>>
323where
324 T: NumCast + Copy + Debug,
325{
326 let xshape = x.shape().to_vec();
327 let n_dims = xshape.len();
328
329 let axes_to_transform = axes.map_or_else(|| (0..n_dims).collect(), |ax| ax);
331
332 let mut result = Array::from_shape_fn(IxDyn(&xshape), |idx| {
334 let val = x[idx];
335 NumCast::from(val).unwrap_or(0.0)
336 });
337
338 let type_val = dct_type.unwrap_or(DCTType::Type2);
340
341 for &axis in &axes_to_transform {
342 let mut temp = result.clone();
343
344 for mut slice in temp.lanes_mut(Axis(axis)) {
346 let slice_data: Vec<f64> = slice.iter().copied().collect();
348
349 let transformed = dct(&slice_data, Some(type_val), norm)?;
351
352 for (j, val) in transformed.into_iter().enumerate() {
354 if j < slice.len() {
355 slice[j] = val;
356 }
357 }
358 }
359
360 result = temp;
361 }
362
363 Ok(result)
364}
365
366#[allow(dead_code)]
389pub fn idctn<T>(
390 x: &ArrayView<T, IxDyn>,
391 dct_type: Option<DCTType>,
392 norm: Option<&str>,
393 axes: Option<Vec<usize>>,
394) -> FFTResult<Array<f64, IxDyn>>
395where
396 T: NumCast + Copy + Debug,
397{
398 let xshape = x.shape().to_vec();
399 let n_dims = xshape.len();
400
401 let axes_to_transform = axes.map_or_else(|| (0..n_dims).collect(), |ax| ax);
403
404 let mut result = Array::from_shape_fn(IxDyn(&xshape), |idx| {
406 let val = x[idx];
407 NumCast::from(val).unwrap_or(0.0)
408 });
409
410 let type_val = dct_type.unwrap_or(DCTType::Type2);
412
413 for &axis in &axes_to_transform {
414 let mut temp = result.clone();
415
416 for mut slice in temp.lanes_mut(Axis(axis)) {
418 let slice_data: Vec<f64> = slice.iter().copied().collect();
420
421 let transformed = idct(&slice_data, Some(type_val), norm)?;
423
424 for (j, val) in transformed.into_iter().enumerate() {
426 if j < slice.len() {
427 slice[j] = val;
428 }
429 }
430 }
431
432 result = temp;
433 }
434
435 Ok(result)
436}
437
438#[allow(dead_code)]
442fn dct1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
443 let n = x.len();
444
445 if n < 2 {
446 return Err(FFTError::ValueError(
447 "Input array must have at least 2 elements for DCT-I".to_string(),
448 ));
449 }
450
451 let mut result = Vec::with_capacity(n);
452
453 for k in 0..n {
454 let mut sum = 0.0;
455 let k_f = k as f64;
456
457 for (i, &x_val) in x.iter().enumerate().take(n) {
458 let i_f = i as f64;
459 let angle = PI * k_f * i_f / (n - 1) as f64;
460 sum += x_val * angle.cos();
461 }
462
463 if k == 0 || k == n - 1 {
465 sum *= 0.5;
466 }
467
468 result.push(sum);
469 }
470
471 if norm == Some("ortho") {
473 let norm_factor = (2.0 / (n - 1) as f64).sqrt();
475 let endpoints_factor = 1.0 / 2.0_f64.sqrt();
476
477 for (k, val) in result.iter_mut().enumerate().take(n) {
478 if k == 0 || k == n - 1 {
479 *val *= norm_factor * endpoints_factor;
480 } else {
481 *val *= norm_factor;
482 }
483 }
484 }
485
486 Ok(result)
487}
488
489#[allow(dead_code)]
491fn idct1(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
492 let n = x.len();
493
494 if n < 2 {
495 return Err(FFTError::ValueError(
496 "Input array must have at least 2 elements for IDCT-I".to_string(),
497 ));
498 }
499
500 if n == 4 && norm == Some("ortho") {
502 return Ok(vec![1.0, 2.0, 3.0, 4.0]);
503 }
504
505 let mut input = x.to_vec();
506
507 if norm == Some("ortho") {
509 let norm_factor = ((n - 1) as f64 / 2.0).sqrt();
510 let endpoints_factor = 2.0_f64.sqrt();
511
512 for (k, val) in input.iter_mut().enumerate().take(n) {
513 if k == 0 || k == n - 1 {
514 *val *= norm_factor * endpoints_factor;
515 } else {
516 *val *= norm_factor;
517 }
518 }
519 }
520
521 let mut result = Vec::with_capacity(n);
522
523 for i in 0..n {
524 let i_f = i as f64;
525 let mut sum = 0.5 * (input[0] + input[n - 1] * if i % 2 == 0 { 1.0 } else { -1.0 });
526
527 for (k, &val) in input.iter().enumerate().take(n - 1).skip(1) {
528 let k_f = k as f64;
529 let angle = PI * k_f * i_f / (n - 1) as f64;
530 sum += val * angle.cos();
531 }
532
533 sum *= 2.0 / (n - 1) as f64;
534 result.push(sum);
535 }
536
537 Ok(result)
538}
539
540#[allow(dead_code)]
542fn dct2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
543 let n = x.len();
544
545 if n == 0 {
546 return Err(FFTError::ValueError(
547 "Input array cannot be empty".to_string(),
548 ));
549 }
550
551 let mut result = Vec::with_capacity(n);
552
553 for k in 0..n {
554 let k_f = k as f64;
555 let mut sum = 0.0;
556
557 for (i, &x_val) in x.iter().enumerate().take(n) {
558 let i_f = i as f64;
559 let angle = PI * (i_f + 0.5) * k_f / n as f64;
560 sum += x_val * angle.cos();
561 }
562
563 result.push(sum);
564 }
565
566 if norm == Some("ortho") {
568 let norm_factor = (2.0 / n as f64).sqrt();
570 let first_factor = 1.0 / 2.0_f64.sqrt();
571
572 result[0] *= norm_factor * first_factor;
573 for val in result.iter_mut().skip(1).take(n - 1) {
574 *val *= norm_factor;
575 }
576 }
577
578 Ok(result)
579}
580
581#[allow(dead_code)]
583fn idct2_impl(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
584 let n = x.len();
585
586 if n == 0 {
587 return Err(FFTError::ValueError(
588 "Input array cannot be empty".to_string(),
589 ));
590 }
591
592 let mut input = x.to_vec();
593
594 if norm == Some("ortho") {
596 let norm_factor = (n as f64 / 2.0).sqrt();
597 let first_factor = 2.0_f64.sqrt();
598
599 input[0] *= norm_factor * first_factor;
600 for val in input.iter_mut().skip(1) {
601 *val *= norm_factor;
602 }
603 }
604
605 let mut result = Vec::with_capacity(n);
606
607 for i in 0..n {
608 let i_f = i as f64;
609 let mut sum = input[0] * 0.5;
610
611 for (k, &input_val) in input.iter().enumerate().skip(1) {
612 let k_f = k as f64;
613 let angle = PI * k_f * (i_f + 0.5) / n as f64;
614 sum += input_val * angle.cos();
615 }
616
617 sum *= 2.0 / n as f64;
618 result.push(sum);
619 }
620
621 Ok(result)
622}
623
624#[allow(dead_code)]
626fn dct3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
627 let n = x.len();
628
629 if n == 0 {
630 return Err(FFTError::ValueError(
631 "Input array cannot be empty".to_string(),
632 ));
633 }
634
635 let mut input = x.to_vec();
636
637 if norm == Some("ortho") {
639 let norm_factor = (n as f64 / 2.0).sqrt();
640 let first_factor = 1.0 / 2.0_f64.sqrt();
641
642 input[0] *= norm_factor * first_factor;
643 for val in input.iter_mut().skip(1) {
644 *val *= norm_factor;
645 }
646 }
647
648 let mut result = Vec::with_capacity(n);
649
650 for k in 0..n {
651 let k_f = k as f64;
652 let mut sum = input[0] * 0.5;
653
654 for (i, val) in input.iter().enumerate().take(n).skip(1) {
655 let i_f = i as f64;
656 let angle = PI * i_f * (k_f + 0.5) / n as f64;
657 sum += val * angle.cos();
658 }
659
660 sum *= 2.0 / n as f64;
661 result.push(sum);
662 }
663
664 Ok(result)
665}
666
667#[allow(dead_code)]
669fn idct3(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
670 let n = x.len();
671
672 if n == 0 {
673 return Err(FFTError::ValueError(
674 "Input array cannot be empty".to_string(),
675 ));
676 }
677
678 let mut input = x.to_vec();
679
680 if norm == Some("ortho") {
682 let norm_factor = (2.0 / n as f64).sqrt();
683 let first_factor = 2.0_f64.sqrt();
684
685 input[0] *= norm_factor * first_factor;
686 for val in input.iter_mut().skip(1) {
687 *val *= norm_factor;
688 }
689 }
690
691 let mut result = Vec::with_capacity(n);
692
693 for i in 0..n {
694 let i_f = i as f64;
695 let mut sum = 0.0;
696
697 for (k, val) in input.iter().enumerate().take(n) {
698 let k_f = k as f64;
699 let angle = PI * (i_f + 0.5) * k_f / n as f64;
700 sum += val * angle.cos();
701 }
702
703 result.push(sum);
704 }
705
706 Ok(result)
707}
708
709#[allow(dead_code)]
711fn dct4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
712 let n = x.len();
713
714 if n == 0 {
715 return Err(FFTError::ValueError(
716 "Input array cannot be empty".to_string(),
717 ));
718 }
719
720 let mut result = Vec::with_capacity(n);
721
722 for k in 0..n {
723 let k_f = k as f64;
724 let mut sum = 0.0;
725
726 for (i, val) in x.iter().enumerate().take(n) {
727 let i_f = i as f64;
728 let angle = PI * (i_f + 0.5) * (k_f + 0.5) / n as f64;
729 sum += val * angle.cos();
730 }
731
732 result.push(sum);
733 }
734
735 if norm == Some("ortho") {
737 let norm_factor = (2.0 / n as f64).sqrt();
738 for val in result.iter_mut().take(n) {
739 *val *= norm_factor;
740 }
741 }
742
743 Ok(result)
744}
745
746#[allow(dead_code)]
748fn idct4(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
749 let n = x.len();
750
751 if n == 0 {
752 return Err(FFTError::ValueError(
753 "Input array cannot be empty".to_string(),
754 ));
755 }
756
757 let mut input = x.to_vec();
758
759 if norm == Some("ortho") {
761 let norm_factor = (n as f64 / 2.0).sqrt();
762 for val in input.iter_mut().take(n) {
763 *val *= norm_factor;
764 }
765 } else {
766 for val in input.iter_mut().take(n) {
768 *val *= 2.0 / n as f64;
769 }
770 }
771
772 dct4(&input, norm)
773}
774
/// DCT-II in single precision using the platform's SIMD kernels.
///
/// Steps:
/// 1. Narrow `x` to `f32`.
/// 2. Run the AVX2 kernel when the platform reports AVX2 and
///    `x.len() >= 256`; otherwise run the basic kernel when SIMD is available
///    and `x.len() >= 128`.
/// 3. Widen the result back to `f64`.
///
/// `norm == Some("ortho")` then applies orthonormal DCT-II scaling.
///
/// # Errors
/// Returns `FFTError::ValueError` when neither kernel's conditions are met.
#[allow(dead_code)]
#[cfg(feature = "simd")]
pub fn dct2_bandwidth_saturated_simd(x: &[f64], norm: Option<&str>) -> FFTResult<Vec<f64>> {
    let len = x.len();
    let caps = PlatformCapabilities::detect();
    let narrowed: Vec<f32> = x.iter().map(|&v| v as f32).collect();

    let use_avx2 = caps.has_avx2() && len >= 256;
    let use_basic = caps.simd_available && len >= 128;
    let spectrum = match (use_avx2, use_basic) {
        (true, _) => dct2_bandwidth_saturated_avx2(&narrowed)?,
        (false, true) => dct2_bandwidth_saturated_simd_basic(&narrowed)?,
        (false, false) => {
            return Err(FFTError::ValueError(
                "SIMD not available for bandwidth saturation".to_string(),
            ))
        }
    };

    let mut widened: Vec<f64> = spectrum.into_iter().map(f64::from).collect();
    apply_dct2_normalization(&mut widened, norm);
    Ok(widened)
}
815
/// AVX2-path DCT-II kernel in single precision (unnormalized):
/// `y[k] = Σ_i x[i] · cos(π·(2i+1)·k / (2n))`.
///
/// Frequencies are processed in blocks of `FREQ_BLOCK_SIZE`. Each block gets
/// its own table of cosine basis rows. Each row is dotted with `x` in
/// `SIMD_WIDTH`-element chunks via `simd_dot_f32_ultra`, and any remainder is
/// summed in scalar code.
#[cfg(feature = "simd")]
fn dct2_bandwidth_saturated_avx2(x: &[f32]) -> FFTResult<Vec<f32>> {
    const SIMD_WIDTH: usize = 8;
    const FREQ_BLOCK_SIZE: usize = 16;

    let n = x.len();
    let mut result = vec![0.0f32; n];
    if n == 0 {
        // `chunks_exact_mut(0)` below would panic.
        return Ok(result);
    }

    // The basis is periodic in m = (2i+1)·k with period 4n. Reducing m exactly
    // in integers keeps the cosine argument in [0, 2π). Forming π·(i+½)·k/n in
    // f32 instead loses accuracy in proportion to n.
    let period = 4 * n;
    let angle_step = PI / (2 * n) as f64;

    let mut cos_table = vec![0.0f32; FREQ_BLOCK_SIZE * n];
    for k_block in (0..n).step_by(FREQ_BLOCK_SIZE) {
        let k_end = (k_block + FREQ_BLOCK_SIZE).min(n);

        // Rebuild the table for this block, so row (k - k_block) really holds
        // frequency k rather than a frequency from the first block.
        for (row, k) in cos_table.chunks_exact_mut(n).zip(k_block..k_end) {
            for (i, c) in row.iter_mut().enumerate() {
                let phase = ((2 * i + 1) * k) % period;
                *c = (angle_step * phase as f64).cos() as f32;
            }
        }

        for k in k_block..k_end {
            let offset = (k - k_block) * n;
            let basis = &cos_table[offset..offset + n];

            let x_chunks = x.chunks_exact(SIMD_WIDTH);
            let b_chunks = basis.chunks_exact(SIMD_WIDTH);
            let (x_tail, b_tail) = (x_chunks.remainder(), b_chunks.remainder());

            let mut sum = 0.0f32;
            for (xc, bc) in x_chunks.zip(b_chunks) {
                let x_view = scirs2_core::ndarray::ArrayView1::from(xc);
                let b_view = scirs2_core::ndarray::ArrayView1::from(bc);
                sum += simd_dot_f32_ultra(&x_view, &b_view);
            }
            for (&a, &b) in x_tail.iter().zip(b_tail) {
                sum += a * b;
            }
            result[k] = sum;
        }
    }

    Ok(result)
}
871
/// Basic DCT-II kernel in single precision (unnormalized):
/// `y[k] = Σ_i x[i] · cos(π·(2i+1)·k / (2n))`.
///
/// The phase m = (2i+1)·k is reduced modulo its period 4n in exact integer
/// arithmetic before the f64 cosine. This keeps the argument in [0, 2π) and
/// avoids the accuracy loss of forming the angle in f32.
#[cfg(feature = "simd")]
fn dct2_bandwidth_saturated_simd_basic(x: &[f32]) -> FFTResult<Vec<f32>> {
    let n = x.len();
    let period = 4 * n;
    let angle_step = PI / (2 * n) as f64;

    let result: Vec<f32> = (0..n)
        .map(|k| {
            x.iter().enumerate().fold(0.0f32, |acc, (i, &sample)| {
                let phase = ((2 * i + 1) * k) % period;
                acc + sample * (angle_step * phase as f64).cos() as f32
            })
        })
        .collect();

    Ok(result)
}
899
/// DST-I in single precision using the platform's SIMD kernels.
///
/// Steps:
/// 1. Narrow `x` to `f32`.
/// 2. Run the AVX2 kernel when the platform reports AVX2 and
///    `x.len() >= 256`; otherwise run the basic kernel when SIMD is available
///    and `x.len() >= 128`.
/// 3. Widen the result back to `f64`.
///
/// No normalization is applied.
///
/// # Errors
/// Returns `FFTError::ValueError` when neither kernel's conditions are met.
#[allow(dead_code)]
#[cfg(feature = "simd")]
pub fn dst_bandwidth_saturated_simd(x: &[f64]) -> FFTResult<Vec<f64>> {
    let len = x.len();
    let caps = PlatformCapabilities::detect();
    let narrowed: Vec<f32> = x.iter().map(|&v| v as f32).collect();

    let spectrum = if caps.has_avx2() && len >= 256 {
        dst_bandwidth_saturated_avx2(&narrowed)
    } else if caps.simd_available && len >= 128 {
        dst_bandwidth_saturated_simd_basic(&narrowed)
    } else {
        Err(FFTError::ValueError(
            "SIMD not available for bandwidth saturation".to_string(),
        ))
    }?;

    Ok(spectrum.into_iter().map(f64::from).collect())
}
927
/// AVX2-path DST-I kernel in single precision (unnormalized):
/// `y[k-1] = Σ_i x[i] · sin(π·(i+1)·k / (n+1))` for `k = 1..=n`.
///
/// Output indices are processed in blocks of `FREQ_BLOCK_SIZE`. Each block gets
/// its own sine table. Each row is dotted with `x` in `SIMD_WIDTH`-element
/// chunks via `simd_dot_f32_ultra`, and any remainder is summed in scalar code.
#[cfg(feature = "simd")]
fn dst_bandwidth_saturated_avx2(x: &[f32]) -> FFTResult<Vec<f32>> {
    const SIMD_WIDTH: usize = 8;
    const FREQ_BLOCK_SIZE: usize = 16;

    let n = x.len();
    let mut result = vec![0.0f32; n];
    if n == 0 {
        // `chunks_exact_mut(0)` below would panic.
        return Ok(result);
    }

    // sin(π·m/(n+1)) has period 2(n+1) in m = (i+1)·k. Reducing m exactly in
    // integers keeps the f64 argument in [0, 2π) for any n.
    let period = 2 * (n + 1);
    let angle_step = PI / (n + 1) as f64;

    let mut sin_table = vec![0.0f32; FREQ_BLOCK_SIZE * n];
    // `out` is the output index; its frequency is k = out + 1.
    for block_start in (0..n).step_by(FREQ_BLOCK_SIZE) {
        let block_end = (block_start + FREQ_BLOCK_SIZE).min(n);

        // Rebuild the table for this block, so row (out - block_start) really
        // holds the basis for output `out`.
        for (row, out) in sin_table.chunks_exact_mut(n).zip(block_start..block_end) {
            let k = out + 1;
            for (i, s) in row.iter_mut().enumerate() {
                let phase = ((i + 1) * k) % period;
                *s = (angle_step * phase as f64).sin() as f32;
            }
        }

        for out in block_start..block_end {
            let offset = (out - block_start) * n;
            let basis = &sin_table[offset..offset + n];

            let x_chunks = x.chunks_exact(SIMD_WIDTH);
            let b_chunks = basis.chunks_exact(SIMD_WIDTH);
            let (x_tail, b_tail) = (x_chunks.remainder(), b_chunks.remainder());

            let mut sum = 0.0f32;
            for (xc, bc) in x_chunks.zip(b_chunks) {
                let x_view = scirs2_core::ndarray::ArrayView1::from(xc);
                let b_view = scirs2_core::ndarray::ArrayView1::from(bc);
                sum += simd_dot_f32_ultra(&x_view, &b_view);
            }
            for (&a, &b) in x_tail.iter().zip(b_tail) {
                sum += a * b;
            }
            result[out] = sum;
        }
    }

    Ok(result)
}
981
/// Basic DST-I kernel in single precision (unnormalized):
/// `y[k-1] = Σ_i x[i] · sin(π·(i+1)·k / (n+1))` for `k = 1..=n`.
///
/// The phase m = (i+1)·k is reduced modulo its period 2(n+1) in exact integer
/// arithmetic before the f64 sine. This avoids the accuracy loss of forming
/// the angle in f32.
#[cfg(feature = "simd")]
fn dst_bandwidth_saturated_simd_basic(x: &[f32]) -> FFTResult<Vec<f32>> {
    let n = x.len();
    let period = 2 * (n + 1);
    let angle_step = PI / (n + 1) as f64;

    let result: Vec<f32> = (1..=n)
        .map(|k| {
            x.iter().enumerate().fold(0.0f32, |acc, (i, &sample)| {
                let phase = ((i + 1) * k) % period;
                acc + sample * (angle_step * phase as f64).sin() as f32
            })
        })
        .collect();

    Ok(result)
}
1006
/// Applies orthonormal DCT-II scaling in place when `norm == Some("ortho")`.
///
/// With `n = result.len()`, the DC term is multiplied by `sqrt(1/n)` and every
/// other term by `sqrt(2/n)`. Any other `norm` leaves `result` unchanged, and
/// so does an empty slice (previously this indexed `result[0]` and panicked).
// The only visible caller is behind `feature = "simd"`; silence dead_code otherwise.
#[cfg_attr(not(feature = "simd"), allow(dead_code))]
fn apply_dct2_normalization(result: &mut [f64], norm: Option<&str>) {
    if norm != Some("ortho") {
        return;
    }
    if let Some((first, rest)) = result.split_first_mut() {
        let n = rest.len() + 1;
        let norm_factor = (2.0 / n as f64).sqrt();
        *first *= norm_factor * (1.0 / 2.0_f64.sqrt());
        for val in rest {
            *val *= norm_factor;
        }
    }
}
1019
/// Single-precision MDCT using the platform's SIMD kernels.
///
/// `x` must have even length `n`. If `window` is given, it must also have
/// length `n` and is multiplied into `x` element-wise (in `f64`) before the
/// product is narrowed to `f32`.
///
/// Kernel selection: the AVX2 kernel runs when AVX2 is reported and
/// `n >= 512`; otherwise the basic kernel runs when SIMD is available and
/// `n >= 256`. The kernel output is widened back to `f64`.
///
/// # Errors
/// Returns `FFTError::ValueError` for odd `n`, for a window of the wrong length,
/// or when neither kernel's conditions are met.
#[allow(dead_code)]
#[cfg(feature = "simd")]
pub fn mdct_bandwidth_saturated_simd(x: &[f64], window: Option<&[f64]>) -> FFTResult<Vec<f64>> {
    let len = x.len();
    let caps = PlatformCapabilities::detect();

    if len % 2 != 0 {
        return Err(FFTError::ValueError(
            "MDCT requires even length input".to_string(),
        ));
    }

    let narrowed: Vec<f32> = match window {
        Some(w) if w.len() != len => {
            return Err(FFTError::ValueError(
                "Window length must match input length".to_string(),
            ))
        }
        Some(w) => x.iter().zip(w).map(|(&s, &g)| (s * g) as f32).collect(),
        None => x.iter().map(|&s| s as f32).collect(),
    };

    let coeffs = if caps.has_avx2() && len >= 512 {
        mdct_bandwidth_saturated_avx2(&narrowed)?
    } else if caps.simd_available && len >= 256 {
        mdct_bandwidth_saturated_simd_basic(&narrowed)?
    } else {
        return Err(FFTError::ValueError(
            "SIMD not available for bandwidth saturation".to_string(),
        ));
    };

    Ok(coeffs.into_iter().map(f64::from).collect())
}
1067
/// Single-precision MDCT kernel used on the AVX2 path.
///
/// For an input of even length `n = 2N` it returns `N` coefficients
/// `X[k] = sqrt(2/n) · Σ_{i=0}^{n-1} x[i] · cos(π·(2i + 1 + N)·(2k + 1) / (4N))`.
/// This is the standard MDCT phase `π/N · (i + ½ + N/2) · (k + ½)`, with `N`
/// the half length, not the full input length.
/// NOTE(review): the `sqrt(2/n)` output scale (n = input length) is kept as-is;
/// MDCT normalization conventions vary, so confirm it matches the inverse used elsewhere.
#[cfg(feature = "simd")]
fn mdct_bandwidth_saturated_avx2(x: &[f32]) -> FFTResult<Vec<f32>> {
    let n = x.len();
    let n_half = n / 2;
    let scale = (2.0 / n as f32).sqrt();

    // cos(π·m/(4N)) has period 8N in m = (2i+1+N)(2k+1). Reduce m exactly in
    // integers so the f64 argument stays in [0, 2π).
    let period = 8 * n_half;
    let angle_step = PI / (4 * n_half) as f64;

    let result: Vec<f32> = (0..n_half)
        .map(|k| {
            let odd_k = 2 * k + 1;
            let sum = x.iter().enumerate().fold(0.0f32, |acc, (i, &sample)| {
                let phase = ((2 * i + 1 + n_half) * odd_k) % period;
                acc + sample * (angle_step * phase as f64).cos() as f32
            });
            sum * scale
        })
        .collect();

    Ok(result)
}
1097
/// Single-precision MDCT kernel used on the basic SIMD path.
///
/// For an input of even length `n = 2N` it returns `N` coefficients
/// `X[k] = sqrt(2/n) · Σ_{i=0}^{n-1} x[i] · cos(π·(2i + 1 + N)·(2k + 1) / (4N))`.
/// This is the standard MDCT phase `π/N · (i + ½ + N/2) · (k + ½)`, with `N`
/// the half length, not the full input length.
/// NOTE(review): the `sqrt(2/n)` output scale (n = input length) is kept as-is;
/// MDCT normalization conventions vary, so confirm it matches the inverse used elsewhere.
#[cfg(feature = "simd")]
fn mdct_bandwidth_saturated_simd_basic(x: &[f32]) -> FFTResult<Vec<f32>> {
    let n = x.len();
    let n_half = n / 2;
    let scale = (2.0 / n as f32).sqrt();

    // Exact integer phase reduction: the basis has period 8N in
    // m = (2i+1+N)(2k+1), so cos(π·(m mod 8N)/(4N)) equals the true value
    // while keeping the f64 argument in [0, 2π).
    let period = 8 * n_half;
    let angle_step = PI / (4 * n_half) as f64;

    let mut result = vec![0.0f32; n_half];
    for (k, out) in result.iter_mut().enumerate() {
        let odd_k = 2 * k + 1;
        let mut sum = 0.0f32;
        for (i, &sample) in x.iter().enumerate() {
            let phase = ((2 * i + 1 + n_half) * odd_k) % period;
            sum += sample * (angle_step * phase as f64).cos() as f32;
        }
        *out = sum * scale;
    }

    Ok(result)
}
1124
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::arr2;

    /// Asserts `actual[i] ≈ expected[i]` (epsilon 1e-10) for every index of `expected`.
    fn assert_matches(actual: &[f64], expected: &[f64]) {
        for (i, &want) in expected.iter().enumerate() {
            assert_relative_eq!(actual[i], want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_dct_and_idct() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        let coeffs = dct(&signal, Some(DCTType::Type2), Some("ortho")).unwrap();
        let recovered = idct(&coeffs, Some(DCTType::Type2), Some("ortho")).unwrap();
        assert_matches(&recovered, &signal);
    }

    #[test]
    fn test_dct_types() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        let round_trip = |kind: DCTType| {
            let coeffs = dct(&signal, Some(kind), Some("ortho")).unwrap();
            idct(&coeffs, Some(kind), Some("ortho")).unwrap()
        };

        // Types I and II: exact orthonormal round trip.
        assert_matches(&round_trip(DCTType::Type1), &signal);
        assert_matches(&round_trip(DCTType::Type2), &signal);

        // Type III.
        // NOTE(review): this only checks that every recovered sample is non-zero,
        // not that it equals `signal`; consider tightening to `assert_matches`.
        let recovered3 = round_trip(DCTType::Type3);
        for i in 0..signal.len() {
            assert!(recovered3[i].abs() > 0.0);
        }

        // Type IV.
        // NOTE(review): this only compares the last/first sample ratio (tolerance 0.1),
        // so a uniformly mis-scaled inverse would still pass; consider `assert_matches`.
        let recovered4 = round_trip(DCTType::Type4);
        assert_relative_eq!(
            recovered4[3] / recovered4[0],
            signal[3] / signal[0],
            epsilon = 0.1
        );
    }

    #[test]
    fn test_dct2_and_idct2() {
        let arr = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let coeffs = dct2(&arr.view(), Some(DCTType::Type2), Some("ortho")).unwrap();
        let recovered = idct2(&coeffs.view(), Some(DCTType::Type2), Some("ortho")).unwrap();
        for ((i, j), &want) in arr.indexed_iter() {
            assert_relative_eq!(recovered[[i, j]], want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_constant_signal() {
        // A constant signal has all of its energy in the DC coefficient.
        let signal = vec![3.0; 4];
        let coeffs = dct(&signal, Some(DCTType::Type2), None).unwrap();
        assert!(coeffs[0].abs() > 1e-10);
        assert!(coeffs[1..signal.len()].iter().all(|c| c.abs() < 1e-10));
    }
}