1use scirs2_core::ndarray::{Array1, ArrayBase, Data};
8use std::f64::consts::PI;
9
10use crate::error::{FFTError, FFTResult};
11use crate::window::Window;
12
13#[allow(dead_code)]
40pub fn mdct<S>(
41 x: &ArrayBase<S, scirs2_core::ndarray::Ix1>,
42 n: usize,
43 window: Option<Window>,
44) -> FFTResult<Array1<f64>>
45where
46 S: Data<Elem = f64>,
47{
48 if !n.is_multiple_of(2) {
49 return Err(FFTError::ValueError("MDCT size must be even".to_string()));
50 }
51
52 if x.len() != n {
53 return Err(FFTError::ValueError(format!(
54 "Input length {} does not match MDCT size {}",
55 x.len(),
56 n
57 )));
58 }
59
60 let half_n = n / 2;
61 let mut result = Array1::zeros(half_n);
62
63 let windowed = if let Some(win) = window {
65 let win_coeffs = crate::window::get_window(win, n, true)?;
66 x.to_owned() * &win_coeffs
67 } else {
68 x.to_owned()
69 };
70
71 for k in 0..half_n {
73 let mut sum = 0.0;
74 for n_idx in 0..n {
75 let angle = PI / n as f64 * (n_idx as f64 + 0.5 + half_n as f64) * (k as f64 + 0.5);
76 sum += windowed[n_idx] * angle.cos();
77 }
78 result[k] = sum * (2.0 / n as f64).sqrt();
79 }
80
81 Ok(result)
82}
83
84#[allow(dead_code)]
111pub fn imdct<S>(
112 x: &ArrayBase<S, scirs2_core::ndarray::Ix1>,
113 window: Option<Window>,
114) -> FFTResult<Array1<f64>>
115where
116 S: Data<Elem = f64>,
117{
118 let half_n = x.len();
119 let n = half_n * 2;
120 let mut result = Array1::zeros(n);
121
122 for n_idx in 0..n {
124 let mut sum = 0.0;
125 for k in 0..half_n {
126 let angle = PI / n as f64 * (n_idx as f64 + 0.5 + half_n as f64) * (k as f64 + 0.5);
127 sum += x[k] * angle.cos();
128 }
129 result[n_idx] = sum * (2.0 / n as f64).sqrt();
130 }
131
132 if let Some(win) = window {
134 let win_coeffs = crate::window::get_window(win, n, true)?;
135 result *= &win_coeffs;
136 }
137
138 Ok(result)
139}
140
141#[allow(dead_code)]
156pub fn mdst<S>(
157 x: &ArrayBase<S, scirs2_core::ndarray::Ix1>,
158 n: usize,
159 window: Option<Window>,
160) -> FFTResult<Array1<f64>>
161where
162 S: Data<Elem = f64>,
163{
164 if !n.is_multiple_of(2) {
165 return Err(FFTError::ValueError("MDST size must be even".to_string()));
166 }
167
168 if x.len() != n {
169 return Err(FFTError::ValueError(format!(
170 "Input length {} does not match MDST size {}",
171 x.len(),
172 n
173 )));
174 }
175
176 let half_n = n / 2;
177 let mut result = Array1::zeros(half_n);
178
179 let windowed = if let Some(win) = window {
181 let win_coeffs = crate::window::get_window(win, n, true)?;
182 x.to_owned() * &win_coeffs
183 } else {
184 x.to_owned()
185 };
186
187 for k in 0..half_n {
189 let mut sum = 0.0;
190 for n_idx in 0..n {
191 let angle = PI / n as f64 * (n_idx as f64 + 0.5 + half_n as f64) * (k as f64 + 0.5);
192 sum += windowed[n_idx] * angle.sin();
193 }
194 result[k] = sum * (2.0 / n as f64).sqrt();
195 }
196
197 Ok(result)
198}
199
200#[allow(dead_code)]
213pub fn imdst<S>(
214 x: &ArrayBase<S, scirs2_core::ndarray::Ix1>,
215 window: Option<Window>,
216) -> FFTResult<Array1<f64>>
217where
218 S: Data<Elem = f64>,
219{
220 let half_n = x.len();
221 let n = half_n * 2;
222 let mut result = Array1::zeros(n);
223
224 for n_idx in 0..n {
226 let mut sum = 0.0;
227 for k in 0..half_n {
228 let angle = PI / n as f64 * (n_idx as f64 + 0.5 + half_n as f64) * (k as f64 + 0.5);
229 sum += x[k] * angle.sin();
230 }
231 result[n_idx] = sum * (2.0 / n as f64).sqrt();
232 }
233
234 if let Some(win) = window {
236 let win_coeffs = crate::window::get_window(win, n, true)?;
237 result *= &win_coeffs;
238 }
239
240 Ok(result)
241}
242
243#[allow(dead_code)]
258pub fn mdct_overlap_add(
259 blocks: &[Array1<f64>],
260 window: Option<Window>,
261 hop_size: usize,
262) -> FFTResult<Array1<f64>> {
263 if blocks.is_empty() {
264 return Err(FFTError::ValueError("No blocks provided".to_string()));
265 }
266
267 let block_size = blocks[0].len() * 2;
268 let n_blocks = blocks.len();
269 let output_len = (n_blocks - 1) * hop_size + block_size;
270 let mut output = Array1::zeros(output_len);
271
272 for (i, block) in blocks.iter().enumerate() {
273 let reconstructed = imdct(block, window.clone())?;
274 let start_idx = i * hop_size;
275
276 for j in 0..block_size {
278 if start_idx + j < output_len {
279 output[start_idx + j] += reconstructed[j];
280 }
281 }
282 }
283
284 Ok(output)
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate::window::Window;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_mdct_basic() {
        let signal = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mdct_result = mdct(&signal, 8, None).unwrap();

        // An n-sample block yields n / 2 coefficients.
        assert_eq!(mdct_result.len(), 4);
    }

    #[test]
    fn test_mdct_imdct_windowed_output_length() {
        // A single block cannot be reconstructed on its own (aliasing only
        // cancels after overlap-adding neighbouring blocks), so this checks that
        // a windowed round trip restores the original block length.
        let signal = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let window = Some(Window::Hann);

        let mdct_coeffs = mdct(&signal, 8, window.clone()).unwrap();
        let reconstructed = imdct(&mdct_coeffs, window).unwrap();

        assert_eq!(reconstructed.len(), 8);
    }

    #[test]
    fn test_imdct_output_length() {
        let coeffs = array![1.0, 0.0, 0.0, 0.0];
        let samples = imdct(&coeffs, None).unwrap();
        assert_eq!(samples.len(), 8);
    }

    #[test]
    fn test_mdct_of_zeros_is_zero() {
        let signal = Array1::<f64>::zeros(8);
        let coeffs = mdct(&signal, 8, None).unwrap();
        assert!(coeffs.iter().all(|&c| c == 0.0));
    }

    #[test]
    fn test_mdst_basic() {
        let signal = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mdst_result = mdst(&signal, 8, None).unwrap();

        assert_eq!(mdst_result.len(), 4);
    }

    #[test]
    fn test_overlap_add() {
        let block1 = array![1.0, 2.0, 3.0, 4.0];
        let block2 = array![2.0, 3.0, 4.0, 5.0];
        let blocks = vec![block1, block2];

        let result = mdct_overlap_add(&blocks, Some(Window::Hann), 4).unwrap();

        // (2 blocks - 1) * hop 4 + block size 8
        assert_eq!(result.len(), 12);
    }

    #[test]
    fn test_overlap_add_no_blocks() {
        let blocks: Vec<Array1<f64>> = Vec::new();
        assert!(mdct_overlap_add(&blocks, None, 4).is_err());
    }

    #[test]
    fn test_mdct_invalid_size() {
        let signal = array![1.0, 2.0, 3.0];
        let result = mdct(&signal, 3, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_mdct_length_mismatch() {
        let signal = array![1.0, 2.0, 3.0, 4.0];
        assert!(mdct(&signal, 8, None).is_err());
    }

    #[test]
    fn test_mdst_invalid_size() {
        let signal = array![1.0, 2.0, 3.0];
        assert!(mdst(&signal, 3, None).is_err());
    }

    #[test]
    fn test_mdst_length_mismatch() {
        let signal = array![1.0, 2.0, 3.0, 4.0];
        assert!(mdst(&signal, 8, None).is_err());
    }
}