// scirs2_transform/signal_transforms/cqt.rs
use crate::error::{Result, TransformError};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::numeric::Complex;
use scirs2_fft::fft;
use std::f64::consts::PI;

11#[derive(Debug, Clone)]
13pub struct CQTConfig {
14 pub sample_rate: f64,
16 pub hop_size: usize,
18 pub fmin: f64,
20 pub bins_per_octave: usize,
22 pub n_octaves: usize,
24 pub q_factor: f64,
26 pub window: WindowFunction,
28}
29
30impl Default for CQTConfig {
31 fn default() -> Self {
32 CQTConfig {
33 sample_rate: 22050.0,
34 hop_size: 512,
35 fmin: 32.7, bins_per_octave: 12,
37 n_octaves: 7,
38 q_factor: 1.0,
39 window: WindowFunction::Hann,
40 }
41 }
42}
43
/// Tapering window applied to each CQT kernel.
///
/// With `θ = 2π·i / (n − 1)` for sample `i` of an `n`-sample window:
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WindowFunction {
    /// `0.5 · (1 − cos θ)`
    Hann,
    /// `0.54 − 0.46 · cos θ`
    Hamming,
    /// `0.42 − 0.5 · cos θ + 0.08 · cos 2θ`
    Blackman,
}

55impl WindowFunction {
56 fn generate(&self, n: usize) -> Array1<f64> {
58 match self {
59 WindowFunction::Hann => Array1::from_vec(
60 (0..n)
61 .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f64 / (n - 1) as f64).cos()))
62 .collect(),
63 ),
64 WindowFunction::Hamming => Array1::from_vec(
65 (0..n)
66 .map(|i| 0.54 - 0.46 * (2.0 * PI * i as f64 / (n - 1) as f64).cos())
67 .collect(),
68 ),
69 WindowFunction::Blackman => Array1::from_vec(
70 (0..n)
71 .map(|i| {
72 let angle = 2.0 * PI * i as f64 / (n - 1) as f64;
73 0.42 - 0.5 * angle.cos() + 0.08 * (2.0 * angle).cos()
74 })
75 .collect(),
76 ),
77 }
78 }
79}
80
81#[derive(Debug, Clone)]
83pub struct CQT {
84 config: CQTConfig,
85 kernel: Vec<Array1<Complex<f64>>>,
86 frequencies: Vec<f64>,
87}
88
89impl CQT {
90 pub fn new(config: CQTConfig) -> Result<Self> {
92 let n_bins = config.bins_per_octave * config.n_octaves;
93 let mut kernel = Vec::with_capacity(n_bins);
94 let mut frequencies = Vec::with_capacity(n_bins);
95
96 for k in 0..n_bins {
98 let freq = config.fmin * 2.0_f64.powf(k as f64 / config.bins_per_octave as f64);
99 frequencies.push(freq);
100
101 let bin_kernel = Self::compute_kernel(
103 freq,
104 config.sample_rate,
105 config.q_factor,
106 config.bins_per_octave,
107 &config.window,
108 )?;
109 kernel.push(bin_kernel);
110 }
111
112 Ok(CQT {
113 config,
114 kernel,
115 frequencies,
116 })
117 }
118
119 pub fn default() -> Result<Self> {
121 Self::new(CQTConfig::default())
122 }
123
124 fn compute_kernel(
126 freq: f64,
127 sample_rate: f64,
128 q_factor: f64,
129 bins_per_octave: usize,
130 window: &WindowFunction,
131 ) -> Result<Array1<Complex<f64>>> {
132 let q = q_factor / (2.0_f64.powf(1.0 / bins_per_octave as f64) - 1.0);
134
135 let filter_len = ((q * sample_rate / freq).ceil() as usize).max(1);
137
138 let window_vec = window.generate(filter_len);
140
141 let mut kernel = Array1::from_elem(filter_len, Complex::new(0.0, 0.0));
143
144 for n in 0..filter_len {
145 let phase = 2.0 * PI * freq * n as f64 / sample_rate;
146 let win_val = window_vec[n];
147 kernel[n] = Complex::new(win_val * phase.cos(), -win_val * phase.sin());
148 }
149
150 let norm: f64 = kernel.iter().map(|c| c.norm_sqr()).sum::<f64>().sqrt();
152 if norm > 1e-10 {
153 for val in kernel.iter_mut() {
154 *val = *val / norm;
155 }
156 }
157
158 Ok(kernel)
159 }
160
161 pub fn transform(&self, signal: &ArrayView1<f64>) -> Result<Array2<Complex<f64>>> {
163 let signal_len = signal.len();
164 if signal_len == 0 {
165 return Err(TransformError::InvalidInput("Empty signal".to_string()));
166 }
167
168 let n_bins = self.kernel.len();
169 let n_frames = (signal_len / self.config.hop_size).max(1);
170
171 let mut cqt = Array2::from_elem((n_bins, n_frames), Complex::new(0.0, 0.0));
172
173 for frame_idx in 0..n_frames {
175 let frame_start = frame_idx * self.config.hop_size;
176
177 for (bin_idx, kernel) in self.kernel.iter().enumerate() {
179 let mut response = Complex::new(0.0, 0.0);
180
181 for (k, &kernel_val) in kernel.iter().enumerate() {
182 let signal_idx = frame_start + k;
183 if signal_idx < signal_len {
184 response = response + kernel_val * signal[signal_idx];
185 }
186 }
187
188 cqt[[bin_idx, frame_idx]] = response;
189 }
190 }
191
192 Ok(cqt)
193 }
194
195 pub fn magnitude(&self, signal: &ArrayView1<f64>) -> Result<Array2<f64>> {
197 let cqt = self.transform(signal)?;
198 let (n_bins, n_frames) = cqt.dim();
199
200 let mut magnitude = Array2::zeros((n_bins, n_frames));
201 for i in 0..n_bins {
202 for j in 0..n_frames {
203 magnitude[[i, j]] = cqt[[i, j]].norm();
204 }
205 }
206
207 Ok(magnitude)
208 }
209
210 pub fn power(&self, signal: &ArrayView1<f64>) -> Result<Array2<f64>> {
212 let cqt = self.transform(signal)?;
213 let (n_bins, n_frames) = cqt.dim();
214
215 let mut power = Array2::zeros((n_bins, n_frames));
216 for i in 0..n_bins {
217 for j in 0..n_frames {
218 power[[i, j]] = cqt[[i, j]].norm_sqr();
219 }
220 }
221
222 Ok(power)
223 }
224
225 pub fn frequencies(&self) -> &[f64] {
227 &self.frequencies
228 }
229
230 pub fn config(&self) -> &CQTConfig {
232 &self.config
233 }
234
235 pub fn time_bins(&self, signal_len: usize) -> Vec<f64> {
237 let n_frames = (signal_len / self.config.hop_size).max(1);
238 (0..n_frames)
239 .map(|i| (i * self.config.hop_size) as f64 / self.config.sample_rate)
240 .collect()
241 }
242}
243
244#[derive(Debug, Clone)]
246pub struct Chromagram {
247 cqt: CQT,
248 n_chroma: usize,
249}
250
251impl Chromagram {
252 pub fn new(config: CQTConfig) -> Result<Self> {
254 let adjusted_config = CQTConfig {
256 bins_per_octave: 12 * ((config.bins_per_octave + 11) / 12),
257 ..config
258 };
259
260 let cqt = CQT::new(adjusted_config)?;
261
262 Ok(Chromagram { cqt, n_chroma: 12 })
263 }
264
265 pub fn default() -> Result<Self> {
267 Self::new(CQTConfig::default())
268 }
269
270 pub fn compute(&self, signal: &ArrayView1<f64>) -> Result<Array2<f64>> {
272 let cqt_mag = self.cqt.magnitude(signal)?;
274 let (n_bins, n_frames) = cqt_mag.dim();
275
276 let mut chroma = Array2::zeros((self.n_chroma, n_frames));
278
279 for i in 0..n_bins {
280 let chroma_bin = i % self.n_chroma;
281 for j in 0..n_frames {
282 chroma[[chroma_bin, j]] += cqt_mag[[i, j]];
283 }
284 }
285
286 for j in 0..n_frames {
288 let mut sum = 0.0;
289 for i in 0..self.n_chroma {
290 sum += chroma[[i, j]];
291 }
292 if sum > 1e-10 {
293 for i in 0..self.n_chroma {
294 chroma[[i, j]] /= sum;
295 }
296 }
297 }
298
299 Ok(chroma)
300 }
301
302 pub fn compute_normalized(&self, signal: &ArrayView1<f64>) -> Result<Array2<f64>> {
304 let cqt_power = self.cqt.power(signal)?;
305 let (n_bins, n_frames) = cqt_power.dim();
306
307 let mut chroma = Array2::zeros((self.n_chroma, n_frames));
309
310 for i in 0..n_bins {
311 let chroma_bin = i % self.n_chroma;
312 for j in 0..n_frames {
313 chroma[[chroma_bin, j]] += cqt_power[[i, j]];
314 }
315 }
316
317 for j in 0..n_frames {
319 let mut norm: f64 = 0.0;
320 for i in 0..self.n_chroma {
321 norm += chroma[[i, j]] * chroma[[i, j]];
322 }
323 norm = norm.sqrt();
324
325 if norm > 1e-10 {
326 for i in 0..self.n_chroma {
327 chroma[[i, j]] /= norm;
328 }
329 }
330 }
331
332 Ok(chroma)
333 }
334
335 pub fn chroma_labels() -> Vec<&'static str> {
337 vec![
338 "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B",
339 ]
340 }
341
342 pub fn cqt(&self) -> &CQT {
344 &self.cqt
345 }
346
347 pub fn time_bins(&self, signal_len: usize) -> Vec<f64> {
349 self.cqt.time_bins(signal_len)
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// 22050 samples of `sin(0.01 · i)`, shared by several tests.
    fn slow_sine() -> Array1<f64> {
        Array1::from_vec((0..22050).map(|i| (i as f64 * 0.01).sin()).collect())
    }

    #[test]
    fn test_cqt_creation() -> Result<()> {
        let cqt = CQT::default()?;
        let freqs = cqt.frequencies();

        assert!(!freqs.is_empty());
        assert_eq!(freqs.len(), cqt.kernel.len());
        // Bin centre frequencies must be strictly increasing.
        for pair in freqs.windows(2) {
            assert!(pair[1] / pair[0] > 1.0);
        }

        Ok(())
    }

    #[test]
    fn test_cqt_transform() -> Result<()> {
        let signal = slow_sine();
        let result = CQT::default()?.transform(&signal.view())?;
        let (n_bins, n_frames) = result.dim();

        assert!(n_bins > 0);
        assert!(n_frames > 0);

        Ok(())
    }

    #[test]
    fn test_cqt_magnitude() -> Result<()> {
        // One second of a 440 Hz tone sampled at 22050 Hz.
        let tone: Vec<f64> = (0..22050)
            .map(|i| (2.0 * PI * 440.0 * i as f64 / 22050.0).sin())
            .collect();
        let signal = Array1::from_vec(tone);

        let cqt = CQT::new(CQTConfig {
            sample_rate: 22050.0,
            fmin: 55.0,
            bins_per_octave: 12,
            n_octaves: 6,
            ..Default::default()
        })?;
        let mag = cqt.magnitude(&signal.view())?;
        let (n_bins, n_frames) = mag.dim();

        assert!(n_bins > 0);
        assert!(n_frames > 0);
        assert!(mag.iter().all(|&x| x >= 0.0));

        Ok(())
    }

    #[test]
    fn test_chromagram_creation() -> Result<()> {
        assert_eq!(Chromagram::default()?.n_chroma, 12);
        Ok(())
    }

    #[test]
    fn test_chromagram_compute() -> Result<()> {
        let signal = slow_sine();
        let result = Chromagram::default()?.compute(&signal.view())?;

        assert_eq!(result.dim().0, 12);
        assert!(result.dim().1 > 0);

        // Every non-silent frame sums to one.
        for column in result.columns() {
            let total: f64 = column.iter().sum();
            if total > 1e-10 {
                assert_abs_diff_eq!(total, 1.0, epsilon = 1e-6);
            }
        }

        Ok(())
    }

    #[test]
    fn test_chromagram_normalized() -> Result<()> {
        let signal = slow_sine();
        let result = Chromagram::default()?.compute_normalized(&signal.view())?;

        assert_eq!(result.dim().0, 12);
        assert!(result.dim().1 > 0);

        // Every non-silent frame has unit L2 norm.
        for column in result.columns() {
            let norm_sq: f64 = column.iter().map(|v| v * v).sum();
            if norm_sq > 1e-10 {
                assert_abs_diff_eq!(norm_sq.sqrt(), 1.0, epsilon = 1e-6);
            }
        }

        Ok(())
    }

    #[test]
    fn test_chroma_labels() {
        let labels = Chromagram::chroma_labels();
        assert_eq!(labels.len(), 12);
        assert_eq!(labels.first(), Some(&"C"));
        assert_eq!(labels.last(), Some(&"B"));
    }

    #[test]
    fn test_window_functions() {
        let hann = WindowFunction::Hann.generate(64);
        assert_eq!(hann.len(), 64);
        assert_abs_diff_eq!(hann[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(hann[63], 0.0, epsilon = 1e-10);

        let hamming = WindowFunction::Hamming.generate(64);
        assert_eq!(hamming.len(), 64);
        assert!(hamming[0] > 0.0);

        let blackman = WindowFunction::Blackman.generate(64);
        assert_eq!(blackman.len(), 64);
        assert_abs_diff_eq!(blackman[0], 0.0, epsilon = 1e-2);
    }

    #[test]
    fn test_cqt_time_bins() -> Result<()> {
        let time_bins = CQT::default()?.time_bins(22050);

        assert!(!time_bins.is_empty());
        assert_abs_diff_eq!(time_bins[0], 0.0, epsilon = 1e-10);

        // Frames must be evenly spaced in time.
        let steps: Vec<f64> = time_bins.windows(2).map(|w| w[1] - w[0]).collect();
        if let Some(&dt) = steps.first() {
            for &step in &steps {
                assert_abs_diff_eq!(step, dt, epsilon = 1e-6);
            }
        }

        Ok(())
    }
}