Skip to main content

torsh_tensor/
fft.rs

//! Fast Fourier Transform (FFT) operations for tensors
//!
//! This module provides FFT functionality including:
//! - 1D FFTs along the last (or any chosen) dimension, and 2D FFTs over the last two dimensions
//! - Real and complex FFTs
//! - Inverse FFTs
//! - An iterative radix-2 Cooley-Tukey implementation with reusable plans
//!   (transform lengths must be powers of two)
8
9use crate::{Tensor, TensorElement};
10use std::f64::consts::PI;
11use torsh_core::dtype::Complex64;
12use torsh_core::error::{Result, TorshError};
13
14/// FFT plan for optimized repeated transforms
15#[derive(Debug, Clone)]
16pub struct FFTPlan {
17    /// Size of the transform
18    pub size: usize,
19    /// Precomputed twiddle factors
20    pub twiddles: Vec<Complex64>,
21    /// Bit-reversed indices for in-place computation
22    pub bit_reversed_indices: Vec<usize>,
23    /// Whether this is a forward or inverse transform
24    pub is_forward: bool,
25}
26
27impl FFTPlan {
28    /// Create a new FFT plan for the given size
29    pub fn new(size: usize, is_forward: bool) -> Result<Self> {
30        if size == 0 || (size & (size - 1)) != 0 {
31            return Err(TorshError::InvalidArgument(
32                "FFT size must be a power of 2".to_string(),
33            ));
34        }
35
36        let mut twiddles = Vec::with_capacity(size / 2);
37        let direction = if is_forward { -1.0 } else { 1.0 };
38
39        // Precompute twiddle factors
40        for k in 0..size / 2 {
41            let angle = direction * 2.0 * PI * k as f64 / size as f64;
42            twiddles.push(Complex64::new(angle.cos(), angle.sin()));
43        }
44
45        // Compute bit-reversed indices
46        let mut bit_reversed_indices = vec![0; size];
47        let mut j = 0;
48        #[allow(clippy::needless_range_loop)]
49        for i in 1..size {
50            let mut bit = size >> 1;
51            while j & bit != 0 {
52                j ^= bit;
53                bit >>= 1;
54            }
55            j ^= bit;
56            bit_reversed_indices[i] = j;
57        }
58
59        Ok(Self {
60            size,
61            twiddles,
62            bit_reversed_indices,
63            is_forward,
64        })
65    }
66
67    /// Execute the FFT plan on the given data
68    pub fn execute(&self, data: &mut [Complex64]) -> Result<()> {
69        if data.len() != self.size {
70            return Err(TorshError::InvalidArgument(format!(
71                "Data size {} does not match plan size {}",
72                data.len(),
73                self.size
74            )));
75        }
76
77        // Bit-reverse the input
78        for i in 0..self.size {
79            let j = self.bit_reversed_indices[i];
80            if i < j {
81                data.swap(i, j);
82            }
83        }
84
85        // Cooley-Tukey FFT algorithm
86        let mut n = 2;
87        while n <= self.size {
88            let step = self.size / n;
89            for i in (0..self.size).step_by(n) {
90                for j in 0..n / 2 {
91                    let u = data[i + j];
92                    let v = data[i + j + n / 2] * self.twiddles[j * step];
93                    data[i + j] = u + v;
94                    data[i + j + n / 2] = u - v;
95                }
96            }
97            n <<= 1;
98        }
99
100        // Normalize for inverse transform
101        if !self.is_forward {
102            let norm = 1.0 / self.size as f64;
103            for sample in data.iter_mut() {
104                *sample *= norm;
105            }
106        }
107
108        Ok(())
109    }
110}
111
112/// FFT operations for tensors
113impl<T: TensorElement + Into<f64> + From<f64>> Tensor<T> {
114    /// Compute 1D FFT along the last dimension
115    pub fn fft(&self) -> Result<Tensor<Complex64>> {
116        self.fft_with_plan(None)
117    }
118
119    /// Compute 1D FFT with a precomputed plan
120    pub fn fft_with_plan(&self, plan: Option<&FFTPlan>) -> Result<Tensor<Complex64>> {
121        let shape = self.shape();
122        let last_dim_size = shape.dims().last().copied().unwrap_or(1);
123
124        // Check if size is a power of 2
125        if last_dim_size == 0 || (last_dim_size & (last_dim_size - 1)) != 0 {
126            return Err(TorshError::InvalidArgument(
127                "FFT requires the last dimension to be a power of 2".to_string(),
128            ));
129        }
130
131        // Create or use existing plan
132        let owned_plan;
133        let fft_plan = match plan {
134            Some(p) => {
135                if p.size != last_dim_size || !p.is_forward {
136                    return Err(TorshError::InvalidArgument(
137                        "Plan size or direction mismatch".to_string(),
138                    ));
139                }
140                p
141            }
142            None => {
143                owned_plan = FFTPlan::new(last_dim_size, true)?;
144                &owned_plan
145            }
146        };
147
148        // Convert input data to complex
149        let input_data = self.to_vec()?;
150        let total_elements = input_data.len();
151        let num_ffts = total_elements / last_dim_size;
152
153        let mut complex_data = Vec::with_capacity(total_elements);
154        for &value in &input_data {
155            complex_data.push(Complex64::new(value.into(), 0.0));
156        }
157
158        // Perform FFT on each vector
159        for i in 0..num_ffts {
160            let start = i * last_dim_size;
161            let end = start + last_dim_size;
162            fft_plan.execute(&mut complex_data[start..end])?;
163        }
164
165        // Create output tensor (same shape but complex type)
166        Tensor::from_complex_data(complex_data, shape.dims().to_vec(), self.device())
167    }
168
169    /// Compute 1D inverse FFT along the last dimension
170    pub fn ifft(&self) -> Result<Tensor<T>>
171    where
172        T: TensorElement + From<f64>,
173    {
174        let complex_tensor = self.to_complex()?;
175        complex_tensor.ifft_complex()?.to_real()
176    }
177
178    /// Convert generic tensor to complex form
179    fn to_complex(&self) -> Result<Tensor<Complex64>> {
180        let input_data = self.to_vec()?;
181        let complex_data: Vec<Complex64> = input_data
182            .iter()
183            .map(|&value| Complex64::new(value.into(), 0.0))
184            .collect();
185
186        Tensor::from_complex_data(complex_data, self.shape().dims().to_vec(), self.device())
187    }
188
189    /// Compute 2D FFT on the last two dimensions
190    pub fn fft2(&self) -> Result<Tensor<Complex64>> {
191        let shape = self.shape();
192        let dims = shape.dims();
193
194        if dims.len() < 2 {
195            return Err(TorshError::InvalidArgument(
196                "2D FFT requires at least 2 dimensions".to_string(),
197            ));
198        }
199
200        // First, FFT along the last dimension
201        let temp = self.fft()?;
202
203        // Then, FFT along the second-to-last dimension
204        temp.fft_along_dim(dims.len() - 2)
205    }
206
207    /// Compute 2D inverse FFT on the last two dimensions
208    pub fn ifft2(&self) -> Result<Tensor<T>>
209    where
210        T: TensorElement + From<f64>,
211    {
212        let complex_tensor = self.to_complex()?;
213        complex_tensor.ifft2_complex()?.to_real()
214    }
215
216    /// Compute FFT along a specific dimension for real tensors
217    pub fn fft_along_dim_real(&self, dim: usize) -> Result<Tensor<Complex64>> {
218        let shape = self.shape();
219        let dims = shape.dims();
220
221        if dim >= dims.len() {
222            return Err(TorshError::InvalidArgument(format!(
223                "Dimension {} out of bounds for tensor with {} dimensions",
224                dim,
225                dims.len()
226            )));
227        }
228
229        // If it's the last dimension, use the optimized path
230        if dim == dims.len() - 1 {
231            return self.fft();
232        }
233
234        // For other dimensions, we need to transpose, FFT, then transpose back
235        let transposed = self.transpose_to_last_dim(dim)?;
236        let fft_result = transposed.fft()?;
237        fft_result.transpose_from_last_dim(dim)
238    }
239
240    /// Real-to-complex FFT (more efficient for real inputs)
241    pub fn rfft(&self) -> Result<Tensor<Complex64>> {
242        // For real FFT, we only need to compute half of the coefficients
243        // due to Hermitian symmetry
244        let shape = self.shape();
245        let last_dim_size = shape.dims().last().copied().unwrap_or(1);
246        let output_size = last_dim_size / 2 + 1;
247
248        let full_fft = self.fft()?;
249
250        // Extract the first half + 1 coefficients
251        let mut new_shape = shape.dims().to_vec();
252        *new_shape
253            .last_mut()
254            .expect("shape should have at least one dimension") = output_size;
255
256        full_fft.slice_last_dim_complex(0, output_size)
257    }
258
259    /// Complex-to-real inverse FFT
260    pub fn irfft(&self, output_size: Option<usize>) -> Result<Tensor<T>>
261    where
262        T: TensorElement + From<f64>,
263    {
264        let shape = self.shape();
265        let input_size = shape.dims().last().copied().unwrap_or(1);
266        let out_size = output_size.unwrap_or((input_size - 1) * 2);
267
268        // Reconstruct full complex spectrum using Hermitian symmetry
269        let full_spectrum = self.reconstruct_hermitian_spectrum(out_size)?;
270
271        // Perform inverse FFT and convert to real
272        let complex_result = full_spectrum.ifft_complex()?;
273        complex_result.to_real()
274    }
275
276    /// Compute power spectral density
277    pub fn power_spectrum(&self) -> Result<Tensor<T>>
278    where
279        T: TensorElement + From<f64>,
280    {
281        let fft_result = self.fft()?;
282        fft_result.power_spectrum_from_fft()
283    }
284
285    /// Compute magnitude spectrum
286    pub fn magnitude_spectrum(&self) -> Result<Tensor<T>>
287    where
288        T: TensorElement + From<f64>,
289    {
290        let fft_result = self.fft()?;
291        fft_result.magnitude_spectrum_from_fft()
292    }
293
294    /// Compute phase spectrum
295    pub fn phase_spectrum(&self) -> Result<Tensor<T>>
296    where
297        T: TensorElement + From<f64>,
298    {
299        let fft_result = self.fft()?;
300        fft_result.phase_spectrum_from_fft()
301    }
302
303    /// Slice the last dimension
304    #[allow(dead_code)]
305    fn slice_last_dim(&self, start: usize, size: usize) -> Result<Self> {
306        // This is a simplified implementation
307        // In a full implementation, you'd use proper tensor slicing
308        let shape = self.shape();
309        let dims = shape.dims();
310        let last_dim_size = dims.last().copied().unwrap_or(1);
311
312        if start + size > last_dim_size {
313            return Err(TorshError::IndexOutOfBounds {
314                index: start + size - 1,
315                size: last_dim_size,
316            });
317        }
318
319        // Create new shape
320        let mut new_dims = dims.to_vec();
321        *new_dims
322            .last_mut()
323            .expect("shape should have at least one dimension") = size;
324
325        // Extract the data (simplified - would need proper strided slicing)
326        let input_data = self.to_vec()?;
327        let total_elements = input_data.len();
328        let num_vectors = total_elements / last_dim_size;
329
330        let mut output_data = Vec::with_capacity(num_vectors * size);
331        for i in 0..num_vectors {
332            let base_idx = i * last_dim_size;
333            for j in 0..size {
334                output_data.push(input_data[base_idx + start + j]);
335            }
336        }
337
338        Self::from_data(output_data, new_dims, self.device())
339    }
340
341    /// Reconstruct Hermitian spectrum for IRFFT
342    fn reconstruct_hermitian_spectrum(&self, output_size: usize) -> Result<Tensor<Complex64>> {
343        // This is a simplified implementation
344        // Real implementation would properly handle Hermitian symmetry
345        let shape = self.shape();
346        let input_size = shape.dims().last().copied().unwrap_or(1);
347
348        if output_size < (input_size - 1) * 2 {
349            return Err(TorshError::InvalidArgument(
350                "Output size too small for IRFFT".to_string(),
351            ));
352        }
353
354        // For now, just pad with zeros (proper implementation would use conjugate symmetry)
355        let mut new_dims = shape.dims().to_vec();
356        *new_dims
357            .last_mut()
358            .expect("shape should have at least one dimension") = output_size;
359
360        let input_data = self.to_vec()?;
361        let mut output_data = Vec::with_capacity(input_data.len() * output_size / input_size);
362
363        // Simple zero-padding (would need proper Hermitian reconstruction)
364        for &value in &input_data {
365            // Convert T to Complex64 - assuming T can be converted to f64
366            let f64_value: f64 = value.into();
367            output_data.push(Complex64::new(f64_value, 0.0));
368        }
369
370        // Pad with zeros to reach output size
371        while output_data.len() < output_data.capacity() {
372            output_data.push(Complex64::new(0.0, 0.0));
373        }
374
375        Tensor::from_complex_data(output_data, new_dims, self.device())
376    }
377}
378
379/// General tensor operations that don't require Into<f64>
380impl<T: TensorElement> Tensor<T> {
381    /// Helper method to transpose a dimension to the last position
382    fn transpose_to_last_dim(&self, dim: usize) -> Result<Self> {
383        let ndim = self.shape().dims().len();
384        if dim == ndim - 1 {
385            return Ok(self.clone());
386        }
387        self.transpose(dim as i32, (ndim - 1) as i32)
388    }
389
390    /// Helper method to transpose back from last dimension
391    fn transpose_from_last_dim(&self, original_dim: usize) -> Result<Self> {
392        let ndim = self.shape().dims().len();
393        if original_dim == ndim - 1 {
394            return Ok(self.clone());
395        }
396        self.transpose(original_dim as i32, (ndim - 1) as i32)
397    }
398}
399
400/// Operations specific to complex tensors
401impl Tensor<Complex64> {
402    /// Create tensor from complex data
403    pub fn from_complex_data(
404        data: Vec<Complex64>,
405        shape: Vec<usize>,
406        device: torsh_core::device::DeviceType,
407    ) -> Result<Self> {
408        Tensor::from_data(data, shape, device)
409    }
410
411    /// Convert complex tensor to real by taking the real part
412    pub fn to_real<T: TensorElement + From<f64>>(&self) -> Result<Tensor<T>> {
413        let complex_data = self.to_vec()?;
414        let real_data: Vec<T> = complex_data.iter().map(|c| T::from(c.re)).collect();
415
416        Tensor::from_data(real_data, self.shape().dims().to_vec(), self.device())
417    }
418
419    /// Compute power spectrum from FFT result
420    pub fn power_spectrum_from_fft<T: TensorElement + From<f64>>(&self) -> Result<Tensor<T>> {
421        let complex_data = self.to_vec()?;
422        let power_data: Vec<T> = complex_data
423            .iter()
424            .map(|c| T::from(c.norm().powi(2)))
425            .collect();
426
427        Tensor::from_data(power_data, self.shape().dims().to_vec(), self.device())
428    }
429
430    /// Compute magnitude spectrum from FFT result
431    pub fn magnitude_spectrum_from_fft<T: TensorElement + From<f64>>(&self) -> Result<Tensor<T>> {
432        let complex_data = self.to_vec()?;
433        let magnitude_data: Vec<T> = complex_data.iter().map(|c| T::from(c.norm())).collect();
434
435        Tensor::from_data(magnitude_data, self.shape().dims().to_vec(), self.device())
436    }
437
438    /// Compute phase spectrum from FFT result
439    pub fn phase_spectrum_from_fft<T: TensorElement + From<f64>>(&self) -> Result<Tensor<T>> {
440        let complex_data = self.to_vec()?;
441        let phase_data: Vec<T> = complex_data.iter().map(|c| T::from(c.arg())).collect();
442
443        Tensor::from_data(phase_data, self.shape().dims().to_vec(), self.device())
444    }
445
446    /// Compute FFT for complex data
447    pub fn fft_complex(&self) -> Result<Tensor<Complex64>> {
448        let shape = self.shape();
449        let last_dim_size = shape.dims().last().copied().unwrap_or(1);
450
451        let plan = FFTPlan::new(last_dim_size, true)?;
452
453        let mut complex_data = self.to_vec()?;
454        let num_ffts = complex_data.len() / last_dim_size;
455
456        // Perform FFT on each vector
457        for i in 0..num_ffts {
458            let start = i * last_dim_size;
459            let end = start + last_dim_size;
460            plan.execute(&mut complex_data[start..end])?;
461        }
462
463        // Create output tensor with same shape
464        Tensor::from_complex_data(complex_data, shape.dims().to_vec(), self.device())
465    }
466
467    /// Compute inverse FFT for complex data
468    pub fn ifft_complex(&self) -> Result<Tensor<Complex64>> {
469        let shape = self.shape();
470        let last_dim_size = shape.dims().last().copied().unwrap_or(1);
471
472        let plan = FFTPlan::new(last_dim_size, false)?;
473
474        let mut complex_data = self.to_vec()?;
475        let num_ffts = complex_data.len() / last_dim_size;
476
477        // Perform inverse FFT on each vector
478        for i in 0..num_ffts {
479            let start = i * last_dim_size;
480            let end = start + last_dim_size;
481            plan.execute(&mut complex_data[start..end])?;
482        }
483
484        Tensor::from_complex_data(complex_data, shape.dims().to_vec(), self.device())
485    }
486
487    /// Compute 2D inverse FFT for complex data
488    pub fn ifft2_complex(&self) -> Result<Tensor<Complex64>> {
489        let shape = self.shape();
490        let dims = shape.dims();
491
492        if dims.len() < 2 {
493            return Err(TorshError::InvalidArgument(
494                "2D IFFT requires at least 2 dimensions".to_string(),
495            ));
496        }
497
498        // First, IFFT along the second-to-last dimension
499        let temp = self.ifft_along_dim(dims.len() - 2)?;
500
501        // Then, IFFT along the last dimension
502        temp.ifft_complex()
503    }
504
505    /// Compute inverse FFT along a specific dimension
506    pub fn ifft_along_dim(&self, dim: usize) -> Result<Tensor<Complex64>> {
507        let shape = self.shape();
508        let dims = shape.dims();
509
510        if dim >= dims.len() {
511            return Err(TorshError::InvalidArgument(format!(
512                "Dimension {} out of bounds for tensor with {} dimensions",
513                dim,
514                dims.len()
515            )));
516        }
517
518        // If it's the last dimension, use the optimized path
519        if dim == dims.len() - 1 {
520            return self.ifft_complex();
521        }
522
523        // For other dimensions, we need to transpose, IFFT, then transpose back
524        let transposed = self.transpose_to_last_dim_complex(dim)?;
525        let ifft_result = transposed.ifft_complex()?;
526        ifft_result.transpose_from_last_dim_complex(dim)
527    }
528
529    /// Simple transpose helper for FFT operations (complex version)
530    fn transpose_to_last_dim_complex(&self, _dim: usize) -> Result<Tensor<Complex64>> {
531        // Simple implementation - just return self for now
532        // In a full implementation, this would properly transpose the tensor
533        Ok(self.clone())
534    }
535
536    /// Simple transpose helper for FFT operations (reverse, complex version)
537    fn transpose_from_last_dim_complex(&self, _dim: usize) -> Result<Tensor<Complex64>> {
538        // Simple implementation - just return self for now
539        // In a full implementation, this would properly transpose the tensor back
540        Ok(self.clone())
541    }
542
543    /// 2D FFT for complex tensors
544    pub fn fft2_complex(&self) -> Result<Tensor<Complex64>> {
545        let shape = self.shape();
546        let dims = shape.dims().to_vec();
547
548        if dims.len() < 2 {
549            return Err(TorshError::InvalidArgument(
550                "2D FFT requires at least 2 dimensions".to_string(),
551            ));
552        }
553
554        // First, FFT along the last dimension
555        let temp = self.fft_complex()?;
556
557        // Then, FFT along the second-to-last dimension
558        temp.fft_along_dim_complex(dims.len() - 2)
559    }
560
561    /// Compute FFT along a specific dimension for complex tensors
562    pub fn fft_along_dim(&self, dim: usize) -> Result<Tensor<Complex64>> {
563        self.fft_along_dim_complex(dim)
564    }
565
566    /// Internal implementation of FFT along dimension for complex tensors
567    pub fn fft_along_dim_complex(&self, dim: usize) -> Result<Tensor<Complex64>> {
568        let shape = self.shape();
569        let dims = shape.dims();
570
571        if dim >= dims.len() {
572            return Err(TorshError::InvalidArgument(format!(
573                "Dimension {} out of bounds for tensor with {} dimensions",
574                dim,
575                dims.len()
576            )));
577        }
578
579        // If it's the last dimension, use the optimized path
580        if dim == dims.len() - 1 {
581            return self.fft_complex();
582        }
583
584        // For other dimensions, we need to transpose, FFT, then transpose back
585        let transposed = self.transpose_to_last_dim_complex(dim)?;
586        let fft_result = transposed.fft_complex()?;
587        fft_result.transpose_from_last_dim_complex(dim)
588    }
589
590    /// Slice along the last dimension for complex tensors
591    pub fn slice_last_dim_complex(&self, start: usize, size: usize) -> Result<Tensor<Complex64>> {
592        let shape = self.shape();
593        let dims = shape.dims().to_vec();
594
595        if dims.is_empty() {
596            return Err(TorshError::InvalidArgument(
597                "Cannot slice empty tensor".to_string(),
598            ));
599        }
600
601        let last_dim = dims.len() - 1;
602        let last_dim_size = dims[last_dim];
603        let end = start + size;
604
605        if start >= last_dim_size || end > last_dim_size {
606            return Err(TorshError::InvalidArgument(format!(
607                "Invalid slice range {start}..{end} for dimension of size {last_dim_size}"
608            )));
609        }
610
611        let data = self.to_vec()?;
612        let num_elements_per_slice = dims[..last_dim].iter().product::<usize>();
613        let mut result_data = Vec::with_capacity(num_elements_per_slice * size);
614
615        for i in 0..num_elements_per_slice {
616            let slice_start = i * last_dim_size + start;
617            let slice_end = slice_start + size;
618            result_data.extend_from_slice(&data[slice_start..slice_end]);
619        }
620
621        let mut new_dims = dims;
622        new_dims[last_dim] = size;
623
624        Tensor::from_complex_data(result_data, new_dims, self.device())
625    }
626}
627
628/// Windowing functions for signal processing
629pub mod windows {
630    use super::*;
631
632    /// Hann window
633    pub fn hann<T: TensorElement + From<f64>>(size: usize) -> Result<Tensor<T>> {
634        let data: Vec<T> = (0..size)
635            .map(|i| {
636                let factor = 0.5 * (1.0 - (2.0 * PI * i as f64 / (size - 1) as f64).cos());
637                T::from(factor)
638            })
639            .collect();
640
641        Tensor::from_data(data, vec![size], torsh_core::device::DeviceType::Cpu)
642    }
643
644    /// Hamming window
645    pub fn hamming<T: TensorElement + From<f64>>(size: usize) -> Result<Tensor<T>> {
646        let data: Vec<T> = (0..size)
647            .map(|i| {
648                let factor = 0.54 - 0.46 * (2.0 * PI * i as f64 / (size - 1) as f64).cos();
649                T::from(factor)
650            })
651            .collect();
652
653        Tensor::from_data(data, vec![size], torsh_core::device::DeviceType::Cpu)
654    }
655
656    /// Blackman window
657    pub fn blackman<T: TensorElement + From<f64>>(size: usize) -> Result<Tensor<T>> {
658        let data: Vec<T> = (0..size)
659            .map(|i| {
660                let n = i as f64;
661                let n_max = (size - 1) as f64;
662                let factor =
663                    0.42 - 0.5 * (2.0 * PI * n / n_max).cos() + 0.08 * (4.0 * PI * n / n_max).cos();
664                T::from(factor)
665            })
666            .collect();
667
668        Tensor::from_data(data, vec![size], torsh_core::device::DeviceType::Cpu)
669    }
670
671    /// Kaiser window
672    pub fn kaiser<T: TensorElement + From<f64>>(size: usize, beta: f64) -> Result<Tensor<T>> {
673        // Simplified Kaiser window (proper implementation would use modified Bessel function)
674        let data: Vec<T> = (0..size)
675            .map(|i| {
676                let n = i as f64;
677                let n_max = (size - 1) as f64;
678                let factor = (beta * (1.0 - ((2.0 * n / n_max) - 1.0).powi(2)).sqrt()).exp();
679                T::from(factor)
680            })
681            .collect();
682
683        Tensor::from_data(data, vec![size], torsh_core::device::DeviceType::Cpu)
684    }
685}
686
687#[cfg(test)]
688mod tests {
689    use super::*;
690    use crate::Tensor;
691
692    #[test]
693    fn test_fft_plan_creation() {
694        let plan = FFTPlan::new(8, true).expect("FFT plan creation should succeed");
695        assert_eq!(plan.size, 8);
696        assert_eq!(plan.twiddles.len(), 4);
697        assert_eq!(plan.bit_reversed_indices.len(), 8);
698        assert!(plan.is_forward);
699    }
700
701    #[test]
702    fn test_complex_arithmetic() {
703        let a = Complex64::new(1.0, 2.0);
704        let b = Complex64::new(3.0, 4.0);
705
706        let sum = a + b;
707        assert_eq!(sum.re, 4.0);
708        assert_eq!(sum.im, 6.0);
709
710        let product = a * b;
711        assert_eq!(product.re, -5.0); // (1*3 - 2*4)
712        assert_eq!(product.im, 10.0); // (1*4 + 2*3)
713
714        assert_eq!(a.norm(), (5.0_f64).sqrt());
715    }
716
717    #[test]
718    fn test_fft_basic() {
719        // Test with a simple signal
720        let data = vec![1.0, 0.0, 0.0, 0.0];
721        let tensor = Tensor::from_data(data, vec![4], torsh_core::device::DeviceType::Cpu)
722            .expect("tensor creation should succeed");
723
724        // Test that FFT works correctly
725        let result = tensor.fft();
726        assert!(result.is_ok(), "FFT should work with valid input");
727
728        let fft_result = result.expect("FFT operation should succeed");
729        assert_eq!(fft_result.shape().dims(), &[4]);
730
731        // The FFT of [1, 0, 0, 0] should be [1, 1, 1, 1]
732        let output_data = fft_result
733            .to_vec()
734            .expect("to_vec conversion should succeed");
735        assert_eq!(output_data.len(), 4);
736        // Check that the DC component is 1
737        assert!((output_data[0].re - 1.0).abs() < 1e-6);
738        assert!(output_data[0].im.abs() < 1e-6);
739    }
740
741    #[test]
742    fn test_windowing_functions() {
743        let hann_window = windows::hann::<f64>(8).expect("FFT operation should succeed");
744        assert_eq!(hann_window.shape().dims(), &[8]);
745
746        let hamming_window = windows::hamming::<f64>(8).expect("FFT operation should succeed");
747        assert_eq!(hamming_window.shape().dims(), &[8]);
748
749        let blackman_window = windows::blackman::<f64>(8).expect("FFT operation should succeed");
750        assert_eq!(blackman_window.shape().dims(), &[8]);
751    }
752}