1use crate::{Tensor, TensorElement};
10use std::f64::consts::PI;
11use torsh_core::dtype::Complex64;
12use torsh_core::error::{Result, TorshError};
13
14#[derive(Debug, Clone)]
16pub struct FFTPlan {
17 pub size: usize,
19 pub twiddles: Vec<Complex64>,
21 pub bit_reversed_indices: Vec<usize>,
23 pub is_forward: bool,
25}
26
27impl FFTPlan {
28 pub fn new(size: usize, is_forward: bool) -> Result<Self> {
30 if size == 0 || (size & (size - 1)) != 0 {
31 return Err(TorshError::InvalidArgument(
32 "FFT size must be a power of 2".to_string(),
33 ));
34 }
35
36 let mut twiddles = Vec::with_capacity(size / 2);
37 let direction = if is_forward { -1.0 } else { 1.0 };
38
39 for k in 0..size / 2 {
41 let angle = direction * 2.0 * PI * k as f64 / size as f64;
42 twiddles.push(Complex64::new(angle.cos(), angle.sin()));
43 }
44
45 let mut bit_reversed_indices = vec![0; size];
47 let mut j = 0;
48 #[allow(clippy::needless_range_loop)]
49 for i in 1..size {
50 let mut bit = size >> 1;
51 while j & bit != 0 {
52 j ^= bit;
53 bit >>= 1;
54 }
55 j ^= bit;
56 bit_reversed_indices[i] = j;
57 }
58
59 Ok(Self {
60 size,
61 twiddles,
62 bit_reversed_indices,
63 is_forward,
64 })
65 }
66
67 pub fn execute(&self, data: &mut [Complex64]) -> Result<()> {
69 if data.len() != self.size {
70 return Err(TorshError::InvalidArgument(format!(
71 "Data size {} does not match plan size {}",
72 data.len(),
73 self.size
74 )));
75 }
76
77 for i in 0..self.size {
79 let j = self.bit_reversed_indices[i];
80 if i < j {
81 data.swap(i, j);
82 }
83 }
84
85 let mut n = 2;
87 while n <= self.size {
88 let step = self.size / n;
89 for i in (0..self.size).step_by(n) {
90 for j in 0..n / 2 {
91 let u = data[i + j];
92 let v = data[i + j + n / 2] * self.twiddles[j * step];
93 data[i + j] = u + v;
94 data[i + j + n / 2] = u - v;
95 }
96 }
97 n <<= 1;
98 }
99
100 if !self.is_forward {
102 let norm = 1.0 / self.size as f64;
103 for sample in data.iter_mut() {
104 *sample *= norm;
105 }
106 }
107
108 Ok(())
109 }
110}
111
112impl<T: TensorElement + Into<f64> + From<f64>> Tensor<T> {
114 pub fn fft(&self) -> Result<Tensor<Complex64>> {
116 self.fft_with_plan(None)
117 }
118
119 pub fn fft_with_plan(&self, plan: Option<&FFTPlan>) -> Result<Tensor<Complex64>> {
121 let shape = self.shape();
122 let last_dim_size = shape.dims().last().copied().unwrap_or(1);
123
124 if last_dim_size == 0 || (last_dim_size & (last_dim_size - 1)) != 0 {
126 return Err(TorshError::InvalidArgument(
127 "FFT requires the last dimension to be a power of 2".to_string(),
128 ));
129 }
130
131 let owned_plan;
133 let fft_plan = match plan {
134 Some(p) => {
135 if p.size != last_dim_size || !p.is_forward {
136 return Err(TorshError::InvalidArgument(
137 "Plan size or direction mismatch".to_string(),
138 ));
139 }
140 p
141 }
142 None => {
143 owned_plan = FFTPlan::new(last_dim_size, true)?;
144 &owned_plan
145 }
146 };
147
148 let input_data = self.to_vec()?;
150 let total_elements = input_data.len();
151 let num_ffts = total_elements / last_dim_size;
152
153 let mut complex_data = Vec::with_capacity(total_elements);
154 for &value in &input_data {
155 complex_data.push(Complex64::new(value.into(), 0.0));
156 }
157
158 for i in 0..num_ffts {
160 let start = i * last_dim_size;
161 let end = start + last_dim_size;
162 fft_plan.execute(&mut complex_data[start..end])?;
163 }
164
165 Tensor::from_complex_data(complex_data, shape.dims().to_vec(), self.device())
167 }
168
169 pub fn ifft(&self) -> Result<Tensor<T>>
171 where
172 T: TensorElement + From<f64>,
173 {
174 let complex_tensor = self.to_complex()?;
175 complex_tensor.ifft_complex()?.to_real()
176 }
177
178 fn to_complex(&self) -> Result<Tensor<Complex64>> {
180 let input_data = self.to_vec()?;
181 let complex_data: Vec<Complex64> = input_data
182 .iter()
183 .map(|&value| Complex64::new(value.into(), 0.0))
184 .collect();
185
186 Tensor::from_complex_data(complex_data, self.shape().dims().to_vec(), self.device())
187 }
188
189 pub fn fft2(&self) -> Result<Tensor<Complex64>> {
191 let shape = self.shape();
192 let dims = shape.dims();
193
194 if dims.len() < 2 {
195 return Err(TorshError::InvalidArgument(
196 "2D FFT requires at least 2 dimensions".to_string(),
197 ));
198 }
199
200 let temp = self.fft()?;
202
203 temp.fft_along_dim(dims.len() - 2)
205 }
206
207 pub fn ifft2(&self) -> Result<Tensor<T>>
209 where
210 T: TensorElement + From<f64>,
211 {
212 let complex_tensor = self.to_complex()?;
213 complex_tensor.ifft2_complex()?.to_real()
214 }
215
216 pub fn fft_along_dim_real(&self, dim: usize) -> Result<Tensor<Complex64>> {
218 let shape = self.shape();
219 let dims = shape.dims();
220
221 if dim >= dims.len() {
222 return Err(TorshError::InvalidArgument(format!(
223 "Dimension {} out of bounds for tensor with {} dimensions",
224 dim,
225 dims.len()
226 )));
227 }
228
229 if dim == dims.len() - 1 {
231 return self.fft();
232 }
233
234 let transposed = self.transpose_to_last_dim(dim)?;
236 let fft_result = transposed.fft()?;
237 fft_result.transpose_from_last_dim(dim)
238 }
239
240 pub fn rfft(&self) -> Result<Tensor<Complex64>> {
242 let shape = self.shape();
245 let last_dim_size = shape.dims().last().copied().unwrap_or(1);
246 let output_size = last_dim_size / 2 + 1;
247
248 let full_fft = self.fft()?;
249
250 let mut new_shape = shape.dims().to_vec();
252 *new_shape
253 .last_mut()
254 .expect("shape should have at least one dimension") = output_size;
255
256 full_fft.slice_last_dim_complex(0, output_size)
257 }
258
259 pub fn irfft(&self, output_size: Option<usize>) -> Result<Tensor<T>>
261 where
262 T: TensorElement + From<f64>,
263 {
264 let shape = self.shape();
265 let input_size = shape.dims().last().copied().unwrap_or(1);
266 let out_size = output_size.unwrap_or((input_size - 1) * 2);
267
268 let full_spectrum = self.reconstruct_hermitian_spectrum(out_size)?;
270
271 let complex_result = full_spectrum.ifft_complex()?;
273 complex_result.to_real()
274 }
275
276 pub fn power_spectrum(&self) -> Result<Tensor<T>>
278 where
279 T: TensorElement + From<f64>,
280 {
281 let fft_result = self.fft()?;
282 fft_result.power_spectrum_from_fft()
283 }
284
285 pub fn magnitude_spectrum(&self) -> Result<Tensor<T>>
287 where
288 T: TensorElement + From<f64>,
289 {
290 let fft_result = self.fft()?;
291 fft_result.magnitude_spectrum_from_fft()
292 }
293
294 pub fn phase_spectrum(&self) -> Result<Tensor<T>>
296 where
297 T: TensorElement + From<f64>,
298 {
299 let fft_result = self.fft()?;
300 fft_result.phase_spectrum_from_fft()
301 }
302
303 #[allow(dead_code)]
305 fn slice_last_dim(&self, start: usize, size: usize) -> Result<Self> {
306 let shape = self.shape();
309 let dims = shape.dims();
310 let last_dim_size = dims.last().copied().unwrap_or(1);
311
312 if start + size > last_dim_size {
313 return Err(TorshError::IndexOutOfBounds {
314 index: start + size - 1,
315 size: last_dim_size,
316 });
317 }
318
319 let mut new_dims = dims.to_vec();
321 *new_dims
322 .last_mut()
323 .expect("shape should have at least one dimension") = size;
324
325 let input_data = self.to_vec()?;
327 let total_elements = input_data.len();
328 let num_vectors = total_elements / last_dim_size;
329
330 let mut output_data = Vec::with_capacity(num_vectors * size);
331 for i in 0..num_vectors {
332 let base_idx = i * last_dim_size;
333 for j in 0..size {
334 output_data.push(input_data[base_idx + start + j]);
335 }
336 }
337
338 Self::from_data(output_data, new_dims, self.device())
339 }
340
341 fn reconstruct_hermitian_spectrum(&self, output_size: usize) -> Result<Tensor<Complex64>> {
343 let shape = self.shape();
346 let input_size = shape.dims().last().copied().unwrap_or(1);
347
348 if output_size < (input_size - 1) * 2 {
349 return Err(TorshError::InvalidArgument(
350 "Output size too small for IRFFT".to_string(),
351 ));
352 }
353
354 let mut new_dims = shape.dims().to_vec();
356 *new_dims
357 .last_mut()
358 .expect("shape should have at least one dimension") = output_size;
359
360 let input_data = self.to_vec()?;
361 let mut output_data = Vec::with_capacity(input_data.len() * output_size / input_size);
362
363 for &value in &input_data {
365 let f64_value: f64 = value.into();
367 output_data.push(Complex64::new(f64_value, 0.0));
368 }
369
370 while output_data.len() < output_data.capacity() {
372 output_data.push(Complex64::new(0.0, 0.0));
373 }
374
375 Tensor::from_complex_data(output_data, new_dims, self.device())
376 }
377}
378
379impl<T: TensorElement> Tensor<T> {
381 fn transpose_to_last_dim(&self, dim: usize) -> Result<Self> {
383 let ndim = self.shape().dims().len();
384 if dim == ndim - 1 {
385 return Ok(self.clone());
386 }
387 self.transpose(dim as i32, (ndim - 1) as i32)
388 }
389
390 fn transpose_from_last_dim(&self, original_dim: usize) -> Result<Self> {
392 let ndim = self.shape().dims().len();
393 if original_dim == ndim - 1 {
394 return Ok(self.clone());
395 }
396 self.transpose(original_dim as i32, (ndim - 1) as i32)
397 }
398}
399
400impl Tensor<Complex64> {
402 pub fn from_complex_data(
404 data: Vec<Complex64>,
405 shape: Vec<usize>,
406 device: torsh_core::device::DeviceType,
407 ) -> Result<Self> {
408 Tensor::from_data(data, shape, device)
409 }
410
411 pub fn to_real<T: TensorElement + From<f64>>(&self) -> Result<Tensor<T>> {
413 let complex_data = self.to_vec()?;
414 let real_data: Vec<T> = complex_data.iter().map(|c| T::from(c.re)).collect();
415
416 Tensor::from_data(real_data, self.shape().dims().to_vec(), self.device())
417 }
418
419 pub fn power_spectrum_from_fft<T: TensorElement + From<f64>>(&self) -> Result<Tensor<T>> {
421 let complex_data = self.to_vec()?;
422 let power_data: Vec<T> = complex_data
423 .iter()
424 .map(|c| T::from(c.norm().powi(2)))
425 .collect();
426
427 Tensor::from_data(power_data, self.shape().dims().to_vec(), self.device())
428 }
429
430 pub fn magnitude_spectrum_from_fft<T: TensorElement + From<f64>>(&self) -> Result<Tensor<T>> {
432 let complex_data = self.to_vec()?;
433 let magnitude_data: Vec<T> = complex_data.iter().map(|c| T::from(c.norm())).collect();
434
435 Tensor::from_data(magnitude_data, self.shape().dims().to_vec(), self.device())
436 }
437
438 pub fn phase_spectrum_from_fft<T: TensorElement + From<f64>>(&self) -> Result<Tensor<T>> {
440 let complex_data = self.to_vec()?;
441 let phase_data: Vec<T> = complex_data.iter().map(|c| T::from(c.arg())).collect();
442
443 Tensor::from_data(phase_data, self.shape().dims().to_vec(), self.device())
444 }
445
446 pub fn fft_complex(&self) -> Result<Tensor<Complex64>> {
448 let shape = self.shape();
449 let last_dim_size = shape.dims().last().copied().unwrap_or(1);
450
451 let plan = FFTPlan::new(last_dim_size, true)?;
452
453 let mut complex_data = self.to_vec()?;
454 let num_ffts = complex_data.len() / last_dim_size;
455
456 for i in 0..num_ffts {
458 let start = i * last_dim_size;
459 let end = start + last_dim_size;
460 plan.execute(&mut complex_data[start..end])?;
461 }
462
463 Tensor::from_complex_data(complex_data, shape.dims().to_vec(), self.device())
465 }
466
467 pub fn ifft_complex(&self) -> Result<Tensor<Complex64>> {
469 let shape = self.shape();
470 let last_dim_size = shape.dims().last().copied().unwrap_or(1);
471
472 let plan = FFTPlan::new(last_dim_size, false)?;
473
474 let mut complex_data = self.to_vec()?;
475 let num_ffts = complex_data.len() / last_dim_size;
476
477 for i in 0..num_ffts {
479 let start = i * last_dim_size;
480 let end = start + last_dim_size;
481 plan.execute(&mut complex_data[start..end])?;
482 }
483
484 Tensor::from_complex_data(complex_data, shape.dims().to_vec(), self.device())
485 }
486
487 pub fn ifft2_complex(&self) -> Result<Tensor<Complex64>> {
489 let shape = self.shape();
490 let dims = shape.dims();
491
492 if dims.len() < 2 {
493 return Err(TorshError::InvalidArgument(
494 "2D IFFT requires at least 2 dimensions".to_string(),
495 ));
496 }
497
498 let temp = self.ifft_along_dim(dims.len() - 2)?;
500
501 temp.ifft_complex()
503 }
504
505 pub fn ifft_along_dim(&self, dim: usize) -> Result<Tensor<Complex64>> {
507 let shape = self.shape();
508 let dims = shape.dims();
509
510 if dim >= dims.len() {
511 return Err(TorshError::InvalidArgument(format!(
512 "Dimension {} out of bounds for tensor with {} dimensions",
513 dim,
514 dims.len()
515 )));
516 }
517
518 if dim == dims.len() - 1 {
520 return self.ifft_complex();
521 }
522
523 let transposed = self.transpose_to_last_dim_complex(dim)?;
525 let ifft_result = transposed.ifft_complex()?;
526 ifft_result.transpose_from_last_dim_complex(dim)
527 }
528
529 fn transpose_to_last_dim_complex(&self, _dim: usize) -> Result<Tensor<Complex64>> {
531 Ok(self.clone())
534 }
535
536 fn transpose_from_last_dim_complex(&self, _dim: usize) -> Result<Tensor<Complex64>> {
538 Ok(self.clone())
541 }
542
543 pub fn fft2_complex(&self) -> Result<Tensor<Complex64>> {
545 let shape = self.shape();
546 let dims = shape.dims().to_vec();
547
548 if dims.len() < 2 {
549 return Err(TorshError::InvalidArgument(
550 "2D FFT requires at least 2 dimensions".to_string(),
551 ));
552 }
553
554 let temp = self.fft_complex()?;
556
557 temp.fft_along_dim_complex(dims.len() - 2)
559 }
560
561 pub fn fft_along_dim(&self, dim: usize) -> Result<Tensor<Complex64>> {
563 self.fft_along_dim_complex(dim)
564 }
565
566 pub fn fft_along_dim_complex(&self, dim: usize) -> Result<Tensor<Complex64>> {
568 let shape = self.shape();
569 let dims = shape.dims();
570
571 if dim >= dims.len() {
572 return Err(TorshError::InvalidArgument(format!(
573 "Dimension {} out of bounds for tensor with {} dimensions",
574 dim,
575 dims.len()
576 )));
577 }
578
579 if dim == dims.len() - 1 {
581 return self.fft_complex();
582 }
583
584 let transposed = self.transpose_to_last_dim_complex(dim)?;
586 let fft_result = transposed.fft_complex()?;
587 fft_result.transpose_from_last_dim_complex(dim)
588 }
589
590 pub fn slice_last_dim_complex(&self, start: usize, size: usize) -> Result<Tensor<Complex64>> {
592 let shape = self.shape();
593 let dims = shape.dims().to_vec();
594
595 if dims.is_empty() {
596 return Err(TorshError::InvalidArgument(
597 "Cannot slice empty tensor".to_string(),
598 ));
599 }
600
601 let last_dim = dims.len() - 1;
602 let last_dim_size = dims[last_dim];
603 let end = start + size;
604
605 if start >= last_dim_size || end > last_dim_size {
606 return Err(TorshError::InvalidArgument(format!(
607 "Invalid slice range {start}..{end} for dimension of size {last_dim_size}"
608 )));
609 }
610
611 let data = self.to_vec()?;
612 let num_elements_per_slice = dims[..last_dim].iter().product::<usize>();
613 let mut result_data = Vec::with_capacity(num_elements_per_slice * size);
614
615 for i in 0..num_elements_per_slice {
616 let slice_start = i * last_dim_size + start;
617 let slice_end = slice_start + size;
618 result_data.extend_from_slice(&data[slice_start..slice_end]);
619 }
620
621 let mut new_dims = dims;
622 new_dims[last_dim] = size;
623
624 Tensor::from_complex_data(result_data, new_dims, self.device())
625 }
626}
627
628pub mod windows {
630 use super::*;
631
632 pub fn hann<T: TensorElement + From<f64>>(size: usize) -> Result<Tensor<T>> {
634 let data: Vec<T> = (0..size)
635 .map(|i| {
636 let factor = 0.5 * (1.0 - (2.0 * PI * i as f64 / (size - 1) as f64).cos());
637 T::from(factor)
638 })
639 .collect();
640
641 Tensor::from_data(data, vec![size], torsh_core::device::DeviceType::Cpu)
642 }
643
644 pub fn hamming<T: TensorElement + From<f64>>(size: usize) -> Result<Tensor<T>> {
646 let data: Vec<T> = (0..size)
647 .map(|i| {
648 let factor = 0.54 - 0.46 * (2.0 * PI * i as f64 / (size - 1) as f64).cos();
649 T::from(factor)
650 })
651 .collect();
652
653 Tensor::from_data(data, vec![size], torsh_core::device::DeviceType::Cpu)
654 }
655
656 pub fn blackman<T: TensorElement + From<f64>>(size: usize) -> Result<Tensor<T>> {
658 let data: Vec<T> = (0..size)
659 .map(|i| {
660 let n = i as f64;
661 let n_max = (size - 1) as f64;
662 let factor =
663 0.42 - 0.5 * (2.0 * PI * n / n_max).cos() + 0.08 * (4.0 * PI * n / n_max).cos();
664 T::from(factor)
665 })
666 .collect();
667
668 Tensor::from_data(data, vec![size], torsh_core::device::DeviceType::Cpu)
669 }
670
671 pub fn kaiser<T: TensorElement + From<f64>>(size: usize, beta: f64) -> Result<Tensor<T>> {
673 let data: Vec<T> = (0..size)
675 .map(|i| {
676 let n = i as f64;
677 let n_max = (size - 1) as f64;
678 let factor = (beta * (1.0 - ((2.0 * n / n_max) - 1.0).powi(2)).sqrt()).exp();
679 T::from(factor)
680 })
681 .collect();
682
683 Tensor::from_data(data, vec![size], torsh_core::device::DeviceType::Cpu)
684 }
685}
686
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Tensor;

    /// Absolute tolerance for floating-point comparisons of transform outputs.
    const TOL: f64 = 1e-9;

    /// Asserts that `got` equals `re + i·im` within `TOL`.
    fn assert_complex_close(got: Complex64, re: f64, im: f64) {
        assert!(
            (got.re - re).abs() < TOL && (got.im - im).abs() < TOL,
            "got {got:?}, expected ({re}, {im})"
        );
    }

    #[test]
    fn test_fft_plan_creation() {
        let plan = FFTPlan::new(8, true).expect("FFT plan creation should succeed");
        assert_eq!(plan.size, 8);
        assert_eq!(plan.twiddles.len(), 4);
        assert_eq!(plan.bit_reversed_indices.len(), 8);
        assert!(plan.is_forward);
    }

    #[test]
    fn test_fft_plan_rejects_invalid_sizes() {
        assert!(FFTPlan::new(0, true).is_err());
        assert!(FFTPlan::new(6, true).is_err());
        assert!(FFTPlan::new(1, false).is_ok());
    }

    #[test]
    fn test_fft_plan_known_values() {
        let plan = FFTPlan::new(4, true).expect("FFT plan creation should succeed");
        let mut data: Vec<Complex64> = [1.0, 2.0, 3.0, 4.0]
            .iter()
            .map(|&x| Complex64::new(x, 0.0))
            .collect();

        plan.execute(&mut data).expect("FFT execution should succeed");

        assert_complex_close(data[0], 10.0, 0.0);
        assert_complex_close(data[1], -2.0, 2.0);
        assert_complex_close(data[2], -2.0, 0.0);
        assert_complex_close(data[3], -2.0, -2.0);
    }

    #[test]
    fn test_fft_plan_round_trip() {
        let forward = FFTPlan::new(8, true).expect("forward plan creation should succeed");
        let inverse = FFTPlan::new(8, false).expect("inverse plan creation should succeed");
        let original: Vec<Complex64> = (0..8)
            .map(|i| Complex64::new(i as f64, 0.5 * i as f64 - 1.0))
            .collect();

        let mut data = original.clone();
        forward.execute(&mut data).expect("forward FFT should succeed");
        inverse.execute(&mut data).expect("inverse FFT should succeed");

        for (got, want) in data.iter().zip(&original) {
            assert_complex_close(*got, want.re, want.im);
        }
    }

    #[test]
    fn test_fft_plan_rejects_length_mismatch() {
        let plan = FFTPlan::new(4, true).expect("FFT plan creation should succeed");
        let mut data = vec![Complex64::new(0.0, 0.0); 2];
        assert!(plan.execute(&mut data).is_err());
    }

    #[test]
    fn test_complex_arithmetic() {
        let a = Complex64::new(1.0, 2.0);
        let b = Complex64::new(3.0, 4.0);

        let sum = a + b;
        assert_eq!(sum.re, 4.0);
        assert_eq!(sum.im, 6.0);

        let product = a * b;
        assert_eq!(product.re, -5.0);
        assert_eq!(product.im, 10.0);

        assert_eq!(a.norm(), (5.0_f64).sqrt());
    }

    #[test]
    fn test_fft_basic() {
        let data = vec![1.0, 0.0, 0.0, 0.0];
        let tensor = Tensor::from_data(data, vec![4], torsh_core::device::DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let fft_result = tensor.fft().expect("FFT should work with valid input");
        assert_eq!(fft_result.shape().dims(), &[4]);

        let output_data = fft_result
            .to_vec()
            .expect("to_vec conversion should succeed");
        assert_eq!(output_data.len(), 4);
        // A unit impulse has a flat spectrum: every bin equals 1 + 0i.
        for &bin in &output_data {
            assert_complex_close(bin, 1.0, 0.0);
        }
    }

    #[test]
    fn test_windowing_functions() {
        let hann_window = windows::hann::<f64>(8).expect("Hann window creation should succeed");
        assert_eq!(hann_window.shape().dims(), &[8]);
        let hann_values = hann_window
            .to_vec()
            .expect("to_vec conversion should succeed");
        assert!(hann_values[0].abs() < TOL);
        assert!(hann_values[7].abs() < TOL);

        let hamming_window =
            windows::hamming::<f64>(8).expect("Hamming window creation should succeed");
        assert_eq!(hamming_window.shape().dims(), &[8]);
        let hamming_values = hamming_window
            .to_vec()
            .expect("to_vec conversion should succeed");
        assert!((hamming_values[0] - 0.08).abs() < TOL);

        let blackman_window =
            windows::blackman::<f64>(8).expect("Blackman window creation should succeed");
        assert_eq!(blackman_window.shape().dims(), &[8]);
    }
}