zaft/r2c/
c2r.rs

1/*
2 * // Copyright (c) Radzivon Bartoshyk 10/2025. All rights reserved.
3 * //
4 * // Redistribution and use in source and binary forms, with or without modification,
5 * // are permitted provided that the following conditions are met:
6 * //
7 * // 1.  Redistributions of source code must retain the above copyright notice, this
8 * // list of conditions and the following disclaimer.
9 * //
10 * // 2.  Redistributions in binary form must reproduce the above copyright notice,
11 * // this list of conditions and the following disclaimer in the documentation
12 * // and/or other materials provided with the distribution.
13 * //
14 * // 3.  Neither the name of the copyright holder nor the names of its
15 * // contributors may be used to endorse or promote products derived from
16 * // this software without specific prior written permission.
17 * //
18 * // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 * // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 * // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 * // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 * // FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 * // DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 * // SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 * // CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 * // OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 * // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 */
29use crate::err::try_vec;
30use crate::r2c::R2CTwiddlesHandler;
31use crate::r2c::c2r_twiddles::C2RTwiddlesFactory;
32use crate::traits::FftTrigonometry;
33use crate::util::compute_twiddle;
34use crate::{FftDirection, FftExecutor, ZaftError};
35use num_complex::Complex;
36use num_traits::{AsPrimitive, Float, MulAdd, Num, Zero};
37use std::ops::{Add, Mul, Neg, Sub};
38
39pub trait C2RFftExecutor<T> {
40    fn execute(&self, input: &[Complex<T>], output: &mut [T]) -> Result<(), ZaftError>;
41    fn real_length(&self) -> usize;
42    fn complex_length(&self) -> usize;
43}
44
/// Complex-to-real inverse FFT for even real lengths.
///
/// The half-spectrum is folded into a complex sequence of `length / 2` values, which is
/// transformed by `intercept`. The complex result is then unpacked as interleaved
/// (re, im) pairs into the real output.
pub(crate) struct C2RFftEvenInterceptor<T> {
    // Inverse complex FFT of size `length / 2`.
    intercept: Box<dyn FftExecutor<T> + Send + Sync>,
    // Inverse-direction twiddles for indices `1..ceil(length / 4)`, applied while folding.
    twiddles: Vec<Complex<T>>,
    // Number of real output samples per transform (even).
    length: usize,
    // Number of complex input bins per transform: `length / 2 + 1`.
    complex_length: usize,
    // Type-specific routine that applies `twiddles` to the symmetric bin pairs.
    twiddles_handler: Box<dyn R2CTwiddlesHandler<T> + Send + Sync>,
}
52
53impl<
54    T: Copy
55        + Clone
56        + FftTrigonometry
57        + Mul<T, Output = T>
58        + 'static
59        + Zero
60        + Num
61        + Float
62        + C2RTwiddlesFactory<T>,
63> C2RFftEvenInterceptor<T>
64where
65    f64: AsPrimitive<T>,
66{
67    pub(crate) fn install(
68        length: usize,
69        intercept: Box<dyn FftExecutor<T> + Send + Sync>,
70    ) -> Result<Self, ZaftError> {
71        assert_eq!(length % 2, 0, "R2C must be even in even interceptor");
72        assert_eq!(
73            intercept.length(),
74            length / 2,
75            "Underlying interceptor must have a half-length of real values"
76        );
77        assert_eq!(
78            intercept.direction(),
79            FftDirection::Inverse,
80            "Complex to real fft must be inverse"
81        );
82
83        let twiddles_count = if length % 4 == 0 {
84            length / 4
85        } else {
86            length / 4 + 1
87        };
88        let mut twiddles = try_vec![Complex::<T>::zero(); twiddles_count - 1];
89        for (i, twiddle) in twiddles.iter_mut().enumerate() {
90            *twiddle = compute_twiddle(i + 1, length, FftDirection::Inverse);
91        }
92        Ok(Self {
93            intercept,
94            twiddles,
95            length,
96            complex_length: length / 2 + 1,
97            twiddles_handler: T::make_c2r_twiddles_handler(),
98        })
99    }
100}
101
102impl<
103    T: Copy
104        + Mul<T, Output = T>
105        + Add<T, Output = T>
106        + Sub<T, Output = T>
107        + Num
108        + 'static
109        + Neg<Output = T>
110        + MulAdd<T, Output = T>,
111> C2RFftExecutor<T> for C2RFftEvenInterceptor<T>
112where
113    f64: AsPrimitive<T>,
114{
115    fn execute(&self, input: &[Complex<T>], output: &mut [T]) -> Result<(), ZaftError> {
116        if output.len() % self.length != 0 {
117            return Err(ZaftError::InvalidSizeMultiplier(input.len(), self.length));
118        }
119        if input.len() % self.complex_length != 0 {
120            return Err(ZaftError::InvalidSizeMultiplier(
121                output.len(),
122                self.complex_length,
123            ));
124        }
125
126        let mut scratch = try_vec![Complex::<T>::zero(); self.complex_length];
127
128        for (input, output) in input
129            .chunks_exact(self.complex_length)
130            .zip(output.chunks_exact_mut(self.length))
131        {
132            scratch.copy_from_slice(input);
133            scratch[0].im = 0.0f64.as_();
134            scratch.last_mut().unwrap().im = 0.0f64.as_();
135
136            let (mut input_left, mut input_right) = scratch.split_at_mut(input.len() / 2);
137
138            // We have to preprocess the input in-place before we send it to the FFT.
139            // The first and centermost values have to be preprocessed separately from the rest, so do that now.
140            match (input_left.first_mut(), input_right.last_mut()) {
141                (Some(first_input), Some(last_input)) => {
142                    let first_sum = *first_input + *last_input;
143                    let first_diff = *first_input - *last_input;
144
145                    *first_input = Complex {
146                        re: first_sum.re - first_sum.im,
147                        im: first_diff.re - first_diff.im,
148                    };
149
150                    input_left = &mut input_left[1..];
151                    let right_len = input_right.len();
152                    input_right = &mut input_right[..right_len - 1];
153                }
154                _ => return Ok(()),
155            };
156
157            self.twiddles_handler
158                .handle(&self.twiddles, input_left, input_right);
159
160            // If the output len is odd, the loop above can't preprocess the centermost element, so handle that separately
161            if scratch.len() % 2 == 1 {
162                let center_element = input[input.len() / 2];
163                let doubled = center_element + center_element;
164                scratch[input.len() / 2] = doubled.conj();
165            }
166
167            self.intercept.execute(&mut scratch[..output.len() / 2])?;
168
169            for (dst, src) in output.chunks_exact_mut(2).zip(scratch.iter()) {
170                dst[0] = src.re;
171                dst[1] = src.im;
172            }
173        }
174
175        Ok(())
176    }
177
178    fn real_length(&self) -> usize {
179        self.length
180    }
181
182    fn complex_length(&self) -> usize {
183        self.complex_length
184    }
185}
186
/// Complex-to-real inverse FFT for odd real lengths.
///
/// The half-spectrum is mirrored into a full Hermitian spectrum of `length` bins. A
/// full-length inverse complex FFT is then run on it, and its real parts are kept.
pub(crate) struct C2RFftOddInterceptor<T> {
    // Inverse complex FFT of size `length`.
    intercept: Box<dyn FftExecutor<T> + Send + Sync>,
    // Number of real output samples per transform (odd).
    length: usize,
    // Number of complex input bins per transform: `length / 2 + 1`.
    complex_length: usize,
}
192
193impl<T: Copy + Clone + FftTrigonometry + Mul<T, Output = T> + 'static + Zero + Num + Float>
194    C2RFftOddInterceptor<T>
195where
196    f64: AsPrimitive<T>,
197{
198    pub(crate) fn install(
199        length: usize,
200        intercept: Box<dyn FftExecutor<T> + Send + Sync>,
201    ) -> Result<Self, ZaftError> {
202        assert_ne!(length % 2, 0, "R2C must be even in even interceptor");
203        assert_eq!(
204            intercept.length(),
205            length,
206            "Underlying interceptor must have full length of real values"
207        );
208        assert_eq!(
209            intercept.direction(),
210            FftDirection::Inverse,
211            "Complex to real fft must be inverse"
212        );
213
214        Ok(Self {
215            intercept,
216            length,
217            complex_length: length / 2 + 1,
218        })
219    }
220}
221
222impl<
223    T: Copy
224        + Mul<T, Output = T>
225        + Add<T, Output = T>
226        + Sub<T, Output = T>
227        + Num
228        + 'static
229        + Neg<Output = T>
230        + MulAdd<T, Output = T>,
231> C2RFftExecutor<T> for C2RFftOddInterceptor<T>
232where
233    f64: AsPrimitive<T>,
234{
235    fn execute(&self, input: &[Complex<T>], output: &mut [T]) -> Result<(), ZaftError> {
236        if output.len() % self.length != 0 {
237            return Err(ZaftError::InvalidSizeMultiplier(input.len(), self.length));
238        }
239        if input.len() % self.complex_length != 0 {
240            return Err(ZaftError::InvalidSizeMultiplier(
241                output.len(),
242                self.complex_length,
243            ));
244        }
245
246        let mut scratch = try_vec![Complex::<T>::zero(); self.length];
247
248        for (input, output) in input
249            .chunks_exact(self.complex_length)
250            .zip(output.chunks_exact_mut(self.length))
251        {
252            scratch[..input.len()].copy_from_slice(input);
253            scratch[0].im = 0.0.as_();
254            for (buf, val) in scratch
255                .iter_mut()
256                .rev()
257                .take(self.length / 2)
258                .zip(input.iter().skip(1))
259            {
260                *buf = val.conj();
261            }
262            self.intercept.execute(&mut scratch)?;
263            for (dst, src) in output.iter_mut().zip(scratch.iter()) {
264                *dst = src.re;
265            }
266        }
267
268        Ok(())
269    }
270
271    fn real_length(&self) -> usize {
272        self.length
273    }
274
275    fn complex_length(&self) -> usize {
276        self.complex_length
277    }
278}