use std::{f64::consts::TAU, ops::{AddAssign, Mul, MulAssign}};
use num::{complex::ComplexFloat, traits::{Inv, SaturatingSub}, Complex, Float, NumCast, One, Saturating, Zero};
use slice_ops::{Slice, SliceOps};
pub trait SliceMath<T>: SliceOps<T>
{
fn recip_assign_all(&mut self)
where
T: Inv<Output = T>;
fn conj_assign_all(&mut self)
where
T: ComplexFloat;
fn convolve_direct<Rhs, C>(&self, rhs: &[Rhs]) -> C
where
T: Mul<Rhs, Output: AddAssign + Zero> + Copy,
Rhs: Copy,
C: FromIterator<<T as Mul<Rhs>>::Output>;
fn convolve_fft<Rhs, C>(&self, rhs: &[Rhs]) -> C
where
T: Float + Copy,
Rhs: Float + Copy,
Complex<T>: MulAssign + AddAssign + ComplexFloat<Real = T> + Mul<Complex<Rhs>, Output: ComplexFloat<Real: Float>>,
Complex<Rhs>: MulAssign + AddAssign + ComplexFloat<Real = Rhs>,
<Complex<T> as Mul<Complex<Rhs>>>::Output: ComplexFloat<Real: Float> + Into<Complex<<<Complex<T> as Mul<Complex<Rhs>>>::Output as ComplexFloat>::Real>>,
Complex<<<Complex<T> as Mul<Complex<Rhs>>>::Output as ComplexFloat>::Real>: MulAssign + AddAssign + ComplexFloat<Real = <<Complex<T> as Mul<Complex<Rhs>>>::Output as ComplexFloat>::Real>,
C: FromIterator<<<Complex<T> as Mul<Complex<Rhs>>>::Output as ComplexFloat>::Real>;
fn fft(&mut self)
where
T: ComplexFloat<Real: Float> + MulAssign + AddAssign + From<Complex<T::Real>>;
fn ifft(&mut self)
where
T: ComplexFloat<Real: Float> + MulAssign + AddAssign + From<Complex<T::Real>>;
fn real_fft(&self, y: &mut [Complex<T>])
where
T: Float,
Complex<T>: ComplexFloat<Real = T> + MulAssign + AddAssign;
fn real_ifft(&mut self, x: &[Complex<T>])
where
T: Float,
Complex<T>: ComplexFloat<Real = T> + MulAssign + AddAssign;
}
impl<T> SliceMath<T> for [T]
{
fn recip_assign_all(&mut self)
where
T: Inv<Output = T>
{
let mut i = 0;
while i < self.len()
{
unsafe {
let ptr = self.as_mut_ptr().add(i);
ptr.write(ptr.read().inv());
}
i += 1;
}
}
fn conj_assign_all(&mut self)
where
T: ComplexFloat
{
let mut i = 0;
while i < self.len()
{
unsafe {
let ptr = self.as_mut_ptr().add(i);
ptr.write(ptr.read().conj());
}
i += 1;
}
}
fn convolve_direct<Rhs, C>(&self, rhs: &[Rhs]) -> C
where
T: Mul<Rhs, Output: AddAssign + Zero> + Copy,
Rhs: Copy,
C: FromIterator<<T as Mul<Rhs>>::Output>
{
let y_len = (self.len() + rhs.len()).saturating_sub(1);
(0..y_len).map(|n| {
let mut y = Zero::zero();
for k in (n + 1).saturating_sub(self.len())..rhs.len().min(n + 1)
{
y += self[n - k]*rhs[k];
}
y
}).collect()
}
fn convolve_fft<Rhs, C>(&self, rhs: &[Rhs]) -> C
where
T: Float + Copy,
Rhs: Float + Copy,
Complex<T>: MulAssign + AddAssign + ComplexFloat<Real = T> + Mul<Complex<Rhs>, Output: ComplexFloat<Real: Float>>,
Complex<Rhs>: MulAssign + AddAssign + ComplexFloat<Real = Rhs>,
<Complex<T> as Mul<Complex<Rhs>>>::Output: ComplexFloat<Real: Float> + Into<Complex<<<Complex<T> as Mul<Complex<Rhs>>>::Output as ComplexFloat>::Real>>,
Complex<<<Complex<T> as Mul<Complex<Rhs>>>::Output as ComplexFloat>::Real>: MulAssign + AddAssign + ComplexFloat<Real = <<Complex<T> as Mul<Complex<Rhs>>>::Output as ComplexFloat>::Real>,
C: FromIterator<<<Complex<T> as Mul<Complex<Rhs>>>::Output as ComplexFloat>::Real>
{
let y_len = (self.len() + rhs.len()).saturating_sub(1);
let len = y_len.next_power_of_two();
let mut x: Vec<T> = self.to_vec();
let mut h: Vec<Rhs> = rhs.to_vec();
x.resize(len, T::zero());
h.resize(len, Rhs::zero());
let mut x_f = vec![Complex::zero(); len/2 + 1];
let mut h_f = vec![Complex::zero(); len/2 + 1];
x.real_fft(&mut x_f);
h.real_fft(&mut h_f);
let y_f: Vec<_> = x_f.into_iter()
.zip(h_f.into_iter())
.map(|(x_f, h_f)| (x_f*h_f).into())
.collect();
let mut y = vec![Zero::zero(); len];
y.real_ifft(&y_f);
y.truncate(y_len);
y.into_iter()
.collect()
}
fn fft(&mut self)
where
T: ComplexFloat<Real: Float> + MulAssign + AddAssign + From<Complex<<T>::Real>>
{
let len = self.len();
if len.is_power_of_two()
{
self.bit_reverse_permutation();
for s in 0..len.ilog2()
{
let m = 2usize << s;
let wm = <T as From<_>>::from(Complex::cis(<T::Real as NumCast>::from(-TAU/m as f64).unwrap()));
for k in (0..len).step_by(m)
{
let mut w = T::one();
for j in 0..m/2
{
let t = w*self[k + j + m/2];
let u = self[k + j];
self[k + j] = u + t;
self[k + j + m/2] = u - t;
w *= wm;
}
}
}
}
else
{
let wn = <T as From<_>>::from(Complex::cis(<T::Real as NumCast>::from(-TAU/len as f64).unwrap()));
let mut wnk = T::one();
let mut buf = vec![T::zero(); len];
unsafe {
std::ptr::swap_nonoverlapping(buf.as_mut_ptr(), self.as_mut_ptr(), len);
}
for k in 0..len
{
let mut wnki = T::one();
for i in 0..len
{
self[k] += buf[i]*wnki;
wnki *= wnk;
}
wnk *= wn;
}
}
}
fn ifft(&mut self)
where
T: ComplexFloat<Real: Float> + MulAssign + AddAssign + From<Complex<<T>::Real>>
{
let len = self.len();
if len.is_power_of_two()
{
self.bit_reverse_permutation();
for s in 0..len.ilog2()
{
let m = 2usize << s;
let wm = <T as From<_>>::from(Complex::cis(<T::Real as NumCast>::from(TAU/m as f64).unwrap()));
for k in (0..len).step_by(m)
{
let mut w = T::one();
for j in 0..m/2
{
let t = w*self[k + j + m/2];
let u = self[k + j];
self[k + j] = u + t;
self[k + j + m/2] = u - t;
w *= wm;
}
}
}
}
else
{
let wn = <T as From<_>>::from(Complex::cis(<T::Real as NumCast>::from(TAU/len as f64).unwrap()));
let mut wnk = T::one();
let mut buf = vec![T::zero(); len];
unsafe {
std::ptr::swap_nonoverlapping(buf.as_mut_ptr(), self.as_mut_ptr(), len);
}
for k in 0..len
{
let mut wnki = T::one();
for i in 0..len
{
self[k] += buf[i]*wnki;
wnki *= wnk;
}
wnk *= wn;
}
}
self.mul_assign_all(<T as From<_>>::from(<Complex<_> as From<_>>::from(<T::Real as NumCast>::from(1.0/len as f64).unwrap())));
}
fn real_fft(&self, y: &mut [Complex<T>])
where
T: Float,
Complex<T>: ComplexFloat<Real = T> + MulAssign + AddAssign
{
let len = self.len();
if len <= 1
{
return;
}
assert_eq!(y.len(), len/2 + 1, "Invalid output buffer length.");
let mut x: Vec<Complex<T>> = self.into_iter()
.map(|&x| <Complex<_> as From<_>>::from(x))
.collect();
x.fft();
for (x, y) in x.into_iter()
.zip(y.iter_mut())
{
*y = x;
}
}
fn real_ifft(&mut self, x: &[Complex<T>])
where
T: Float,
Complex<T>: ComplexFloat<Real = T> + MulAssign + AddAssign
{
let len = self.len();
if len <= 1
{
return;
}
assert_eq!(x.len(), len/2 + 1, "Invalid input buffer length.");
let mut x: Vec<Complex<T>> = x.into_iter()
.map(|&x| x)
.chain(x[1..(x.len() + len % 2 - 1)].into_iter()
.rev()
.map(|&x| x.conj())
).collect();
x.ifft();
for (x, y) in x.into_iter()
.zip(self.iter_mut())
{
*y = x.re();
}
}
}
#[test]
fn test()
{
let mut x = [1.0, 1.0, 0.0, 0.0];
let mut y = [Complex::zero(); 3];
x.real_fft(&mut y);
x.real_ifft(&y);
println!("{:?}", x)
}