stem_splitter_core/core/
dsp.rs

1use num_complex::Complex32;
2use rustfft::{num_traits::Zero, FftPlanner};
3
/// Convert interleaved PCM samples into per-frame `[left, right]` pairs.
///
/// * `channels == 1`: every mono sample is duplicated into both channels.
/// * `channels >= 2`: the first two samples of each interleaved frame become
///   left/right; any additional channels (e.g. surround) are discarded.
/// * `channels == 0` is treated as stereo, as before.
///
/// A trailing partial frame (fewer samples than one full frame) is ignored.
pub fn to_planar_stereo(interleaved: &[f32], channels: u16) -> Vec<[f32; 2]> {
    if channels == 1 {
        return interleaved.iter().map(|&x| [x, x]).collect();
    }
    // Step over whole frames: stepping by 2 regardless of the channel count
    // would turn channels 3+ of a multichannel stream into extra "stereo"
    // frames and scramble the timeline.
    let stride = usize::from(channels).max(2);
    interleaved
        .chunks_exact(stride)
        .map(|frame| [frame[0], frame[1]])
        .collect()
}
17
/// Periodic Hann window of length `n_fft`:
/// `w[i] = 0.5 - 0.5 * cos(2 * pi * i / n_fft)`.
///
/// This is the `periodic=True` variant that `torch.hann_window(n_fft)` returns
/// by default (the window torch-based STFT code such as Demucs pairs with
/// `torch.stft`), not the symmetric variant that divides by `n_fft - 1`.
/// Lengths 0 and 1 return an empty window and `[1.0]` respectively, like
/// `torch.hann_window`.
fn hann(n_fft: usize) -> Vec<f32> {
    match n_fft {
        0 => Vec::new(),
        1 => vec![1.0],
        _ => {
            // Periodic form: divide by N, not N - 1.
            let n = n_fft as f32;
            (0..n_fft)
                .map(|i| 0.5 - 0.5 * (2.0 * std::f32::consts::PI * i as f32 / n).cos())
                .collect()
        }
    }
}
28
29/// Compute complex-as-channels spectrogram for stereo with center padding.
30/// Returns (buffer, F=2048, Frames=336) for T=343_980, n_fft=4096, hop=1024.
31/// Layout is [1, 4, F, Frames] flattened => channels order: L.re, L.im, R.re, R.im.
32pub fn stft_cac_stereo_centered(
33    left: &[f32],
34    right: &[f32],
35    n_fft: usize,
36    hop: usize,
37) -> (Vec<f32>, usize, usize) {
38    assert_eq!(left.len(), right.len());
39    let t = left.len();
40    // Demucs export expects center=True: pad n_fft/2 both sides
41    let pad = n_fft / 2;
42    let lpad = vec![0.0f32; pad];
43    let rpad = vec![0.0f32; pad];
44
45    let mut l_sig = Vec::with_capacity(pad + t + pad);
46    let mut r_sig = Vec::with_capacity(pad + t + pad);
47    l_sig.extend_from_slice(&lpad);
48    l_sig.extend_from_slice(left);
49    l_sig.extend_from_slice(&lpad);
50    r_sig.extend_from_slice(&rpad);
51    r_sig.extend_from_slice(right);
52    r_sig.extend_from_slice(&rpad);
53
54    // Frames = 1 + floor(T / hop) when center=True and T divisible-ish
55    let frames = 1 + (t / hop);
56    let window = hann(n_fft);
57
58    // FFT
59    let mut planner = FftPlanner::new();
60    let fft = planner.plan_fft_forward(n_fft);
61
62    // We keep only F = n_fft/2 bins (drop Nyquist so F=2048 for 4096)
63    let f_bins = n_fft / 2;
64
65    // Layout target: [1, 4, F, Frames]
66    let mut out = vec![0.0f32; 4 * f_bins * frames];
67
68    // Scratch buffers
69    let mut buf_l = vec![Complex32::zero(); n_fft];
70    let mut buf_r = vec![Complex32::zero(); n_fft];
71
72    for fr in 0..frames {
73        let start = fr * hop;
74        // slice from padded signals
75        let li = &l_sig[start..start + n_fft];
76        let ri = &r_sig[start..start + n_fft];
77
78        // window + pack into complex
79        for i in 0..n_fft {
80            let w = window[i];
81            buf_l[i].re = li[i] * w;
82            buf_l[i].im = 0.0;
83            buf_r[i].re = ri[i] * w;
84            buf_r[i].im = 0.0;
85        }
86
87        fft.process(&mut buf_l);
88        fft.process(&mut buf_r);
89
90        // write channels [L.re, L.im, R.re, R.im] over [F,Frames]
91        for fi in 0..f_bins {
92            let base_fr = fi * frames + fr; // [F,Frames] index
93
94            // L.re
95            out[0 * f_bins * frames + base_fr] = buf_l[fi].re;
96            // L.im
97            out[1 * f_bins * frames + base_fr] = buf_l[fi].im;
98            // R.re
99            out[2 * f_bins * frames + base_fr] = buf_r[fi].re;
100            // R.im
101            out[3 * f_bins * frames + base_fr] = buf_r[fi].im;
102        }
103    }
104
105    (out, f_bins, frames)
106}
107
108/// Inverse STFT for complex-as-channels stereo spectrogram
109/// Input: complex-as-channels [L.re, L.im, R.re, R.im] with shape [4, F, Frames]
110/// Returns: (left, right) stereo waveform of length target_length
111pub fn istft_cac_stereo(
112    spec_cac: &[f32],
113    f_bins: usize,
114    frames: usize,
115    n_fft: usize,
116    hop: usize,
117    target_length: usize,
118) -> (Vec<f32>, Vec<f32>) {
119    let window = hann(n_fft);
120    
121    // Prepare IFFT
122    let mut planner = FftPlanner::new();
123    let ifft = planner.plan_fft_inverse(n_fft);
124    
125    // Padded length (matching forward STFT)
126    let pad = n_fft / 2;
127    let padded_length = target_length + 2 * pad;
128    
129    // Output buffers (padded)
130    let mut left_out = vec![0.0f32; padded_length];
131    let mut right_out = vec![0.0f32; padded_length];
132    let mut window_sum = vec![0.0f32; padded_length];
133    
134    // Scratch buffers for IFFT
135    let mut buf_l = vec![Complex32::zero(); n_fft];
136    let mut buf_r = vec![Complex32::zero(); n_fft];
137    
138    for fr in 0..frames {
139        // Clear buffers
140        buf_l.fill(Complex32::zero());
141        buf_r.fill(Complex32::zero());
142        
143        // Reconstruct full spectrum from half spectrum
144        // Fill positive frequencies [0..f_bins]
145        for fi in 0..f_bins {
146            let base_fr = fi * frames + fr;
147            buf_l[fi] = Complex32::new(
148                spec_cac[0 * f_bins * frames + base_fr],  // L.re
149                spec_cac[1 * f_bins * frames + base_fr],  // L.im
150            );
151            buf_r[fi] = Complex32::new(
152                spec_cac[2 * f_bins * frames + base_fr],  // R.re
153                spec_cac[3 * f_bins * frames + base_fr],  // R.im
154            );
155        }
156        
157        // Fill negative frequencies (complex conjugate mirror)
158        // Skip DC (fi=0) and only mirror [1..f_bins-1]
159        for fi in 1..f_bins {
160            let neg_fi = n_fft - fi;
161            buf_l[neg_fi] = buf_l[fi].conj();
162            buf_r[neg_fi] = buf_r[fi].conj();
163        }
164        
165        // Ensure DC and Nyquist are real
166        buf_l[0].im = 0.0;
167        buf_r[0].im = 0.0;
168        if n_fft % 2 == 0 && f_bins < n_fft {
169            buf_l[n_fft / 2].im = 0.0;
170            buf_r[n_fft / 2].im = 0.0;
171        }
172        
173        // Apply IFFT
174        ifft.process(&mut buf_l);
175        ifft.process(&mut buf_r);
176        
177        // Overlap-add with window (no extra scaling - already in IFFT)
178        let start = fr * hop;
179        for i in 0..n_fft {
180            let pos = start + i;
181            if pos < padded_length {
182                let w = window[i];
183                // IFFT returns normalized values, apply window for overlap-add
184                left_out[pos] += buf_l[i].re * w / (n_fft as f32);
185                right_out[pos] += buf_r[i].re * w / (n_fft as f32);
186                window_sum[pos] += w * w;
187            }
188        }
189    }
190    
191    // Normalize by window sum to account for overlap
192    for i in 0..padded_length {
193        let sum = window_sum[i];
194        if sum > 1e-10 {
195            left_out[i] /= sum;
196            right_out[i] /= sum;
197        }
198    }
199    
200    // Remove padding and ensure we don't go out of bounds
201    let start = pad.min(left_out.len());
202    let end = (pad + target_length).min(left_out.len());
203    
204    let left_final = if end > start {
205        left_out[start..end].to_vec()
206    } else {
207        vec![0.0; target_length]
208    };
209    
210    let right_final = if end > start {
211        right_out[start..end].to_vec()
212    } else {
213        vec![0.0; target_length]
214    };
215    
216    (left_final, right_final)
217}