Skip to main content

scirs2_transform/signal_transforms/
cqt.rs

1//! Constant-Q Transform (CQT) and Chromagram Implementation
2//!
3//! Provides musically-motivated time-frequency analysis with logarithmic frequency spacing.
4
5use crate::error::{Result, TransformError};
6use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use scirs2_core::numeric::Complex;
8use scirs2_fft::fft;
9use std::f64::consts::PI;
10
/// CQT configuration
///
/// Consumed by [`CQT::new`]; the resulting transform has
/// `bins_per_octave * n_octaves` frequency bins.
#[derive(Debug, Clone)]
pub struct CQTConfig {
    /// Sampling rate of the input signal in Hz
    pub sample_rate: f64,
    /// Hop size in samples between the starts of successive analysis frames.
    /// Must be non-zero: it divides the signal length when counting frames.
    pub hop_size: usize,
    /// Minimum frequency in Hz (the centre frequency of bin 0)
    pub fmin: f64,
    /// Number of bins per octave; bin `k` sits at `fmin * 2^(k / bins_per_octave)` Hz
    pub bins_per_octave: usize,
    /// Number of octaves covered above `fmin`
    pub n_octaves: usize,
    /// Filter quality factor: multiplies the base Q value
    /// `1 / (2^(1 / bins_per_octave) - 1)`, so every kernel's length
    /// `ceil(Q * sample_rate / freq)` scales linearly with it.
    pub q_factor: f64,
    /// Window (taper) applied to each kernel
    pub window: WindowFunction,
}
29
30impl Default for CQTConfig {
31    fn default() -> Self {
32        CQTConfig {
33            sample_rate: 22050.0,
34            hop_size: 512,
35            fmin: 32.7, // C1
36            bins_per_octave: 12,
37            n_octaves: 7,
38            q_factor: 1.0,
39            window: WindowFunction::Hann,
40        }
41    }
42}
43
/// Window functions for CQT
///
/// Each kernel is multiplied by one of these tapers before being normalised.
/// In the formulas below `θ = 2πi / (n - 1)` for sample `i` of an `n`-point window.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WindowFunction {
    /// Hann window: `0.5 * (1 - cos θ)`; zero at both ends
    Hann,
    /// Hamming window: `0.54 - 0.46 * cos θ`; endpoints are 0.08 rather than zero
    Hamming,
    /// Blackman window: `0.42 - 0.5 * cos θ + 0.08 * cos 2θ`
    Blackman,
}
54
55impl WindowFunction {
56    /// Generate window of given length
57    fn generate(&self, n: usize) -> Array1<f64> {
58        match self {
59            WindowFunction::Hann => Array1::from_vec(
60                (0..n)
61                    .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f64 / (n - 1) as f64).cos()))
62                    .collect(),
63            ),
64            WindowFunction::Hamming => Array1::from_vec(
65                (0..n)
66                    .map(|i| 0.54 - 0.46 * (2.0 * PI * i as f64 / (n - 1) as f64).cos())
67                    .collect(),
68            ),
69            WindowFunction::Blackman => Array1::from_vec(
70                (0..n)
71                    .map(|i| {
72                        let angle = 2.0 * PI * i as f64 / (n - 1) as f64;
73                        0.42 - 0.5 * angle.cos() + 0.08 * (2.0 * angle).cos()
74                    })
75                    .collect(),
76            ),
77        }
78    }
79}
80
/// Constant-Q Transform
///
/// Holds the configuration together with one precomputed complex kernel per
/// frequency bin; build it with [`CQT::new`].
#[derive(Debug, Clone)]
pub struct CQT {
    /// Configuration the kernels were built from
    config: CQTConfig,
    /// One time-domain kernel per bin, in the same order as `frequencies`.
    /// Each is scaled to unit L2 norm unless its norm is below 1e-10.
    kernel: Vec<Array1<Complex<f64>>>,
    /// Centre frequency of each bin in Hz: `fmin * 2^(k / bins_per_octave)`
    frequencies: Vec<f64>,
}
88
89impl CQT {
90    /// Create a new CQT instance
91    pub fn new(config: CQTConfig) -> Result<Self> {
92        let n_bins = config.bins_per_octave * config.n_octaves;
93        let mut kernel = Vec::with_capacity(n_bins);
94        let mut frequencies = Vec::with_capacity(n_bins);
95
96        // Compute frequency for each bin
97        for k in 0..n_bins {
98            let freq = config.fmin * 2.0_f64.powf(k as f64 / config.bins_per_octave as f64);
99            frequencies.push(freq);
100
101            // Compute kernel for this frequency
102            let bin_kernel = Self::compute_kernel(
103                freq,
104                config.sample_rate,
105                config.q_factor,
106                config.bins_per_octave,
107                &config.window,
108            )?;
109            kernel.push(bin_kernel);
110        }
111
112        Ok(CQT {
113            config,
114            kernel,
115            frequencies,
116        })
117    }
118
119    /// Create with default configuration
120    pub fn default() -> Result<Self> {
121        Self::new(CQTConfig::default())
122    }
123
124    /// Compute CQT kernel for a specific frequency
125    fn compute_kernel(
126        freq: f64,
127        sample_rate: f64,
128        q_factor: f64,
129        bins_per_octave: usize,
130        window: &WindowFunction,
131    ) -> Result<Array1<Complex<f64>>> {
132        // Calculate Q value
133        let q = q_factor / (2.0_f64.powf(1.0 / bins_per_octave as f64) - 1.0);
134
135        // Calculate filter length
136        let filter_len = ((q * sample_rate / freq).ceil() as usize).max(1);
137
138        // Generate window
139        let window_vec = window.generate(filter_len);
140
141        // Create complex exponential
142        let mut kernel = Array1::from_elem(filter_len, Complex::new(0.0, 0.0));
143
144        for n in 0..filter_len {
145            let phase = 2.0 * PI * freq * n as f64 / sample_rate;
146            let win_val = window_vec[n];
147            kernel[n] = Complex::new(win_val * phase.cos(), -win_val * phase.sin());
148        }
149
150        // Normalize
151        let norm: f64 = kernel.iter().map(|c| c.norm_sqr()).sum::<f64>().sqrt();
152        if norm > 1e-10 {
153            for val in kernel.iter_mut() {
154                *val = *val / norm;
155            }
156        }
157
158        Ok(kernel)
159    }
160
161    /// Compute the CQT of a signal
162    pub fn transform(&self, signal: &ArrayView1<f64>) -> Result<Array2<Complex<f64>>> {
163        let signal_len = signal.len();
164        if signal_len == 0 {
165            return Err(TransformError::InvalidInput("Empty signal".to_string()));
166        }
167
168        let n_bins = self.kernel.len();
169        let n_frames = (signal_len / self.config.hop_size).max(1);
170
171        let mut cqt = Array2::from_elem((n_bins, n_frames), Complex::new(0.0, 0.0));
172
173        // Process each frame
174        for frame_idx in 0..n_frames {
175            let frame_start = frame_idx * self.config.hop_size;
176
177            // Apply each frequency kernel
178            for (bin_idx, kernel) in self.kernel.iter().enumerate() {
179                let mut response = Complex::new(0.0, 0.0);
180
181                for (k, &kernel_val) in kernel.iter().enumerate() {
182                    let signal_idx = frame_start + k;
183                    if signal_idx < signal_len {
184                        response = response + kernel_val * signal[signal_idx];
185                    }
186                }
187
188                cqt[[bin_idx, frame_idx]] = response;
189            }
190        }
191
192        Ok(cqt)
193    }
194
195    /// Compute magnitude of CQT
196    pub fn magnitude(&self, signal: &ArrayView1<f64>) -> Result<Array2<f64>> {
197        let cqt = self.transform(signal)?;
198        let (n_bins, n_frames) = cqt.dim();
199
200        let mut magnitude = Array2::zeros((n_bins, n_frames));
201        for i in 0..n_bins {
202            for j in 0..n_frames {
203                magnitude[[i, j]] = cqt[[i, j]].norm();
204            }
205        }
206
207        Ok(magnitude)
208    }
209
210    /// Compute power (magnitude squared) of CQT
211    pub fn power(&self, signal: &ArrayView1<f64>) -> Result<Array2<f64>> {
212        let cqt = self.transform(signal)?;
213        let (n_bins, n_frames) = cqt.dim();
214
215        let mut power = Array2::zeros((n_bins, n_frames));
216        for i in 0..n_bins {
217            for j in 0..n_frames {
218                power[[i, j]] = cqt[[i, j]].norm_sqr();
219            }
220        }
221
222        Ok(power)
223    }
224
225    /// Get the frequencies corresponding to each bin
226    pub fn frequencies(&self) -> &[f64] {
227        &self.frequencies
228    }
229
230    /// Get the configuration
231    pub fn config(&self) -> &CQTConfig {
232        &self.config
233    }
234
235    /// Get time bins in seconds
236    pub fn time_bins(&self, signal_len: usize) -> Vec<f64> {
237        let n_frames = (signal_len / self.config.hop_size).max(1);
238        (0..n_frames)
239            .map(|i| (i * self.config.hop_size) as f64 / self.config.sample_rate)
240            .collect()
241    }
242}
243
/// Chromagram (pitch class profile)
///
/// Folds the frequency bins of an underlying [`CQT`] into pitch-class rows.
#[derive(Debug, Clone)]
pub struct Chromagram {
    /// Underlying transform; `Chromagram::new` rounds its `bins_per_octave` up
    /// to a multiple of 12 before building it.
    cqt: CQT,
    /// Number of pitch-class rows in the output (`Chromagram::new` sets this to 12)
    n_chroma: usize,
}
251impl Chromagram {
252    /// Create a new chromagram with CQT configuration
253    pub fn new(config: CQTConfig) -> Result<Self> {
254        // Ensure bins_per_octave is a multiple of 12 for clean folding
255        let adjusted_config = CQTConfig {
256            bins_per_octave: 12 * ((config.bins_per_octave + 11) / 12),
257            ..config
258        };
259
260        let cqt = CQT::new(adjusted_config)?;
261
262        Ok(Chromagram { cqt, n_chroma: 12 })
263    }
264
265    /// Create with default configuration
266    pub fn default() -> Result<Self> {
267        Self::new(CQTConfig::default())
268    }
269
270    /// Compute chromagram from signal
271    pub fn compute(&self, signal: &ArrayView1<f64>) -> Result<Array2<f64>> {
272        // Get CQT magnitude
273        let cqt_mag = self.cqt.magnitude(signal)?;
274        let (n_bins, n_frames) = cqt_mag.dim();
275
276        // Fold into 12 chroma bins
277        let mut chroma = Array2::zeros((self.n_chroma, n_frames));
278
279        for i in 0..n_bins {
280            let chroma_bin = i % self.n_chroma;
281            for j in 0..n_frames {
282                chroma[[chroma_bin, j]] += cqt_mag[[i, j]];
283            }
284        }
285
286        // Normalize each frame
287        for j in 0..n_frames {
288            let mut sum = 0.0;
289            for i in 0..self.n_chroma {
290                sum += chroma[[i, j]];
291            }
292            if sum > 1e-10 {
293                for i in 0..self.n_chroma {
294                    chroma[[i, j]] /= sum;
295                }
296            }
297        }
298
299        Ok(chroma)
300    }
301
302    /// Compute energy-normalized chromagram
303    pub fn compute_normalized(&self, signal: &ArrayView1<f64>) -> Result<Array2<f64>> {
304        let cqt_power = self.cqt.power(signal)?;
305        let (n_bins, n_frames) = cqt_power.dim();
306
307        // Fold into 12 chroma bins
308        let mut chroma = Array2::zeros((self.n_chroma, n_frames));
309
310        for i in 0..n_bins {
311            let chroma_bin = i % self.n_chroma;
312            for j in 0..n_frames {
313                chroma[[chroma_bin, j]] += cqt_power[[i, j]];
314            }
315        }
316
317        // L2 normalize each frame
318        for j in 0..n_frames {
319            let mut norm: f64 = 0.0;
320            for i in 0..self.n_chroma {
321                norm += chroma[[i, j]] * chroma[[i, j]];
322            }
323            norm = norm.sqrt();
324
325            if norm > 1e-10 {
326                for i in 0..self.n_chroma {
327                    chroma[[i, j]] /= norm;
328                }
329            }
330        }
331
332        Ok(chroma)
333    }
334
335    /// Get chroma labels (note names)
336    pub fn chroma_labels() -> Vec<&'static str> {
337        vec![
338            "C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B",
339        ]
340    }
341
342    /// Get the underlying CQT
343    pub fn cqt(&self) -> &CQT {
344        &self.cqt
345    }
346
347    /// Get time bins in seconds
348    pub fn time_bins(&self, signal_len: usize) -> Vec<f64> {
349        self.cqt.time_bins(signal_len)
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// One second of a unit-amplitude sine at `freq` Hz, sampled at 22.05 kHz.
    fn tone(freq: f64) -> Array1<f64> {
        Array1::from_vec(
            (0..22050)
                .map(|i| (2.0 * PI * freq * i as f64 / 22050.0).sin())
                .collect(),
        )
    }

    /// Row index of the largest value in column `frame` of `values`.
    fn argmax_in_frame(values: &Array2<f64>, frame: usize) -> usize {
        (0..values.dim().0)
            .max_by(|&a, &b| values[[a, frame]].total_cmp(&values[[b, frame]]))
            .expect("matrix has at least one row")
    }

    #[test]
    fn test_cqt_creation() -> Result<()> {
        let cqt = CQT::default()?;
        let freqs = cqt.frequencies();

        assert!(!freqs.is_empty());
        assert_eq!(freqs.len(), cqt.kernel.len());
        // Default: 7 octaves x 12 bins per octave.
        assert_eq!(freqs.len(), 84);

        // Logarithmic spacing: adjacent bins are exactly one semitone apart.
        let semitone = 2.0_f64.powf(1.0 / 12.0);
        for pair in freqs.windows(2) {
            assert_abs_diff_eq!(pair[1] / pair[0], semitone, epsilon = 1e-9);
        }

        Ok(())
    }

    #[test]
    fn test_cqt_transform() -> Result<()> {
        let signal = Array1::from_vec((0..22050).map(|i| (i as f64 * 0.01).sin()).collect());
        let cqt = CQT::default()?;

        let result = cqt.transform(&signal.view())?;

        // 84 bins; 22050 / 512 = 43 whole hops.
        assert_eq!(result.dim(), (84, 43));

        Ok(())
    }

    #[test]
    fn test_cqt_transform_rejects_empty_signal() -> Result<()> {
        let cqt = CQT::default()?;
        let empty = Array1::<f64>::zeros(0);

        assert!(cqt.transform(&empty.view()).is_err());

        Ok(())
    }

    #[test]
    fn test_cqt_magnitude() -> Result<()> {
        // Simple tone at A4 (440 Hz)
        let signal = tone(440.0);

        let config = CQTConfig {
            sample_rate: 22050.0,
            fmin: 55.0, // A1
            bins_per_octave: 12,
            n_octaves: 6,
            ..Default::default()
        };

        let cqt = CQT::new(config)?;
        let mag = cqt.magnitude(&signal.view())?;

        assert_eq!(mag.dim(), (72, 22050 / 512));
        assert!(mag.iter().all(|&x| x >= 0.0));

        // 55 Hz * 2^(36 / 12) = 440 Hz. Frame 10 is far enough from the end that
        // even the longest (lowest-frequency) kernel is fully inside the signal.
        assert_eq!(argmax_in_frame(&mag, 10), 36);

        Ok(())
    }

    #[test]
    fn test_chromagram_creation() -> Result<()> {
        let chroma = Chromagram::default()?;

        assert_eq!(chroma.n_chroma, 12);

        Ok(())
    }

    #[test]
    fn test_chromagram_compute() -> Result<()> {
        let signal = Array1::from_vec((0..22050).map(|i| (i as f64 * 0.01).sin()).collect());
        let chroma = Chromagram::default()?;

        let result = chroma.compute(&signal.view())?;

        assert_eq!(result.dim().0, 12);
        assert!(result.dim().1 > 0);

        // Each non-silent column sums to ~1.
        for column in 0..result.dim().1 {
            let sum: f64 = (0..12).map(|row| result[[row, column]]).sum();
            if sum > 1e-10 {
                assert_abs_diff_eq!(sum, 1.0, epsilon = 1e-6);
            }
        }

        Ok(())
    }

    #[test]
    fn test_chromagram_a440_peaks_at_a() -> Result<()> {
        let signal = tone(440.0);
        let chroma = Chromagram::default()?;

        let result = chroma.compute(&signal.view())?;
        let peak = argmax_in_frame(&result, 10);

        assert_eq!(Chromagram::chroma_labels()[peak], "A");

        Ok(())
    }

    #[test]
    fn test_chromagram_normalized() -> Result<()> {
        let signal = Array1::from_vec((0..22050).map(|i| (i as f64 * 0.01).sin()).collect());
        let chroma = Chromagram::default()?;

        let result = chroma.compute_normalized(&signal.view())?;

        assert_eq!(result.dim().0, 12);
        assert!(result.dim().1 > 0);

        // Each non-silent column has unit L2 norm.
        for column in 0..result.dim().1 {
            let norm_sq: f64 = (0..12).map(|row| result[[row, column]].powi(2)).sum();
            if norm_sq > 1e-10 {
                assert_abs_diff_eq!(norm_sq.sqrt(), 1.0, epsilon = 1e-6);
            }
        }

        Ok(())
    }

    #[test]
    fn test_chroma_labels() {
        let labels = Chromagram::chroma_labels();
        assert_eq!(labels.len(), 12);
        assert_eq!(labels[0], "C");
        assert_eq!(labels[9], "A");
        assert_eq!(labels[11], "B");
    }

    #[test]
    fn test_window_functions() {
        let hann = WindowFunction::Hann.generate(64);
        assert_eq!(hann.len(), 64);
        assert_abs_diff_eq!(hann[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(hann[63], 0.0, epsilon = 1e-10);
        // Symmetric windows mirror around the centre.
        assert_abs_diff_eq!(hann[1], hann[62], epsilon = 1e-12);

        let hamming = WindowFunction::Hamming.generate(64);
        assert_eq!(hamming.len(), 64);
        assert!(hamming[0] > 0.0);

        let blackman = WindowFunction::Blackman.generate(64);
        assert_eq!(blackman.len(), 64);
        assert_abs_diff_eq!(blackman[0], 0.0, epsilon = 1e-2);
    }

    #[test]
    fn test_cqt_time_bins() -> Result<()> {
        let cqt = CQT::default()?;
        let time_bins = cqt.time_bins(22050);

        assert_eq!(time_bins.len(), 22050 / 512);
        assert_abs_diff_eq!(time_bins[0], 0.0, epsilon = 1e-10);

        // Uniform spacing of exactly one hop (512 samples at 22.05 kHz).
        let hop_seconds = 512.0 / 22050.0;
        for pair in time_bins.windows(2) {
            assert_abs_diff_eq!(pair[1] - pair[0], hop_seconds, epsilon = 1e-9);
        }

        Ok(())
    }
}