tfhe_fft/
lib.rs

1//! tfhe-fft is a pure Rust high performance fast Fourier transform library that processes
2//! vectors of sizes that are powers of two.
3//!
4//! This library provides two FFT modules:
5//!  - The ordered module FFT applies a forward/inverse FFT that takes its input in standard
6//!  order, and outputs the result in standard order. For more detail on what the FFT
7//!  computes, check the ordered module-level documentation.
8//!  - The unordered module FFT applies a forward FFT that takes its input in standard order,
9//!  and outputs the result in a certain permuted order that may depend on the FFT plan. On the
10//!  other hand, the inverse FFT takes its input in that same permuted order and outputs its result
11//!  in standard order. This is useful for cases where the order of the coefficients in the
12//!  Fourier domain is not important. An example is using the Fourier transform for vector
13//!  convolution. The only operations that are performed in the Fourier domain are elementwise, and
14//!  so the order of the coefficients does not affect the results.
15//!
16//! Additionally, an optional 128-bit negacyclic FFT module is provided.
17//!
18//! # Features
19//!
20//!  - `std` (default): This enables runtime arch detection for accelerated SIMD instructions, and
21//!  an FFT plan that measures the various implementations to choose the fastest one at runtime.
22//!  - `fft128`: This flag provides access to the 128-bit FFT, which is accessible in the
23//!  `fft128` module.
24//!  - `nightly`: This enables unstable Rust features to further speed up the FFT, by enabling
25//!  AVX512F instructions on CPUs that support them. This feature requires a nightly Rust
26//!  toolchain.
27//!  - `serde`: This enables serialization and deserialization functions for the unordered plan.
28//!  These allow for data in the Fourier domain to be serialized from the permuted order to the
29//!  standard order, and deserialized from the standard order to the permuted order.
30//!  This is needed since the inverse transform must be used with the same plan that
31//!  computed/deserialized the forward transform (or more specifically, a plan with the same
32//!  internal base FFT size).
33//!
34//! # Example
35#![cfg_attr(feature = "std", doc = "```")]
36#![cfg_attr(not(feature = "std"), doc = "```ignore")]
37//! use tfhe_fft::c64;
38//! use tfhe_fft::ordered::{Plan, Method};
39//! use dyn_stack::{PodStack, GlobalPodBuffer};
40//! use num_complex::ComplexFloat;
41//! use std::time::Duration;
42//!
43//! const N: usize = 4;
44//! let plan = Plan::new(4, Method::Measure(Duration::from_millis(10)));
45//! let mut scratch_memory = GlobalPodBuffer::new(plan.fft_scratch().unwrap());
46//! let stack = PodStack::new(&mut scratch_memory);
47//!
48//! let data = [
49//!     c64::new(1.0, 0.0),
50//!     c64::new(2.0, 0.0),
51//!     c64::new(3.0, 0.0),
52//!     c64::new(4.0, 0.0),
53//! ];
54//!
55//! let mut transformed_fwd = data;
56//! plan.fwd(&mut transformed_fwd, stack);
57//!
58//! let mut transformed_inv = transformed_fwd;
59//! plan.inv(&mut transformed_inv, stack);
60//!
61//! for (actual, expected) in transformed_inv.iter().map(|z| z / N as f64).zip(data) {
62//!     assert!((expected - actual).abs() < 1e-9);
63//! }
64//! ```
65
66#![cfg_attr(not(feature = "std"), no_std)]
67#![allow(
68    clippy::erasing_op,
69    clippy::identity_op,
70    clippy::zero_prefixed_literal,
71    clippy::excessive_precision,
72    clippy::type_complexity,
73    clippy::too_many_arguments,
74    non_camel_case_types
75)]
76#![cfg_attr(docsrs, feature(doc_cfg))]
77#![warn(rustdoc::broken_intra_doc_links)]
78
79use core::marker::PhantomData;
80
81use fft_simd::{FftSimd, Pod};
82use num_complex::Complex64;
83
/// Complex number whose real and imaginary parts are each a 64-bit float (`f64`).
85pub type c64 = Complex64;
86
/// Zips any number of `IntoIterator` values (up to 16) into one iterator of flat tuples.
///
/// `izip!(a)` is just `a.into_iter()`. With two or more arguments the iterators are chained with
/// `Iterator::zip` and the resulting left-nested pairs `(((a, b), c), d)` are flattened to
/// `(a, b, c, d)`. Iteration stops as soon as the shortest input is exhausted. A trailing comma
/// is accepted.
macro_rules! izip {
    // Internal arms: one hand-written flattening closure per arity. They are spelled out rather
    // than generated recursively to avoid a bug with type hints in rust-analyzer:
    // https://github.com/rust-lang/rust-analyzer/issues/13526
    (@ __closure @ ($i0:expr)) => {
        |v0| (v0,)
    };
    (@ __closure @ ($i0:expr, $i1:expr)) => {
        |(v0, v1)| (v0, v1)
    };
    (@ __closure @ ($i0:expr, $i1:expr, $i2:expr)) => {
        |((v0, v1), v2)| (v0, v1, v2)
    };
    (@ __closure @ ($i0:expr, $i1:expr, $i2:expr, $i3:expr)) => {
        |(((v0, v1), v2), v3)| (v0, v1, v2, v3)
    };
    (@ __closure @ ($i0:expr, $i1:expr, $i2:expr, $i3:expr, $i4:expr)) => {
        |((((v0, v1), v2), v3), v4)| (v0, v1, v2, v3, v4)
    };
    (@ __closure @ ($i0:expr, $i1:expr, $i2:expr, $i3:expr, $i4:expr, $i5:expr)) => {
        |(((((v0, v1), v2), v3), v4), v5)| (v0, v1, v2, v3, v4, v5)
    };
    (@ __closure @ ($i0:expr, $i1:expr, $i2:expr, $i3:expr, $i4:expr, $i5:expr, $i6:expr)) => {
        |((((((v0, v1), v2), v3), v4), v5), v6)| (v0, v1, v2, v3, v4, v5, v6)
    };
    (@ __closure @ (
        $i0:expr, $i1:expr, $i2:expr, $i3:expr, $i4:expr, $i5:expr, $i6:expr, $i7:expr
    )) => {
        |(((((((v0, v1), v2), v3), v4), v5), v6), v7)| (v0, v1, v2, v3, v4, v5, v6, v7)
    };
    (@ __closure @ (
        $i0:expr, $i1:expr, $i2:expr, $i3:expr, $i4:expr, $i5:expr, $i6:expr, $i7:expr,
        $i8:expr
    )) => {
        |((((((((v0, v1), v2), v3), v4), v5), v6), v7), v8)| {
            (v0, v1, v2, v3, v4, v5, v6, v7, v8)
        }
    };
    (@ __closure @ (
        $i0:expr, $i1:expr, $i2:expr, $i3:expr, $i4:expr, $i5:expr, $i6:expr, $i7:expr,
        $i8:expr, $i9:expr
    )) => {
        |(((((((((v0, v1), v2), v3), v4), v5), v6), v7), v8), v9)| {
            (v0, v1, v2, v3, v4, v5, v6, v7, v8, v9)
        }
    };
    (@ __closure @ (
        $i0:expr, $i1:expr, $i2:expr, $i3:expr, $i4:expr, $i5:expr, $i6:expr, $i7:expr,
        $i8:expr, $i9:expr, $i10:expr
    )) => {
        |((((((((((v0, v1), v2), v3), v4), v5), v6), v7), v8), v9), v10)| {
            (v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10)
        }
    };
    (@ __closure @ (
        $i0:expr, $i1:expr, $i2:expr, $i3:expr, $i4:expr, $i5:expr, $i6:expr, $i7:expr,
        $i8:expr, $i9:expr, $i10:expr, $i11:expr
    )) => {
        |(((((((((((v0, v1), v2), v3), v4), v5), v6), v7), v8), v9), v10), v11)| {
            (v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11)
        }
    };
    (@ __closure @ (
        $i0:expr, $i1:expr, $i2:expr, $i3:expr, $i4:expr, $i5:expr, $i6:expr, $i7:expr,
        $i8:expr, $i9:expr, $i10:expr, $i11:expr, $i12:expr
    )) => {
        |((((((((((((v0, v1), v2), v3), v4), v5), v6), v7), v8), v9), v10), v11), v12)| {
            (v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12)
        }
    };
    (@ __closure @ (
        $i0:expr, $i1:expr, $i2:expr, $i3:expr, $i4:expr, $i5:expr, $i6:expr, $i7:expr,
        $i8:expr, $i9:expr, $i10:expr, $i11:expr, $i12:expr, $i13:expr
    )) => {
        |(((((((((((((v0, v1), v2), v3), v4), v5), v6), v7), v8), v9), v10), v11), v12), v13)| {
            (v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13)
        }
    };
    (@ __closure @ (
        $i0:expr, $i1:expr, $i2:expr, $i3:expr, $i4:expr, $i5:expr, $i6:expr, $i7:expr,
        $i8:expr, $i9:expr, $i10:expr, $i11:expr, $i12:expr, $i13:expr, $i14:expr
    )) => {
        |((((((((((((((v0, v1), v2), v3), v4), v5), v6), v7), v8), v9), v10), v11), v12), v13), v14)| {
            (v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14)
        }
    };
    (@ __closure @ (
        $i0:expr, $i1:expr, $i2:expr, $i3:expr, $i4:expr, $i5:expr, $i6:expr, $i7:expr,
        $i8:expr, $i9:expr, $i10:expr, $i11:expr, $i12:expr, $i13:expr, $i14:expr, $i15:expr
    )) => {
        |(((((((((((((((v0, v1), v2), v3), v4), v5), v6), v7), v8), v9), v10), v11), v12), v13), v14), v15)| {
            (v0, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15)
        }
    };

    // Public arm, one input: nothing to zip, just turn it into an iterator.
    ($head:expr $(,)?) => {{
        ::core::iter::IntoIterator::into_iter($head)
    }};
    // Public arm, two or more inputs: zip left to right, then flatten with the closure whose
    // arity matches the argument count.
    ($head:expr, $($tail:expr),+ $(,)?) => {{
        ::core::iter::IntoIterator::into_iter($head)
            $(.zip($tail))+
            .map(izip!(@ __closure @ ($head, $($tail),+)))
    }};
}
120
121mod fft_simd;
122mod nat;
123
124#[cfg(feature = "std")]
125pub(crate) mod time;
126
127#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
128mod x86;
129
130type FnArray = [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 10];
131
132#[derive(Copy, Clone)]
133struct FftImpl {
134    fwd: FnArray,
135    inv: FnArray,
136}
137
138impl FftImpl {
139    #[inline]
140    pub fn make_fn_ptr(&self, n: usize) -> [fn(&mut [c64], &mut [c64], &[c64], &[c64]); 2] {
141        let idx = n.trailing_zeros() as usize - 1;
142        [self.fwd[idx], self.inv[idx]]
143    }
144}
145
146/// Computes the FFT of size 2^(N+1).
147trait RecursiveFft: nat::Nat {
148    fn fft_recurse_impl<c64xN: Pod>(
149        simd: impl FftSimd<c64xN>,
150        fwd: bool,
151        read_from_x: bool,
152        s: usize,
153        x: &mut [c64xN],
154        y: &mut [c64xN],
155        w_init: &[c64xN],
156        w: &[c64],
157    );
158}
159
160#[inline]
161fn fn_ptr<const FWD: bool, N: RecursiveFft, c64xN: Pod, Simd: FftSimd<c64xN>>(
162    simd: Simd,
163) -> fn(&mut [c64], &mut [c64], &[c64], &[c64]) {
164    // we can't pass `simd` to the closure even though it's a zero-sized struct,
165    // because we want the closure to be coercible to a function pointer.
166    // so we ignore the passed parameter and reconstruct it inside the closure -------------
167    let _ = simd;
168
169    #[inline(never)]
170    |buf: &mut [c64], scratch: &mut [c64], w_init: &[c64], w: &[c64]| {
171        struct Impl<'a, const FWD: bool, N, c64xN, Simd> {
172            simd: Simd,
173            buf: &'a mut [c64],
174            scratch: &'a mut [c64],
175            w_init: &'a [c64],
176            w: &'a [c64],
177            __marker: PhantomData<(N, c64xN)>,
178        }
179        // `simd` is reconstructed here. we know the unwrap can never fail because it was already
180        // passed to us as a function parameter, which proves that it's possible to construct.
181        let simd = Simd::try_new().unwrap();
182
183        // we use NullaryFnOnce instead of a closure because we need the #[inline(always)]
184        // annotation, which doesn't always work with closures for some reason.
185        impl<const FWD: bool, N: RecursiveFft, c64xN: Pod, Simd: FftSimd<c64xN>> pulp::NullaryFnOnce
186            for Impl<'_, FWD, N, c64xN, Simd>
187        {
188            type Output = ();
189
190            #[inline(always)]
191            fn call(self) -> Self::Output {
192                let Self {
193                    simd,
194                    buf,
195                    scratch,
196                    w_init,
197                    w,
198                    __marker: _,
199                } = self;
200                let n = 1 << (N::VALUE + 1);
201                assert_eq!(buf.len(), n);
202                assert_eq!(scratch.len(), n);
203                assert_eq!(w_init.len(), n);
204                assert_eq!(w.len(), n);
205                N::fft_recurse_impl(
206                    simd,
207                    FWD,
208                    true,
209                    1,
210                    bytemuck::cast_slice_mut(buf),
211                    bytemuck::cast_slice_mut(scratch),
212                    bytemuck::cast_slice(w_init),
213                    w,
214                );
215            }
216        }
217
218        simd.vectorize(Impl::<FWD, N, c64xN, Simd> {
219            simd,
220            buf,
221            scratch,
222            w_init,
223            w,
224            __marker: PhantomData,
225        })
226    }
227}
228
229mod dif2;
230mod dit2;
231
232mod dif4;
233mod dit4;
234
235mod dif8;
236mod dit8;
237
238mod dif16;
239mod dit16;
240
241pub mod ordered;
242pub mod unordered;
243
244#[cfg(feature = "fft128")]
245#[cfg_attr(docsrs, doc(cfg(feature = "fft128")))]
246pub mod fft128;