1use ndarray::{concatenate, prelude::*};
2use ndarray_rand::{rand_distr::Uniform, RandomExt};
3use ndarray_slice::Slice1Ext;
4use num_complex::{Complex, ComplexFloat};
5use rustfft::FftPlanner;
6use std::cell::RefCell;
7
8use crate::errors::VmdError;
9use crate::utils::array::{fftshift1d, ifftshift1d, Flip};
10
11thread_local! {
12 static FFT_PLANNER: RefCell<FftPlanner<f64>> = RefCell::new(FftPlanner::new());
13}
14
15#[allow(non_snake_case, clippy::type_complexity)]
16pub fn vmd(
44 input: &[f64],
45 alpha: f64,
46 tau: f64,
47 K: usize,
48 DC: i32,
49 init: i32,
50 tol: f64,
51) -> Result<(Array2<f64>, Array2<Complex<f64>>, Array2<f64>), VmdError> {
52 let fs = 1.0 / input.len() as f64;
56
57 let T = input.len();
58 let midpoint = (input.len() as f64 / 2.0).ceil() as usize;
59
60 let mut f_mirr = {
61 let input = ArrayView1::from_shape(T, input)?;
62 let first_half = input.slice(s![..midpoint]);
63 let second_half = input.slice(s![midpoint..]);
64 concatenate(Axis(0), &[first_half.flip(), input, second_half.flip()])?
65 .map(|&f| Complex::new(f, 0.))
67 };
68
69 let T = f_mirr.len() as f64;
70 let t = Array::range(1., T + 1., 1.) / T;
71 let t_len = t.len();
72 let freqs = t - 0.5 - (1. / T);
73 const N_ITER: usize = 500;
74
75 let fft_fhat = {
77 FFT_PLANNER.with(|planner| {
78 let fft = planner.borrow_mut().plan_fft_forward(T as usize);
79 fft.process(f_mirr.as_slice_mut().unwrap());
80 f_mirr
81 })
98 };
99
100 let f_hat = fftshift1d(fft_fhat.view());
101 let mut f_hat_plus = f_hat;
102 f_hat_plus
103 .slice_mut(s![..T as usize / 2])
104 .map_inplace(|v| *v = Complex::new(0., 0.));
105
106 let mut omega_plus = Array::from_shape_fn((N_ITER, K), |(_, _)| 0.);
108 match init {
109 1 => {
110 for i in 0..K {
111 omega_plus[[0, i]] = (0.5 / K as f64) * i as f64
112 }
113 }
114 2 => {
116 let rexpr = fs.log(std::f64::consts::E);
118 let random = Array::random([1, K], Uniform::new(0., 1.));
119 let rexpr2 = (0.5_f64.log(std::f64::consts::E) - rexpr) * random;
121 let mut expr = rexpr + rexpr2;
122
123 expr.map_inplace(|f| *f = f.exp());
124 let mut axis_sort = expr.slice_axis_mut(Axis(0), ndarray::Slice::new(0, None, 1));
125 axis_sort
126 .row_mut(0)
127 .sort_unstable_by(|f1, f2| f1.partial_cmp(f2).unwrap());
128 expr.row_mut(0).assign_to(
129 omega_plus
130 .slice_axis_mut(Axis(0), ndarray::Slice::new(0, None, 1))
131 .row_mut(0),
132 );
133 }
134 _ => {
135 omega_plus.slice_mut(s![.., ..]).map_inplace(|f| *f = 0.);
136 }
137 };
138 if DC != 0 {
139 omega_plus[[0, 0]] = 0.;
140 }
141
142 const ROWS: usize = 3;
147 let mut lambda_hat: Array2<Complex<f64>> = Array::zeros((ROWS, freqs.len()));
148
149 let mut u_hat_plus: Array3<Complex<f64>> = Array::zeros((ROWS, freqs.len(), K));
153 let mut udiff = tol + f64::EPSILON;
154 let mut n = 0;
155 let mut sum_uk: Array1<Complex<f64>> = Array::zeros(freqs.len());
156
157 let mut cur: usize = 0; let mut next: usize = 1; let mut prev: usize;
160
161 let alpha: Array1<f64> = Array::ones(K) * alpha;
163
164 while udiff > tol && n < N_ITER - 1 {
166 let T = T as usize;
167 let k = 0;
171 let s1 = u_hat_plus.slice(s![cur, .., K - 1]);
172 let s2 = u_hat_plus.slice(s![cur, .., 0]);
173 sum_uk += &s1;
174 sum_uk -= &s2;
175
176 let lambda_hat_slice = &lambda_hat.slice(s![cur, ..]) / Complex::new(2., 0.);
178 let lexpr = &f_hat_plus - &sum_uk - &lambda_hat_slice;
179 let rexpr = 1. + alpha[k] * (&freqs - omega_plus[[n, k]]).map_mut(|f| f.powi(2));
180 (lexpr / rexpr).move_into(u_hat_plus.slice_mut(s![next, .., k]));
181
182 if DC == 0 {
183 let expr1 = freqs.slice(s![T / 2..T]);
184 let subexpr2 = u_hat_plus.slice(s![next, T / 2..T, k]);
185 let expr2 = subexpr2.map(|f| ComplexFloat::abs(*f).powi(2));
186 let expr1: f64 = expr1.dot(&expr2);
187 let expr2 = expr2.sum();
188 omega_plus[[n + 1, k]] = expr1 / expr2;
189 }
190
191 for k in 1..K {
193 sum_uk += &u_hat_plus.slice(s![next, .., k - 1]);
195 sum_uk -= &u_hat_plus.slice(s![cur, .., k]);
196
197 let lexpr = &f_hat_plus - &sum_uk - &lambda_hat_slice;
200 let rexpr = 1. + alpha[k] * (&freqs - omega_plus[[n, k]]).map(|v| v.powi(2));
201 (lexpr / rexpr).move_into(u_hat_plus.slice_mut(s![next, .., k]));
202
203 let expr1 = freqs.slice(s![T / 2..T]);
205 let subexpr2 = u_hat_plus.slice(s![next, T / 2..T, k]);
206 let expr2 = subexpr2.map(|f| ComplexFloat::abs(*f).powi(2));
207 let expr1: f64 = expr1.dot(&expr2);
208 let expr2 = expr2.sum();
209 omega_plus[[n + 1, k]] = expr1 / expr2;
210 }
211
212 let expr1 = (&u_hat_plus
214 .slice(s![next, .., ..])
215 .sum_axis(ndarray::Axis(1))
216 - &f_hat_plus)
217 * tau;
218 let expr1 = &lambda_hat.slice(s![cur, ..]) + expr1;
219 expr1.move_into(lambda_hat.slice_mut(s![next, ..]));
220
221 n += 1;
223 cur = n % ROWS;
224 next = (n + 1) % ROWS;
225 prev = (n - 1) % ROWS;
226
227 let mut udiff_ = Complex::new(f64::EPSILON, 0.);
228 for i in 0..K {
229 let expr1 = &u_hat_plus.slice(s![cur, .., i]) - &u_hat_plus.slice(s![prev, .., i]);
230 let expr2 = expr1.map(|f| f.conj());
231 let expr = expr1.dot(&expr2) * (1. / T as f64);
232
233 udiff_ += expr;
234 }
235 udiff = ComplexFloat::abs(udiff_);
236 }
237 let n_iter = std::cmp::min(n, N_ITER);
240 let omega = omega_plus.slice(s![..n_iter, ..]);
241
242 let T = T as usize;
248 let mut u_hat = Array::from_elem([T, K], Complex::new(0.0, 0.0));
249 u_hat
250 .slice_mut(s![T / 2..T, ..])
251 .assign(&u_hat_plus.slice(s![(n_iter - 1) % ROWS, T / 2..T, ..]));
252 u_hat_plus
254 .slice(s![(n_iter - 1) % ROWS, T / 2..T, ..])
255 .map(|f| f.conj())
256 .move_into(u_hat.slice_mut(s![1..T/2+1;-1,..]));
257 u_hat
258 .slice(s![-1, ..])
259 .map(|f| f.conj())
260 .move_into(u_hat.slice_mut(s![0, ..]));
261
262 let mut u: Array2<f64> = ndarray::Array::zeros([K, t_len]);
263 FFT_PLANNER.with(|planner| {
264 let ffti = planner
265 .borrow_mut()
266 .plan_fft_inverse(u_hat.slice(s![.., 0]).len());
267 for k in 0..K {
268 let subexpr = u_hat.slice(s![.., k]);
269 let mut ishifted = ifftshift1d(subexpr);
270 ffti.process(ishifted.as_slice_mut().unwrap());
271 let len = ishifted.len() as f64;
274 (ishifted / len)
277 .map(|f| f.re())
278 .move_into(u.slice_mut(s![k, ..]));
279 }
280 });
281
282 let u = u.slice_mut(s![.., T / 4..3 * T / 4]);
284
285 let mut u_hat: Array2<Complex<f64>> = Array::zeros([u.shape()[1], K]);
287 FFT_PLANNER.with(|planner| {
288 for k in 0..K {
289 let mut u_ = u.slice(s![k, ..]).map(|f| Complex::new(*f, 0.));
290 let fft = planner.borrow_mut().plan_fft_forward(u_.len());
291 fft.process(u_.as_slice_mut().unwrap());
292 fftshift1d(u_.view()).move_into(u_hat.slice_mut(s![.., k]));
293 }
294 });
295
296 Ok((u.to_owned(), u_hat, omega.to_owned()))
297}