1use crate::err::try_vec;
30use crate::r2c::R2CTwiddlesHandler;
31use crate::r2c::c2r_twiddles::C2RTwiddlesFactory;
32use crate::traits::FftTrigonometry;
33use crate::util::compute_twiddle;
34use crate::{FftDirection, FftExecutor, ZaftError};
35use num_complex::Complex;
36use num_traits::{AsPrimitive, Float, MulAdd, Num, Zero};
37use std::ops::{Add, Mul, Neg, Sub};
38
39pub trait C2RFftExecutor<T> {
40 fn execute(&self, input: &[Complex<T>], output: &mut [T]) -> Result<(), ZaftError>;
41 fn real_length(&self) -> usize;
42 fn complex_length(&self) -> usize;
43}
44
45pub(crate) struct C2RFftEvenInterceptor<T> {
46 intercept: Box<dyn FftExecutor<T> + Send + Sync>,
47 twiddles: Vec<Complex<T>>,
48 length: usize,
49 complex_length: usize,
50 twiddles_handler: Box<dyn R2CTwiddlesHandler<T> + Send + Sync>,
51}
52
53impl<
54 T: Copy
55 + Clone
56 + FftTrigonometry
57 + Mul<T, Output = T>
58 + 'static
59 + Zero
60 + Num
61 + Float
62 + C2RTwiddlesFactory<T>,
63> C2RFftEvenInterceptor<T>
64where
65 f64: AsPrimitive<T>,
66{
67 pub(crate) fn install(
68 length: usize,
69 intercept: Box<dyn FftExecutor<T> + Send + Sync>,
70 ) -> Result<Self, ZaftError> {
71 assert_eq!(length % 2, 0, "R2C must be even in even interceptor");
72 assert_eq!(
73 intercept.length(),
74 length / 2,
75 "Underlying interceptor must have a half-length of real values"
76 );
77 assert_eq!(
78 intercept.direction(),
79 FftDirection::Inverse,
80 "Complex to real fft must be inverse"
81 );
82
83 let twiddles_count = if length % 4 == 0 {
84 length / 4
85 } else {
86 length / 4 + 1
87 };
88 let mut twiddles = try_vec![Complex::<T>::zero(); twiddles_count - 1];
89 for (i, twiddle) in twiddles.iter_mut().enumerate() {
90 *twiddle = compute_twiddle(i + 1, length, FftDirection::Inverse);
91 }
92 Ok(Self {
93 intercept,
94 twiddles,
95 length,
96 complex_length: length / 2 + 1,
97 twiddles_handler: T::make_c2r_twiddles_handler(),
98 })
99 }
100}
101
102impl<
103 T: Copy
104 + Mul<T, Output = T>
105 + Add<T, Output = T>
106 + Sub<T, Output = T>
107 + Num
108 + 'static
109 + Neg<Output = T>
110 + MulAdd<T, Output = T>,
111> C2RFftExecutor<T> for C2RFftEvenInterceptor<T>
112where
113 f64: AsPrimitive<T>,
114{
115 fn execute(&self, input: &[Complex<T>], output: &mut [T]) -> Result<(), ZaftError> {
116 if output.len() % self.length != 0 {
117 return Err(ZaftError::InvalidSizeMultiplier(input.len(), self.length));
118 }
119 if input.len() % self.complex_length != 0 {
120 return Err(ZaftError::InvalidSizeMultiplier(
121 output.len(),
122 self.complex_length,
123 ));
124 }
125
126 let mut scratch = try_vec![Complex::<T>::zero(); self.complex_length];
127
128 for (input, output) in input
129 .chunks_exact(self.complex_length)
130 .zip(output.chunks_exact_mut(self.length))
131 {
132 scratch.copy_from_slice(input);
133 scratch[0].im = 0.0f64.as_();
134 scratch.last_mut().unwrap().im = 0.0f64.as_();
135
136 let (mut input_left, mut input_right) = scratch.split_at_mut(input.len() / 2);
137
138 match (input_left.first_mut(), input_right.last_mut()) {
141 (Some(first_input), Some(last_input)) => {
142 let first_sum = *first_input + *last_input;
143 let first_diff = *first_input - *last_input;
144
145 *first_input = Complex {
146 re: first_sum.re - first_sum.im,
147 im: first_diff.re - first_diff.im,
148 };
149
150 input_left = &mut input_left[1..];
151 let right_len = input_right.len();
152 input_right = &mut input_right[..right_len - 1];
153 }
154 _ => return Ok(()),
155 };
156
157 self.twiddles_handler
158 .handle(&self.twiddles, input_left, input_right);
159
160 if scratch.len() % 2 == 1 {
162 let center_element = input[input.len() / 2];
163 let doubled = center_element + center_element;
164 scratch[input.len() / 2] = doubled.conj();
165 }
166
167 self.intercept.execute(&mut scratch[..output.len() / 2])?;
168
169 for (dst, src) in output.chunks_exact_mut(2).zip(scratch.iter()) {
170 dst[0] = src.re;
171 dst[1] = src.im;
172 }
173 }
174
175 Ok(())
176 }
177
178 fn real_length(&self) -> usize {
179 self.length
180 }
181
182 fn complex_length(&self) -> usize {
183 self.complex_length
184 }
185}
186
187pub(crate) struct C2RFftOddInterceptor<T> {
188 intercept: Box<dyn FftExecutor<T> + Send + Sync>,
189 length: usize,
190 complex_length: usize,
191}
192
193impl<T: Copy + Clone + FftTrigonometry + Mul<T, Output = T> + 'static + Zero + Num + Float>
194 C2RFftOddInterceptor<T>
195where
196 f64: AsPrimitive<T>,
197{
198 pub(crate) fn install(
199 length: usize,
200 intercept: Box<dyn FftExecutor<T> + Send + Sync>,
201 ) -> Result<Self, ZaftError> {
202 assert_ne!(length % 2, 0, "R2C must be even in even interceptor");
203 assert_eq!(
204 intercept.length(),
205 length,
206 "Underlying interceptor must have full length of real values"
207 );
208 assert_eq!(
209 intercept.direction(),
210 FftDirection::Inverse,
211 "Complex to real fft must be inverse"
212 );
213
214 Ok(Self {
215 intercept,
216 length,
217 complex_length: length / 2 + 1,
218 })
219 }
220}
221
222impl<
223 T: Copy
224 + Mul<T, Output = T>
225 + Add<T, Output = T>
226 + Sub<T, Output = T>
227 + Num
228 + 'static
229 + Neg<Output = T>
230 + MulAdd<T, Output = T>,
231> C2RFftExecutor<T> for C2RFftOddInterceptor<T>
232where
233 f64: AsPrimitive<T>,
234{
235 fn execute(&self, input: &[Complex<T>], output: &mut [T]) -> Result<(), ZaftError> {
236 if output.len() % self.length != 0 {
237 return Err(ZaftError::InvalidSizeMultiplier(input.len(), self.length));
238 }
239 if input.len() % self.complex_length != 0 {
240 return Err(ZaftError::InvalidSizeMultiplier(
241 output.len(),
242 self.complex_length,
243 ));
244 }
245
246 let mut scratch = try_vec![Complex::<T>::zero(); self.length];
247
248 for (input, output) in input
249 .chunks_exact(self.complex_length)
250 .zip(output.chunks_exact_mut(self.length))
251 {
252 scratch[..input.len()].copy_from_slice(input);
253 scratch[0].im = 0.0.as_();
254 for (buf, val) in scratch
255 .iter_mut()
256 .rev()
257 .take(self.length / 2)
258 .zip(input.iter().skip(1))
259 {
260 *buf = val.conj();
261 }
262 self.intercept.execute(&mut scratch)?;
263 for (dst, src) in output.iter_mut().zip(scratch.iter()) {
264 *dst = src.re;
265 }
266 }
267
268 Ok(())
269 }
270
271 fn real_length(&self) -> usize {
272 self.length
273 }
274
275 fn complex_length(&self) -> usize {
276 self.complex_length
277 }
278}