1use crate::{BackendResult, Buffer, Device};
7use torsh_core::dtype::DType;
8
9#[cfg(not(feature = "std"))]
10use alloc::{boxed::Box, string::String, vec::Vec};
11
/// Direction in which a transform is applied.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FftDirection {
    /// Forward transform.
    Forward,
    /// Inverse transform.
    Inverse,
}
20
/// Scaling mode requested for a transform.
///
/// NOTE(review): the variant names look like the NumPy/PyTorch `norm` modes
/// (no scaling, `1/n` on the inverse, `1/sqrt(n)` on both directions), but the
/// actual scaling is applied by backend code not visible here — confirm.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FftNormalization {
    /// No scaling requested.
    None,
    /// Presumably scales only the inverse transform — confirm against the backend.
    Backward,
    /// Presumably the orthonormal (symmetric) scaling — confirm against the backend.
    Ortho,
}
31
/// Kind of transform a plan describes: data layout (`C2C`, `R2C`, `C2R`)
/// combined with dimensionality (no suffix, `2D`, `3D`).
///
/// The `R2C*` variants are the ones for which `FftPlan::output_buffer_size`
/// shrinks the last dimension to `dim / 2 + 1` elements. The convenience
/// constructors pair `R2C` with a real input dtype and a complex output dtype.
/// `C2R*` is presumably the matching complex-to-real transform.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FftType {
    /// 1-D complex-to-complex.
    C2C,
    /// 1-D real-to-complex.
    R2C,
    /// 1-D complex-to-real.
    C2R,
    /// 2-D complex-to-complex.
    C2C2D,
    /// 2-D real-to-complex.
    R2C2D,
    /// 2-D complex-to-real.
    C2R2D,
    /// 3-D complex-to-complex.
    C2C3D,
    /// 3-D real-to-complex.
    R2C3D,
    /// 3-D complex-to-real.
    C2R3D,
}
54
55#[derive(Debug, Clone)]
57pub struct FftPlan {
58 pub id: String,
60 pub fft_type: FftType,
62 pub dimensions: Vec<usize>,
64 pub batch_size: usize,
66 pub input_dtype: DType,
68 pub output_dtype: DType,
70 pub direction: FftDirection,
72 pub normalization: FftNormalization,
74 pub backend_data: Vec<u8>,
76}
77
78impl FftPlan {
79 pub fn new(
81 fft_type: FftType,
82 dimensions: Vec<usize>,
83 batch_size: usize,
84 input_dtype: DType,
85 output_dtype: DType,
86 direction: FftDirection,
87 normalization: FftNormalization,
88 ) -> Self {
89 let id = format!(
90 "{:?}_{:?}_{}_{}_{:?}_{:?}_{:?}",
91 fft_type, dimensions, batch_size, input_dtype, output_dtype, direction, normalization
92 );
93
94 Self {
95 id,
96 fft_type,
97 dimensions,
98 batch_size,
99 input_dtype,
100 output_dtype,
101 direction,
102 normalization,
103 backend_data: Vec::new(),
104 }
105 }
106
107 pub fn new_1d(size: usize, direction: FftDirection) -> Self {
120 Self::new(
121 FftType::C2C,
122 vec![size],
123 1, DType::C64, DType::C64, direction,
127 FftNormalization::None,
128 )
129 }
130
131 pub fn total_elements(&self) -> usize {
133 self.dimensions.iter().product::<usize>() * self.batch_size
134 }
135
136 pub fn input_buffer_size(&self) -> usize {
138 let element_size = match self.input_dtype {
139 DType::F32 => 4,
140 DType::F64 => 8,
141 DType::C64 => 8,
142 DType::C128 => 16,
143 _ => 4, };
145
146 self.total_elements() * element_size
147 }
148
149 pub fn output_buffer_size(&self) -> usize {
151 let element_size = match self.output_dtype {
152 DType::F32 => 4,
153 DType::F64 => 8,
154 DType::C64 => 8,
155 DType::C128 => 16,
156 _ => 8, };
158
159 match self.fft_type {
160 FftType::R2C | FftType::R2C2D | FftType::R2C3D => {
161 let mut output_elements = self.batch_size;
163 for (i, &dim) in self.dimensions.iter().enumerate() {
164 if i == self.dimensions.len() - 1 {
165 output_elements *= (dim / 2) + 1;
167 } else {
168 output_elements *= dim;
169 }
170 }
171 output_elements * element_size
172 }
173 _ => self.total_elements() * element_size,
174 }
175 }
176
177 pub fn is_valid(&self) -> bool {
179 !self.dimensions.is_empty() && self.batch_size > 0 && self.dimensions.iter().all(|&d| d > 0)
180 }
181}
182
/// FFT entry points a backend exposes.
///
/// All transform methods take caller-provided `input` and `output` buffers and
/// report success or failure through `BackendResult<()>`. How the buffers must
/// be laid out is defined by the implementing backend, not by this trait.
#[async_trait::async_trait]
pub trait FftOps: Send + Sync {
    /// Turns a plan description into an executor that can run it on `device`.
    async fn create_fft_plan(
        &self,
        device: &Device,
        plan: &FftPlan,
    ) -> BackendResult<Box<dyn FftExecutor>>;

    /// Runs a 1-D transform of length `size` from `input` into `output`.
    async fn fft_1d(
        &self,
        device: &Device,
        input: &Buffer,
        output: &Buffer,
        size: usize,
        direction: FftDirection,
        normalization: FftNormalization,
    ) -> BackendResult<()>;

    /// Runs a 2-D transform; `size` holds the two extents.
    async fn fft_2d(
        &self,
        device: &Device,
        input: &Buffer,
        output: &Buffer,
        size: (usize, usize),
        direction: FftDirection,
        normalization: FftNormalization,
    ) -> BackendResult<()>;

    /// Runs a 3-D transform; `size` holds the three extents.
    async fn fft_3d(
        &self,
        device: &Device,
        input: &Buffer,
        output: &Buffer,
        size: (usize, usize, usize),
        direction: FftDirection,
        normalization: FftNormalization,
    ) -> BackendResult<()>;

    /// Runs `batch_size` transforms whose per-transform extents are given by `size`.
    async fn fft_batch(
        &self,
        device: &Device,
        input: &Buffer,
        output: &Buffer,
        size: &[usize],
        batch_size: usize,
        direction: FftDirection,
        normalization: FftNormalization,
    ) -> BackendResult<()>;

    /// Real-input transform over the extents in `size`.
    ///
    /// NOTE(review): unlike `irfft`, this method accepts a `direction`; what an
    /// inverse `rfft` should do is not visible in this file — confirm.
    async fn rfft(
        &self,
        device: &Device,
        input: &Buffer,
        output: &Buffer,
        size: &[usize],
        direction: FftDirection,
        normalization: FftNormalization,
    ) -> BackendResult<()>;

    /// Counterpart of `rfft` (presumably complex input to real output) over the
    /// extents in `size`.
    async fn irfft(
        &self,
        device: &Device,
        input: &Buffer,
        output: &Buffer,
        size: &[usize],
        normalization: FftNormalization,
    ) -> BackendResult<()>;

    /// Whether this backend can actually execute FFTs.
    fn supports_fft(&self) -> bool;

    /// Transform lengths the backend considers efficient, given the
    /// `min_size` and `max_size` bounds supplied by the caller.
    fn get_optimal_fft_sizes(&self, min_size: usize, max_size: usize) -> Vec<usize>;
}
265
/// A prepared transform: a plan bound to whatever backend state is needed to run it.
#[async_trait::async_trait]
pub trait FftExecutor: Send + Sync {
    /// Runs the planned transform on `device`, reading `input` and writing `output`.
    async fn execute(&self, device: &Device, input: &Buffer, output: &Buffer) -> BackendResult<()>;

    /// The plan this executor was created from.
    fn plan(&self) -> &FftPlan;

    /// Memory the executor reports as required (units are defined by the
    /// implementation; the default executor in this file reports bytes).
    fn memory_requirements(&self) -> usize;

    /// Whether the executor is in a usable state.
    fn is_valid(&self) -> bool;
}
281
/// Fallback FFT provider for backends without FFT support.
///
/// The type carries no state. `Debug`, `Clone`, `Copy` and `Default` are
/// derived so it can be logged, copied freely and built with
/// `Default::default()` like any other unit-like handle.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultFftOps;

impl DefaultFftOps {
    /// Creates the fallback provider; equivalent to `DefaultFftOps::default()`.
    pub const fn new() -> Self {
        Self
    }
}
297
298#[async_trait::async_trait]
299impl FftOps for DefaultFftOps {
300 async fn create_fft_plan(
301 &self,
302 _device: &Device,
303 plan: &FftPlan,
304 ) -> BackendResult<Box<dyn FftExecutor>> {
305 Ok(Box::new(DefaultFftExecutor { plan: plan.clone() }))
306 }
307
308 async fn fft_1d(
309 &self,
310 _device: &Device,
311 _input: &Buffer,
312 _output: &Buffer,
313 _size: usize,
314 _direction: FftDirection,
315 _normalization: FftNormalization,
316 ) -> BackendResult<()> {
317 Err(torsh_core::error::TorshError::BackendError(
318 "FFT operations not implemented for this backend".to_string(),
319 ))
320 }
321
322 async fn fft_2d(
323 &self,
324 _device: &Device,
325 _input: &Buffer,
326 _output: &Buffer,
327 _size: (usize, usize),
328 _direction: FftDirection,
329 _normalization: FftNormalization,
330 ) -> BackendResult<()> {
331 Err(torsh_core::error::TorshError::BackendError(
332 "FFT operations not implemented for this backend".to_string(),
333 ))
334 }
335
336 async fn fft_3d(
337 &self,
338 _device: &Device,
339 _input: &Buffer,
340 _output: &Buffer,
341 _size: (usize, usize, usize),
342 _direction: FftDirection,
343 _normalization: FftNormalization,
344 ) -> BackendResult<()> {
345 Err(torsh_core::error::TorshError::BackendError(
346 "FFT operations not implemented for this backend".to_string(),
347 ))
348 }
349
350 async fn fft_batch(
351 &self,
352 _device: &Device,
353 _input: &Buffer,
354 _output: &Buffer,
355 _size: &[usize],
356 _batch_size: usize,
357 _direction: FftDirection,
358 _normalization: FftNormalization,
359 ) -> BackendResult<()> {
360 Err(torsh_core::error::TorshError::BackendError(
361 "FFT operations not implemented for this backend".to_string(),
362 ))
363 }
364
365 async fn rfft(
366 &self,
367 _device: &Device,
368 _input: &Buffer,
369 _output: &Buffer,
370 _size: &[usize],
371 _direction: FftDirection,
372 _normalization: FftNormalization,
373 ) -> BackendResult<()> {
374 Err(torsh_core::error::TorshError::BackendError(
375 "FFT operations not implemented for this backend".to_string(),
376 ))
377 }
378
379 async fn irfft(
380 &self,
381 _device: &Device,
382 _input: &Buffer,
383 _output: &Buffer,
384 _size: &[usize],
385 _normalization: FftNormalization,
386 ) -> BackendResult<()> {
387 Err(torsh_core::error::TorshError::BackendError(
388 "FFT operations not implemented for this backend".to_string(),
389 ))
390 }
391
392 fn supports_fft(&self) -> bool {
393 false
394 }
395
396 fn get_optimal_fft_sizes(&self, min_size: usize, max_size: usize) -> Vec<usize> {
397 let mut sizes = Vec::new();
399 let mut size = 1;
400 while size < min_size {
401 size *= 2;
402 }
403 while size <= max_size {
404 sizes.push(size);
405 size *= 2;
406 }
407 sizes
408 }
409}
410
411pub struct DefaultFftExecutor {
413 plan: FftPlan,
414}
415
416#[async_trait::async_trait]
417impl FftExecutor for DefaultFftExecutor {
418 async fn execute(
419 &self,
420 _device: &Device,
421 _input: &Buffer,
422 _output: &Buffer,
423 ) -> BackendResult<()> {
424 Err(torsh_core::error::TorshError::BackendError(
425 "FFT execution not implemented for this backend".to_string(),
426 ))
427 }
428
429 fn plan(&self) -> &FftPlan {
430 &self.plan
431 }
432
433 fn memory_requirements(&self) -> usize {
434 self.plan.input_buffer_size() + self.plan.output_buffer_size()
435 }
436
437 fn is_valid(&self) -> bool {
438 self.plan.is_valid()
439 }
440}
441
442pub mod convenience {
444 use super::*;
445
446 pub fn create_c2c_1d_plan(
448 size: usize,
449 batch_size: usize,
450 direction: FftDirection,
451 normalization: FftNormalization,
452 ) -> FftPlan {
453 FftPlan::new(
454 FftType::C2C,
455 vec![size],
456 batch_size,
457 DType::C64,
458 DType::C64,
459 direction,
460 normalization,
461 )
462 }
463
464 pub fn create_r2c_1d_plan(
466 size: usize,
467 batch_size: usize,
468 normalization: FftNormalization,
469 ) -> FftPlan {
470 FftPlan::new(
471 FftType::R2C,
472 vec![size],
473 batch_size,
474 DType::F32,
475 DType::C64,
476 FftDirection::Forward,
477 normalization,
478 )
479 }
480
481 pub fn create_c2c_2d_plan(
483 size: (usize, usize),
484 batch_size: usize,
485 direction: FftDirection,
486 normalization: FftNormalization,
487 ) -> FftPlan {
488 FftPlan::new(
489 FftType::C2C2D,
490 vec![size.0, size.1],
491 batch_size,
492 DType::C64,
493 DType::C64,
494 direction,
495 normalization,
496 )
497 }
498
499 pub fn create_c2c_3d_plan(
501 size: (usize, usize, usize),
502 batch_size: usize,
503 direction: FftDirection,
504 normalization: FftNormalization,
505 ) -> FftPlan {
506 FftPlan::new(
507 FftType::C2C3D,
508 vec![size.0, size.1, size.2],
509 batch_size,
510 DType::C64,
511 DType::C64,
512 direction,
513 normalization,
514 )
515 }
516
517 pub fn next_power_of_2(n: usize) -> usize {
519 if n == 0 {
520 return 1;
521 }
522 let mut power = 1;
523 while power < n {
524 power *= 2;
525 }
526 power
527 }
528
529 pub fn is_optimal_fft_size(size: usize) -> bool {
531 if size == 0 {
532 return false;
533 }
534
535 let mut n = size;
536 for prime in &[2, 3, 5, 7] {
537 while n % prime == 0 {
538 n /= prime;
539 }
540 }
541
542 n == 1
543 }
544
545 pub fn next_optimal_fft_size(size: usize) -> usize {
547 let mut candidate = size;
548 while !is_optimal_fft_size(candidate) {
549 candidate += 1;
550 }
551 candidate
552 }
553}
554
#[cfg(test)]
mod tests {
    use super::*;

    /// Forward, unnormalized, single-batch plan of length 1024 used by the plan tests.
    fn plan_1024(fft_type: FftType, input_dtype: DType, output_dtype: DType) -> FftPlan {
        FftPlan::new(
            fft_type,
            vec![1024],
            1,
            input_dtype,
            output_dtype,
            FftDirection::Forward,
            FftNormalization::None,
        )
    }

    #[test]
    fn test_fft_plan_creation() {
        let c2c = plan_1024(FftType::C2C, DType::C64, DType::C64);

        // Every constructor argument must be stored unchanged.
        assert_eq!(c2c.fft_type, FftType::C2C);
        assert_eq!(c2c.dimensions, vec![1024]);
        assert_eq!(c2c.batch_size, 1);
        assert_eq!(c2c.input_dtype, DType::C64);
        assert_eq!(c2c.output_dtype, DType::C64);
        assert_eq!(c2c.direction, FftDirection::Forward);
        assert_eq!(c2c.normalization, FftNormalization::None);
        assert!(c2c.is_valid());
    }

    #[test]
    fn test_fft_plan_buffer_sizes() {
        let c2c = plan_1024(FftType::C2C, DType::C64, DType::C64);

        // C64 elements are counted as 8 bytes on both sides.
        assert_eq!(c2c.input_buffer_size(), 1024 * 8);
        assert_eq!(c2c.output_buffer_size(), 1024 * 8);
    }

    #[test]
    fn test_r2c_plan_buffer_sizes() {
        let r2c = plan_1024(FftType::R2C, DType::F32, DType::C64);

        // F32 input: 4 bytes per element.
        assert_eq!(r2c.input_buffer_size(), 1024 * 4);
        // R2C output keeps only n / 2 + 1 elements of 8 bytes each.
        assert_eq!(r2c.output_buffer_size(), (1024 / 2 + 1) * 8);
    }

    #[test]
    fn test_convenience_functions() {
        let c2c = convenience::create_c2c_1d_plan(
            1024,
            1,
            FftDirection::Forward,
            FftNormalization::None,
        );

        assert_eq!(c2c.fft_type, FftType::C2C);
        assert_eq!(c2c.dimensions, vec![1024]);
        assert!(c2c.is_valid());
    }

    #[test]
    fn test_optimal_fft_sizes() {
        // 1024 = 2^10 and 1080 = 2^3 * 3^3 * 5 have only small prime factors;
        // 1023 = 3 * 11 * 31 does not.
        assert!(convenience::is_optimal_fft_size(1024));
        assert!(convenience::is_optimal_fft_size(1080));
        assert!(!convenience::is_optimal_fft_size(1023));

        assert_eq!(convenience::next_power_of_2(1000), 1024);
        assert_eq!(convenience::next_power_of_2(1024), 1024);

        assert_eq!(convenience::next_optimal_fft_size(1023), 1024);
        assert_eq!(convenience::next_optimal_fft_size(1024), 1024);
    }

    #[test]
    fn test_default_fft_ops() {
        let fallback = DefaultFftOps;
        assert!(!fallback.supports_fft());

        let candidates = fallback.get_optimal_fft_sizes(100, 2000);
        assert!(!candidates.is_empty());
        assert!(candidates.iter().all(|len| (100..=2000).contains(len)));
        assert!(candidates.iter().all(|len| len.is_power_of_two()));
    }
}