scirs2_fft/
mdct.rs

//! Modified Discrete Cosine Transform (MDCT) and Modified Discrete Sine Transform (MDST)
//!
//! The MDCT and MDST are lapped transforms based on the type-IV DCT/DST. They are widely
//! used in audio coding (MP3, AAC, Vorbis) because, with 50%-overlapping blocks, the
//! time-domain aliasing introduced by neighbouring blocks cancels on overlap-add, so the
//! signal can be reconstructed exactly.
6
7use scirs2_core::ndarray::{Array1, ArrayBase, Data};
8use std::f64::consts::PI;
9
10use crate::error::{FFTError, FFTResult};
11use crate::window::Window;
12
13/// Compute the Modified Discrete Cosine Transform (MDCT)
14///
15/// The MDCT is a lapped transform with 50% overlap between consecutive blocks.
16/// It is critically sampled and allows perfect reconstruction.
17///
18/// # Arguments
19///
20/// * `x` - Input signal
21/// * `n` - Transform size (output will be n/2 coefficients)
22/// * `window` - Window function to apply
23///
24/// # Returns
25///
26/// MDCT coefficients (n/2 values)
27///
28/// # Example
29///
30/// ```
31/// use scirs2_core::ndarray::array;
32/// use scirs2_fft::mdct::mdct;
33/// use scirs2_fft::window::Window;
34///
35/// let signal = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
36/// let mdct_result = mdct(&signal, 8, Some(Window::Hann)).unwrap();
37/// assert_eq!(mdct_result.len(), 4); // Output is half the transform size
38/// ```
39#[allow(dead_code)]
40pub fn mdct<S>(
41    x: &ArrayBase<S, scirs2_core::ndarray::Ix1>,
42    n: usize,
43    window: Option<Window>,
44) -> FFTResult<Array1<f64>>
45where
46    S: Data<Elem = f64>,
47{
48    if !n.is_multiple_of(2) {
49        return Err(FFTError::ValueError("MDCT size must be even".to_string()));
50    }
51
52    if x.len() != n {
53        return Err(FFTError::ValueError(format!(
54            "Input length {} does not match MDCT size {}",
55            x.len(),
56            n
57        )));
58    }
59
60    let half_n = n / 2;
61    let mut result = Array1::zeros(half_n);
62
63    // Apply window if specified
64    let windowed = if let Some(win) = window {
65        let win_coeffs = crate::window::get_window(win, n, true)?;
66        x.to_owned() * &win_coeffs
67    } else {
68        x.to_owned()
69    };
70
71    // Compute MDCT coefficients
72    for k in 0..half_n {
73        let mut sum = 0.0;
74        for n_idx in 0..n {
75            let angle = PI / n as f64 * (n_idx as f64 + 0.5 + half_n as f64) * (k as f64 + 0.5);
76            sum += windowed[n_idx] * angle.cos();
77        }
78        result[k] = sum * (2.0 / n as f64).sqrt();
79    }
80
81    Ok(result)
82}
83
84/// Compute the Inverse Modified Discrete Cosine Transform (IMDCT)
85///
86/// The IMDCT reconstructs a signal from MDCT coefficients.
87/// To achieve perfect reconstruction, overlapping blocks must be properly combined.
88///
89/// # Arguments
90///
91/// * `x` - MDCT coefficients
92/// * `window` - Window function to apply (should match the forward transform)
93///
94/// # Returns
95///
96/// Reconstructed signal (2 * input length)
97///
98/// # Example
99///
100/// ```
101/// use scirs2_core::ndarray::array;
102/// use scirs2_fft::mdct::{mdct, imdct};
103/// use scirs2_fft::window::Window;
104///
105/// let signal = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
106/// let mdct_coeffs = mdct(&signal, 8, Some(Window::Hann)).unwrap();
107/// let reconstructed = imdct(&mdct_coeffs, Some(Window::Hann)).unwrap();
108/// assert_eq!(reconstructed.len(), 8); // Output is twice the input length
109/// ```
110#[allow(dead_code)]
111pub fn imdct<S>(
112    x: &ArrayBase<S, scirs2_core::ndarray::Ix1>,
113    window: Option<Window>,
114) -> FFTResult<Array1<f64>>
115where
116    S: Data<Elem = f64>,
117{
118    let half_n = x.len();
119    let n = half_n * 2;
120    let mut result = Array1::zeros(n);
121
122    // Compute IMDCT values
123    for n_idx in 0..n {
124        let mut sum = 0.0;
125        for k in 0..half_n {
126            let angle = PI / n as f64 * (n_idx as f64 + 0.5 + half_n as f64) * (k as f64 + 0.5);
127            sum += x[k] * angle.cos();
128        }
129        result[n_idx] = sum * (2.0 / n as f64).sqrt();
130    }
131
132    // Apply window if specified
133    if let Some(win) = window {
134        let win_coeffs = crate::window::get_window(win, n, true)?;
135        result *= &win_coeffs;
136    }
137
138    Ok(result)
139}
140
141/// Modified Discrete Sine Transform (MDST)
142///
143/// The MDST is similar to MDCT but uses sine basis functions.
144/// It is less commonly used than MDCT but provides similar properties.
145///
146/// # Arguments
147///
148/// * `x` - Input signal
149/// * `n` - Transform size (output will be n/2 coefficients)
150/// * `window` - Window function to apply
151///
152/// # Returns
153///
154/// MDST coefficients (n/2 values)
155#[allow(dead_code)]
156pub fn mdst<S>(
157    x: &ArrayBase<S, scirs2_core::ndarray::Ix1>,
158    n: usize,
159    window: Option<Window>,
160) -> FFTResult<Array1<f64>>
161where
162    S: Data<Elem = f64>,
163{
164    if !n.is_multiple_of(2) {
165        return Err(FFTError::ValueError("MDST size must be even".to_string()));
166    }
167
168    if x.len() != n {
169        return Err(FFTError::ValueError(format!(
170            "Input length {} does not match MDST size {}",
171            x.len(),
172            n
173        )));
174    }
175
176    let half_n = n / 2;
177    let mut result = Array1::zeros(half_n);
178
179    // Apply window if specified
180    let windowed = if let Some(win) = window {
181        let win_coeffs = crate::window::get_window(win, n, true)?;
182        x.to_owned() * &win_coeffs
183    } else {
184        x.to_owned()
185    };
186
187    // Compute MDST coefficients
188    for k in 0..half_n {
189        let mut sum = 0.0;
190        for n_idx in 0..n {
191            let angle = PI / n as f64 * (n_idx as f64 + 0.5 + half_n as f64) * (k as f64 + 0.5);
192            sum += windowed[n_idx] * angle.sin();
193        }
194        result[k] = sum * (2.0 / n as f64).sqrt();
195    }
196
197    Ok(result)
198}
199
200/// Inverse Modified Discrete Sine Transform (IMDST)
201///
202/// Reconstructs a signal from MDST coefficients.
203///
204/// # Arguments
205///
206/// * `x` - MDST coefficients
207/// * `window` - Window function to apply (should match the forward transform)
208///
209/// # Returns
210///
211/// Reconstructed signal (2 * input length)
212#[allow(dead_code)]
213pub fn imdst<S>(
214    x: &ArrayBase<S, scirs2_core::ndarray::Ix1>,
215    window: Option<Window>,
216) -> FFTResult<Array1<f64>>
217where
218    S: Data<Elem = f64>,
219{
220    let half_n = x.len();
221    let n = half_n * 2;
222    let mut result = Array1::zeros(n);
223
224    // Compute IMDST values
225    for n_idx in 0..n {
226        let mut sum = 0.0;
227        for k in 0..half_n {
228            let angle = PI / n as f64 * (n_idx as f64 + 0.5 + half_n as f64) * (k as f64 + 0.5);
229            sum += x[k] * angle.sin();
230        }
231        result[n_idx] = sum * (2.0 / n as f64).sqrt();
232    }
233
234    // Apply window if specified
235    if let Some(win) = window {
236        let win_coeffs = crate::window::get_window(win, n, true)?;
237        result *= &win_coeffs;
238    }
239
240    Ok(result)
241}
242
243/// Perform overlap-add reconstruction from MDCT coefficients
244///
245/// This function handles the proper overlapping and adding of consecutive
246/// MDCT blocks for perfect reconstruction.
247///
248/// # Arguments
249///
250/// * `blocks` - Vector of MDCT coefficient blocks
251/// * `window` - Window function used in the forward transform
252/// * `hop_size` - Hop size between consecutive blocks (typically block_size/2)
253///
254/// # Returns
255///
256/// Reconstructed signal
257#[allow(dead_code)]
258pub fn mdct_overlap_add(
259    blocks: &[Array1<f64>],
260    window: Option<Window>,
261    hop_size: usize,
262) -> FFTResult<Array1<f64>> {
263    if blocks.is_empty() {
264        return Err(FFTError::ValueError("No blocks provided".to_string()));
265    }
266
267    let block_size = blocks[0].len() * 2;
268    let n_blocks = blocks.len();
269    let output_len = (n_blocks - 1) * hop_size + block_size;
270    let mut output = Array1::zeros(output_len);
271
272    for (i, block) in blocks.iter().enumerate() {
273        let reconstructed = imdct(block, window.clone())?;
274        let start_idx = i * hop_size;
275
276        // Add overlapping parts
277        for j in 0..block_size {
278            if start_idx + j < output_len {
279                output[start_idx + j] += reconstructed[j];
280            }
281        }
282    }
283
284    Ok(output)
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate::window::Window;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_mdct_basic() {
        let signal = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mdct_result = mdct(&signal, 8, None).unwrap();

        // MDCT should produce n/2 coefficients
        assert_eq!(mdct_result.len(), 4);
    }

    #[test]
    fn test_mdct_imdct_windowed_lengths() {
        // A single block cannot be reconstructed on its own (its aliasing only
        // cancels on overlap-add); this test checks the round-trip shapes only.
        let signal = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let window = Some(Window::Hann);

        let mdct_coeffs = mdct(&signal, 8, window.clone()).unwrap();
        assert_eq!(mdct_coeffs.len(), 4);

        let reconstructed = imdct(&mdct_coeffs, window).unwrap();
        assert_eq!(reconstructed.len(), 8);
    }

    #[test]
    fn test_mdct_length_mismatch() {
        let signal = array![1.0, 2.0, 3.0, 4.0];
        assert!(mdct(&signal, 8, None).is_err());
    }

    #[test]
    fn test_mdct_invalid_size() {
        let signal = array![1.0, 2.0, 3.0];
        let result = mdct(&signal, 3, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_mdst_basic() {
        let signal = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mdst_result = mdst(&signal, 8, None).unwrap();

        // MDST should produce n/2 coefficients
        assert_eq!(mdst_result.len(), 4);
    }

    #[test]
    fn test_mdst_invalid_size_and_length() {
        assert!(mdst(&array![1.0, 2.0, 3.0], 3, None).is_err());
        assert!(mdst(&array![1.0, 2.0, 3.0, 4.0], 8, None).is_err());
    }

    #[test]
    fn test_imdst_output_length() {
        let coeffs = array![1.0, 2.0, 3.0, 4.0];
        assert_eq!(imdst(&coeffs, None).unwrap().len(), 8);
    }

    #[test]
    fn test_overlap_add() {
        // Create overlapping blocks
        let block1 = array![1.0, 2.0, 3.0, 4.0];
        let block2 = array![2.0, 3.0, 4.0, 5.0];
        let blocks = vec![block1, block2];

        let result = mdct_overlap_add(&blocks, Some(Window::Hann), 4).unwrap();

        // Check output length
        assert_eq!(result.len(), 12); // (2-1)*4 + 8
    }

    #[test]
    fn test_overlap_add_no_blocks() {
        let blocks: Vec<Array1<f64>> = Vec::new();
        assert!(mdct_overlap_add(&blocks, None, 4).is_err());
    }
}