stem_splitter_core/core/
dsp.rs1use num_complex::Complex32;
2use rustfft::{num_traits::Zero, FftPlanner};
3
/// Converts interleaved samples into a list of stereo `[left, right]` frames.
///
/// * `channels == 1`: every mono sample is duplicated into both output
///   channels.
/// * `channels >= 2`: the input is read one interleaved frame (`channels`
///   samples) at a time and the first two channels of each frame are kept;
///   any further channels are discarded.
/// * `channels == 0` is treated like stereo input (stride of two samples).
///
/// A trailing, incomplete frame is dropped.
pub fn to_planar_stereo(interleaved: &[f32], channels: u16) -> Vec<[f32; 2]> {
    if channels == 1 {
        return interleaved.iter().map(|&sample| [sample, sample]).collect();
    }

    // Step by the real frame width so that inputs with more than two channels
    // do not bleed unrelated channels into the stereo pair. The lower bound of
    // 2 keeps a (meaningless) channel count of 0 on the stereo path.
    let stride = usize::from(channels).max(2);
    interleaved
        .chunks_exact(stride)
        .map(|frame| [frame[0], frame[1]])
        .collect()
}
17
/// Builds a Hann window with `n_fft` coefficients.
///
/// Coefficient `k` is `0.5 - 0.5 * cos(2π·k / (n_fft - 1))`, i.e. the window
/// is symmetric and starts at zero. For `n_fft <= 1` (including zero) the
/// result is the single coefficient `[1.0]`.
fn hann(n_fft: usize) -> Vec<f32> {
    use std::f32::consts::PI;

    if n_fft < 2 {
        return vec![1.0];
    }

    let last = (n_fft - 1) as f32;
    let mut window = Vec::with_capacity(n_fft);
    for k in 0..n_fft {
        let phase = 2.0 * PI * (k as f32) / last;
        window.push(0.5 - 0.5 * phase.cos());
    }
    window
}
28
29pub fn stft_cac_stereo_centered(
33 left: &[f32],
34 right: &[f32],
35 n_fft: usize,
36 hop: usize,
37) -> (Vec<f32>, usize, usize) {
38 assert_eq!(left.len(), right.len());
39 let t = left.len();
40 let pad = n_fft / 2;
42 let lpad = vec![0.0f32; pad];
43 let rpad = vec![0.0f32; pad];
44
45 let mut l_sig = Vec::with_capacity(pad + t + pad);
46 let mut r_sig = Vec::with_capacity(pad + t + pad);
47 l_sig.extend_from_slice(&lpad);
48 l_sig.extend_from_slice(left);
49 l_sig.extend_from_slice(&lpad);
50 r_sig.extend_from_slice(&rpad);
51 r_sig.extend_from_slice(right);
52 r_sig.extend_from_slice(&rpad);
53
54 let frames = 1 + (t / hop);
56 let window = hann(n_fft);
57
58 let mut planner = FftPlanner::new();
60 let fft = planner.plan_fft_forward(n_fft);
61
62 let f_bins = n_fft / 2;
64
65 let mut out = vec![0.0f32; 4 * f_bins * frames];
67
68 let mut buf_l = vec![Complex32::zero(); n_fft];
70 let mut buf_r = vec![Complex32::zero(); n_fft];
71
72 for fr in 0..frames {
73 let start = fr * hop;
74 let li = &l_sig[start..start + n_fft];
76 let ri = &r_sig[start..start + n_fft];
77
78 for i in 0..n_fft {
80 let w = window[i];
81 buf_l[i].re = li[i] * w;
82 buf_l[i].im = 0.0;
83 buf_r[i].re = ri[i] * w;
84 buf_r[i].im = 0.0;
85 }
86
87 fft.process(&mut buf_l);
88 fft.process(&mut buf_r);
89
90 for fi in 0..f_bins {
92 let base_fr = fi * frames + fr; out[0 * f_bins * frames + base_fr] = buf_l[fi].re;
96 out[1 * f_bins * frames + base_fr] = buf_l[fi].im;
98 out[2 * f_bins * frames + base_fr] = buf_r[fi].re;
100 out[3 * f_bins * frames + base_fr] = buf_r[fi].im;
102 }
103 }
104
105 (out, f_bins, frames)
106}
107
108pub fn istft_cac_stereo(
112 spec_cac: &[f32],
113 f_bins: usize,
114 frames: usize,
115 n_fft: usize,
116 hop: usize,
117 target_length: usize,
118) -> (Vec<f32>, Vec<f32>) {
119 let window = hann(n_fft);
120
121 let mut planner = FftPlanner::new();
123 let ifft = planner.plan_fft_inverse(n_fft);
124
125 let pad = n_fft / 2;
127 let padded_length = target_length + 2 * pad;
128
129 let mut left_out = vec![0.0f32; padded_length];
131 let mut right_out = vec![0.0f32; padded_length];
132 let mut window_sum = vec![0.0f32; padded_length];
133
134 let mut buf_l = vec![Complex32::zero(); n_fft];
136 let mut buf_r = vec![Complex32::zero(); n_fft];
137
138 for fr in 0..frames {
139 buf_l.fill(Complex32::zero());
141 buf_r.fill(Complex32::zero());
142
143 for fi in 0..f_bins {
146 let base_fr = fi * frames + fr;
147 buf_l[fi] = Complex32::new(
148 spec_cac[0 * f_bins * frames + base_fr], spec_cac[1 * f_bins * frames + base_fr], );
151 buf_r[fi] = Complex32::new(
152 spec_cac[2 * f_bins * frames + base_fr], spec_cac[3 * f_bins * frames + base_fr], );
155 }
156
157 for fi in 1..f_bins {
160 let neg_fi = n_fft - fi;
161 buf_l[neg_fi] = buf_l[fi].conj();
162 buf_r[neg_fi] = buf_r[fi].conj();
163 }
164
165 buf_l[0].im = 0.0;
167 buf_r[0].im = 0.0;
168 if n_fft % 2 == 0 && f_bins < n_fft {
169 buf_l[n_fft / 2].im = 0.0;
170 buf_r[n_fft / 2].im = 0.0;
171 }
172
173 ifft.process(&mut buf_l);
175 ifft.process(&mut buf_r);
176
177 let start = fr * hop;
179 for i in 0..n_fft {
180 let pos = start + i;
181 if pos < padded_length {
182 let w = window[i];
183 left_out[pos] += buf_l[i].re * w / (n_fft as f32);
185 right_out[pos] += buf_r[i].re * w / (n_fft as f32);
186 window_sum[pos] += w * w;
187 }
188 }
189 }
190
191 for i in 0..padded_length {
193 let sum = window_sum[i];
194 if sum > 1e-10 {
195 left_out[i] /= sum;
196 right_out[i] /= sum;
197 }
198 }
199
200 let start = pad.min(left_out.len());
202 let end = (pad + target_length).min(left_out.len());
203
204 let left_final = if end > start {
205 left_out[start..end].to_vec()
206 } else {
207 vec![0.0; target_length]
208 };
209
210 let right_final = if end > start {
211 right_out[start..end].to_vec()
212 } else {
213 vec![0.0; target_length]
214 };
215
216 (left_final, right_final)
217}