Skip to main content

torsh_backend/
fft.rs

1//! Fast Fourier Transform operations for all backends
2//!
3//! This module provides a unified interface for FFT operations across all backends,
4//! with optimized implementations for each platform.
5
6use crate::{BackendResult, Buffer, Device};
7use torsh_core::dtype::DType;
8
9#[cfg(not(feature = "std"))]
10use alloc::{boxed::Box, string::String, vec::Vec};
11
/// Direction in which an FFT is applied.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FftDirection {
    /// Forward transform.
    Forward,
    /// Inverse transform.
    Inverse,
}
20
/// Scaling convention requested for an FFT.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FftNormalization {
    /// Leave the result unscaled.
    None,
    /// Scale by `1/N`.
    Backward,
    /// Scale by `1/sqrt(N)`.
    Ortho,
}
31
/// Kind of transform: dimensionality (1D/2D/3D) combined with the
/// real/complex nature of input and output.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FftType {
    /// Complex-to-complex, one dimension.
    C2C,
    /// Real-to-complex, one dimension.
    R2C,
    /// Complex-to-real, one dimension.
    C2R,
    /// Complex-to-complex, two dimensions.
    C2C2D,
    /// Real-to-complex, two dimensions.
    R2C2D,
    /// Complex-to-real, two dimensions.
    C2R2D,
    /// Complex-to-complex, three dimensions.
    C2C3D,
    /// Real-to-complex, three dimensions.
    R2C3D,
    /// Complex-to-real, three dimensions.
    C2R3D,
}
54
55/// FFT execution plan
56#[derive(Debug, Clone)]
57pub struct FftPlan {
58    /// Plan ID for caching
59    pub id: String,
60    /// FFT type
61    pub fft_type: FftType,
62    /// Transform dimensions
63    pub dimensions: Vec<usize>,
64    /// Batch size
65    pub batch_size: usize,
66    /// Input data type
67    pub input_dtype: DType,
68    /// Output data type
69    pub output_dtype: DType,
70    /// Direction
71    pub direction: FftDirection,
72    /// Normalization mode
73    pub normalization: FftNormalization,
74    /// Backend-specific plan data
75    pub backend_data: Vec<u8>,
76}
77
78impl FftPlan {
79    /// Create a new FFT plan
80    pub fn new(
81        fft_type: FftType,
82        dimensions: Vec<usize>,
83        batch_size: usize,
84        input_dtype: DType,
85        output_dtype: DType,
86        direction: FftDirection,
87        normalization: FftNormalization,
88    ) -> Self {
89        let id = format!(
90            "{:?}_{:?}_{}_{}_{:?}_{:?}_{:?}",
91            fft_type, dimensions, batch_size, input_dtype, output_dtype, direction, normalization
92        );
93
94        Self {
95            id,
96            fft_type,
97            dimensions,
98            batch_size,
99            input_dtype,
100            output_dtype,
101            direction,
102            normalization,
103            backend_data: Vec::new(),
104        }
105    }
106
107    /// Create a new 1D FFT plan with default parameters
108    ///
109    /// This is a convenience function for creating 1D FFT plans commonly used in benchmarks.
110    ///
111    /// # Arguments
112    ///
113    /// * `size` - Size of the 1D FFT
114    /// * `direction` - Forward or inverse transform
115    ///
116    /// # Returns
117    ///
118    /// A new FftPlan configured for 1D transforms
119    pub fn new_1d(size: usize, direction: FftDirection) -> Self {
120        Self::new(
121            FftType::C2C,
122            vec![size],
123            1,          // Single batch
124            DType::C64, // Complex64 input
125            DType::C64, // Complex64 output
126            direction,
127            FftNormalization::None,
128        )
129    }
130
131    /// Get the total number of elements in the transform
132    pub fn total_elements(&self) -> usize {
133        self.dimensions.iter().product::<usize>() * self.batch_size
134    }
135
136    /// Get input buffer size in bytes
137    pub fn input_buffer_size(&self) -> usize {
138        let element_size = match self.input_dtype {
139            DType::F32 => 4,
140            DType::F64 => 8,
141            DType::C64 => 8,
142            DType::C128 => 16,
143            _ => 4, // Default to f32
144        };
145
146        self.total_elements() * element_size
147    }
148
149    /// Get output buffer size in bytes
150    pub fn output_buffer_size(&self) -> usize {
151        let element_size = match self.output_dtype {
152            DType::F32 => 4,
153            DType::F64 => 8,
154            DType::C64 => 8,
155            DType::C128 => 16,
156            _ => 8, // Default to c32
157        };
158
159        match self.fft_type {
160            FftType::R2C | FftType::R2C2D | FftType::R2C3D => {
161                // Real-to-complex transforms have reduced output size
162                let mut output_elements = self.batch_size;
163                for (i, &dim) in self.dimensions.iter().enumerate() {
164                    if i == self.dimensions.len() - 1 {
165                        // Last dimension is halved + 1 for R2C
166                        output_elements *= (dim / 2) + 1;
167                    } else {
168                        output_elements *= dim;
169                    }
170                }
171                output_elements * element_size
172            }
173            _ => self.total_elements() * element_size,
174        }
175    }
176
177    /// Check if the plan is valid
178    pub fn is_valid(&self) -> bool {
179        !self.dimensions.is_empty() && self.batch_size > 0 && self.dimensions.iter().all(|&d| d > 0)
180    }
181}
182
183/// FFT operations trait
184#[async_trait::async_trait]
185pub trait FftOps: Send + Sync {
186    /// Create an FFT plan
187    async fn create_fft_plan(
188        &self,
189        device: &Device,
190        plan: &FftPlan,
191    ) -> BackendResult<Box<dyn FftExecutor>>;
192
193    /// Execute a 1D FFT
194    async fn fft_1d(
195        &self,
196        device: &Device,
197        input: &Buffer,
198        output: &Buffer,
199        size: usize,
200        direction: FftDirection,
201        normalization: FftNormalization,
202    ) -> BackendResult<()>;
203
204    /// Execute a 2D FFT
205    async fn fft_2d(
206        &self,
207        device: &Device,
208        input: &Buffer,
209        output: &Buffer,
210        size: (usize, usize),
211        direction: FftDirection,
212        normalization: FftNormalization,
213    ) -> BackendResult<()>;
214
215    /// Execute a 3D FFT
216    async fn fft_3d(
217        &self,
218        device: &Device,
219        input: &Buffer,
220        output: &Buffer,
221        size: (usize, usize, usize),
222        direction: FftDirection,
223        normalization: FftNormalization,
224    ) -> BackendResult<()>;
225
226    /// Execute a batched FFT
227    async fn fft_batch(
228        &self,
229        device: &Device,
230        input: &Buffer,
231        output: &Buffer,
232        size: &[usize],
233        batch_size: usize,
234        direction: FftDirection,
235        normalization: FftNormalization,
236    ) -> BackendResult<()>;
237
238    /// Execute a real-to-complex FFT
239    async fn rfft(
240        &self,
241        device: &Device,
242        input: &Buffer,
243        output: &Buffer,
244        size: &[usize],
245        direction: FftDirection,
246        normalization: FftNormalization,
247    ) -> BackendResult<()>;
248
249    /// Execute a complex-to-real FFT
250    async fn irfft(
251        &self,
252        device: &Device,
253        input: &Buffer,
254        output: &Buffer,
255        size: &[usize],
256        normalization: FftNormalization,
257    ) -> BackendResult<()>;
258
259    /// Check if FFT operations are supported
260    fn supports_fft(&self) -> bool;
261
262    /// Get optimal FFT sizes for performance
263    fn get_optimal_fft_sizes(&self, min_size: usize, max_size: usize) -> Vec<usize>;
264}
265
266/// FFT executor for cached plans
267#[async_trait::async_trait]
268pub trait FftExecutor: Send + Sync {
269    /// Execute the FFT plan
270    async fn execute(&self, device: &Device, input: &Buffer, output: &Buffer) -> BackendResult<()>;
271
272    /// Get the plan this executor was created for
273    fn plan(&self) -> &FftPlan;
274
275    /// Get memory requirements for execution
276    fn memory_requirements(&self) -> usize;
277
278    /// Check if the executor is valid
279    fn is_valid(&self) -> bool;
280}
281
/// Stateless default implementation of the FFT operations.
pub struct DefaultFftOps;

impl DefaultFftOps {
    /// Returns a `DefaultFftOps` value.
    pub fn new() -> Self {
        DefaultFftOps
    }
}

impl Default for DefaultFftOps {
    /// Same value as [`DefaultFftOps::new`].
    fn default() -> Self {
        DefaultFftOps::new()
    }
}
297
298#[async_trait::async_trait]
299impl FftOps for DefaultFftOps {
300    async fn create_fft_plan(
301        &self,
302        _device: &Device,
303        plan: &FftPlan,
304    ) -> BackendResult<Box<dyn FftExecutor>> {
305        Ok(Box::new(DefaultFftExecutor { plan: plan.clone() }))
306    }
307
308    async fn fft_1d(
309        &self,
310        _device: &Device,
311        _input: &Buffer,
312        _output: &Buffer,
313        _size: usize,
314        _direction: FftDirection,
315        _normalization: FftNormalization,
316    ) -> BackendResult<()> {
317        Err(torsh_core::error::TorshError::BackendError(
318            "FFT operations not implemented for this backend".to_string(),
319        ))
320    }
321
322    async fn fft_2d(
323        &self,
324        _device: &Device,
325        _input: &Buffer,
326        _output: &Buffer,
327        _size: (usize, usize),
328        _direction: FftDirection,
329        _normalization: FftNormalization,
330    ) -> BackendResult<()> {
331        Err(torsh_core::error::TorshError::BackendError(
332            "FFT operations not implemented for this backend".to_string(),
333        ))
334    }
335
336    async fn fft_3d(
337        &self,
338        _device: &Device,
339        _input: &Buffer,
340        _output: &Buffer,
341        _size: (usize, usize, usize),
342        _direction: FftDirection,
343        _normalization: FftNormalization,
344    ) -> BackendResult<()> {
345        Err(torsh_core::error::TorshError::BackendError(
346            "FFT operations not implemented for this backend".to_string(),
347        ))
348    }
349
350    async fn fft_batch(
351        &self,
352        _device: &Device,
353        _input: &Buffer,
354        _output: &Buffer,
355        _size: &[usize],
356        _batch_size: usize,
357        _direction: FftDirection,
358        _normalization: FftNormalization,
359    ) -> BackendResult<()> {
360        Err(torsh_core::error::TorshError::BackendError(
361            "FFT operations not implemented for this backend".to_string(),
362        ))
363    }
364
365    async fn rfft(
366        &self,
367        _device: &Device,
368        _input: &Buffer,
369        _output: &Buffer,
370        _size: &[usize],
371        _direction: FftDirection,
372        _normalization: FftNormalization,
373    ) -> BackendResult<()> {
374        Err(torsh_core::error::TorshError::BackendError(
375            "FFT operations not implemented for this backend".to_string(),
376        ))
377    }
378
379    async fn irfft(
380        &self,
381        _device: &Device,
382        _input: &Buffer,
383        _output: &Buffer,
384        _size: &[usize],
385        _normalization: FftNormalization,
386    ) -> BackendResult<()> {
387        Err(torsh_core::error::TorshError::BackendError(
388            "FFT operations not implemented for this backend".to_string(),
389        ))
390    }
391
392    fn supports_fft(&self) -> bool {
393        false
394    }
395
396    fn get_optimal_fft_sizes(&self, min_size: usize, max_size: usize) -> Vec<usize> {
397        // Return power-of-2 sizes as default
398        let mut sizes = Vec::new();
399        let mut size = 1;
400        while size < min_size {
401            size *= 2;
402        }
403        while size <= max_size {
404            sizes.push(size);
405            size *= 2;
406        }
407        sizes
408    }
409}
410
411/// Default FFT executor implementation
412pub struct DefaultFftExecutor {
413    plan: FftPlan,
414}
415
416#[async_trait::async_trait]
417impl FftExecutor for DefaultFftExecutor {
418    async fn execute(
419        &self,
420        _device: &Device,
421        _input: &Buffer,
422        _output: &Buffer,
423    ) -> BackendResult<()> {
424        Err(torsh_core::error::TorshError::BackendError(
425            "FFT execution not implemented for this backend".to_string(),
426        ))
427    }
428
429    fn plan(&self) -> &FftPlan {
430        &self.plan
431    }
432
433    fn memory_requirements(&self) -> usize {
434        self.plan.input_buffer_size() + self.plan.output_buffer_size()
435    }
436
437    fn is_valid(&self) -> bool {
438        self.plan.is_valid()
439    }
440}
441
442/// Convenience functions for common FFT operations
443pub mod convenience {
444    use super::*;
445
446    /// Create a 1D complex-to-complex FFT plan
447    pub fn create_c2c_1d_plan(
448        size: usize,
449        batch_size: usize,
450        direction: FftDirection,
451        normalization: FftNormalization,
452    ) -> FftPlan {
453        FftPlan::new(
454            FftType::C2C,
455            vec![size],
456            batch_size,
457            DType::C64,
458            DType::C64,
459            direction,
460            normalization,
461        )
462    }
463
464    /// Create a 1D real-to-complex FFT plan
465    pub fn create_r2c_1d_plan(
466        size: usize,
467        batch_size: usize,
468        normalization: FftNormalization,
469    ) -> FftPlan {
470        FftPlan::new(
471            FftType::R2C,
472            vec![size],
473            batch_size,
474            DType::F32,
475            DType::C64,
476            FftDirection::Forward,
477            normalization,
478        )
479    }
480
481    /// Create a 2D complex-to-complex FFT plan
482    pub fn create_c2c_2d_plan(
483        size: (usize, usize),
484        batch_size: usize,
485        direction: FftDirection,
486        normalization: FftNormalization,
487    ) -> FftPlan {
488        FftPlan::new(
489            FftType::C2C2D,
490            vec![size.0, size.1],
491            batch_size,
492            DType::C64,
493            DType::C64,
494            direction,
495            normalization,
496        )
497    }
498
499    /// Create a 3D complex-to-complex FFT plan
500    pub fn create_c2c_3d_plan(
501        size: (usize, usize, usize),
502        batch_size: usize,
503        direction: FftDirection,
504        normalization: FftNormalization,
505    ) -> FftPlan {
506        FftPlan::new(
507            FftType::C2C3D,
508            vec![size.0, size.1, size.2],
509            batch_size,
510            DType::C64,
511            DType::C64,
512            direction,
513            normalization,
514        )
515    }
516
517    /// Get the next power of 2 greater than or equal to n
518    pub fn next_power_of_2(n: usize) -> usize {
519        if n == 0 {
520            return 1;
521        }
522        let mut power = 1;
523        while power < n {
524            power *= 2;
525        }
526        power
527    }
528
529    /// Check if a size is optimal for FFT (power of 2, 3, 5, 7)
530    pub fn is_optimal_fft_size(size: usize) -> bool {
531        if size == 0 {
532            return false;
533        }
534
535        let mut n = size;
536        for prime in &[2, 3, 5, 7] {
537            while n % prime == 0 {
538                n /= prime;
539            }
540        }
541
542        n == 1
543    }
544
545    /// Find the next optimal FFT size
546    pub fn next_optimal_fft_size(size: usize) -> usize {
547        let mut candidate = size;
548        while !is_optimal_fft_size(candidate) {
549            candidate += 1;
550        }
551        candidate
552    }
553}
554
#[cfg(test)]
mod tests {
    use super::*;

    /// Single-batch, unnormalized, forward 1024-point C2C plan over `DType::C64`.
    fn forward_c2c_1024() -> FftPlan {
        FftPlan::new(
            FftType::C2C,
            vec![1024],
            1,
            DType::C64,
            DType::C64,
            FftDirection::Forward,
            FftNormalization::None,
        )
    }

    #[test]
    fn test_fft_plan_creation() {
        let plan = forward_c2c_1024();

        // Every constructor argument must be stored unchanged.
        assert_eq!(plan.fft_type, FftType::C2C);
        assert_eq!(plan.dimensions, vec![1024]);
        assert_eq!(plan.batch_size, 1);
        assert_eq!(plan.input_dtype, DType::C64);
        assert_eq!(plan.output_dtype, DType::C64);
        assert_eq!(plan.direction, FftDirection::Forward);
        assert_eq!(plan.normalization, FftNormalization::None);
        assert!(plan.is_valid());
    }

    #[test]
    fn test_fft_plan_buffer_sizes() {
        let plan = forward_c2c_1024();

        // C64 elements are sized at 8 bytes each.
        assert_eq!(plan.input_buffer_size(), 1024 * 8);
        assert_eq!(plan.output_buffer_size(), 1024 * 8);
    }

    #[test]
    fn test_r2c_plan_buffer_sizes() {
        let plan = FftPlan::new(
            FftType::R2C,
            vec![1024],
            1,
            DType::F32,
            DType::C64,
            FftDirection::Forward,
            FftNormalization::None,
        );

        // Input: 1024 F32 values at 4 bytes each.
        assert_eq!(plan.input_buffer_size(), 1024 * 4);
        // Output: N/2 + 1 C64 values at 8 bytes each.
        assert_eq!(plan.output_buffer_size(), (1024 / 2 + 1) * 8);
    }

    #[test]
    fn test_convenience_functions() {
        let plan =
            convenience::create_c2c_1d_plan(1024, 1, FftDirection::Forward, FftNormalization::None);

        assert_eq!(plan.fft_type, FftType::C2C);
        assert_eq!(plan.dimensions, vec![1024]);
        assert!(plan.is_valid());
    }

    #[test]
    fn test_optimal_fft_sizes() {
        // 1024 = 2^10
        assert!(convenience::is_optimal_fft_size(1024));
        // 1080 = 2^3 * 3^3 * 5
        assert!(convenience::is_optimal_fft_size(1080));
        // 1023 = 3 * 11 * 31, so it has factors outside {2, 3, 5, 7}
        assert!(!convenience::is_optimal_fft_size(1023));

        assert_eq!(convenience::next_power_of_2(1000), 1024);
        assert_eq!(convenience::next_power_of_2(1024), 1024);

        assert_eq!(convenience::next_optimal_fft_size(1023), 1024);
        assert_eq!(convenience::next_optimal_fft_size(1024), 1024);
    }

    #[test]
    fn test_default_fft_ops() {
        let ops = DefaultFftOps;
        assert!(!ops.supports_fft());

        let sizes = ops.get_optimal_fft_sizes(100, 2000);
        assert!(!sizes.is_empty());
        assert!(sizes.iter().all(|size| (100..=2000).contains(size)));
        assert!(sizes.iter().all(|size| size.is_power_of_two()));
    }
}