1use crate::error::FFTResult;
7use crate::fft::{fft, ifft};
8use scirs2_core::ndarray::Array2;
9use scirs2_core::numeric::Complex64;
10use std::f64::consts::PI;
11
12#[allow(dead_code)]
19pub fn frft_dft<T>(x: &[T], alpha: f64) -> FFTResult<Vec<Complex64>>
20where
21 T: Copy + Into<f64>,
22{
23 let n = x.len();
24 if n == 0 {
25 return Ok(vec![]);
26 }
27
28 let x_complex: Vec<Complex64> = x
30 .iter()
31 .map(|&val| Complex64::new(val.into(), 0.0))
32 .collect();
33
34 let alpha_mod = alpha.rem_euclid(4.0);
36 if alpha_mod.abs() < 1e-10 {
37 return Ok(x_complex);
38 } else if (alpha_mod - 1.0).abs() < 1e-10 {
39 return fft(&x_complex, None);
40 } else if (alpha_mod - 2.0).abs() < 1e-10 {
41 return Ok(x_complex.into_iter().rev().collect());
42 } else if (alpha_mod - 3.0).abs() < 1e-10 {
43 return ifft(&x_complex, None);
44 }
45
46 let _angle = alpha * PI / 2.0;
48
49 let eigenvectors = compute_dft_eigenvectors(n);
51 let eigenvalues = compute_dft_eigenvalues(n);
52
53 let mut coefficients = vec![Complex64::new(0.0, 0.0); n];
55 for k in 0..n {
56 for j in 0..n {
57 coefficients[k] += x_complex[j] * eigenvectors[(j, k)].conj();
58 }
59 }
60
61 for k in 0..n {
63 let fractional_eigenvalue = eigenvalues[k].powc(Complex64::new(alpha, 0.0));
64 coefficients[k] *= fractional_eigenvalue;
65 }
66
67 let mut result = vec![Complex64::new(0.0, 0.0); n];
69 for j in 0..n {
70 for k in 0..n {
71 result[j] += coefficients[k] * eigenvectors[(j, k)];
72 }
73 }
74
75 Ok(result)
76}
77
78#[allow(dead_code)]
80fn compute_dft_eigenvectors(n: usize) -> Array2<Complex64> {
81 let mut eigenvectors = Array2::zeros((n, n));
82
83 let n_f64 = n as f64;
86
87 for k in 0..n {
88 for j in 0..n {
89 let x = (j as f64 - n_f64 / 2.0) / (n_f64 / 4.0).sqrt();
90 let hermite_value = hermite_function(k, x);
91 let phase = Complex64::new(0.0, -PI * j as f64 * k as f64 / n_f64).exp();
92 eigenvectors[(j, k)] = hermite_value * phase;
93 }
94 }
95
96 for k in 0..n {
98 let norm: f64 = (0..n)
99 .map(|j| eigenvectors[(j, k)].norm_sqr())
100 .sum::<f64>()
101 .sqrt();
102 if norm > 0.0 {
103 for j in 0..n {
104 eigenvectors[(j, k)] /= norm;
105 }
106 }
107 }
108
109 eigenvectors
110}
111
112#[allow(dead_code)]
114fn compute_dft_eigenvalues(n: usize) -> Vec<Complex64> {
115 let mut eigenvalues = vec![Complex64::new(0.0, 0.0); n];
116
117 for (k, eigenvalue) in eigenvalues.iter_mut().enumerate().take(n) {
119 let eigenvalue_index = k % 4;
121 *eigenvalue = match eigenvalue_index {
122 0 => Complex64::new(1.0, 0.0),
123 1 => Complex64::new(0.0, -1.0),
124 2 => Complex64::new(-1.0, 0.0),
125 3 => Complex64::new(0.0, 1.0),
126 _ => unreachable!(),
127 };
128 }
129
130 eigenvalues
131}
132
133#[allow(dead_code)]
135fn hermite_function(n: usize, x: f64) -> Complex64 {
136 let hermite = match n {
138 0 => 1.0,
139 1 => 2.0 * x,
140 2 => 4.0 * x * x - 2.0,
141 3 => 8.0 * x * x * x - 12.0 * x,
142 _ => {
143 let mut h_prev = 4.0 * x * x - 2.0;
145 let mut h_curr = 8.0 * x * x * x - 12.0 * x;
146
147 for k in 4..=n {
148 let h_next = 2.0 * x * h_curr - 2.0 * (k - 1) as f64 * h_prev;
149 h_prev = h_curr;
150 h_curr = h_next;
151 }
152 h_curr
153 }
154 };
155
156 let gaussian = (-x * x / 2.0).exp();
157 Complex64::new(hermite * gaussian, 0.0)
158}
159
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Total energy (sum of squared magnitudes) of a transform output.
    fn output_energy_of(samples: &[Complex64]) -> f64 {
        samples.iter().map(|c| c.norm_sqr()).sum()
    }

    #[test]
    fn test_dft_identity() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];
        let transformed = frft_dft(&signal, 0.0).unwrap();

        for (i, expected) in signal.iter().copied().enumerate() {
            let sample = transformed[i];
            assert_relative_eq!(sample.re, expected, epsilon = 1e-6);
            assert_relative_eq!(sample.im, 0.0, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_dft_energy_conservation() {
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let input_energy: f64 = signal.iter().map(|&s| s * s).sum();

        // Orders 0 and 2 return the input or a reordering of it, so energy must match exactly.
        for &alpha in &[0.0, 2.0] {
            let output_energy = output_energy_of(&frft_dft(&signal, alpha).unwrap());
            assert_relative_eq!(output_energy, input_energy, epsilon = 1e-10);
        }

        // Orders 1 and 3 delegate to fft/ifft, whose normalisation may rescale the energy.
        for &alpha in &[1.0, 3.0] {
            let ratio = output_energy_of(&frft_dft(&signal, alpha).unwrap()) / input_energy;
            assert!(
                ratio > 0.1 && ratio < 10.0,
                "Energy ratio {ratio} for alpha {alpha} is outside acceptable range"
            );
        }

        // Non-integer orders only get a loose sanity bound.
        for &alpha in &[0.1, 0.5, 1.5, 2.5, 3.5] {
            let ratio = output_energy_of(&frft_dft(&signal, alpha).unwrap()) / input_energy;
            assert!(
                ratio > 0.01 && ratio < 100.0,
                "Energy ratio {ratio} for alpha {alpha} is completely unreasonable"
            );
        }
    }
}