spectrograms/spectrogram.rs
1use std::marker::PhantomData;
2use std::num::NonZeroUsize;
3use std::ops::{Deref, DerefMut};
4
5use ndarray::{Array1, Array2};
6use non_empty_slice::{NonEmptySlice, NonEmptyVec, non_empty_vec};
7use num_complex::Complex;
8
9#[cfg(feature = "python")]
10use pyo3::prelude::*;
11
12use crate::cqt::CqtKernel;
13use crate::erb::ErbFilterbank;
14use crate::{
15 CqtParams, ErbParams, R2cPlan, SpectrogramError, SpectrogramResult, WindowType,
16 min_max_single_pass, nzu,
17};
// Small positive tolerance (1e-12).
// NOTE(review): not referenced in this part of the file; presumably it guards
// logarithms or divisions against exact zeros — confirm at the use sites.
const EPS: f64 = 1e-12;
19
20//
21// ========================
22// Sparse Matrix for efficient filterbank multiplication
23// ========================
24//
25
/// Sparse matrix stored row by row, tuned for matrix-vector products.
///
/// Each row owns two parallel lists: its non-zero coefficients (`values`) and the
/// column every coefficient belongs to (`indices`). Unlike a classic CSR layout,
/// rows can be filled independently, which keeps row-by-row construction simple.
///
/// The layout targets matrices with very few non-zeros per row, such as mel
/// filterbanks (triangular filters) and logarithmic frequency mappings (linear
/// interpolation between 1-2 adjacent bins).
///
/// For typical spectrograms:
/// - `LogHz` interpolation: only 1-2 non-zeros per row (~99% sparse)
/// - Mel filterbank: ~10-50 non-zeros per row depending on FFT size (~90-98% sparse)
///
/// Because only stored coefficients are visited, no time is spent multiplying by
/// zero, which can be 10-100x faster than a dense product for such matrices.
#[derive(Debug, Clone)]
struct SparseMatrix {
    /// Row count.
    nrows: usize,
    /// Column count.
    ncols: usize,
    /// Stored coefficients; `values[r]` holds the entries of row `r` in insertion order.
    values: Vec<Vec<f64>>,
    /// Column positions; `indices[r][k]` is the column of `values[r][k]`.
    indices: Vec<Vec<usize>>,
}

impl SparseMatrix {
    /// Build an `nrows` x `ncols` matrix with no stored entries.
    fn new(nrows: usize, ncols: usize) -> Self {
        let values = vec![Vec::new(); nrows];
        let indices = vec![Vec::new(); nrows];
        Self {
            nrows,
            ncols,
            values,
            indices,
        }
    }

    /// Record `value` at (`row`, `col`) unless its magnitude is at most `1e-10`.
    ///
    /// Entries are appended, never replaced: calling this twice for the same cell
    /// stores two entries, and `multiply_vec` then adds both contributions.
    ///
    /// # Panics (debug mode only)
    /// Panics in debug builds if `row` or `col` is out of bounds. Release builds
    /// silently ignore such calls.
    fn set(&mut self, row: usize, col: usize, value: f64) {
        let in_bounds = row < self.nrows && col < self.ncols;
        debug_assert!(
            in_bounds,
            "SparseMatrix index out of bounds: ({}, {}) for {}x{} matrix",
            row,
            col,
            self.nrows,
            self.ncols
        );
        if !in_bounds {
            return;
        }

        // Near-zero coefficients are dropped. Written as a positive comparison so
        // NaN (which compares false) is dropped as well.
        let is_significant = value.abs() > 1e-10;
        if is_significant {
            self.values[row].push(value);
            self.indices[row].push(col);
        }
    }

    /// Row count of the matrix.
    const fn nrows(&self) -> usize {
        self.nrows
    }

    /// Column count of the matrix.
    const fn ncols(&self) -> usize {
        self.ncols
    }

    /// Sparse matrix-vector product: `out = self * input`.
    ///
    /// Only stored coefficients are visited. Lengths of `input` and `out` are
    /// checked against the matrix shape in debug builds only.
    #[inline]
    fn multiply_vec(&self, input: &[f64], out: &mut [f64]) {
        debug_assert_eq!(input.len(), self.ncols);
        debug_assert_eq!(out.len(), self.nrows);

        let rows = self.values.iter().zip(self.indices.iter());
        for (row_idx, (coeffs, cols)) in rows.enumerate() {
            // Explicit fold from +0.0 keeps the accumulation identical to a manual loop.
            let acc = coeffs
                .iter()
                .zip(cols.iter())
                .fold(0.0, |sum, (&coeff, &col_idx)| sum + coeff * input[col_idx]);
            out[row_idx] = acc;
        }
    }
}
117
// Linear frequency
/// Spectrogram with a `LinearHz` frequency scale and `Power` amplitude scale.
pub type LinearPowerSpectrogram = Spectrogram<LinearHz, Power>;
/// Spectrogram with a `LinearHz` frequency scale and `Magnitude` amplitude scale.
pub type LinearMagnitudeSpectrogram = Spectrogram<LinearHz, Magnitude>;
/// Spectrogram with a `LinearHz` frequency scale and `Decibels` amplitude scale.
pub type LinearDbSpectrogram = Spectrogram<LinearHz, Decibels>;
/// Spectrogram with a `LinearHz` frequency scale, generic over the amplitude scale.
pub type LinearSpectrogram<AmpScale> = Spectrogram<LinearHz, AmpScale>;

// Log-frequency (e.g. CQT-style)
/// Spectrogram with a `LogHz` frequency scale and `Power` amplitude scale.
pub type LogHzPowerSpectrogram = Spectrogram<LogHz, Power>;
/// Spectrogram with a `LogHz` frequency scale and `Magnitude` amplitude scale.
pub type LogHzMagnitudeSpectrogram = Spectrogram<LogHz, Magnitude>;
/// Spectrogram with a `LogHz` frequency scale and `Decibels` amplitude scale.
pub type LogHzDbSpectrogram = Spectrogram<LogHz, Decibels>;
/// Spectrogram with a `LogHz` frequency scale, generic over the amplitude scale.
pub type LogHzSpectrogram<AmpScale> = Spectrogram<LogHz, AmpScale>;

// ERB / gammatone
/// Spectrogram with an `Erb` frequency scale and `Power` amplitude scale.
pub type ErbPowerSpectrogram = Spectrogram<Erb, Power>;
/// Spectrogram with an `Erb` frequency scale and `Magnitude` amplitude scale.
pub type ErbMagnitudeSpectrogram = Spectrogram<Erb, Magnitude>;
/// Spectrogram with an `Erb` frequency scale and `Decibels` amplitude scale.
pub type ErbDbSpectrogram = Spectrogram<Erb, Decibels>;
/// Alternative name for [`ErbPowerSpectrogram`].
pub type GammatonePowerSpectrogram = ErbPowerSpectrogram;
/// Alternative name for [`ErbMagnitudeSpectrogram`].
pub type GammatoneMagnitudeSpectrogram = ErbMagnitudeSpectrogram;
/// Alternative name for [`ErbDbSpectrogram`].
pub type GammatoneDbSpectrogram = ErbDbSpectrogram;
/// Spectrogram with an `Erb` frequency scale, generic over the amplitude scale.
pub type ErbSpectrogram<AmpScale> = Spectrogram<Erb, AmpScale>;
/// Alternative name for [`ErbSpectrogram`].
pub type GammatoneSpectrogram<AmpScale> = ErbSpectrogram<AmpScale>;

// Mel
/// Spectrogram with a `Mel` frequency scale and `Magnitude` amplitude scale.
pub type MelMagnitudeSpectrogram = Spectrogram<Mel, Magnitude>;
/// Spectrogram with a `Mel` frequency scale and `Power` amplitude scale.
pub type MelPowerSpectrogram = Spectrogram<Mel, Power>;
/// Spectrogram with a `Mel` frequency scale and `Decibels` amplitude scale.
pub type MelDbSpectrogram = Spectrogram<Mel, Decibels>;
/// Alternative name for [`MelDbSpectrogram`].
pub type LogMelSpectrogram = MelDbSpectrogram;
/// Spectrogram with a `Mel` frequency scale, generic over the amplitude scale.
pub type MelSpectrogram<AmpScale> = Spectrogram<Mel, AmpScale>;

// CQT
/// Spectrogram with a `Cqt` frequency scale and `Power` amplitude scale.
pub type CqtPowerSpectrogram = Spectrogram<Cqt, Power>;
/// Spectrogram with a `Cqt` frequency scale and `Magnitude` amplitude scale.
pub type CqtMagnitudeSpectrogram = Spectrogram<Cqt, Magnitude>;
/// Spectrogram with a `Cqt` frequency scale and `Decibels` amplitude scale.
pub type CqtDbSpectrogram = Spectrogram<Cqt, Decibels>;
/// Spectrogram with a `Cqt` frequency scale, generic over the amplitude scale.
pub type CqtSpectrogram<AmpScale> = Spectrogram<Cqt, AmpScale>;
152
153use crate::fft_backend::r2c_output_size;
154
/// A spectrogram plan is the compiled, reusable execution object.
///
/// It owns:
/// - FFT plan (reusable)
/// - window samples
/// - mapping (identity / mel filterbank / etc.)
/// - amplitude scaling config
/// - workspace buffers to avoid allocations in hot loops
///
/// It computes one specific spectrogram type: `Spectrogram<FreqScale, AmpScale>`.
/// Plans are built by the `*_plan` methods of `SpectrogramPlanner`.
///
/// # Type Parameters
///
/// - `FreqScale`: Frequency scale type (e.g. `LinearHz`, `LogHz`, `Mel`, etc.)
/// - `AmpScale`: Amplitude scaling type (e.g. `Power`, `Magnitude`, `Decibels`, etc.)
pub struct SpectrogramPlan<FreqScale, AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
    FreqScale: Copy + Clone + 'static,
{
    // Copy of the parameters the plan was built from; cloned into every result.
    params: SpectrogramParams,

    // Framing / FFT stage; also supplies the frame count for a given signal length.
    stft: StftPlan,
    // Converts a frame (or its spectrum) into the output frequency bins.
    mapping: FrequencyMapping<FreqScale>,
    // Amplitude scaling applied in place to the mapped bins.
    scaling: AmplitudeScaling<AmpScale>,

    // Frequency axis cloned into every computed spectrogram.
    freq_axis: FrequencyAxis<FreqScale>,
    // Scratch buffers reused for every frame.
    workspace: Workspace,

    // Zero-sized marker for `AmpScale`.
    // NOTE(review): `scaling` already mentions `AmpScale`, so this marker may be redundant.
    _amp: PhantomData<AmpScale>,
}
186
187impl<FreqScale, AmpScale> SpectrogramPlan<FreqScale, AmpScale>
188where
189 AmpScale: AmpScaleSpec + 'static,
190 FreqScale: Copy + Clone + 'static,
191{
192 /// Get the spectrogram parameters used to create this plan.
193 ///
194 /// # Returns
195 ///
196 /// A reference to the `SpectrogramParams` used in this plan.
197 #[inline]
198 #[must_use]
199 pub const fn params(&self) -> &SpectrogramParams {
200 &self.params
201 }
202
203 /// Get the frequency axis for this spectrogram plan.
204 ///
205 /// # Returns
206 ///
207 /// A reference to the `FrequencyAxis<FreqScale>` used in this plan.
208 #[inline]
209 #[must_use]
210 pub const fn freq_axis(&self) -> &FrequencyAxis<FreqScale> {
211 &self.freq_axis
212 }
213
214 /// Compute a spectrogram for a mono signal.
215 ///
216 /// This function performs:
217 /// - framing + windowing
218 /// - FFT per frame
219 /// - magnitude/power
220 /// - frequency mapping (identity/mel/etc.)
221 /// - amplitude scaling (linear or dB)
222 ///
223 /// It allocates the output `Array2` once, but does not allocate per-frame.
224 ///
225 /// # Arguments
226 ///
227 /// * `samples` - Audio samples
228 ///
229 /// # Returns
230 ///
231 /// A `Spectrogram<FreqScale, AmpScale>` containing the computed spectrogram.
232 ///
233 /// # Errors
234 ///
235 /// Returns an error if STFT computation or mapping fails.
236 #[inline]
237 pub fn compute(
238 &mut self,
239 samples: &NonEmptySlice<f64>,
240 ) -> SpectrogramResult<Spectrogram<FreqScale, AmpScale>> {
241 let n_frames = self.stft.frame_count(samples.len())?;
242 let n_bins = self.mapping.output_bins();
243
244 // Create output matrix: (n_bins, n_frames)
245 let mut data = Array2::<f64>::zeros((n_bins.get(), n_frames.get()));
246
247 // Ensure workspace is correctly sized
248 self.workspace
249 .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
250
251 // Main loop: fill each frame (column)
252 for frame_idx in 0..n_frames.get() {
253 // CQT needs unwindowed frames because its kernels already contain windowing.
254 // Other mappings use the FFT spectrum, so they need the windowed frame.
255 if self.mapping.kind.needs_unwindowed_frame() {
256 // Fill frame without windowing (for CQT)
257 self.stft
258 .fill_frame_unwindowed(samples, frame_idx, &mut self.workspace)?;
259 } else {
260 // Compute windowed frame spectrum (for FFT-based mappings)
261 self.stft
262 .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
263 }
264
265 // mapping: spectrum(out_len) -> mapped(n_bins)
266 // For CQT, this uses workspace.frame; for others, workspace.spectrum
267 // For ERB, we need the complex FFT output (fft_out)
268 // We need to borrow workspace fields separately to avoid borrow conflicts
269 let Workspace {
270 spectrum,
271 mapped,
272 frame,
273 ..
274 } = &mut self.workspace;
275
276 self.mapping.apply(spectrum, frame, mapped)?;
277
278 // amplitude scaling in-place on mapped vector
279 self.scaling.apply_in_place(mapped)?;
280
281 // write column into output
282 for (row, &val) in mapped.iter().enumerate() {
283 data[[row, frame_idx]] = val;
284 }
285 }
286
287 let times = build_time_axis_seconds(&self.params, n_frames);
288 let axes = Axes::new(self.freq_axis.clone(), times);
289
290 Ok(Spectrogram::new(data, axes, self.params.clone()))
291 }
292
293 /// Compute a single frame of the spectrogram.
294 ///
295 /// This is useful for streaming/online processing where you want to
296 /// process audio frame-by-frame without computing the entire spectrogram.
297 ///
298 /// # Arguments
299 ///
300 /// * `samples` - Audio samples (must contain at least enough samples for the requested frame)
301 /// * `frame_idx` - Frame index to compute
302 ///
303 /// # Returns
304 ///
305 /// A vector of frequency bin values for the requested frame.
306 ///
307 /// # Errors
308 ///
309 /// Returns an error if the frame index is out of bounds or if STFT computation fails.
310 ///
311 /// # Examples
312 ///
313 /// ```
314 /// use spectrograms::*;
315 /// use non_empty_slice::non_empty_vec;
316 ///
317 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
318 /// let samples = non_empty_vec![0.0; nzu!(16000)];
319 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
320 /// let params = SpectrogramParams::new(stft, 16000.0)?;
321 ///
322 /// let planner = SpectrogramPlanner::new();
323 /// let mut plan = planner.linear_plan::<Power>(¶ms, None)?;
324 ///
325 /// // Compute just the first frame
326 /// let frame = plan.compute_frame(&samples, 0)?;
327 /// assert_eq!(frame.len(), nzu!(257)); // n_fft/2 + 1
328 /// # Ok(())
329 /// # }
330 /// ```
331 #[inline]
332 pub fn compute_frame(
333 &mut self,
334 samples: &NonEmptySlice<f64>,
335 frame_idx: usize,
336 ) -> SpectrogramResult<NonEmptyVec<f64>> {
337 let n_bins = self.mapping.output_bins();
338
339 // Ensure workspace is correctly sized
340 self.workspace
341 .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
342
343 // CQT needs unwindowed frames because its kernels already contain windowing.
344 // Other mappings use the FFT spectrum, so they need the windowed frame.
345 if self.mapping.kind.needs_unwindowed_frame() {
346 // Fill frame without windowing (for CQT)
347 self.stft
348 .fill_frame_unwindowed(samples, frame_idx, &mut self.workspace)?;
349 } else {
350 // Compute windowed frame spectrum (for FFT-based mappings)
351 self.stft
352 .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
353 }
354
355 // Apply mapping (using split borrows to avoid borrow conflicts)
356 let Workspace {
357 spectrum,
358 mapped,
359 frame,
360 ..
361 } = &mut self.workspace;
362
363 self.mapping.apply(spectrum, frame, mapped)?;
364
365 // Apply amplitude scaling
366 self.scaling.apply_in_place(mapped)?;
367
368 Ok(mapped.clone())
369 }
370
371 /// Compute spectrogram into a pre-allocated buffer.
372 ///
373 /// This avoids allocating the output matrix, which is useful when
374 /// you want to reuse buffers or have strict memory requirements.
375 ///
376 /// # Arguments
377 ///
378 /// * `samples` - Audio samples
379 /// * `output` - Pre-allocated output matrix (must be correct size: `n_bins` x `n_frames`)
380 ///
381 /// # Returns
382 ///
383 /// An empty result on success.
384 ///
385 /// # Errors
386 ///
387 /// Returns an error if the output buffer dimensions don't match the expected size.
388 ///
389 /// # Examples
390 ///
391 /// ```
392 /// use spectrograms::*;
393 /// use ndarray::Array2;
394 /// use non_empty_slice::non_empty_vec;
395 ///
396 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
397 /// let samples = non_empty_vec![0.0; nzu!(16000)];
398 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
399 /// let params = SpectrogramParams::new(stft, 16000.0)?;
400 ///
401 /// let planner = SpectrogramPlanner::new();
402 /// let mut plan = planner.linear_plan::<Power>(¶ms, None)?;
403 ///
404 /// // Pre-allocate output buffer
405 /// let mut output = Array2::<f64>::zeros((257, 63));
406 /// plan.compute_into(&samples, &mut output)?;
407 /// # Ok(())
408 /// # }
409 /// ```
410 #[inline]
411 pub fn compute_into(
412 &mut self,
413 samples: &NonEmptySlice<f64>,
414 output: &mut Array2<f64>,
415 ) -> SpectrogramResult<()> {
416 let n_frames = self.stft.frame_count(samples.len())?;
417 let n_bins = self.mapping.output_bins();
418
419 // Validate output dimensions
420 if output.nrows() != n_bins.get() {
421 return Err(SpectrogramError::dimension_mismatch(
422 n_bins.get(),
423 output.nrows(),
424 ));
425 }
426 if output.ncols() != n_frames.get() {
427 return Err(SpectrogramError::dimension_mismatch(
428 n_frames.get(),
429 output.ncols(),
430 ));
431 }
432
433 // Ensure workspace is correctly sized
434 self.workspace
435 .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
436
437 // Main loop: fill each frame (column)
438 for frame_idx in 0..n_frames.get() {
439 // CQT needs unwindowed frames because its kernels already contain windowing.
440 // Other mappings use the FFT spectrum, so they need the windowed frame.
441 if self.mapping.kind.needs_unwindowed_frame() {
442 // Fill frame without windowing (for CQT)
443 self.stft
444 .fill_frame_unwindowed(samples, frame_idx, &mut self.workspace)?;
445 } else {
446 // Compute windowed frame spectrum (for FFT-based mappings)
447 self.stft
448 .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
449 }
450
451 // mapping: spectrum(out_len) -> mapped(n_bins)
452 // For CQT, this uses workspace.frame; for others, workspace.spectrum
453 // For ERB, we need the complex FFT output (fft_out)
454 // We need to borrow workspace fields separately to avoid borrow conflicts
455 let Workspace {
456 spectrum,
457 mapped,
458 frame,
459 ..
460 } = &mut self.workspace;
461
462 self.mapping.apply(spectrum, frame, mapped)?;
463
464 // amplitude scaling in-place on mapped vector
465 self.scaling.apply_in_place(mapped)?;
466
467 // write column into output
468 for (row, &val) in mapped.iter().enumerate() {
469 output[[row, frame_idx]] = val;
470 }
471 }
472
473 Ok(())
474 }
475
476 /// Get the expected output dimensions for a given signal length.
477 ///
478 /// # Arguments
479 ///
480 /// * `signal_length` - Length of the input signal in samples.
481 ///
482 /// # Returns
483 ///
484 /// A tuple `(n_bins, n_frames)` representing the output spectrogram shape.
485 ///
486 /// # Errors
487 ///
488 /// Returns an error if the signal length is too short to produce any frames.
489 ///
490 /// # Examples
491 ///
492 /// ```
493 /// use spectrograms::*;
494 ///
495 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
496 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
497 /// let params = SpectrogramParams::new(stft, 16000.0)?;
498 ///
499 /// let planner = SpectrogramPlanner::new();
500 /// let plan = planner.linear_plan::<Power>(¶ms, None)?;
501 ///
502 /// let (n_bins, n_frames) = plan.output_shape(nzu!(16000))?;
503 /// assert_eq!(n_bins, nzu!(257));
504 /// assert_eq!(n_frames, nzu!(63));
505 /// # Ok(())
506 /// # }
507 /// ```
508 #[inline]
509 pub fn output_shape(
510 &self,
511 signal_length: NonZeroUsize,
512 ) -> SpectrogramResult<(NonZeroUsize, NonZeroUsize)> {
513 let n_frames = self.stft.frame_count(signal_length)?;
514 let n_bins = self.mapping.output_bins();
515 Ok((n_bins, n_frames))
516 }
517}
518
/// STFT (Short-Time Fourier Transform) result containing complex frequency bins.
///
/// This is the raw STFT output before any frequency mapping or amplitude scaling.
///
/// # Fields
///
/// - `data`: Complex STFT matrix with shape (`frequency_bins`, `time_frames`)
/// - `frequencies`: Frequency axis in Hz
/// - `sample_rate`: Sample rate in Hz
/// - `params`: STFT computation parameters
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct StftResult {
    /// Complex STFT matrix with shape (`frequency_bins`, `time_frames`)
    pub data: Array2<Complex<f64>>,
    /// Frequency axis in Hz
    pub frequencies: NonEmptyVec<f64>,
    /// Sample rate in Hz
    pub sample_rate: f64,
    /// STFT computation parameters
    pub params: StftParams,
}
540
541impl StftResult {
542 /// Get the number of frequency bins.
543 ///
544 /// # Returns
545 ///
546 /// Number of frequency bins in the STFT result.
547 #[inline]
548 #[must_use]
549 pub fn n_bins(&self) -> NonZeroUsize {
550 // safety: nrows() > 0 for NonEmptyVec
551 unsafe { NonZeroUsize::new_unchecked(self.data.nrows()) }
552 }
553
554 /// Get the number of time frames.
555 ///
556 /// # Returns
557 ///
558 /// Number of time frames in the STFT result.
559 #[inline]
560 #[must_use]
561 pub fn n_frames(&self) -> NonZeroUsize {
562 // safety: ncols() > 0 for NonEmptyVec
563 unsafe { NonZeroUsize::new_unchecked(self.data.ncols()) }
564 }
565
566 /// Get the frequency resolution in Hz
567 ///
568 /// # Returns
569 ///
570 /// Frequency bin width in Hz.
571 #[inline]
572 #[must_use]
573 pub fn frequency_resolution(&self) -> f64 {
574 self.sample_rate / self.params.n_fft().get() as f64
575 }
576
577 /// Get the time resolution in seconds.
578 ///
579 /// # Returns
580 ///
581 /// Time between successive frames in seconds.
582 #[inline]
583 #[must_use]
584 pub fn time_resolution(&self) -> f64 {
585 self.params.hop_size().get() as f64 / self.sample_rate
586 }
587
588 /// Normalizes self.data to remove the complex aspect of it.
589 ///
590 /// # Returns
591 ///
592 /// An Array2\<f64\> containing the norms of each complex number in self.data.
593 #[inline]
594 pub fn norm(&self) -> Array2<f64> {
595 self.as_ref().mapv(Complex::norm)
596 }
597}
598
599impl AsRef<Array2<Complex<f64>>> for StftResult {
600 #[inline]
601 fn as_ref(&self) -> &Array2<Complex<f64>> {
602 &self.data
603 }
604}
605
606impl AsMut<Array2<Complex<f64>>> for StftResult {
607 #[inline]
608 fn as_mut(&mut self) -> &mut Array2<Complex<f64>> {
609 &mut self.data
610 }
611}
612
613impl Deref for StftResult {
614 type Target = Array2<Complex<f64>>;
615
616 #[inline]
617 fn deref(&self) -> &Self::Target {
618 &self.data
619 }
620}
621
622impl DerefMut for StftResult {
623 #[inline]
624 fn deref_mut(&mut self) -> &mut Self::Target {
625 &mut self.data
626 }
627}
628
/// A planner is an object that can build spectrogram plans.
///
/// Plan construction is where:
/// - FFT plans are created
/// - mapping matrices are compiled
/// - axes are computed
///
/// Keeping this in the planner separates plan building from the output types.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct SpectrogramPlanner;
640
641impl SpectrogramPlanner {
642 /// Create a new spectrogram planner.
643 ///
644 /// # Returns
645 ///
646 /// A new `SpectrogramPlanner` instance.
647 #[inline]
648 #[must_use]
649 pub const fn new() -> Self {
650 Self
651 }
652
653 /// Compute the Short-Time Fourier Transform (STFT) of a signal.
654 ///
655 /// This returns the raw complex STFT matrix before any frequency mapping
656 /// or amplitude scaling. Useful for applications that need the full complex
657 /// spectrum or custom processing.
658 ///
659 /// # Arguments
660 ///
661 /// * `samples` - Audio samples (any type that can be converted to a slice)
662 /// * `params` - STFT computation parameters
663 ///
664 /// # Returns
665 ///
666 /// An `StftResult` containing the complex STFT matrix and metadata.
667 ///
668 /// # Errors
669 ///
670 /// Returns an error if STFT computation fails.
671 ///
672 /// # Examples
673 ///
674 /// ```
675 /// use spectrograms::*;
676 /// use non_empty_slice::non_empty_vec;
677 ///
678 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
679 /// let samples = non_empty_vec![0.0; nzu!(16000)];
680 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
681 /// let params = SpectrogramParams::new(stft, 16000.0)?;
682 ///
683 /// let planner = SpectrogramPlanner::new();
684 /// let stft_result = planner.compute_stft(&samples, ¶ms)?;
685 ///
686 /// println!("STFT: {} bins x {} frames", stft_result.n_bins(), stft_result.n_frames());
687 /// # Ok(())
688 /// # }
689 /// ```
690 ///
691 /// # Performance Note
692 ///
693 /// This method creates a new FFT plan each time. For processing multiple
694 /// signals, create a reusable plan with `StftPlan::new()` instead.
695 ///
696 /// # Examples
697 ///
698 /// ```rust
699 /// use spectrograms::*;
700 /// use non_empty_slice::non_empty_vec;
701 ///
702 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
703 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
704 /// let params = SpectrogramParams::new(stft, 16000.0)?;
705 ///
706 /// // One-shot (convenient)
707 /// let planner = SpectrogramPlanner::new();
708 /// let stft_result = planner.compute_stft(&non_empty_vec![0.0; nzu!(16000)], ¶ms)?;
709 ///
710 /// // Reusable plan (efficient for batch)
711 /// let mut plan = StftPlan::new(¶ms)?;
712 /// for signal in &[non_empty_vec![0.0; nzu!(16000)], non_empty_vec![1.0; nzu!(16000)]] {
713 /// let stft = plan.compute(&signal, ¶ms)?;
714 /// }
715 /// # Ok(())
716 /// # }
717 /// ```
718 #[inline]
719 pub fn compute_stft(
720 &self,
721 samples: &NonEmptySlice<f64>,
722 params: &SpectrogramParams,
723 ) -> SpectrogramResult<StftResult> {
724 let mut plan = StftPlan::new(params)?;
725 plan.compute(samples, params)
726 }
727
728 /// Compute the power spectrum of a single audio frame.
729 ///
730 /// This is useful for real-time processing or analyzing individual frames.
731 ///
732 /// # Arguments
733 ///
734 /// * `samples` - Audio frame (length ≤ n_fft, will be zero-padded if shorter)
735 /// * `n_fft` - FFT size
736 /// * `window` - Window type to apply
737 ///
738 /// # Returns
739 ///
740 /// A vector of power values (|X|²) with length `n_fft/2` + 1.
741 ///
742 /// # Automatic Zero-Padding
743 ///
744 /// If the input signal is shorter than `n_fft`, it will be automatically
745 /// zero-padded to the required length.
746 ///
747 /// # Errors
748 ///
749 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
750 ///
751 /// # Examples
752 ///
753 /// ```
754 /// use spectrograms::*;
755 /// use non_empty_slice::non_empty_vec;
756 ///
757 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
758 /// let frame = non_empty_vec![0.0; nzu!(512)];
759 ///
760 /// let planner = SpectrogramPlanner::new();
761 /// let power = planner.compute_power_spectrum(frame.as_ref(), nzu!(512), WindowType::Hanning)?;
762 ///
763 /// assert_eq!(power.len(), nzu!(257)); // 512/2 + 1
764 /// # Ok(())
765 /// # }
766 /// ```
767 #[inline]
768 pub fn compute_power_spectrum(
769 &self,
770 samples: &NonEmptySlice<f64>,
771 n_fft: NonZeroUsize,
772 window: WindowType,
773 ) -> SpectrogramResult<NonEmptyVec<f64>> {
774 if samples.len() > n_fft {
775 return Err(SpectrogramError::invalid_input(format!(
776 "Input length ({}) exceeds FFT size ({})",
777 samples.len(),
778 n_fft
779 )));
780 }
781
782 let window_samples = make_window(window, n_fft);
783 let out_len = r2c_output_size(n_fft.get());
784
785 // Create FFT plan
786 #[cfg(feature = "realfft")]
787 let mut fft = {
788 let mut planner = crate::RealFftPlanner::new();
789 let plan = planner.get_or_create(n_fft.get());
790 crate::RealFftPlan::new(n_fft.get(), plan)
791 };
792
793 #[cfg(feature = "fftw")]
794 let mut fft = {
795 use std::sync::Arc;
796 let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
797 crate::FftwPlan::new(Arc::new(plan))
798 };
799
800 // Apply window and compute FFT
801 let mut windowed = vec![0.0; n_fft.get()];
802 for i in 0..samples.len().get() {
803 windowed[i] = samples[i] * window_samples[i];
804 }
805 // The rest is already zero-padded
806 let mut fft_out = vec![Complex::new(0.0, 0.0); out_len];
807 fft.process(&windowed, &mut fft_out)?;
808
809 // Convert to power
810 let power: Vec<f64> = fft_out.iter().map(num_complex::Complex::norm_sqr).collect();
811 // safety: power is non-empty since n_fft > 0
812 Ok(unsafe { NonEmptyVec::new_unchecked(power) })
813 }
814
815 /// Compute the magnitude spectrum of a single audio frame.
816 ///
817 /// This is useful for real-time processing or analyzing individual frames.
818 ///
819 /// # Arguments
820 ///
821 /// * `samples` - Audio frame (length ≤ n_fft, will be zero-padded if shorter)
822 /// * `n_fft` - FFT size
823 /// * `window` - Window type to apply
824 ///
825 /// # Returns
826 ///
827 /// A vector of magnitude values (|X|) with length `n_fft/2` + 1.
828 ///
829 /// # Automatic Zero-Padding
830 ///
831 /// If the input signal is shorter than `n_fft`, it will be automatically
832 /// zero-padded to the required length.
833 ///
834 /// # Errors
835 ///
836 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
837 ///
838 /// # Examples
839 ///
840 /// ```
841 /// use spectrograms::*;
842 /// use non_empty_slice::non_empty_vec;
843 ///
844 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
845 /// let frame = non_empty_vec![0.0; nzu!(512)];
846 ///
847 /// let planner = SpectrogramPlanner::new();
848 /// let magnitude = planner.compute_magnitude_spectrum(frame.as_ref(), nzu!(512), WindowType::Hanning)?;
849 ///
850 /// assert_eq!(magnitude.len(), nzu!(257)); // 512/2 + 1
851 /// # Ok(())
852 /// # }
853 /// ```
854 #[inline]
855 pub fn compute_magnitude_spectrum(
856 &self,
857 samples: &NonEmptySlice<f64>,
858 n_fft: NonZeroUsize,
859 window: WindowType,
860 ) -> SpectrogramResult<NonEmptyVec<f64>> {
861 let power = self.compute_power_spectrum(samples, n_fft, window)?;
862 let power = power.iter().map(|&p| p.sqrt()).collect::<Vec<f64>>();
863 // safety: power is non-empty since power_spectrum returned successfully
864 Ok(unsafe { NonEmptyVec::new_unchecked(power) })
865 }
866
867 /// Build a linear-frequency spectrogram plan.
868 ///
869 /// # Type Parameters
870 ///
871 /// `AmpScale` determines whether output is:
872 /// - Magnitude
873 /// - Power
874 /// - Decibels
875 ///
876 /// # Arguments
877 ///
878 /// * `params` - Spectrogram parameters
879 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale
880 /// is `Decibels`)
881 ///
882 /// # Returns
883 ///
884 /// A `SpectrogramPlan` configured for linear-frequency spectrogram computation.
885 ///
886 /// # Errors
887 ///
888 /// Returns an error if the plan cannot be created due to invalid parameters.
889 #[inline]
890 pub fn linear_plan<AmpScale>(
891 &self,
892 params: &SpectrogramParams,
893 db: Option<&LogParams>, // only used when AmpScale = Decibels
894 ) -> SpectrogramResult<SpectrogramPlan<LinearHz, AmpScale>>
895 where
896 AmpScale: AmpScaleSpec + 'static,
897 {
898 let stft = StftPlan::new(params)?;
899 let mapping = FrequencyMapping::<LinearHz>::new(params)?;
900 let scaling = AmplitudeScaling::<AmpScale>::new(db);
901 let freq_axis = build_frequency_axis::<LinearHz>(params, &mapping);
902
903 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
904
905 Ok(SpectrogramPlan {
906 params: params.clone(),
907 stft,
908 mapping,
909 scaling,
910 freq_axis,
911 workspace,
912 _amp: PhantomData,
913 })
914 }
915
916 /// Build a mel-frequency spectrogram plan.
917 ///
918 /// This compiles a mel filterbank matrix and caches it inside the plan.
919 ///
920 /// # Type Parameters
921 ///
922 /// `AmpScale`: determines whether output is:
923 /// - Magnitude
924 /// - Power
925 /// - Decibels
926 ///
927 /// # Arguments
928 ///
929 /// * `params` - Spectrogram parameters
930 /// * `mel` - Mel-specific parameters
931 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
932 ///
933 /// # Returns
934 ///
935 /// A `SpectrogramPlan` configured for mel spectrogram computation.
936 ///
937 /// # Errors
938 ///
939 /// Returns an error if the plan cannot be created due to invalid parameters.
940 #[inline]
941 pub fn mel_plan<AmpScale>(
942 &self,
943 params: &SpectrogramParams,
944 mel: &MelParams,
945 db: Option<&LogParams>, // only used when AmpScale = Decibels
946 ) -> SpectrogramResult<SpectrogramPlan<Mel, AmpScale>>
947 where
948 AmpScale: AmpScaleSpec + 'static,
949 {
950 // cross-validation: mel range must be compatible with sample rate
951 let nyquist = params.nyquist_hz();
952 if mel.f_max() > nyquist {
953 return Err(SpectrogramError::invalid_input(
954 "mel f_max must be <= Nyquist",
955 ));
956 }
957
958 let stft = StftPlan::new(params)?;
959 let mapping = FrequencyMapping::<Mel>::new_mel(params, mel)?;
960 let scaling = AmplitudeScaling::<AmpScale>::new(db);
961 let freq_axis = build_frequency_axis::<Mel>(params, &mapping);
962
963 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
964
965 Ok(SpectrogramPlan {
966 params: params.clone(),
967 stft,
968 mapping,
969 scaling,
970 freq_axis,
971 workspace,
972 _amp: PhantomData,
973 })
974 }
975
976 /// Build an ERB-scale spectrogram plan.
977 ///
978 /// This creates a spectrogram with ERB-spaced frequency bands using gammatone
979 /// filterbank approximation in the frequency domain.
980 ///
981 /// # Type Parameters
982 ///
983 /// `AmpScale`: determines whether output is:
984 /// - Magnitude
985 /// - Power
986 /// - Decibels
987 ///
988 /// # Arguments
989 ///
990 /// * `params` - Spectrogram parameters
991 /// * `erb` - ERB-specific parameters
992 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
993 ///
994 /// # Returns
995 ///
996 /// A `SpectrogramPlan` configured for ERB spectrogram computation.
997 ///
998 /// # Errors
999 ///
1000 /// Returns an error if the plan cannot be created due to invalid parameters.
1001 #[inline]
1002 pub fn erb_plan<AmpScale>(
1003 &self,
1004 params: &SpectrogramParams,
1005 erb: &ErbParams,
1006 db: Option<&LogParams>,
1007 ) -> SpectrogramResult<SpectrogramPlan<Erb, AmpScale>>
1008 where
1009 AmpScale: AmpScaleSpec + 'static,
1010 {
1011 // cross-validation: erb range must be compatible with sample rate
1012 let nyquist = params.nyquist_hz();
1013 if erb.f_max() > nyquist {
1014 return Err(SpectrogramError::invalid_input(format!(
1015 "f_max={} exceeds Nyquist={}",
1016 erb.f_max(),
1017 nyquist
1018 )));
1019 }
1020
1021 let stft = StftPlan::new(params)?;
1022 let mapping = FrequencyMapping::<Erb>::new_erb(params, erb)?;
1023 let scaling = AmplitudeScaling::<AmpScale>::new(db);
1024 let freq_axis = build_frequency_axis::<Erb>(params, &mapping);
1025
1026 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1027
1028 Ok(SpectrogramPlan {
1029 params: params.clone(),
1030 stft,
1031 mapping,
1032 scaling,
1033 freq_axis,
1034 workspace,
1035 _amp: PhantomData,
1036 })
1037 }
1038
1039 /// Build a log-frequency plan.
1040 ///
1041 /// This creates a spectrogram with logarithmically-spaced frequency bins.
1042 ///
1043 /// # Type Parameters
1044 ///
1045 /// `AmpScale`: determines whether output is:
1046 /// - Magnitude
1047 /// - Power
1048 /// - Decibels
1049 ///
1050 /// # Arguments
1051 ///
1052 /// * `params` - Spectrogram parameters
1053 /// * `loghz` - LogHz-specific parameters
1054 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
1055 ///
1056 /// # Returns
1057 ///
1058 /// A `SpectrogramPlan` configured for log-frequency spectrogram computation.
1059 ///
1060 /// # Errors
1061 ///
1062 /// Returns an error if the plan cannot be created due to invalid parameters.
1063 #[inline]
1064 pub fn log_hz_plan<AmpScale>(
1065 &self,
1066 params: &SpectrogramParams,
1067 loghz: &LogHzParams,
1068 db: Option<&LogParams>,
1069 ) -> SpectrogramResult<SpectrogramPlan<LogHz, AmpScale>>
1070 where
1071 AmpScale: AmpScaleSpec + 'static,
1072 {
1073 // cross-validation: loghz range must be compatible with sample rate
1074 let nyquist = params.nyquist_hz();
1075 if loghz.f_max() > nyquist {
1076 return Err(SpectrogramError::invalid_input(format!(
1077 "f_max={} exceeds Nyquist={}",
1078 loghz.f_max(),
1079 nyquist
1080 )));
1081 }
1082
1083 let stft = StftPlan::new(params)?;
1084 let mapping = FrequencyMapping::<LogHz>::new_loghz(params, loghz)?;
1085 let scaling = AmplitudeScaling::<AmpScale>::new(db);
1086 let freq_axis = build_frequency_axis::<LogHz>(params, &mapping);
1087
1088 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1089
1090 Ok(SpectrogramPlan {
1091 params: params.clone(),
1092 stft,
1093 mapping,
1094 scaling,
1095 freq_axis,
1096 workspace,
1097 _amp: PhantomData,
1098 })
1099 }
1100
1101 /// Build a cqt spectrogram plan.
1102 ///
1103 /// # Type Parameters
1104 ///
1105 /// `AmpScale`: determines whether output is:
1106 /// - Magnitude
1107 /// - Power
1108 /// - Decibels
1109 ///
1110 /// # Arguments
1111 ///
1112 /// * `params` - Spectrogram parameters
1113 /// * `cqt` - CQT-specific parameters
1114 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
1115 ///
1116 /// # Returns
1117 ///
1118 /// A `SpectrogramPlan` configured for CQT spectrogram computation.
1119 ///
1120 /// # Errors
1121 ///
1122 /// Returns an error if the plan cannot be created due to invalid parameters.
1123 #[inline]
1124 pub fn cqt_plan<AmpScale>(
1125 &self,
1126 params: &SpectrogramParams,
1127 cqt: &CqtParams,
1128 db: Option<&LogParams>, // only used when AmpScale = Decibels
1129 ) -> SpectrogramResult<SpectrogramPlan<Cqt, AmpScale>>
1130 where
1131 AmpScale: AmpScaleSpec + 'static,
1132 {
1133 let stft = StftPlan::new(params)?;
1134 let mapping = FrequencyMapping::<Cqt>::new(params, cqt)?;
1135 let scaling = AmplitudeScaling::<AmpScale>::new(db);
1136 let freq_axis = build_frequency_axis::<Cqt>(params, &mapping);
1137
1138 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1139
1140 Ok(SpectrogramPlan {
1141 params: params.clone(),
1142 stft,
1143 mapping,
1144 scaling,
1145 freq_axis,
1146 workspace,
1147 _amp: PhantomData,
1148 })
1149 }
1150}
1151
/// STFT plan containing reusable FFT plan and buffers.
///
/// This struct is responsible for performing the Short-Time Fourier Transform (STFT)
/// on audio signals based on the provided parameters.
///
/// It encapsulates the FFT plan, windowing function, and internal buffers to efficiently
/// compute the STFT for multiple frames of audio data. Because the scratch buffers are
/// owned by the plan, the frame-computing methods take `&mut self`.
///
/// # Fields
///
/// - `n_fft`: Size of the FFT (also the length of one analysis frame).
/// - `hop_size`: Hop size between consecutive frames, in samples.
/// - `window`: Windowing function samples, generated once when the plan is built.
/// - `centre`: Whether to centre the frames with `n_fft / 2` zeros of implicit padding
///   on both sides of the signal.
/// - `out_len`: Length of the FFT output (number of frequency bins).
/// - `fft`: Boxed FFT plan for real-to-complex transformation.
/// - `fft_out`: Internal buffer for FFT output.
/// - `frame`: Internal buffer for windowed audio frames.
pub struct StftPlan {
    // Frame length / FFT size.
    n_fft: NonZeroUsize,
    // Distance in samples between the starts of consecutive frames.
    hop_size: NonZeroUsize,
    // Window coefficients multiplied into each frame before the FFT.
    window: NonEmptyVec<f64>,
    // When true, frames are positioned as if the signal had n_fft / 2 zeros on each side.
    centre: bool,

    // Number of complex bins produced by the real-to-complex FFT.
    out_len: NonZeroUsize,

    // FFT plan (reused for all frames)
    fft: Box<dyn R2cPlan>,

    // internal scratch
    // Complex spectrum of the most recently processed frame (`out_len` entries).
    fft_out: NonEmptyVec<Complex<f64>>,
    // Windowed time-domain samples of the most recently processed frame (`n_fft` entries).
    frame: NonEmptyVec<f64>,
}
1185
1186impl StftPlan {
1187 /// Create a new STFT plan from parameters.
1188 ///
1189 /// # Arguments
1190 ///
1191 /// * `params` - Spectrogram parameters containing STFT config
1192 ///
1193 /// # Returns
1194 ///
1195 /// A new `StftPlan` instance.
1196 ///
1197 /// # Errors
1198 ///
1199 /// Returns an error if the FFT plan cannot be created.
1200 #[inline]
1201 pub fn new(params: &SpectrogramParams) -> SpectrogramResult<Self> {
1202 let stft = params.stft();
1203 let n_fft = stft.n_fft();
1204 let hop_size = stft.hop_size();
1205 let centre = stft.centre();
1206
1207 let window = make_window(stft.window(), n_fft);
1208
1209 let out_len = r2c_output_size(n_fft.get());
1210 let out_len = NonZeroUsize::new(out_len)
1211 .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1212
1213 #[cfg(feature = "realfft")]
1214 let fft = {
1215 let mut planner = crate::RealFftPlanner::new();
1216 let plan = planner.get_or_create(n_fft.get());
1217 let plan = crate::RealFftPlan::new(n_fft.get(), plan);
1218 Box::new(plan)
1219 };
1220
1221 #[cfg(feature = "fftw")]
1222 let fft = {
1223 use std::sync::Arc;
1224 let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
1225 Box::new(crate::FftwPlan::new(Arc::new(plan)))
1226 };
1227
1228 Ok(Self {
1229 n_fft,
1230 hop_size,
1231 window,
1232 centre,
1233 out_len,
1234 fft,
1235 fft_out: non_empty_vec![Complex::new(0.0, 0.0); out_len],
1236 frame: non_empty_vec![0.0; n_fft],
1237 })
1238 }
1239
1240 fn frame_count(&self, n_samples: NonZeroUsize) -> SpectrogramResult<NonZeroUsize> {
1241 // Framing policy:
1242 // - centre = true: implicit padding of n_fft/2 on both sides
1243 // - centre = false: no padding
1244 //
1245 // Define the number of frames such that each frame has a valid centre sample position.
1246 let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1247 let padded_len = n_samples.get() + 2 * pad;
1248
1249 if padded_len < self.n_fft.get() {
1250 // still produce 1 frame (all padding / partial)
1251 return Ok(nzu!(1));
1252 }
1253
1254 let remaining = padded_len - self.n_fft.get();
1255 let n_frames = remaining / self.hop_size().get() + 1;
1256 let n_frames = NonZeroUsize::new(n_frames).ok_or_else(|| {
1257 SpectrogramError::invalid_input("computed number of frames must be non-zero")
1258 })?;
1259 Ok(n_frames)
1260 }
1261
1262 /// Compute one frame FFT using internal buffers only.
1263 fn compute_frame_fft_simple(
1264 &mut self,
1265 samples: &NonEmptySlice<f64>,
1266 frame_idx: usize,
1267 ) -> SpectrogramResult<()> {
1268 let out = self.frame.as_mut_slice();
1269 debug_assert_eq!(out.len(), self.n_fft.get());
1270
1271 let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1272 let start = frame_idx
1273 .checked_mul(self.hop_size.get())
1274 .ok_or_else(|| SpectrogramError::invalid_input("frame index overflow"))?;
1275
1276 // Fill windowed frame
1277 for (i, sample) in out.iter_mut().enumerate().take(self.n_fft.get()) {
1278 let v_idx = start + i;
1279 let s_idx = v_idx as isize - pad as isize;
1280
1281 let sample_val = if s_idx < 0 || (s_idx as usize) >= samples.len().get() {
1282 0.0
1283 } else {
1284 samples[s_idx as usize]
1285 };
1286 *sample = sample_val * self.window[i];
1287 }
1288
1289 // Compute FFT
1290 let fft_out = self.fft_out.as_mut_slice();
1291 self.fft.process(out, fft_out)?;
1292
1293 Ok(())
1294 }
1295
1296 /// Compute one frame spectrum into workspace:
1297 /// - fills windowed frame
1298 /// - runs FFT
1299 /// - converts to magnitude/power based on `AmpScale` later
1300 fn compute_frame_spectrum(
1301 &mut self,
1302 samples: &NonEmptySlice<f64>,
1303 frame_idx: usize,
1304 workspace: &mut Workspace,
1305 ) -> SpectrogramResult<()> {
1306 let out = workspace.frame.as_mut_slice();
1307
1308 // self.fill_frame(samples, frame_idx, frame)?;
1309 debug_assert_eq!(out.len(), self.n_fft.get());
1310
1311 let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1312 let start = frame_idx
1313 .checked_mul(self.hop_size().get())
1314 .ok_or_else(|| SpectrogramError::invalid_input("frame index overflow"))?;
1315
1316 // The "virtual" signal is samples with pad zeros on both sides.
1317 // Virtual index 0..padded_len
1318 // Map virtual index to original samples by subtracting pad.
1319 for (i, sample) in out.iter_mut().enumerate().take(self.n_fft.get()) {
1320 let v_idx = start + i;
1321 let s_idx = v_idx as isize - pad as isize;
1322
1323 let sample_val = if s_idx < 0 || (s_idx as usize) >= samples.len().get() {
1324 0.0
1325 } else {
1326 samples[s_idx as usize]
1327 };
1328
1329 *sample = sample_val * self.window[i];
1330 }
1331 let fft_out = workspace.fft_out.as_mut_slice();
1332 // FFT
1333 self.fft.process(out, fft_out)?;
1334
1335 // Convert complex spectrum to linear magnitude OR power here? No:
1336 // Keep "spectrum" as power by default? That would entangle semantics.
1337 //
1338 // Instead, we store magnitude^2 (power) as the canonical intermediate,
1339 // and let AmpScale decide later whether output is magnitude or power.
1340 //
1341 // This is consistent and avoids recomputing norms multiple times.
1342 for (i, c) in workspace.fft_out.iter().enumerate() {
1343 workspace.spectrum[i] = c.norm_sqr();
1344 }
1345
1346 Ok(())
1347 }
1348
1349 /// Fill a time-domain frame WITHOUT applying the window.
1350 ///
1351 /// This is used for CQT mapping, where the CQT kernels already contain
1352 /// windowing applied during kernel generation. Applying the STFT window
1353 /// would result in double-windowing.
1354 ///
1355 /// # Arguments
1356 ///
1357 /// * `samples` - Input audio samples
1358 /// * `frame_idx` - Frame index
1359 /// * `workspace` - Workspace containing the frame buffer to fill
1360 ///
1361 /// # Errors
1362 ///
1363 /// Returns an error if frame index is out of bounds.
1364 fn fill_frame_unwindowed(
1365 &self,
1366 samples: &NonEmptySlice<f64>,
1367 frame_idx: usize,
1368 workspace: &mut Workspace,
1369 ) -> SpectrogramResult<()> {
1370 let out = workspace.frame.as_mut_slice();
1371 debug_assert_eq!(out.len(), self.n_fft.get());
1372
1373 let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1374 let start = frame_idx
1375 .checked_mul(self.hop_size().get())
1376 .ok_or_else(|| SpectrogramError::invalid_input("frame index overflow"))?;
1377
1378 // Fill frame WITHOUT windowing
1379 for (i, sample) in out.iter_mut().enumerate().take(self.n_fft.get()) {
1380 let v_idx = start + i;
1381 let s_idx = v_idx as isize - pad as isize;
1382
1383 let sample_val = if s_idx < 0 || (s_idx as usize) >= samples.len().get() {
1384 0.0
1385 } else {
1386 samples[s_idx as usize]
1387 };
1388
1389 // No window multiplication - just copy the sample
1390 *sample = sample_val;
1391 }
1392
1393 Ok(())
1394 }
1395
1396 /// Compute the full STFT for a signal, returning an `StftResult`.
1397 ///
1398 /// This is a convenience method that handles frame iteration and
1399 /// builds the complete STFT matrix.
1400 ///
1401 ///
1402 /// # Arguments
1403 ///
1404 /// * `samples` - Input audio samples
1405 /// * `params` - STFT computation parameters
1406 ///
1407 /// # Returns
1408 ///
1409 /// An `StftResult` containing the complex STFT matrix and metadata.
1410 ///
1411 /// # Errors
1412 ///
1413 /// Returns an error if computation fails.
1414 ///
1415 /// # Examples
1416 ///
1417 /// ```rust
1418 /// use spectrograms::*;
1419 /// use non_empty_slice::non_empty_vec;
1420 ///
1421 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1422 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1423 /// let params = SpectrogramParams::new(stft, 16000.0)?;
1424 /// let mut plan = StftPlan::new(¶ms)?;
1425 ///
1426 /// let samples = non_empty_vec![0.0; nzu!(16000)];
1427 /// let stft_result = plan.compute(&samples, ¶ms)?;
1428 ///
1429 /// println!("STFT: {} bins x {} frames", stft_result.n_bins(), stft_result.n_frames());
1430 /// # Ok(())
1431 /// # }
1432 /// ```
1433 #[inline]
1434 pub fn compute(
1435 &mut self,
1436 samples: &NonEmptySlice<f64>,
1437 params: &SpectrogramParams,
1438 ) -> SpectrogramResult<StftResult> {
1439 let n_frames = self.frame_count(samples.len())?;
1440 let n_bins = self.out_len;
1441
1442 // Allocate output matrix (frequency_bins x time_frames)
1443 let mut data = Array2::<Complex<f64>>::zeros((n_bins.get(), n_frames.get()));
1444
1445 // Compute each frame
1446 for frame_idx in 0..n_frames.get() {
1447 self.compute_frame_fft_simple(samples, frame_idx)?;
1448
1449 // Copy from internal buffer to output
1450 for (bin_idx, &value) in self.fft_out.iter().enumerate() {
1451 data[[bin_idx, frame_idx]] = value;
1452 }
1453 }
1454
1455 // Build frequency axis
1456 let frequencies: Vec<f64> = (0..n_bins.get())
1457 .map(|k| k as f64 * params.sample_rate_hz() / params.stft().n_fft().get() as f64)
1458 .collect();
1459 // SAFETY: n_bins > 0
1460 let frequencies = unsafe { NonEmptyVec::new_unchecked(frequencies) };
1461
1462 Ok(StftResult {
1463 data,
1464 frequencies,
1465 sample_rate: params.sample_rate_hz(),
1466 params: params.stft().clone(),
1467 })
1468 }
1469
1470 /// Compute a single frame of STFT, returning the complex spectrum.
1471 ///
1472 /// This is useful for streaming/online processing where you want to
1473 /// process audio frame-by-frame.
1474 ///
1475 /// # Arguments
1476 ///
1477 /// * `samples` - Input audio samples
1478 /// * `frame_idx` - Index of the frame to compute
1479 ///
1480 /// # Returns
1481 ///
1482 /// A `NonEmptyVec` containing the complex spectrum for the specified frame.
1483 ///
1484 /// # Errors
1485 ///
1486 /// Returns an error if computation fails.
1487 ///
1488 /// # Examples
1489 ///
1490 /// ```rust
1491 /// use spectrograms::*;
1492 /// use non_empty_slice::non_empty_vec;
1493 ///
1494 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1495 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1496 /// let params = SpectrogramParams::new(stft, 16000.0)?;
1497 /// let mut plan = StftPlan::new(¶ms)?;
1498 ///
1499 /// let samples = non_empty_vec![0.0; nzu!(16000)];
1500 /// let (_, n_frames) = plan.output_shape(samples.len())?;
1501 ///
1502 /// for frame_idx in 0..n_frames.get() {
1503 /// let spectrum = plan.compute_frame_simple(&samples, frame_idx)?;
1504 /// // Process spectrum...
1505 /// }
1506 /// # Ok(())
1507 /// # }
1508 /// ```
1509 #[inline]
1510 pub fn compute_frame_simple(
1511 &mut self,
1512 samples: &NonEmptySlice<f64>,
1513 frame_idx: usize,
1514 ) -> SpectrogramResult<NonEmptyVec<Complex<f64>>> {
1515 self.compute_frame_fft_simple(samples, frame_idx)?;
1516 Ok(self.fft_out.clone())
1517 }
1518
1519 /// Compute STFT into a pre-allocated buffer.
1520 ///
1521 /// This avoids allocating the output matrix, useful for reusing buffers.
1522 ///
1523 /// # Arguments
1524 ///
1525 /// * `samples` - Input audio samples
1526 /// * `output` - Pre-allocated output buffer (shape: `n_bins` x `n_frames`)
1527 ///
1528 /// # Returns
1529 ///
1530 /// `Ok(())` on success, or an error if dimensions mismatch.
1531 ///
1532 /// # Errors
1533 ///
1534 /// Returns `DimensionMismatch` error if the output buffer has incorrect shape.
1535 ///
1536 /// # Examples
1537 ///
1538 /// ```rust
1539 /// use spectrograms::*;
1540 /// use ndarray::Array2;
1541 /// use num_complex::Complex;
1542 /// use non_empty_slice::non_empty_vec;
1543 ///
1544 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1545 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1546 /// let params = SpectrogramParams::new(stft, 16000.0)?;
1547 /// let mut plan = StftPlan::new(¶ms)?;
1548 ///
1549 /// let samples = non_empty_vec![0.0; nzu!(16000)];
1550 /// let (n_bins, n_frames) = plan.output_shape(samples.len())?;
1551 /// let mut output = Array2::<Complex<f64>>::zeros((n_bins.get(), n_frames.get()));
1552 ///
1553 /// plan.compute_into(&samples, &mut output)?;
1554 /// # Ok(())
1555 /// # }
1556 /// ```
1557 #[inline]
1558 pub fn compute_into(
1559 &mut self,
1560 samples: &NonEmptySlice<f64>,
1561 output: &mut Array2<Complex<f64>>,
1562 ) -> SpectrogramResult<()> {
1563 let n_frames = self.frame_count(samples.len())?;
1564 let n_bins = self.out_len;
1565
1566 // Validate output dimensions
1567 if output.nrows() != n_bins.get() {
1568 return Err(SpectrogramError::dimension_mismatch(
1569 n_bins.get(),
1570 output.nrows(),
1571 ));
1572 }
1573 if output.ncols() != n_frames.get() {
1574 return Err(SpectrogramError::dimension_mismatch(
1575 n_frames.get(),
1576 output.ncols(),
1577 ));
1578 }
1579
1580 // Compute into pre-allocated buffer
1581 for frame_idx in 0..n_frames.get() {
1582 self.compute_frame_fft_simple(samples, frame_idx)?;
1583
1584 for (bin_idx, &value) in self.fft_out.iter().enumerate() {
1585 output[[bin_idx, frame_idx]] = value;
1586 }
1587 }
1588
1589 Ok(())
1590 }
1591
1592 /// Get the expected output dimensions for a given signal length.
1593 ///
1594 /// # Arguments
1595 ///
1596 /// * `signal_length` - Length of the input signal in samples
1597 ///
1598 /// # Returns
1599 ///
1600 /// A tuple `(n_frequency_bins, n_time_frames)`.
1601 ///
1602 /// # Errors
1603 ///
1604 /// Returns an error if the computed number of frames is invalid.
1605 #[inline]
1606 pub fn output_shape(
1607 &self,
1608 signal_length: NonZeroUsize,
1609 ) -> SpectrogramResult<(NonZeroUsize, NonZeroUsize)> {
1610 let n_frames = self.frame_count(signal_length)?;
1611 Ok((self.out_len, n_frames))
1612 }
1613
1614 /// Get the number of frequency bins in the output.
1615 ///
1616 /// # Returns
1617 ///
1618 /// The number of frequency bins.
1619 #[inline]
1620 #[must_use]
1621 pub const fn n_bins(&self) -> NonZeroUsize {
1622 self.out_len
1623 }
1624
1625 /// Get the FFT size.
1626 ///
1627 /// # Returns
1628 ///
1629 /// The FFT size.
1630 #[inline]
1631 #[must_use]
1632 pub const fn n_fft(&self) -> NonZeroUsize {
1633 self.n_fft
1634 }
1635
1636 /// Get the hop size.
1637 ///
1638 /// # Returns
1639 ///
1640 /// The hop size.
1641 #[inline]
1642 #[must_use]
1643 pub const fn hop_size(&self) -> NonZeroUsize {
1644 self.hop_size
1645 }
1646}
1647
/// The concrete strategy used to turn one analysed frame into output bins.
///
/// Each variant carries whatever precomputed data its mapping needs, so the
/// per-frame `apply` step only has to combine it with the current frame.
#[derive(Debug, Clone)]
enum MappingKind {
    /// Linear frequency axis: the power spectrum is copied through unchanged.
    /// `out_len` is the expected spectrum (and output) length.
    Identity {
        out_len: NonZeroUsize,
    },
    /// Mel filterbank stored as a sparse matrix applied to the power spectrum.
    Mel {
        matrix: SparseMatrix,
    }, // shape: (n_mels, out_len)
    /// Log-frequency mapping stored as a sparse matrix, together with the
    /// frequency of each output bin as returned when the matrix was built.
    LogHz {
        matrix: SparseMatrix,
        frequencies: NonEmptyVec<f64>,
    }, // shape: (n_bins, out_len)
    /// ERB filterbank applied to the power spectrum.
    Erb {
        filterbank: ErbFilterbank,
    },
    /// Constant-Q kernel applied to the time-domain frame rather than the
    /// FFT spectrum.
    Cqt {
        kernel: CqtKernel,
    },
}
1667
1668impl MappingKind {
1669 /// Check if this mapping requires unwindowed time-domain frames.
1670 ///
1671 /// CQT kernels already contain windowing applied during kernel generation,
1672 /// so they should receive unwindowed frames to avoid double-windowing.
1673 /// All other mappings work on FFT spectra and don't use the time-domain frame.
1674 const fn needs_unwindowed_frame(&self) -> bool {
1675 matches!(self, Self::Cqt { .. })
1676 }
1677}
1678
/// Typed mapping wrapper.
///
/// Pairs a runtime [`MappingKind`] with a compile-time `FreqScale` marker so
/// that each frequency scale gets its own constructor (`impl` block) while
/// the shared `apply` / `output_bins` / `frequencies_hz` logic lives in one
/// generic `impl`.
#[derive(Debug, Clone)]
struct FrequencyMapping<FreqScale> {
    // The concrete mapping strategy and its precomputed data.
    kind: MappingKind,
    // Zero-sized marker tying this value to its frequency-scale type.
    _marker: PhantomData<FreqScale>,
}
1685
1686impl FrequencyMapping<LinearHz> {
1687 fn new(params: &SpectrogramParams) -> SpectrogramResult<Self> {
1688 let out_len = r2c_output_size(params.stft().n_fft().get());
1689 let out_len = NonZeroUsize::new(out_len)
1690 .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1691 Ok(Self {
1692 kind: MappingKind::Identity { out_len },
1693 _marker: PhantomData,
1694 })
1695 }
1696}
1697
1698impl FrequencyMapping<Mel> {
1699 fn new_mel(params: &SpectrogramParams, mel: &MelParams) -> SpectrogramResult<Self> {
1700 let n_fft = params.stft().n_fft();
1701 let out_len = r2c_output_size(n_fft.get());
1702 let out_len = NonZeroUsize::new(out_len)
1703 .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1704
1705 // Validate: mel bins must be <= something sensible
1706 if mel.n_mels() > nzu!(10_000) {
1707 return Err(SpectrogramError::invalid_input(
1708 "n_mels is unreasonably large",
1709 ));
1710 }
1711
1712 let matrix = build_mel_filterbank_matrix(
1713 params.sample_rate_hz(),
1714 n_fft,
1715 mel.n_mels(),
1716 mel.f_min(),
1717 mel.f_max(),
1718 mel.norm(),
1719 )?;
1720
1721 // matrix must be (n_mels, out_len)
1722 if matrix.nrows() != mel.n_mels().get() || matrix.ncols() != out_len.get() {
1723 return Err(SpectrogramError::invalid_input(
1724 "mel filterbank matrix shape mismatch",
1725 ));
1726 }
1727
1728 Ok(Self {
1729 kind: MappingKind::Mel { matrix },
1730 _marker: PhantomData,
1731 })
1732 }
1733}
1734
1735impl FrequencyMapping<LogHz> {
1736 fn new_loghz(params: &SpectrogramParams, loghz: &LogHzParams) -> SpectrogramResult<Self> {
1737 let n_fft = params.stft().n_fft();
1738 let out_len = r2c_output_size(n_fft.get());
1739 let out_len = NonZeroUsize::new(out_len)
1740 .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1741 // Validate: n_bins must be <= something sensible
1742 if loghz.n_bins() > nzu!(10_000) {
1743 return Err(SpectrogramError::invalid_input(
1744 "n_bins is unreasonably large",
1745 ));
1746 }
1747
1748 let (matrix, frequencies) = build_loghz_matrix(
1749 params.sample_rate_hz(),
1750 n_fft,
1751 loghz.n_bins(),
1752 loghz.f_min(),
1753 loghz.f_max(),
1754 )?;
1755
1756 // matrix must be (n_bins, out_len)
1757 if matrix.nrows() != loghz.n_bins().get() || matrix.ncols() != out_len.get() {
1758 return Err(SpectrogramError::invalid_input(
1759 "loghz matrix shape mismatch",
1760 ));
1761 }
1762
1763 Ok(Self {
1764 kind: MappingKind::LogHz {
1765 matrix,
1766 frequencies,
1767 },
1768 _marker: PhantomData,
1769 })
1770 }
1771}
1772
1773impl FrequencyMapping<Erb> {
1774 fn new_erb(params: &SpectrogramParams, erb: &crate::erb::ErbParams) -> SpectrogramResult<Self> {
1775 let n_fft = params.stft().n_fft();
1776 let sample_rate = params.sample_rate_hz();
1777
1778 // Validate: n_filters must be <= something sensible
1779 if erb.n_filters() > nzu!(10_000) {
1780 return Err(SpectrogramError::invalid_input(
1781 "n_filters is unreasonably large",
1782 ));
1783 }
1784
1785 // Generate ERB filterbank with pre-computed frequency responses
1786 let filterbank = crate::erb::ErbFilterbank::generate(erb, sample_rate, n_fft)?;
1787
1788 Ok(Self {
1789 kind: MappingKind::Erb { filterbank },
1790 _marker: PhantomData,
1791 })
1792 }
1793}
1794
1795impl FrequencyMapping<Cqt> {
1796 fn new(params: &SpectrogramParams, cqt: &CqtParams) -> SpectrogramResult<Self> {
1797 let sample_rate = params.sample_rate_hz();
1798 let n_fft = params.stft().n_fft();
1799
1800 // Validate that frequency range is reasonable
1801 let f_max = cqt.bin_frequency(cqt.num_bins().get().saturating_sub(1));
1802 if f_max >= sample_rate / 2.0 {
1803 return Err(SpectrogramError::invalid_input(
1804 "CQT maximum frequency must be below Nyquist frequency",
1805 ));
1806 }
1807
1808 // Generate CQT kernel using n_fft as the signal length for kernel generation
1809 let kernel = CqtKernel::generate(cqt, sample_rate, n_fft);
1810
1811 Ok(Self {
1812 kind: MappingKind::Cqt { kernel },
1813 _marker: PhantomData,
1814 })
1815 }
1816}
1817
1818impl<FreqScale> FrequencyMapping<FreqScale> {
1819 const fn output_bins(&self) -> NonZeroUsize {
1820 // safety: all variants ensure output bins > 0 OR rely on a matrix that is guaranteed to have rows > 0
1821 match &self.kind {
1822 MappingKind::Identity { out_len } => *out_len,
1823 // safety: matrix.nrows() > 0
1824 MappingKind::LogHz { matrix, .. } | MappingKind::Mel { matrix } => unsafe {
1825 NonZeroUsize::new_unchecked(matrix.nrows())
1826 },
1827 MappingKind::Erb { filterbank, .. } => filterbank.num_filters(),
1828 MappingKind::Cqt { kernel, .. } => kernel.num_bins(),
1829 }
1830 }
1831
1832 fn apply(
1833 &self,
1834 spectrum: &NonEmptySlice<f64>,
1835 frame: &NonEmptySlice<f64>,
1836 out: &mut NonEmptySlice<f64>,
1837 ) -> SpectrogramResult<()> {
1838 match &self.kind {
1839 MappingKind::Identity { out_len } => {
1840 if spectrum.len() != *out_len {
1841 return Err(SpectrogramError::dimension_mismatch(
1842 (*out_len).get(),
1843 spectrum.len().get(),
1844 ));
1845 }
1846 if out.len() != *out_len {
1847 return Err(SpectrogramError::dimension_mismatch(
1848 (*out_len).get(),
1849 out.len().get(),
1850 ));
1851 }
1852 out.copy_from_slice(spectrum);
1853 Ok(())
1854 }
1855 MappingKind::LogHz { matrix, .. } | MappingKind::Mel { matrix } => {
1856 let out_bins = matrix.nrows();
1857 let in_bins = matrix.ncols();
1858
1859 if spectrum.len().get() != in_bins {
1860 return Err(SpectrogramError::dimension_mismatch(
1861 in_bins,
1862 spectrum.len().get(),
1863 ));
1864 }
1865 if out.len().get() != out_bins {
1866 return Err(SpectrogramError::dimension_mismatch(
1867 out_bins,
1868 out.len().get(),
1869 ));
1870 }
1871
1872 // Sparse matrix-vector multiplication: out = matrix * spectrum
1873 matrix.multiply_vec(spectrum, out);
1874 Ok(())
1875 }
1876 MappingKind::Erb { filterbank } => {
1877 // Apply ERB filterbank using pre-computed frequency responses
1878 // The filterbank already has |H(f)|^2 pre-computed, so we just
1879 // apply it to the power spectrum
1880 let erb_out = filterbank.apply_to_power_spectrum(spectrum)?;
1881
1882 if out.len().get() != erb_out.len().get() {
1883 return Err(SpectrogramError::dimension_mismatch(
1884 erb_out.len().get(),
1885 out.len().get(),
1886 ));
1887 }
1888
1889 out.copy_from_slice(&erb_out);
1890 Ok(())
1891 }
1892 MappingKind::Cqt { kernel } => {
1893 // CQT works on time-domain unwindowed frame (not FFT spectrum).
1894 // The CQT kernels contain windowing applied during kernel generation,
1895 // so the input frame should be unwindowed to avoid double-windowing.
1896 let cqt_complex = kernel.apply(frame)?;
1897
1898 if out.len().get() != cqt_complex.len().get() {
1899 return Err(SpectrogramError::dimension_mismatch(
1900 cqt_complex.len().get(),
1901 out.len().get(),
1902 ));
1903 }
1904
1905 // Convert complex coefficients to power (|z|^2)
1906 // This matches the convention where intermediate values are in power domain
1907 for (i, c) in cqt_complex.iter().enumerate() {
1908 out[i] = c.norm_sqr();
1909 }
1910
1911 Ok(())
1912 }
1913 }
1914 }
1915
1916 fn frequencies_hz(&self, params: &SpectrogramParams) -> NonEmptyVec<f64> {
1917 match &self.kind {
1918 MappingKind::Identity { out_len } => {
1919 // Standard R2C bins: k * sr / n_fft
1920 let n_fft = params.stft().n_fft().get() as f64;
1921 let sr = params.sample_rate_hz();
1922 let df = sr / n_fft;
1923
1924 let mut f = Vec::with_capacity((*out_len).get());
1925 for k in 0..(*out_len).get() {
1926 f.push(k as f64 * df);
1927 }
1928 // safety: out_len > 0
1929 unsafe { NonEmptyVec::new_unchecked(f) }
1930 }
1931 MappingKind::Mel { matrix } => {
1932 // For mel, the axis is defined by the mel band centre frequencies.
1933 // We compute and store them consistently with how we built the filterbank.
1934 let n_mels = matrix.nrows();
1935 // safety: n_mels > 0
1936 let n_mels = unsafe { NonZeroUsize::new_unchecked(n_mels) };
1937 mel_band_centres_hz(n_mels, params.sample_rate_hz(), params.nyquist_hz())
1938 }
1939 MappingKind::LogHz { frequencies, .. } => {
1940 // Frequencies are stored when the mapping is created
1941 frequencies.clone()
1942 }
1943 MappingKind::Erb { filterbank, .. } => {
1944 // ERB center frequencies
1945 filterbank.center_frequencies().to_non_empty_vec()
1946 }
1947 MappingKind::Cqt { kernel, .. } => {
1948 // CQT center frequencies from the kernel
1949 kernel.frequencies().to_non_empty_vec()
1950 }
1951 }
1952 }
1953}
1954
1955//
1956// ========================
1957// Amplitude scaling
1958// ========================
1959//
1960
/// Marker trait so we can specialise behaviour by `AmpScale`.
///
/// Implementors describe how a power-domain value (`|X|^2`) is turned into
/// the output amplitude scale. The conversion is split in two steps: a
/// per-value step (`apply_from_power`) and a batch dB step
/// (`apply_db_in_place`) that only the decibel scale actually performs.
pub trait AmpScaleSpec: Sized + Send + Sync {
    /// Apply conversion from power-domain value to the desired amplitude scale.
    ///
    /// # Arguments
    ///
    /// - `power`: input power-domain value.
    ///
    /// # Returns
    ///
    /// Converted amplitude value.
    fn apply_from_power(power: f64) -> f64;

    /// Apply dB conversion in-place on a power-domain vector.
    ///
    /// This is a no-op for Power and Magnitude scales.
    ///
    /// # Arguments
    ///
    /// - `x`: power-domain values to convert to dB in-place.
    /// - `floor_db`: dB floor value to apply.
    ///
    /// # Returns
    ///
    /// `SpectrogramResult<()>`: Ok on success, error on invalid input.
    ///
    /// # Errors
    ///
    /// - If `floor_db` is not finite (only checked by scales that perform
    ///   the dB conversion).
    fn apply_db_in_place(x: &mut [f64], floor_db: f64) -> SpectrogramResult<()>;
}
1992
1993impl AmpScaleSpec for Power {
1994 #[inline]
1995 fn apply_from_power(power: f64) -> f64 {
1996 power
1997 }
1998
1999 #[inline]
2000 fn apply_db_in_place(_x: &mut [f64], _floor_db: f64) -> SpectrogramResult<()> {
2001 Ok(())
2002 }
2003}
2004
2005impl AmpScaleSpec for Magnitude {
2006 #[inline]
2007 fn apply_from_power(power: f64) -> f64 {
2008 power.sqrt()
2009 }
2010
2011 #[inline]
2012 fn apply_db_in_place(_x: &mut [f64], _floor_db: f64) -> SpectrogramResult<()> {
2013 Ok(())
2014 }
2015}
2016
2017impl AmpScaleSpec for Decibels {
2018 #[inline]
2019 fn apply_from_power(power: f64) -> f64 {
2020 // dB conversion is applied in batch, not here.
2021 power
2022 }
2023
2024 #[inline]
2025 fn apply_db_in_place(x: &mut [f64], floor_db: f64) -> SpectrogramResult<()> {
2026 // Convert power -> dB: 10*log10(max(power, eps))
2027 // where eps is derived from floor_db to ensure consistency
2028 if !floor_db.is_finite() {
2029 return Err(SpectrogramError::invalid_input("floor_db must be finite"));
2030 }
2031
2032 // Convert floor_db to linear scale to get epsilon
2033 // e.g., floor_db = -80 dB -> eps = 10^(-80/10) = 1e-8
2034 let eps = 10.0_f64.powf(floor_db / 10.0);
2035
2036 for v in x.iter_mut() {
2037 // Clamp power to epsilon before log to avoid log(0) and ensure floor
2038 *v = 10.0 * v.max(eps).log10();
2039 }
2040 Ok(())
2041 }
2042}
2043
/// Amplitude scaling configuration.
///
/// This handles conversion from power-domain intermediate to the desired amplitude scale (Power, Magnitude, Decibels).
#[derive(Debug, Clone)]
struct AmplitudeScaling<AmpScale> {
    // dB floor taken from the `LogParams` supplied at construction;
    // `None` when no log parameters were given, in which case the dB step is skipped.
    db_floor: Option<f64>,
    // Zero-sized marker selecting the `AmpScaleSpec` implementation to use.
    _marker: PhantomData<AmpScale>,
}
2052
2053impl<AmpScale> AmplitudeScaling<AmpScale>
2054where
2055 AmpScale: AmpScaleSpec + 'static,
2056{
2057 fn new(db: Option<&LogParams>) -> Self {
2058 let db_floor = db.map(LogParams::floor_db);
2059 Self {
2060 db_floor,
2061 _marker: PhantomData,
2062 }
2063 }
2064
2065 /// Apply amplitude scaling in-place on a mapped spectrum vector.
2066 ///
2067 /// The input vector is assumed to be in the *power* domain (|X|^2),
2068 /// because the STFT stage produces power as the canonical intermediate.
2069 ///
2070 /// - Power: leaves values unchanged.
2071 /// - Magnitude: sqrt(power).
2072 /// - Decibels: converts power -> dB and floors at `db_floor`.
2073 pub fn apply_in_place(&self, x: &mut [f64]) -> SpectrogramResult<()> {
2074 // Convert from canonical power-domain intermediate into the requested linear domain.
2075 for v in x.iter_mut() {
2076 *v = AmpScale::apply_from_power(*v);
2077 }
2078
2079 // Apply dB conversion if configured (no-op for Power/Magnitude via trait impls).
2080 if let Some(floor_db) = self.db_floor {
2081 AmpScale::apply_db_in_place(x, floor_db)?;
2082 }
2083
2084 Ok(())
2085 }
2086}
2087
2088#[derive(Debug, Clone)]
2089struct Workspace {
2090 spectrum: NonEmptyVec<f64>, // out_len (power spectrum)
2091 mapped: NonEmptyVec<f64>, // n_bins (after mapping)
2092 frame: NonEmptyVec<f64>, // n_fft (windowed frame for FFT)
2093 fft_out: NonEmptyVec<Complex<f64>>, // out_len (FFT output)
2094}
2095
2096impl Workspace {
2097 fn new(n_fft: NonZeroUsize, out_len: NonZeroUsize, n_bins: NonZeroUsize) -> Self {
2098 Self {
2099 spectrum: non_empty_vec![0.0; out_len],
2100 mapped: non_empty_vec![0.0; n_bins],
2101 frame: non_empty_vec![0.0; n_fft],
2102 fft_out: non_empty_vec![Complex::new(0.0, 0.0); out_len],
2103 }
2104 }
2105
2106 fn ensure_sizes(&mut self, n_fft: NonZeroUsize, out_len: NonZeroUsize, n_bins: NonZeroUsize) {
2107 if self.spectrum.len() != out_len {
2108 self.spectrum.resize(out_len, 0.0);
2109 }
2110 if self.mapped.len() != n_bins {
2111 self.mapped.resize(n_bins, 0.0);
2112 }
2113 if self.frame.len() != n_fft {
2114 self.frame.resize(n_fft, 0.0);
2115 }
2116 if self.fft_out.len() != out_len {
2117 self.fft_out.resize(out_len, Complex::new(0.0, 0.0));
2118 }
2119 }
2120}
2121
2122fn build_frequency_axis<FreqScale>(
2123 params: &SpectrogramParams,
2124 mapping: &FrequencyMapping<FreqScale>,
2125) -> FrequencyAxis<FreqScale>
2126where
2127 FreqScale: Copy + Clone + 'static,
2128{
2129 let frequencies = mapping.frequencies_hz(params);
2130 FrequencyAxis::new(frequencies)
2131}
2132
2133fn build_time_axis_seconds(params: &SpectrogramParams, n_frames: NonZeroUsize) -> NonEmptyVec<f64> {
2134 let dt = params.frame_period_seconds();
2135 let mut times = Vec::with_capacity(n_frames.get());
2136
2137 for i in 0..n_frames.get() {
2138 times.push(i as f64 * dt);
2139 }
2140
2141 // safety: times is guaranteed non-empty since n_frames > 0
2142
2143 unsafe { NonEmptyVec::new_unchecked(times) }
2144}
2145
2146/// Generate window function samples.
2147///
2148/// Supports various window types including Rectangular, Hanning, Hamming, Blackman, Kaiser, and Gaussian.
2149///
2150/// # Arguments
2151///
2152/// * `window` - The type of window function to generate.
2153/// * `n_fft` - The size of the FFT, which determines the length of the window.
2154///
2155/// # Returns
2156///
2157/// A `NonEmptyVec<f64>` containing the window function samples.
2158///
2159/// # Panics
2160///
2161/// Panics if a custom window is provided with a size that does not match `n_fft`.
2162#[inline]
2163#[must_use]
2164pub fn make_window(window: WindowType, n_fft: NonZeroUsize) -> NonEmptyVec<f64> {
2165 let n_fft = n_fft.get();
2166 let mut w = vec![0.0; n_fft];
2167
2168 match window {
2169 WindowType::Rectangular => {
2170 w.fill(1.0);
2171 }
2172 WindowType::Hanning => {
2173 // Hann: 0.5 - 0.5*cos(2Ï€n/(N-1))
2174 let n1 = (n_fft - 1) as f64;
2175 for (n, v) in w.iter_mut().enumerate() {
2176 *v = 0.5f64.mul_add(-(2.0 * std::f64::consts::PI * (n as f64) / n1).cos(), 0.5);
2177 }
2178 }
2179 WindowType::Hamming => {
2180 // Hamming: 0.54 - 0.46*cos(2Ï€n/(N-1))
2181 let n1 = (n_fft - 1) as f64;
2182 for (n, v) in w.iter_mut().enumerate() {
2183 *v = 0.46f64.mul_add(-(2.0 * std::f64::consts::PI * (n as f64) / n1).cos(), 0.54);
2184 }
2185 }
2186 WindowType::Blackman => {
2187 // Blackman: 0.42 - 0.5*cos(2Ï€n/(N-1)) + 0.08*cos(4Ï€n/(N-1))
2188 let n1 = (n_fft - 1) as f64;
2189 for (n, v) in w.iter_mut().enumerate() {
2190 let a = 2.0 * std::f64::consts::PI * (n as f64) / n1;
2191 *v = 0.08f64.mul_add((2.0 * a).cos(), 0.5f64.mul_add(-a.cos(), 0.42));
2192 }
2193 }
2194 WindowType::Kaiser { beta } => {
2195 if n_fft == 1 {
2196 w[0] = 1.0;
2197 } else {
2198 let denom = modified_bessel_i0(beta);
2199 let n_max = (n_fft - 1) as f64 / 2.0;
2200
2201 for (i, value) in w.iter_mut().enumerate() {
2202 let n = i as f64 - n_max;
2203 let ratio = if n_max == 0.0 {
2204 0.0
2205 } else {
2206 let normalized = n / n_max;
2207 (1.0 - normalized * normalized).max(0.0)
2208 };
2209 let arg = beta * ratio.sqrt();
2210 *value = if denom == 0.0 {
2211 0.0
2212 } else {
2213 modified_bessel_i0(arg) / denom
2214 };
2215 }
2216 }
2217 }
2218 WindowType::Gaussian { std } => (0..n_fft).for_each(|i| {
2219 let n = i as f64;
2220 let center: f64 = (n_fft - 1) as f64 / 2.0;
2221 let exponent: f64 = -0.5 * ((n - center) / std).powi(2);
2222 w[i] = exponent.exp();
2223 }),
2224 WindowType::Custom { coefficients, size } => {
2225 assert!(
2226 size.get() == n_fft,
2227 "Custom window size mismatch: expected {}, got {}. \
2228 Custom windows must be pre-computed with the exact FFT size.",
2229 n_fft,
2230 size.get()
2231 );
2232 w.copy_from_slice(&coefficients);
2233 }
2234 }
2235
2236 // safety: window is guaranteed non-empty since n_fft > 0
2237 unsafe { NonEmptyVec::new_unchecked(w) }
2238}
2239
/// Modified Bessel function of the first kind, order zero: `I0(x)`.
///
/// Uses the classic polynomial approximations (Abramowitz & Stegun 9.8.1 and
/// 9.8.2):
///
/// - `|x| <= 3.75`: polynomial in `(x / 3.75)^2`.
/// - `|x| >  3.75`: `exp(|x|) / sqrt(|x|)` times a polynomial in `3.75 / |x|`.
///
/// `I0` is an even function, so only `|x|` matters.
fn modified_bessel_i0(x: f64) -> f64 {
    let ax = x.abs();
    if ax <= 3.75 {
        let t = x / 3.75;
        let t2 = t * t;
        1.0 + t2
            * (3.515_622_9
                + t2 * (3.089_942_4
                    + t2 * (1.206_749_2
                        + t2 * (0.265_973_2 + t2 * (0.036_076_8 + t2 * 0.004_581_3)))))
    } else {
        let t = 3.75 / ax;
        let poly = 0.398_942_28
            + t * (0.013_285_92
                + t * (0.002_253_19
                    + t * (-0.001_575_65
                        + t * (0.009_162_81
                            + t * (-0.020_577_06
                                + t * (0.026_355_37 + t * (-0.016_476_33 + t * 0.003_923_77)))))));

        // The leading coefficient 0.398_942_28 already equals 1/sqrt(2π), so the
        // prefactor is exp(|x|)/sqrt(|x|) only. Dividing by sqrt(2π) again would
        // make this branch ~2.5x too small and discontinuous at |x| = 3.75.
        (ax.exp() / ax.sqrt()) * poly
    }
}
2263
/// Convert a frequency in Hz to the mel scale (Slaney formula, i.e. librosa's
/// default with `htk=False`).
///
/// The scale is piecewise:
/// - below 1000 Hz it is linear: `mel = hz / (200/3)`
/// - from 1000 Hz upwards it is logarithmic: `mel = 15 + ln(hz/1000) / log_step`
fn hz_to_mel(hz: f64) -> f64 {
    const F_MIN: f64 = 0.0;
    const F_SP: f64 = 200.0 / 3.0; // ~66.667 Hz per mel in the linear part
    const MIN_LOG_HZ: f64 = 1000.0; // knee between the two regions
    const MIN_LOG_MEL: f64 = (MIN_LOG_HZ - F_MIN) / F_SP; // = 15.0
    const LOGSTEP: f64 = 0.068_751_777_420_949_23; // ln(6.4) / 27

    if hz >= MIN_LOG_HZ {
        // Logarithmic region: measured relative to the knee.
        let log_ratio = (hz / MIN_LOG_HZ).ln();
        return MIN_LOG_MEL + log_ratio / LOGSTEP;
    }

    // Linear region.
    (hz - F_MIN) / F_SP
}
2285
/// Convert a mel value back to Hz (Slaney formula, i.e. librosa's default with
/// `htk=False`).
///
/// This is the inverse of `hz_to_mel`: linear below 15 mel (1000 Hz) and
/// exponential from there upwards.
fn mel_to_hz(mel: f64) -> f64 {
    const F_MIN: f64 = 0.0;
    const F_SP: f64 = 200.0 / 3.0; // ~66.667 Hz per mel in the linear part
    const MIN_LOG_HZ: f64 = 1000.0; // knee between the two regions
    const MIN_LOG_MEL: f64 = (MIN_LOG_HZ - F_MIN) / F_SP; // = 15.0
    const LOGSTEP: f64 = 0.068_751_777_420_949_23; // ln(6.4) / 27

    if mel >= MIN_LOG_MEL {
        // Exponential region: grow from the knee frequency.
        let mels_above_knee = mel - MIN_LOG_MEL;
        return MIN_LOG_HZ * (LOGSTEP * mels_above_knee).exp();
    }

    // Linear region.
    F_SP.mul_add(mel, F_MIN)
}
2304
2305fn build_mel_filterbank_matrix(
2306 sample_rate_hz: f64,
2307 n_fft: NonZeroUsize,
2308 n_mels: NonZeroUsize,
2309 f_min: f64,
2310 f_max: f64,
2311 norm: MelNorm,
2312) -> SpectrogramResult<SparseMatrix> {
2313 if sample_rate_hz <= 0.0 || !sample_rate_hz.is_finite() {
2314 return Err(SpectrogramError::invalid_input(
2315 "sample_rate_hz must be finite and > 0",
2316 ));
2317 }
2318 if f_min < 0.0 || f_min.is_infinite() {
2319 return Err(SpectrogramError::invalid_input("f_min must be >= 0"));
2320 }
2321 if f_max <= f_min {
2322 return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
2323 }
2324 if f_max > sample_rate_hz * 0.5 {
2325 return Err(SpectrogramError::invalid_input("f_max must be <= Nyquist"));
2326 }
2327 let n_mels = n_mels.get();
2328 let n_fft = n_fft.get();
2329 let out_len = r2c_output_size(n_fft);
2330
2331 // FFT bin frequencies
2332 let df = sample_rate_hz / n_fft as f64;
2333
2334 // Mel points: n_mels + 2 (for triangular edges)
2335 let mel_min = hz_to_mel(f_min);
2336 let mel_max = hz_to_mel(f_max);
2337
2338 let n_points = n_mels + 2;
2339 let step = (mel_max - mel_min) / (n_points - 1) as f64;
2340
2341 let mut mel_points = Vec::with_capacity(n_points);
2342 for i in 0..n_points {
2343 mel_points.push((i as f64).mul_add(step, mel_min));
2344 }
2345
2346 let mut hz_points = Vec::with_capacity(n_points);
2347 for m in &mel_points {
2348 hz_points.push(mel_to_hz(*m));
2349 }
2350
2351 // Build filterbank as sparse matrix (librosa-style, in frequency space)
2352 // This builds triangular filters based on actual frequencies, not bin indices
2353 let mut fb = SparseMatrix::new(n_mels, out_len);
2354
2355 for m in 0..n_mels {
2356 let freq_left = hz_points[m];
2357 let freq_center = hz_points[m + 1];
2358 let freq_right = hz_points[m + 2];
2359
2360 let fdiff_left = freq_center - freq_left;
2361 let fdiff_right = freq_right - freq_center;
2362
2363 if fdiff_left == 0.0 || fdiff_right == 0.0 {
2364 // Degenerate triangle, skip
2365 continue;
2366 }
2367
2368 // For each FFT bin, compute the triangular weight based on its frequency
2369 for k in 0..out_len {
2370 let bin_freq = k as f64 * df;
2371
2372 // Lower slope: rises from freq_left to freq_center
2373 let lower = (bin_freq - freq_left) / fdiff_left;
2374
2375 // Upper slope: falls from freq_center to freq_right
2376 let upper = (freq_right - bin_freq) / fdiff_right;
2377
2378 // Triangle is the minimum of the two slopes, clipped to [0, 1]
2379 let weight = lower.min(upper).clamp(0.0, 1.0);
2380
2381 if weight > 0.0 {
2382 fb.set(m, k, weight);
2383 }
2384 }
2385 }
2386
2387 // Apply normalization
2388 match norm {
2389 MelNorm::None => {
2390 // No normalization needed
2391 }
2392 MelNorm::Slaney => {
2393 // Slaney-style area normalization: 2 / (hz_max - hz_min) for each triangle
2394 // NOTE: Uses Hz bandwidth, not mel bandwidth (to match librosa's implementation)
2395 for m in 0..n_mels {
2396 let mel_left = mel_points[m];
2397 let mel_right = mel_points[m + 2];
2398 let hz_left = mel_to_hz(mel_left);
2399 let hz_right = mel_to_hz(mel_right);
2400 let enorm = 2.0 / (hz_right - hz_left);
2401
2402 // Normalize all values in this row
2403 for val in &mut fb.values[m] {
2404 *val *= enorm;
2405 }
2406 }
2407 }
2408 MelNorm::L1 => {
2409 // L1 normalization: sum of weights = 1.0
2410 for m in 0..n_mels {
2411 let sum: f64 = fb.values[m].iter().sum();
2412 if sum > 0.0 {
2413 let normalizer = 1.0 / sum;
2414 for val in &mut fb.values[m] {
2415 *val *= normalizer;
2416 }
2417 }
2418 }
2419 }
2420 MelNorm::L2 => {
2421 // L2 normalization: L2 norm = 1.0
2422 for m in 0..n_mels {
2423 let norm_val: f64 = fb.values[m].iter().map(|&v| v * v).sum::<f64>().sqrt();
2424 if norm_val > 0.0 {
2425 let normalizer = 1.0 / norm_val;
2426 for val in &mut fb.values[m] {
2427 *val *= normalizer;
2428 }
2429 }
2430 }
2431 }
2432 }
2433
2434 Ok(fb)
2435}
2436
2437/// Build a logarithmic frequency interpolation matrix.
2438///
2439/// Maps linearly-spaced FFT bins to logarithmically-spaced frequency bins
2440/// using linear interpolation.
2441fn build_loghz_matrix(
2442 sample_rate_hz: f64,
2443 n_fft: NonZeroUsize,
2444 n_bins: NonZeroUsize,
2445 f_min: f64,
2446 f_max: f64,
2447) -> SpectrogramResult<(SparseMatrix, NonEmptyVec<f64>)> {
2448 if sample_rate_hz <= 0.0 || !sample_rate_hz.is_finite() {
2449 return Err(SpectrogramError::invalid_input(
2450 "sample_rate_hz must be finite and > 0",
2451 ));
2452 }
2453 if f_min <= 0.0 || f_min.is_infinite() {
2454 return Err(SpectrogramError::invalid_input(
2455 "f_min must be finite and > 0",
2456 ));
2457 }
2458 if f_max <= f_min {
2459 return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
2460 }
2461 if f_max > sample_rate_hz * 0.5 {
2462 return Err(SpectrogramError::invalid_input("f_max must be <= Nyquist"));
2463 }
2464
2465 let n_bins = n_bins.get();
2466 let n_fft = n_fft.get();
2467
2468 let out_len = r2c_output_size(n_fft);
2469 let df = sample_rate_hz / n_fft as f64;
2470
2471 // Generate logarithmically-spaced frequencies
2472 let log_f_min = f_min.ln();
2473 let log_f_max = f_max.ln();
2474 let log_step = (log_f_max - log_f_min) / (n_bins - 1) as f64;
2475
2476 let mut log_frequencies = Vec::with_capacity(n_bins);
2477 for i in 0..n_bins {
2478 let log_f = (i as f64).mul_add(log_step, log_f_min);
2479 log_frequencies.push(log_f.exp());
2480 }
2481 // safety: n_bins > 0
2482 let log_frequencies = unsafe { NonEmptyVec::new_unchecked(log_frequencies) };
2483
2484 // Build interpolation matrix as sparse matrix
2485 let mut matrix = SparseMatrix::new(n_bins, out_len);
2486
2487 for (bin_idx, &target_freq) in log_frequencies.iter().enumerate() {
2488 // Find the two FFT bins that bracket this frequency
2489 let exact_bin = target_freq / df;
2490 let lower_bin = exact_bin.floor() as usize;
2491 let upper_bin = (exact_bin.ceil() as usize).min(out_len - 1);
2492
2493 if lower_bin >= out_len {
2494 continue;
2495 }
2496
2497 if lower_bin == upper_bin {
2498 // Exact match
2499 matrix.set(bin_idx, lower_bin, 1.0);
2500 } else {
2501 // Linear interpolation
2502 let frac = exact_bin - lower_bin as f64;
2503 matrix.set(bin_idx, lower_bin, 1.0 - frac);
2504 if upper_bin < out_len {
2505 matrix.set(bin_idx, upper_bin, frac);
2506 }
2507 }
2508 }
2509
2510 Ok((matrix, log_frequencies))
2511}
2512
2513fn mel_band_centres_hz(
2514 n_mels: NonZeroUsize,
2515 sample_rate_hz: f64,
2516 nyquist_hz: f64,
2517) -> NonEmptyVec<f64> {
2518 let f_min = 0.0;
2519 let f_max = nyquist_hz.min(sample_rate_hz * 0.5);
2520
2521 let mel_min = hz_to_mel(f_min);
2522 let mel_max = hz_to_mel(f_max);
2523 let n_mels = n_mels.get();
2524 let step = (mel_max - mel_min) / (n_mels + 1) as f64;
2525
2526 let mut centres = Vec::with_capacity(n_mels);
2527 for i in 0..n_mels {
2528 let mel = (i as f64 + 1.0).mul_add(step, mel_min);
2529 centres.push(mel_to_hz(mel));
2530 }
2531 // safety: centres is guaranteed non-empty since n_mels > 0
2532 unsafe { NonEmptyVec::new_unchecked(centres) }
2533}
2534
2535/// Spectrogram structure holding the computed spectrogram data and metadata.
2536///
2537/// # Type Parameters
2538///
2539/// * `FreqScale`: The frequency scale type (e.g., `LinearHz`, `Mel`, `LogHz`, etc.).
2540/// * `AmpScale`: The amplitude scale type (e.g., `Power`, `Magnitude`, `Decibels`).
2541///
2542/// # Fields
2543///
2544/// * `data`: A 2D array containing the spectrogram data.
2545/// * `axes`: The axes of the spectrogram (frequency and time).
2546/// * `params`: The parameters used to compute the spectrogram.
2547/// * `_amp`: A phantom data marker for the amplitude scale type.
2548#[derive(Debug, Clone)]
2549#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
2550pub struct Spectrogram<FreqScale, AmpScale>
2551where
2552 AmpScale: AmpScaleSpec + 'static,
2553 FreqScale: Copy + Clone + 'static,
2554{
2555 data: Array2<f64>,
2556 axes: Axes<FreqScale>,
2557 params: SpectrogramParams,
2558 #[cfg_attr(feature = "serde", serde(skip))]
2559 _amp: PhantomData<AmpScale>,
2560}
2561
2562impl<FreqScale, AmpScale> Spectrogram<FreqScale, AmpScale>
2563where
2564 AmpScale: AmpScaleSpec + 'static,
2565 FreqScale: Copy + Clone + 'static,
2566{
2567 /// Get the X-axis label.
2568 ///
2569 /// # Returns
2570 ///
2571 /// A static string slice representing the X-axis label.
2572 #[inline]
2573 #[must_use]
2574 pub const fn x_axis_label() -> &'static str {
2575 "Time (s)"
2576 }
2577
2578 /// Get the Y-axis label based on the frequency scale type.
2579 ///
2580 /// # Returns
2581 ///
2582 /// A static string slice representing the Y-axis label.
2583 #[inline]
2584 #[must_use]
2585 pub fn y_axis_label() -> &'static str {
2586 match std::any::TypeId::of::<FreqScale>() {
2587 id if id == std::any::TypeId::of::<LinearHz>() => "Frequency (Hz)",
2588 id if id == std::any::TypeId::of::<Mel>() => "Frequency (Mel)",
2589 id if id == std::any::TypeId::of::<LogHz>() => "Frequency (Log Hz)",
2590 id if id == std::any::TypeId::of::<Erb>() => "Frequency (ERB)",
2591 id if id == std::any::TypeId::of::<Cqt>() => "Frequency (CQT Bins)",
2592 _ => "Frequency",
2593 }
2594 }
2595
2596 /// Internal constructor. Only callable inside the crate.
2597 ///
2598 /// All inputs must already be validated and consistent.
2599 pub(crate) fn new(data: Array2<f64>, axes: Axes<FreqScale>, params: SpectrogramParams) -> Self {
2600 debug_assert_eq!(data.nrows(), axes.frequencies().len().get());
2601 debug_assert_eq!(data.ncols(), axes.times().len().get());
2602
2603 Self {
2604 data,
2605 axes,
2606 params,
2607 _amp: PhantomData,
2608 }
2609 }
2610
2611 /// Set spectrogram data matrix
2612 ///
2613 /// # Arguments
2614 ///
2615 /// * `data` - The new spectrogram data matrix.
2616 #[inline]
2617 pub fn set_data(&mut self, data: Array2<f64>) {
2618 self.data = data;
2619 }
2620
2621 /// Spectrogram data matrix
2622 ///
2623 /// # Returns
2624 ///
2625 /// A reference to the spectrogram data matrix.
2626 #[inline]
2627 #[must_use]
2628 pub const fn data(&self) -> &Array2<f64> {
2629 &self.data
2630 }
2631
2632 /// Consume the spectrogram and extract the data matrix.
2633 ///
2634 /// This method moves the data out of the spectrogram, consuming it.
2635 /// Useful for transferring ownership to Python without copying.
2636 ///
2637 /// # Returns
2638 ///
2639 /// The owned spectrogram data matrix.
2640 #[inline]
2641 #[must_use]
2642 pub fn into_data(self) -> Array2<f64> {
2643 self.data
2644 }
2645
2646 /// Axes of the spectrogram
2647 ///
2648 /// # Returns
2649 ///
2650 /// A reference to the axes of the spectrogram.
2651 #[inline]
2652 #[must_use]
2653 pub const fn axes(&self) -> &Axes<FreqScale> {
2654 &self.axes
2655 }
2656
2657 /// Frequency axis in Hz
2658 ///
2659 /// # Returns
2660 ///
2661 /// A reference to the frequency axis in Hz.
2662 #[inline]
2663 #[must_use]
2664 pub fn frequencies(&self) -> &NonEmptySlice<f64> {
2665 self.axes.frequencies()
2666 }
2667
2668 /// Frequency range in Hz (min, max)
2669 ///
2670 /// # Returns
2671 ///
2672 /// A tuple containing the minimum and maximum frequencies in Hz.
2673 #[inline]
2674 #[must_use]
2675 pub const fn frequency_range(&self) -> (f64, f64) {
2676 self.axes.frequency_range()
2677 }
2678
2679 /// Time axis in seconds
2680 ///
2681 /// # Returns
2682 ///
2683 /// A reference to the time axis in seconds.
2684 #[inline]
2685 #[must_use]
2686 pub fn times(&self) -> &NonEmptySlice<f64> {
2687 self.axes.times()
2688 }
2689
2690 /// Spectrogram computation parameters
2691 ///
2692 /// # Returns
2693 ///
2694 /// A reference to the spectrogram computation parameters.
2695 #[inline]
2696 #[must_use]
2697 pub const fn params(&self) -> &SpectrogramParams {
2698 &self.params
2699 }
2700
2701 /// Duration of the spectrogram in seconds
2702 ///
2703 /// # Returns
2704 ///
2705 /// The duration of the spectrogram in seconds.
2706 #[inline]
2707 #[must_use]
2708 pub fn duration(&self) -> f64 {
2709 self.axes.duration()
2710 }
2711
2712 /// If this is a dB spectrogram, return the (min, max) dB values. otherwise do the maths to compute dB range.
2713 ///
2714 /// # Returns
2715 ///
2716 /// The (min, max) dB values of the spectrogram, or `None` if the amplitude scale is unknown.
2717 #[inline]
2718 #[must_use]
2719 pub fn db_range(&self) -> Option<(f64, f64)> {
2720 let type_self = std::any::TypeId::of::<AmpScale>();
2721
2722 if type_self == std::any::TypeId::of::<Decibels>() {
2723 let (min, max) = min_max_single_pass(self.data.as_slice()?);
2724 Some((min, max))
2725 } else if type_self == std::any::TypeId::of::<Power>() {
2726 // Not a dB spectrogram; compute dB range from power values
2727 let mut min_db = f64::INFINITY;
2728 let mut max_db = f64::NEG_INFINITY;
2729 for &v in &self.data {
2730 let db = 10.0 * (v + EPS).log10();
2731 if db < min_db {
2732 min_db = db;
2733 }
2734 if db > max_db {
2735 max_db = db;
2736 }
2737 }
2738 Some((min_db, max_db))
2739 } else if type_self == std::any::TypeId::of::<Magnitude>() {
2740 // Not a dB spectrogram; compute dB range from magnitude values
2741 let mut min_db = f64::INFINITY;
2742 let mut max_db = f64::NEG_INFINITY;
2743
2744 for &v in &self.data {
2745 let power = v * v;
2746 let db = 10.0 * (power + EPS).log10();
2747 if db < min_db {
2748 min_db = db;
2749 }
2750 if db > max_db {
2751 max_db = db;
2752 }
2753 }
2754
2755 Some((min_db, max_db))
2756 } else {
2757 // Unknown AmpScale type; return dummy values
2758 None
2759 }
2760 }
2761
2762 /// Number of frequency bins
2763 ///
2764 /// # Returns
2765 ///
2766 /// The number of frequency bins in the spectrogram.
2767 #[inline]
2768 #[must_use]
2769 pub fn n_bins(&self) -> NonZeroUsize {
2770 // safety: data.nrows() > 0 is guaranteed by construction
2771 unsafe { NonZeroUsize::new_unchecked(self.data.nrows()) }
2772 }
2773
2774 /// Number of time frames in the spectrogram
2775 ///
2776 /// # Returns
2777 ///
2778 /// The number of time frames (columns) in the spectrogram.
2779 #[inline]
2780 #[must_use]
2781 pub fn n_frames(&self) -> NonZeroUsize {
2782 // safety: data.ncols() > 0 is guaranteed by construction
2783 unsafe { NonZeroUsize::new_unchecked(self.data.ncols()) }
2784 }
2785}
2786
2787impl<FreqScale, AmpScale> AsRef<Array2<f64>> for Spectrogram<FreqScale, AmpScale>
2788where
2789 FreqScale: Copy + Clone + 'static,
2790 AmpScale: AmpScaleSpec + 'static,
2791{
2792 #[inline]
2793 fn as_ref(&self) -> &Array2<f64> {
2794 &self.data
2795 }
2796}
2797
2798impl<FreqScale, AmpScale> Deref for Spectrogram<FreqScale, AmpScale>
2799where
2800 FreqScale: Copy + Clone + 'static,
2801 AmpScale: AmpScaleSpec + 'static,
2802{
2803 type Target = Array2<f64>;
2804
2805 #[inline]
2806 fn deref(&self) -> &Self::Target {
2807 &self.data
2808 }
2809}
2810
2811impl<AmpScale> Spectrogram<LinearHz, AmpScale>
2812where
2813 AmpScale: AmpScaleSpec + 'static,
2814{
2815 /// Compute a linear-frequency spectrogram from audio samples.
2816 ///
2817 /// This is a convenience method that creates a planner internally and computes
2818 /// the spectrogram in one call. For processing multiple signals with the same
2819 /// parameters, use [`SpectrogramPlanner::linear_plan`] to create a reusable plan.
2820 ///
2821 /// # Arguments
2822 ///
2823 /// * `samples` - Audio samples (any type that can be converted to a slice)
2824 /// * `params` - Spectrogram computation parameters
2825 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2826 ///
2827 /// # Returns
2828 ///
2829 /// A linear-frequency spectrogram with the specified amplitude scale.
2830 ///
2831 /// # Errors
2832 ///
2833 /// Returns an error if:
2834 /// - The samples slice is empty
2835 /// - Parameters are invalid
2836 /// - FFT computation fails
2837 ///
2838 /// # Examples
2839 ///
2840 /// ```
2841 /// use spectrograms::*;
2842 /// use non_empty_slice::non_empty_vec;
2843 ///
2844 /// # fn example() -> SpectrogramResult<()> {
2845 /// // Create a simple test signal
2846 /// let sample_rate = 16000.0;
2847 /// let samples_vec: Vec<f64> = (0..16000).map(|i| {
2848 /// (2.0 * std::f64::consts::PI * 440.0 * i as f64 / sample_rate).sin()
2849 /// }).collect();
2850 /// let samples = non_empty_slice::NonEmptyVec::new(samples_vec).unwrap();
2851 ///
2852 /// // Set up parameters
2853 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2854 /// let params = SpectrogramParams::new(stft, sample_rate)?;
2855 ///
2856 /// // Compute power spectrogram
2857 /// let spec = LinearPowerSpectrogram::compute(&samples, ¶ms, None)?;
2858 ///
2859 /// println!("Computed spectrogram: {} bins x {} frames", spec.n_bins(), spec.n_frames());
2860 /// # Ok(())
2861 /// # }
2862 /// ```
2863 #[inline]
2864 pub fn compute(
2865 samples: &NonEmptySlice<f64>,
2866 params: &SpectrogramParams,
2867 db: Option<&LogParams>,
2868 ) -> SpectrogramResult<Self> {
2869 let planner = SpectrogramPlanner::new();
2870 let mut plan = planner.linear_plan(params, db)?;
2871 plan.compute(samples)
2872 }
2873}
2874
2875impl<AmpScale> Spectrogram<Mel, AmpScale>
2876where
2877 AmpScale: AmpScaleSpec + 'static,
2878{
2879 /// Compute a mel-frequency spectrogram from audio samples.
2880 ///
2881 /// This is a convenience method that creates a planner internally and computes
2882 /// the spectrogram in one call. For processing multiple signals with the same
2883 /// parameters, use [`SpectrogramPlanner::mel_plan`] to create a reusable plan.
2884 ///
2885 /// # Arguments
2886 ///
2887 /// * `samples` - Audio samples (any type that can be converted to a slice)
2888 /// * `params` - Spectrogram computation parameters
2889 /// * `mel` - Mel filterbank parameters
2890 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2891 ///
2892 /// # Returns
2893 ///
2894 /// A mel-frequency spectrogram with the specified amplitude scale.
2895 ///
2896 /// # Errors
2897 ///
2898 /// Returns an error if:
2899 /// - The samples slice is empty
2900 /// - Parameters are invalid
2901 /// - Mel `f_max` exceeds Nyquist frequency
2902 /// - FFT computation fails
2903 ///
2904 /// # Examples
2905 ///
2906 /// ```
2907 /// use spectrograms::*;
2908 /// use non_empty_slice::non_empty_vec;
2909 ///
2910 /// # fn example() -> SpectrogramResult<()> {
2911 /// // Create a simple test signal
2912 /// let sample_rate = 16000.0;
2913 /// let samples_vec: Vec<f64> = (0..16000).map(|i| {
2914 /// (2.0 * std::f64::consts::PI * 440.0 * i as f64 / sample_rate).sin()
2915 /// }).collect();
2916 /// let samples = non_empty_slice::NonEmptyVec::new(samples_vec).unwrap();
2917 ///
2918 /// // Set up parameters
2919 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2920 /// let params = SpectrogramParams::new(stft, sample_rate)?;
2921 /// let mel = MelParams::new(nzu!(80), 0.0, 8000.0)?;
2922 ///
2923 /// // Compute mel spectrogram in dB scale
2924 /// let db = LogParams::new(-80.0)?;
2925 /// let spec = MelDbSpectrogram::compute(&samples, ¶ms, &mel, Some(&db))?;
2926 ///
2927 /// println!("Computed mel spectrogram: {} mels x {} frames", spec.n_bins(), spec.n_frames());
2928 /// # Ok(())
2929 /// # }
2930 /// ```
2931 #[inline]
2932 pub fn compute(
2933 samples: &NonEmptySlice<f64>,
2934 params: &SpectrogramParams,
2935 mel: &MelParams,
2936 db: Option<&LogParams>,
2937 ) -> SpectrogramResult<Self> {
2938 let planner = SpectrogramPlanner::new();
2939 let mut plan = planner.mel_plan(params, mel, db)?;
2940 plan.compute(samples)
2941 }
2942}
2943
2944impl<AmpScale> Spectrogram<Erb, AmpScale>
2945where
2946 AmpScale: AmpScaleSpec + 'static,
2947{
2948 /// Compute an ERB-frequency spectrogram from audio samples.
2949 ///
2950 /// This is a convenience method that creates a planner internally and computes
2951 /// the spectrogram in one call. For processing multiple signals with the same
2952 /// parameters, use [`SpectrogramPlanner::erb_plan`] to create a reusable plan.
2953 ///
2954 /// # Arguments
2955 ///
2956 /// * `samples` - Audio samples (any type that can be converted to a slice)
2957 /// * `params` - Spectrogram computation parameters
2958 /// * `erb` - ERB frequency scale parameters
2959 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2960 ///
2961 /// # Returns
2962 ///
2963 /// An ERB-scale spectrogram with the specified amplitude scale.
2964 ///
2965 /// # Errors
2966 ///
2967 /// Returns an error if:
2968 /// - The samples slice is empty
2969 /// - Parameters are invalid
2970 /// - FFT computation fails
2971 ///
2972 /// # Examples
2973 ///
2974 /// ```
2975 /// use spectrograms::*;
2976 /// use non_empty_slice::non_empty_vec;
2977 ///
2978 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
2979 /// let samples = non_empty_vec![0.0; nzu!(16000)];
2980 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2981 /// let params = SpectrogramParams::new(stft, 16000.0)?;
2982 /// let erb = ErbParams::speech_standard();
2983 ///
2984 /// let spec = ErbPowerSpectrogram::compute(&samples, ¶ms, &erb, None)?;
2985 /// assert_eq!(spec.n_bins(), nzu!(40));
2986 /// # Ok(())
2987 /// # }
2988 /// ```
2989 #[inline]
2990 pub fn compute(
2991 samples: &NonEmptySlice<f64>,
2992 params: &SpectrogramParams,
2993 erb: &ErbParams,
2994 db: Option<&LogParams>,
2995 ) -> SpectrogramResult<Self> {
2996 let planner = SpectrogramPlanner::new();
2997 let mut plan = planner.erb_plan(params, erb, db)?;
2998 plan.compute(samples)
2999 }
3000}
3001
3002impl<AmpScale> Spectrogram<LogHz, AmpScale>
3003where
3004 AmpScale: AmpScaleSpec + 'static,
3005{
3006 /// Compute a logarithmic-frequency spectrogram from audio samples.
3007 ///
3008 /// This is a convenience method that creates a planner internally and computes
3009 /// the spectrogram in one call. For processing multiple signals with the same
3010 /// parameters, use [`SpectrogramPlanner::log_hz_plan`] to create a reusable plan.
3011 ///
3012 /// # Arguments
3013 ///
3014 /// * `samples` - Audio samples (any type that can be converted to a slice)
3015 /// * `params` - Spectrogram computation parameters
3016 /// * `loghz` - Logarithmic frequency scale parameters
3017 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
3018 ///
3019 /// # Returns
3020 ///
3021 /// A logarithmic-frequency spectrogram with the specified amplitude scale.
3022 ///
3023 /// # Errors
3024 ///
3025 /// Returns an error if:
3026 /// - The samples slice is empty
3027 /// - Parameters are invalid
3028 /// - FFT computation fails
3029 ///
3030 /// # Examples
3031 ///
3032 /// ```
3033 /// use spectrograms::*;
3034 /// use non_empty_slice::non_empty_vec;
3035 ///
3036 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
3037 /// let samples = non_empty_vec![0.0; nzu!(16000)];
3038 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
3039 /// let params = SpectrogramParams::new(stft, 16000.0)?;
3040 /// let loghz = LogHzParams::new(nzu!(128), 20.0, 8000.0)?;
3041 ///
3042 /// let spec = LogHzPowerSpectrogram::compute(&samples, ¶ms, &loghz, None)?;
3043 /// assert_eq!(spec.n_bins(), nzu!(128));
3044 /// # Ok(())
3045 /// # }
3046 /// ```
3047 #[inline]
3048 pub fn compute(
3049 samples: &NonEmptySlice<f64>,
3050 params: &SpectrogramParams,
3051 loghz: &LogHzParams,
3052 db: Option<&LogParams>,
3053 ) -> SpectrogramResult<Self> {
3054 let planner = SpectrogramPlanner::new();
3055 let mut plan = planner.log_hz_plan(params, loghz, db)?;
3056 plan.compute(samples)
3057 }
3058}
3059
3060impl<AmpScale> Spectrogram<Cqt, AmpScale>
3061where
3062 AmpScale: AmpScaleSpec + 'static,
3063{
3064 /// Compute a constant-Q transform (CQT) spectrogram from audio samples.
3065 ///
3066 /// # Arguments
3067 ///
3068 /// * `samples` - Audio samples (any type that can be converted to a slice)
3069 /// * `params` - Spectrogram computation parameters
3070 /// * `cqt` - CQT parameters
3071 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
3072 ///
3073 /// # Returns
3074 ///
3075 /// A CQT spectrogram with the specified amplitude scale.
3076 ///
3077 /// # Errors
3078 ///
3079 /// Returns an error if:
3080 ///
3081 #[inline]
3082 pub fn compute(
3083 samples: &NonEmptySlice<f64>,
3084 params: &SpectrogramParams,
3085 cqt: &CqtParams,
3086 db: Option<&LogParams>,
3087 ) -> SpectrogramResult<Self> {
3088 let planner = SpectrogramPlanner::new();
3089 let mut plan = planner.cqt_plan(params, cqt, db)?;
3090 plan.compute(samples)
3091 }
3092}
3093
3094// ========================
3095// Display implementations
3096// ========================
3097
3098/// Helper function to get amplitude scale name
3099fn amp_scale_name<AmpScale>() -> &'static str
3100where
3101 AmpScale: AmpScaleSpec + 'static,
3102{
3103 match std::any::TypeId::of::<AmpScale>() {
3104 id if id == std::any::TypeId::of::<Power>() => "Power",
3105 id if id == std::any::TypeId::of::<Magnitude>() => "Magnitude",
3106 id if id == std::any::TypeId::of::<Decibels>() => "Decibels",
3107 _ => "Unknown",
3108 }
3109}
3110
3111/// Helper function to get frequency scale name
3112fn freq_scale_name<FreqScale>() -> &'static str
3113where
3114 FreqScale: Copy + Clone + 'static,
3115{
3116 match std::any::TypeId::of::<FreqScale>() {
3117 id if id == std::any::TypeId::of::<LinearHz>() => "Linear Hz",
3118 id if id == std::any::TypeId::of::<LogHz>() => "Log Hz",
3119 id if id == std::any::TypeId::of::<Mel>() => "Mel",
3120 id if id == std::any::TypeId::of::<Erb>() => "ERB",
3121 id if id == std::any::TypeId::of::<Cqt>() => "CQT",
3122 _ => "Unknown",
3123 }
3124}
3125
/// Human-readable rendering of a spectrogram.
///
/// * `{}` prints a single-line summary: scale names, shape, frequency range and duration.
/// * `{:#}` prints a multi-line report: the same header information, the STFT
///   parameters, min/max/mean statistics and the top-left corner of the data matrix
///   (at most 5 rows by 5 columns).
impl<FreqScale, AmpScale> core::fmt::Display for Spectrogram<FreqScale, AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
    FreqScale: Copy + Clone + 'static,
{
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Values shared by both the compact and the detailed output.
        let (freq_min, freq_max) = self.frequency_range();
        let duration = self.duration();
        let (rows, cols) = self.data.dim();

        // Alternative formatting (#) provides more detailed output with data
        if f.alternate() {
            writeln!(f, "Spectrogram {{")?;
            writeln!(f, " Frequency Scale: {}", freq_scale_name::<FreqScale>())?;
            writeln!(f, " Amplitude Scale: {}", amp_scale_name::<AmpScale>())?;
            writeln!(f, " Shape: {rows} frequency bins × {cols} time frames")?;
            writeln!(f, " Frequency Range: {freq_min:.2} Hz - {freq_max:.2} Hz")?;
            writeln!(f, " Duration: {duration:.3} s")?;
            writeln!(f)?;

            // Display parameters
            writeln!(f, " Parameters:")?;
            writeln!(f, " Sample Rate: {} Hz", self.params.sample_rate_hz())?;
            writeln!(f, " FFT Size: {}", self.params.stft().n_fft())?;
            writeln!(f, " Hop Size: {}", self.params.stft().hop_size())?;
            writeln!(f, " Window: {:?}", self.params.stft().window())?;
            writeln!(f, " Centered: {}", self.params.stft().centre())?;
            writeln!(f)?;

            // Display data statistics.
            // If `as_slice()` yields `None`, the empty fallback slice makes the
            // whole statistics section be skipped rather than reported.
            let data_slice = self.data.as_slice().unwrap_or(&[]);
            if !data_slice.is_empty() {
                let (min_val, max_val) = min_max_single_pass(data_slice);
                let mean = data_slice.iter().sum::<f64>() / data_slice.len() as f64;
                writeln!(f, " Data Statistics:")?;
                writeln!(f, " Min: {min_val:.6}")?;
                writeln!(f, " Max: {max_val:.6}")?;
                writeln!(f, " Mean: {mean:.6}")?;
                writeln!(f)?;
            }

            // Display actual data (truncated if too large)
            writeln!(f, " Data Matrix:")?;
            let max_rows_to_display = 5;
            let max_cols_to_display = 5;

            // One output line per displayed row; columns are comma separated.
            for i in 0..rows.min(max_rows_to_display) {
                write!(f, " [")?;
                for j in 0..cols.min(max_cols_to_display) {
                    if j > 0 {
                        write!(f, ", ")?;
                    }
                    write!(f, "{:9.4}", self.data[[i, j]])?;
                }
                // Tell the reader how many columns were left out.
                if cols > max_cols_to_display {
                    write!(f, ", ... ({} more)", cols - max_cols_to_display)?;
                }
                writeln!(f, "]")?;
            }

            // Tell the reader how many rows were left out.
            if rows > max_rows_to_display {
                writeln!(f, " ... ({} more rows)", rows - max_rows_to_display)?;
            }

            write!(f, "}}")?;
        } else {
            // Default formatting: compact summary
            write!(
                f,
                "Spectrogram<{}, {}>[{}x{}] ({:.2}-{:.2} Hz, {:.3}s)",
                freq_scale_name::<FreqScale>(),
                amp_scale_name::<AmpScale>(),
                rows,
                cols,
                freq_min,
                freq_max,
                duration
            )?;
        }

        Ok(())
    }
}
3210
/// Frequency axis of a spectrogram, tagged with its frequency scale.
///
/// Holds a non-empty list of frequency values. `FreqScale` is a type-level
/// tag only: it is carried through `PhantomData`, so no value of that type is stored.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FrequencyAxis<FreqScale>
where
    FreqScale: Copy + Clone + 'static,
{
    /// Frequency values in Hz.
    frequencies: NonEmptyVec<f64>,
    /// Zero-sized tag tying the axis to `FreqScale`; skipped when (de)serializing with serde.
    #[cfg_attr(feature = "serde", serde(skip))]
    _marker: PhantomData<FreqScale>,
}
3221
3222impl<FreqScale> FrequencyAxis<FreqScale>
3223where
3224 FreqScale: Copy + Clone + 'static,
3225{
3226 pub(crate) const fn new(frequencies: NonEmptyVec<f64>) -> Self {
3227 Self {
3228 frequencies,
3229 _marker: PhantomData,
3230 }
3231 }
3232
3233 /// Get the frequency values in Hz.
3234 ///
3235 /// # Returns
3236 ///
3237 /// Returns a non-empty slice of frequencies.
3238 #[inline]
3239 #[must_use]
3240 pub fn frequencies(&self) -> &NonEmptySlice<f64> {
3241 &self.frequencies
3242 }
3243
3244 /// Get the frequency range (min, max) in Hz.
3245 ///
3246 /// # Returns
3247 ///
3248 /// Returns a tuple containing the minimum and maximum frequency.
3249 #[inline]
3250 #[must_use]
3251 pub const fn frequency_range(&self) -> (f64, f64) {
3252 let data = self.frequencies.as_slice();
3253 let min = data[0];
3254 let max_idx = data.len().saturating_sub(1); // safe for non-empty
3255 let max = data[max_idx];
3256 (min, max)
3257 }
3258
3259 /// Get the number of frequency bins.
3260 ///
3261 /// # Returns
3262 ///
3263 /// Returns the number of frequency bins as a NonZeroUsize.
3264 #[inline]
3265 #[must_use]
3266 pub const fn len(&self) -> NonZeroUsize {
3267 self.frequencies.len()
3268 }
3269}
3270
/// Spectrogram axes container.
///
/// Holds the frequency axis (tagged with `FreqScale`) and the time axis of a spectrogram.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Axes<FreqScale>
where
    FreqScale: Copy + Clone + 'static,
{
    /// Frequency axis.
    freq: FrequencyAxis<FreqScale>,
    /// Time values in seconds.
    times: NonEmptyVec<f64>,
}
3283
3284impl<FreqScale> Axes<FreqScale>
3285where
3286 FreqScale: Copy + Clone + 'static,
3287{
3288 pub(crate) const fn new(freq: FrequencyAxis<FreqScale>, times: NonEmptyVec<f64>) -> Self {
3289 Self { freq, times }
3290 }
3291
3292 /// Get the frequency values in Hz.
3293 ///
3294 /// # Returns
3295 ///
3296 /// Returns a non-empty slice of frequencies.
3297 #[inline]
3298 #[must_use]
3299 pub fn frequencies(&self) -> &NonEmptySlice<f64> {
3300 self.freq.frequencies()
3301 }
3302
3303 /// Get the time values in seconds.
3304 ///
3305 /// # Returns
3306 ///
3307 /// Returns a non-empty slice of time values.
3308 #[inline]
3309 #[must_use]
3310 pub fn times(&self) -> &NonEmptySlice<f64> {
3311 &self.times
3312 }
3313
3314 /// Get the frequency range (min, max) in Hz.
3315 ///
3316 /// # Returns
3317 ///
3318 /// Returns a tuple containing the minimum and maximum frequency.
3319 #[inline]
3320 #[must_use]
3321 pub const fn frequency_range(&self) -> (f64, f64) {
3322 self.freq.frequency_range()
3323 }
3324
3325 /// Get the duration of the spectrogram in seconds.
3326 ///
3327 /// # Returns
3328 ///
3329 /// Returns the duration in seconds.
3330 #[inline]
3331 #[must_use]
3332 pub fn duration(&self) -> f64 {
3333 *self.times.last()
3334 }
3335}
3336
3337// Enum types for frequency and amplitude scales
3338
/// Linear frequency scale
///
/// Used as a type-level tag for the frequency-scale parameter of spectrogram
/// types. The single `_Phantom` variant carries no data and appears to be a
/// placeholder only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LinearHz {
    /// Placeholder variant; carries no data.
    _Phantom,
}
3347
/// Logarithmic frequency scale
///
/// Used as a type-level tag for the frequency-scale parameter of spectrogram
/// types. The single `_Phantom` variant carries no data and appears to be a
/// placeholder only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LogHz {
    /// Placeholder variant; carries no data.
    _Phantom,
}
3356
/// Mel frequency scale
///
/// Used as a type-level tag for the frequency-scale parameter of spectrogram
/// types. The single `_Phantom` variant carries no data and appears to be a
/// placeholder only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Mel {
    /// Placeholder variant; carries no data.
    _Phantom,
}
3365
/// ERB/gammatone frequency scale
///
/// Used as a type-level tag for the frequency-scale parameter of spectrogram
/// types. The single `_Phantom` variant carries no data and appears to be a
/// placeholder only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Erb {
    /// Placeholder variant; carries no data.
    _Phantom,
}
/// Alternative name for the [`Erb`] frequency scale marker.
pub type Gammatone = Erb;
3375
/// Constant-Q Transform frequency scale
///
/// Used as a type-level tag for the frequency-scale parameter of spectrogram
/// types (for example `Spectrogram<Cqt, AmpScale>`). The single `_Phantom`
/// variant carries no data and appears to be a placeholder only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Cqt {
    /// Placeholder variant; carries no data.
    _Phantom,
}
3384
3385// Amplitude scales
3386
/// Power amplitude scale
///
/// Used as a type-level tag for the amplitude-scale parameter of spectrogram
/// types. The single `_Phantom` variant carries no data and appears to be a
/// placeholder only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Power {
    /// Placeholder variant; carries no data.
    _Phantom,
}
3395
/// Decibel amplitude scale
///
/// Used as a type-level tag for the amplitude-scale parameter of spectrogram
/// types. The single `_Phantom` variant carries no data and appears to be a
/// placeholder only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Decibels {
    /// Placeholder variant; carries no data.
    _Phantom,
}
3404
/// Magnitude amplitude scale
///
/// Used as a type-level tag for the amplitude-scale parameter of spectrogram
/// types. The single `_Phantom` variant carries no data and appears to be a
/// placeholder only.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Magnitude {
    /// Placeholder variant; carries no data.
    _Phantom,
}
3413
/// STFT parameters for spectrogram computation.
///
/// * `n_fft`: Size of the FFT window.
/// * `hop_size`: Number of samples between successive frames.
/// * `window`: Window function to apply to each frame.
/// * `centre`: Whether to pad the input signal so that frames are centered.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StftParams {
    /// Size of the FFT window.
    n_fft: NonZeroUsize,
    /// Number of samples between successive frames.
    hop_size: NonZeroUsize,
    /// Window function applied to each frame.
    window: WindowType,
    /// Whether the input signal is padded so that frames are centered.
    centre: bool,
}
3428
3429impl StftParams {
3430 /// Create new STFT parameters.
3431 ///
3432 /// # Arguments
3433 ///
3434 /// * `n_fft` - Size of the FFT window
3435 /// * `hop_size` - Number of samples between successive frames
3436 /// * `window` - Window function to apply to each frame
3437 /// * `centre` - Whether to pad the input signal so that frames are centered
3438 ///
3439 /// # Errors
3440 ///
3441 /// Returns an error if:
3442 /// - `hop_size` > `n_fft`
3443 /// - Custom window size doesn't match `n_fft`
3444 ///
3445 /// # Returns
3446 ///
3447 /// New `StftParams` instance.
3448 #[inline]
3449 pub fn new(
3450 n_fft: NonZeroUsize,
3451 hop_size: NonZeroUsize,
3452 window: WindowType,
3453 centre: bool,
3454 ) -> SpectrogramResult<Self> {
3455 if hop_size.get() > n_fft.get() {
3456 return Err(SpectrogramError::invalid_input("hop_size must be <= n_fft"));
3457 }
3458
3459 // Validate custom window size matches n_fft
3460 if let WindowType::Custom { size, .. } = &window {
3461 if size.get() != n_fft.get() {
3462 return Err(SpectrogramError::invalid_input(format!(
3463 "Custom window size ({}) must match n_fft ({})",
3464 size.get(),
3465 n_fft.get()
3466 )));
3467 }
3468 }
3469
3470 Ok(Self {
3471 n_fft,
3472 hop_size,
3473 window,
3474 centre,
3475 })
3476 }
3477
3478 const unsafe fn new_unchecked(
3479 n_fft: NonZeroUsize,
3480 hop_size: NonZeroUsize,
3481 window: WindowType,
3482 centre: bool,
3483 ) -> Self {
3484 Self {
3485 n_fft,
3486 hop_size,
3487 window,
3488 centre,
3489 }
3490 }
3491
3492 /// Get the FFT window size.
3493 ///
3494 /// # Returns
3495 ///
3496 /// The FFT window size.
3497 #[inline]
3498 #[must_use]
3499 pub const fn n_fft(&self) -> NonZeroUsize {
3500 self.n_fft
3501 }
3502
3503 /// Get the hop size (samples between successive frames).
3504 ///
3505 /// # Returns
3506 ///
3507 /// The hop size.
3508 #[inline]
3509 #[must_use]
3510 pub const fn hop_size(&self) -> NonZeroUsize {
3511 self.hop_size
3512 }
3513
3514 /// Get the window function.
3515 ///
3516 /// # Returns
3517 ///
3518 /// The window function.
3519 #[inline]
3520 #[must_use]
3521 pub fn window(&self) -> WindowType {
3522 self.window.clone()
3523 }
3524
3525 /// Get whether frames are centered (input signal is padded).
3526 ///
3527 /// # Returns
3528 ///
3529 /// `true` if frames are centered, `false` otherwise.
3530 #[inline]
3531 #[must_use]
3532 pub const fn centre(&self) -> bool {
3533 self.centre
3534 }
3535
3536 /// Create a builder for STFT parameters.
3537 ///
3538 /// # Returns
3539 ///
3540 /// A `StftParamsBuilder` instance.
3541 ///
3542 /// # Examples
3543 ///
3544 /// ```
3545 /// use spectrograms::{StftParams, WindowType, nzu};
3546 ///
3547 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
3548 /// let stft = StftParams::builder()
3549 /// .n_fft(nzu!(2048))
3550 /// .hop_size(nzu!(512))
3551 /// .window(WindowType::Hanning)
3552 /// .centre(true)
3553 /// .build()?;
3554 ///
3555 /// assert_eq!(stft.n_fft(), nzu!(2048));
3556 /// assert_eq!(stft.hop_size(), nzu!(512));
3557 /// # Ok(())
3558 /// # }
3559 /// ```
3560 #[inline]
3561 #[must_use]
3562 pub fn builder() -> StftParamsBuilder {
3563 StftParamsBuilder::default()
3564 }
3565}
3566
/// Builder for [`StftParams`].
///
/// `n_fft` and `hop_size` start out unset, so they are `Option`s; `window`
/// and `centre` always hold a value.
#[derive(Debug, Clone)]
pub struct StftParamsBuilder {
    /// FFT window size, `None` until set.
    n_fft: Option<NonZeroUsize>,
    /// Hop size in samples, `None` until set.
    hop_size: Option<NonZeroUsize>,
    /// Window function to apply to each frame.
    window: WindowType,
    /// Whether frames are centered (input signal is padded).
    centre: bool,
}
3575
3576impl Default for StftParamsBuilder {
3577 #[inline]
3578 fn default() -> Self {
3579 Self {
3580 n_fft: None,
3581 hop_size: None,
3582 window: WindowType::Hanning,
3583 centre: true,
3584 }
3585 }
3586}
3587
3588impl StftParamsBuilder {
3589 /// Set the FFT window size.
3590 ///
3591 /// # Arguments
3592 ///
3593 /// * `n_fft` - Size of the FFT window
3594 ///
3595 /// # Returns
3596 ///
3597 /// The builder with the updated FFT window size.
3598 #[inline]
3599 #[must_use]
3600 pub const fn n_fft(mut self, n_fft: NonZeroUsize) -> Self {
3601 self.n_fft = Some(n_fft);
3602 self
3603 }
3604
3605 /// Set the hop size (samples between successive frames).
3606 ///
3607 /// # Arguments
3608 ///
3609 /// * `hop_size` - Number of samples between successive frames
3610 ///
3611 /// # Returns
3612 ///
3613 /// The builder with the updated hop size.
3614 #[inline]
3615 #[must_use]
3616 pub const fn hop_size(mut self, hop_size: NonZeroUsize) -> Self {
3617 self.hop_size = Some(hop_size);
3618 self
3619 }
3620
3621 /// Set the window function.
3622 ///
3623 /// # Arguments
3624 ///
3625 /// * `window` - Window function to apply to each frame
3626 ///
3627 /// # Returns
3628 ///
3629 /// The builder with the updated window function.
3630 #[inline]
3631 #[must_use]
3632 pub fn window(mut self, window: WindowType) -> Self {
3633 self.window = window;
3634 self
3635 }
3636
3637 /// Set whether to center frames (pad input signal).
3638 #[inline]
3639 #[must_use]
3640 pub const fn centre(mut self, centre: bool) -> Self {
3641 self.centre = centre;
3642 self
3643 }
3644
3645 /// Build the [`StftParams`].
3646 ///
3647 /// # Errors
3648 ///
3649 /// Returns an error if:
3650 /// - `n_fft` or `hop_size` are not set or are zero
3651 /// - `hop_size` > `n_fft`
3652 #[inline]
3653 pub fn build(self) -> SpectrogramResult<StftParams> {
3654 let n_fft = self
3655 .n_fft
3656 .ok_or_else(|| SpectrogramError::invalid_input("n_fft must be set"))?;
3657 let hop_size = self
3658 .hop_size
3659 .ok_or_else(|| SpectrogramError::invalid_input("hop_size must be set"))?;
3660
3661 StftParams::new(n_fft, hop_size, self.window, self.centre)
3662 }
3663}
3664
3665//
3666// ========================
3667// Mel parameters
3668// ========================
3669//
3670
/// Mel filterbank normalization strategy.
///
/// Selects how the triangular mel filters are normalized once they have been built.
/// The default is [`MelNorm::None`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum MelNorm {
    /// No normalization (triangular filters with peak = 1.0).
    ///
    /// This is the default and fastest option.
    #[default]
    None,

    /// Slaney-style area normalization (librosa default).
    ///
    /// Each mel filter is normalized by its bandwidth in Hz: `2.0 / (f_max - f_min)`.
    /// This ensures constant energy per mel band regardless of bandwidth.
    ///
    /// Use this for compatibility with librosa's default behavior.
    Slaney,

    /// L1 normalization (sum of weights = 1.0).
    ///
    /// Each mel filter's weights are divided by their sum.
    /// Useful when you want each filter to act as a weighted average.
    L1,

    /// L2 normalization (Euclidean norm = 1.0).
    ///
    /// Each mel filter's weights are divided by their L2 norm.
    /// Provides unit-norm filters in the L2 sense.
    L2,
}
3705
/// Mel filter bank parameters
///
/// * `n_mels`: Number of mel bands
/// * `f_min`: Minimum frequency (Hz)
/// * `f_max`: Maximum frequency (Hz)
/// * `norm`: Filterbank normalization strategy
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MelParams {
    /// Number of mel bands.
    n_mels: NonZeroUsize,
    /// Minimum frequency in Hz.
    f_min: f64,
    /// Maximum frequency in Hz.
    f_max: f64,
    /// Filterbank normalization strategy.
    norm: MelNorm,
}
3720
3721impl MelParams {
3722 /// Create new mel filter bank parameters.
3723 ///
3724 /// # Arguments
3725 ///
3726 /// * `n_mels` - Number of mel bands
3727 /// * `f_min` - Minimum frequency (Hz)
3728 /// * `f_max` - Maximum frequency (Hz)
3729 ///
3730 /// # Errors
3731 ///
3732 /// Returns an error if:
3733 /// - `f_min` is not >= 0
3734 /// - `f_max` is not > `f_min`
3735 ///
3736 /// # Returns
3737 ///
3738 /// A `MelParams` instance with no normalization (default).
3739 #[inline]
3740 pub fn new(n_mels: NonZeroUsize, f_min: f64, f_max: f64) -> SpectrogramResult<Self> {
3741 Self::with_norm(n_mels, f_min, f_max, MelNorm::None)
3742 }
3743
3744 /// Create new mel filter bank parameters with specified normalization.
3745 ///
3746 /// # Arguments
3747 ///
3748 /// * `n_mels` - Number of mel bands
3749 /// * `f_min` - Minimum frequency (Hz)
3750 /// * `f_max` - Maximum frequency (Hz)
3751 /// * `norm` - Filterbank normalization strategy
3752 ///
3753 /// # Errors
3754 ///
3755 /// Returns an error if:
3756 /// - `f_min` is not >= 0
3757 /// - `f_max` is not > `f_min`
3758 ///
3759 /// # Returns
3760 ///
3761 /// A `MelParams` instance.
3762 #[inline]
3763 pub fn with_norm(
3764 n_mels: NonZeroUsize,
3765 f_min: f64,
3766 f_max: f64,
3767 norm: MelNorm,
3768 ) -> SpectrogramResult<Self> {
3769 if f_min < 0.0 {
3770 return Err(SpectrogramError::invalid_input("f_min must be >= 0"));
3771 }
3772
3773 if f_max <= f_min {
3774 return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
3775 }
3776
3777 Ok(Self {
3778 n_mels,
3779 f_min,
3780 f_max,
3781 norm,
3782 })
3783 }
3784
3785 /// Create new mel filter bank parameters.
3786 ///
3787 /// # Arguments
3788 ///
3789 /// * `n_mels` - Number of mel bands
3790 /// * `f_min` - Minimum frequency (Hz)
3791 /// * `f_max` - Maximum frequency (Hz)
3792 ///
3793 /// # Returns
3794 ///
3795 /// A `MelParams` instance with no normalization (default).
3796 ///
3797 /// # Safety
3798 ///
3799 /// The caller must ensure `f_min < f_max` (not validated at runtime).
3800 #[must_use]
3801 pub const unsafe fn new_unchecked(n_mels: NonZeroUsize, f_min: f64, f_max: f64) -> Self {
3802 Self {
3803 n_mels,
3804 f_min,
3805 f_max,
3806 norm: MelNorm::None,
3807 }
3808 }
3809
3810 /// Get the number of mel bands.
3811 ///
3812 /// # Returns
3813 ///
3814 /// The number of mel bands.
3815 #[inline]
3816 #[must_use]
3817 pub const fn n_mels(&self) -> NonZeroUsize {
3818 self.n_mels
3819 }
3820
3821 /// Get the minimum frequency (Hz).
3822 ///
3823 /// # Returns
3824 ///
3825 /// The minimum frequency in Hz.
3826 #[inline]
3827 #[must_use]
3828 pub const fn f_min(&self) -> f64 {
3829 self.f_min
3830 }
3831
3832 /// Get the maximum frequency (Hz).
3833 ///
3834 /// # Returns
3835 ///
3836 /// The maximum frequency in Hz.
3837 #[inline]
3838 #[must_use]
3839 pub const fn f_max(&self) -> f64 {
3840 self.f_max
3841 }
3842
3843 /// Get the filterbank normalization strategy.
3844 ///
3845 /// # Returns
3846 ///
3847 /// The normalization strategy.
3848 #[inline]
3849 #[must_use]
3850 pub const fn norm(&self) -> MelNorm {
3851 self.norm
3852 }
3853
3854 /// Create standard mel filterbank parameters.
3855 ///
3856 /// Uses 128 mel bands from 0 Hz to the Nyquist frequency.
3857 ///
3858 /// # Arguments
3859 ///
3860 /// * `sample_rate` - Sample rate in Hz (used to determine `f_max`)
3861 ///
3862 /// # Returns
3863 ///
3864 /// A `MelParams` instance with standard settings.
3865 ///
3866 /// # Panics
3867 ///
3868 /// Panics if `sample_rate` is not greater than 0.
3869 #[inline]
3870 #[must_use]
3871 pub const fn standard(sample_rate: f64) -> Self {
3872 assert!(sample_rate > 0.0);
3873 // safety: parameters are known to be valid
3874 unsafe { Self::new_unchecked(nzu!(128), 0.0, sample_rate / 2.0) }
3875 }
3876
3877 /// Create mel filterbank parameters optimized for speech.
3878 ///
3879 /// Uses 40 mel bands from 0 Hz to 8000 Hz (typical speech bandwidth).
3880 ///
3881 /// # Returns
3882 ///
3883 /// A `MelParams` instance with speech-optimized settings.
3884 #[inline]
3885 #[must_use]
3886 pub const fn speech_standard() -> Self {
3887 // safety: parameters are known to be valid
3888 unsafe { Self::new_unchecked(nzu!(40), 0.0, 8000.0) }
3889 }
3890}
3891
3892//
3893// ========================
3894// LogHz parameters
3895// ========================
3896//
3897
/// Logarithmic frequency scale parameters
///
/// * `n_bins`: Number of logarithmically-spaced frequency bins
/// * `f_min`: Minimum frequency (Hz)
/// * `f_max`: Maximum frequency (Hz)
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LogHzParams {
    /// Number of logarithmically-spaced frequency bins.
    n_bins: NonZeroUsize,
    /// Minimum frequency in Hz.
    f_min: f64,
    /// Maximum frequency in Hz.
    f_max: f64,
}
3910
3911impl LogHzParams {
3912 /// Create new logarithmic frequency scale parameters.
3913 ///
3914 /// # Arguments
3915 ///
3916 /// * `n_bins` - Number of logarithmically-spaced frequency bins
3917 /// * `f_min` - Minimum frequency (Hz)
3918 /// * `f_max` - Maximum frequency (Hz)
3919 ///
3920 /// # Errors
3921 ///
3922 /// Returns an error if:
3923 /// - `f_min` is not finite and > 0
3924 /// - `f_max` is not > `f_min`
3925 ///
3926 /// # Returns
3927 ///
3928 /// A `LogHzParams` instance.
3929 #[inline]
3930 pub fn new(n_bins: NonZeroUsize, f_min: f64, f_max: f64) -> SpectrogramResult<Self> {
3931 if !(f_min > 0.0 && f_min.is_finite()) {
3932 return Err(SpectrogramError::invalid_input(
3933 "f_min must be finite and > 0",
3934 ));
3935 }
3936
3937 if f_max <= f_min {
3938 return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
3939 }
3940
3941 Ok(Self {
3942 n_bins,
3943 f_min,
3944 f_max,
3945 })
3946 }
3947
3948 const unsafe fn new_unchecked(n_bins: NonZeroUsize, f_min: f64, f_max: f64) -> Self {
3949 Self {
3950 n_bins,
3951 f_min,
3952 f_max,
3953 }
3954 }
3955
3956 /// Get the number of frequency bins.
3957 ///
3958 /// # Returns
3959 ///
3960 /// The number of frequency bins.
3961 #[inline]
3962 #[must_use]
3963 pub const fn n_bins(&self) -> NonZeroUsize {
3964 self.n_bins
3965 }
3966
3967 /// Get the minimum frequency (Hz).
3968 ///
3969 /// # Returns
3970 ///
3971 /// The minimum frequency in Hz.
3972 #[inline]
3973 #[must_use]
3974 pub const fn f_min(&self) -> f64 {
3975 self.f_min
3976 }
3977
3978 /// Get the maximum frequency (Hz).
3979 ///
3980 /// # Returns
3981 ///
3982 /// The maximum frequency in Hz.
3983 #[inline]
3984 #[must_use]
3985 pub const fn f_max(&self) -> f64 {
3986 self.f_max
3987 }
3988
3989 /// Create standard logarithmic frequency parameters.
3990 ///
3991 /// Uses 128 log bins from 20 Hz to the Nyquist frequency.
3992 ///
3993 /// # Arguments
3994 ///
3995 /// * `sample_rate` - Sample rate in Hz (used to determine `f_max`)
3996 #[inline]
3997 #[must_use]
3998 pub fn standard(sample_rate: f64) -> Self {
3999 // safety: parameters are known to be valid
4000 unsafe { Self::new_unchecked(nzu!(128), 20.0, sample_rate / 2.0) }
4001 }
4002
4003 /// Create logarithmic frequency parameters optimized for music.
4004 ///
4005 /// Uses 84 bins (7 octaves * 12 bins/octave) from 27.5 Hz (A0) to 4186 Hz (C8).
4006 #[inline]
4007 #[must_use]
4008 pub const fn music_standard() -> Self {
4009 // safety: parameters are known to be valid
4010 unsafe { Self::new_unchecked(nzu!(84), 27.5, 4186.0) }
4011 }
4012}
4013
4014//
4015// ========================
4016// Log scaling parameters
4017// ========================
4018//
4019
/// Logarithmic (decibel) scaling parameters.
///
/// * `floor_db`: Minimum dB value (floor) for logarithmic scaling
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LogParams {
    /// Minimum dB value (floor) for logarithmic scaling.
    floor_db: f64,
}
4025
4026impl LogParams {
4027 /// Create new logarithmic scaling parameters.
4028 ///
4029 /// # Arguments
4030 ///
4031 /// * `floor_db` - Minimum dB value (floor) for logarithmic scaling
4032 ///
4033 /// # Errors
4034 ///
4035 /// Returns an error if `floor_db` is not finite.
4036 ///
4037 /// # Returns
4038 ///
4039 /// A `LogParams` instance.
4040 #[inline]
4041 pub fn new(floor_db: f64) -> SpectrogramResult<Self> {
4042 if !floor_db.is_finite() {
4043 return Err(SpectrogramError::invalid_input("floor_db must be finite"));
4044 }
4045
4046 Ok(Self { floor_db })
4047 }
4048
4049 /// Create new logarithmic scaling parameters.
4050 ///
4051 /// # Arguments
4052 ///
4053 /// * `floor_db` - Minimum dB value (floor) for logarithmic scaling
4054 ///
4055 /// # Returns
4056 ///
4057 /// A `LogParams` instance.
4058 #[inline]
4059 #[must_use]
4060 pub const fn new_unchecked(floor_db: f64) -> Self {
4061 Self { floor_db }
4062 }
4063
4064 /// Get the floor dB value.
4065 #[inline]
4066 #[must_use]
4067 pub const fn floor_db(&self) -> f64 {
4068 self.floor_db
4069 }
4070}
4071
/// Spectrogram computation parameters.
///
/// * `stft`: STFT parameters
/// * `sample_rate_hz`: Sample rate in Hz
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SpectrogramParams {
    /// STFT parameters (FFT size, hop size, window, centering).
    pub(crate) stft: StftParams,
    /// Sample rate in Hz.
    pub(crate) sample_rate_hz: f64,
}
4082
4083impl SpectrogramParams {
4084 /// Create new spectrogram parameters.
4085 ///
4086 /// # Arguments
4087 ///
4088 /// * `stft` - STFT parameters
4089 /// * `sample_rate_hz` - Sample rate in Hz
4090 ///
4091 /// # Errors
4092 ///
4093 /// Returns an error if the sample rate is not positive and finite.
4094 ///
4095 /// # Returns
4096 ///
4097 /// A `SpectrogramParams` instance.
4098 #[inline]
4099 pub fn new(stft: StftParams, sample_rate_hz: f64) -> SpectrogramResult<Self> {
4100 if !(sample_rate_hz > 0.0 && sample_rate_hz.is_finite()) {
4101 return Err(SpectrogramError::invalid_input(
4102 "sample_rate_hz must be finite and > 0",
4103 ));
4104 }
4105
4106 Ok(Self {
4107 stft,
4108 sample_rate_hz,
4109 })
4110 }
4111
4112 /// Create new spectrogram parameters without checking the arguments.
4113 ///
4114 /// # Arguments
4115 ///
4116 /// * `stft` - STFT parameters
4117 /// * `sample_rate_hz` - Sample rate in Hz
4118 ///
4119 /// # Returns
4120 ///
4121 /// A `SpectrogramParams` instance.
4122 #[inline]
4123 #[must_use]
4124 pub const fn new_unchecked(stft: StftParams, sample_rate_hz: f64) -> Self {
4125 Self {
4126 stft,
4127 sample_rate_hz,
4128 }
4129 }
4130
4131 /// Create a builder for spectrogram parameters.
4132 ///
4133 /// # Errors
4134 ///
4135 /// Returns an error if required parameters are not set or are invalid.
4136 ///
4137 /// # Returns
4138 ///
4139 /// A builder for [`SpectrogramParams`].
4140 ///
4141 /// # Examples
4142 ///
4143 /// ```
4144 /// use spectrograms::{SpectrogramParams, WindowType, nzu};
4145 ///
4146 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4147 /// let params = SpectrogramParams::builder()
4148 /// .sample_rate(16000.0)
4149 /// .n_fft(nzu!(512))
4150 /// .hop_size(nzu!(256))
4151 /// .window(WindowType::Hanning)
4152 /// .centre(true)
4153 /// .build()?;
4154 ///
4155 /// assert_eq!(params.sample_rate_hz(), 16000.0);
4156 /// # Ok(())
4157 /// # }
4158 /// ```
4159 #[inline]
4160 #[must_use]
4161 pub fn builder() -> SpectrogramParamsBuilder {
4162 SpectrogramParamsBuilder::default()
4163 }
4164
4165 /// Create default parameters for speech processing.
4166 ///
4167 /// # Arguments
4168 ///
4169 /// * `sample_rate_hz` - Sample rate in Hz
4170 ///
4171 /// # Returns
4172 ///
4173 /// A `SpectrogramParams` instance with default settings for music analysis.
4174 ///
4175 /// # Errors
4176 ///
4177 /// Returns an error if the sample rate is not positive and finite.
4178 ///
4179 /// Uses:
4180 /// - `n_fft`: 512 (32ms at 16kHz)
4181 /// - `hop_size`: 160 (10ms at 16kHz)
4182 /// - window: Hanning
4183 /// - centre: true
4184 #[inline]
4185 pub fn speech_default(sample_rate_hz: f64) -> SpectrogramResult<Self> {
4186 // safety: parameters are known to be valid
4187 let stft =
4188 unsafe { StftParams::new_unchecked(nzu!(512), nzu!(160), WindowType::Hanning, true) };
4189
4190 Self::new(stft, sample_rate_hz)
4191 }
4192
4193 /// Create default parameters for music processing.
4194 ///
4195 /// # Arguments
4196 ///
4197 /// * `sample_rate_hz` - Sample rate in Hz
4198 ///
4199 /// # Returns
4200 ///
4201 /// A `SpectrogramParams` instance with default settings for music analysis.
4202 ///
4203 /// # Errors
4204 ///
4205 /// Returns an error if the sample rate is not positive and finite.
4206 ///
4207 /// Uses:
4208 /// - `n_fft`: 2048 (46ms at 44.1kHz)
4209 /// - `hop_size`: 512 (11.6ms at 44.1kHz)
4210 /// - window: Hanning
4211 /// - centre: true
4212 #[inline]
4213 pub fn music_default(sample_rate_hz: f64) -> SpectrogramResult<Self> {
4214 // safety: parameters are known to be valid
4215 let stft =
4216 unsafe { StftParams::new_unchecked(nzu!(2048), nzu!(512), WindowType::Hanning, true) };
4217 Self::new(stft, sample_rate_hz)
4218 }
4219
4220 /// Get the STFT parameters.
4221 #[inline]
4222 #[must_use]
4223 pub const fn stft(&self) -> &StftParams {
4224 &self.stft
4225 }
4226
4227 /// Get the sample rate in Hz.
4228 #[inline]
4229 #[must_use]
4230 pub const fn sample_rate_hz(&self) -> f64 {
4231 self.sample_rate_hz
4232 }
4233
4234 /// Get the frame period in seconds.
4235 #[inline]
4236 #[must_use]
4237 #[allow(clippy::cast_precision_loss)]
4238 pub fn frame_period_seconds(&self) -> f64 {
4239 self.stft.hop_size().get() as f64 / self.sample_rate_hz
4240 }
4241
4242 /// Get the Nyquist frequency in Hz.
4243 #[inline]
4244 #[must_use]
4245 pub fn nyquist_hz(&self) -> f64 {
4246 self.sample_rate_hz * 0.5
4247 }
4248}
4249
/// Builder for [`SpectrogramParams`].
///
/// `sample_rate`, `n_fft` and `hop_size` start out unset, so they are
/// `Option`s; `window` and `centre` always hold a value.
#[derive(Debug, Clone)]
pub struct SpectrogramParamsBuilder {
    /// Sample rate in Hz, `None` until set.
    sample_rate: Option<f64>,
    /// FFT size, `None` until set.
    n_fft: Option<NonZeroUsize>,
    /// Hop size in samples, `None` until set.
    hop_size: Option<NonZeroUsize>,
    /// Window function to apply to each frame.
    window: WindowType,
    /// Whether frames are centered (input signal is padded).
    centre: bool,
}
4259
4260impl Default for SpectrogramParamsBuilder {
4261 #[inline]
4262 fn default() -> Self {
4263 Self {
4264 sample_rate: None,
4265 n_fft: None,
4266 hop_size: None,
4267 window: WindowType::Hanning,
4268 centre: true,
4269 }
4270 }
4271}
4272
4273impl SpectrogramParamsBuilder {
    /// Set the sample rate in Hz.
    ///
    /// The value is only stored here; it is checked when the parameters are built.
    ///
    /// # Arguments
    ///
    /// * `sample_rate` - Sample rate in Hz.
    ///
    /// # Returns
    ///
    /// The updated builder instance.
    #[inline]
    #[must_use]
    pub const fn sample_rate(mut self, sample_rate: f64) -> Self {
        self.sample_rate = Some(sample_rate);
        self
    }
4289
    /// Set the FFT window size.
    ///
    /// The value is only stored here; it is checked against the hop size and
    /// window when the parameters are built.
    ///
    /// # Arguments
    ///
    /// * `n_fft` - FFT size.
    ///
    /// # Returns
    ///
    /// The updated builder instance.
    #[inline]
    #[must_use]
    pub const fn n_fft(mut self, n_fft: NonZeroUsize) -> Self {
        self.n_fft = Some(n_fft);
        self
    }
4305
    /// Set the hop size (samples between successive frames).
    ///
    /// The value is only stored here; it is checked against the FFT size when
    /// the parameters are built.
    ///
    /// # Arguments
    ///
    /// * `hop_size` - Hop size in samples.
    ///
    /// # Returns
    ///
    /// The updated builder instance.
    #[inline]
    #[must_use]
    pub const fn hop_size(mut self, hop_size: NonZeroUsize) -> Self {
        self.hop_size = Some(hop_size);
        self
    }
4321
4322 /// Set the window function.
4323 ///
4324 /// # Arguments
4325 ///
4326 /// * `window` - Window function to apply to each frame.
4327 ///
4328 /// # Returns
4329 ///
4330 /// The updated builder instance.
4331 #[inline]
4332 #[must_use]
4333 pub fn window(mut self, window: WindowType) -> Self {
4334 self.window = window;
4335 self
4336 }
4337
4338 /// Set whether to center frames (pad input signal).
4339 ///
4340 /// # Arguments
4341 ///
4342 /// * `centre` - If true, frames are centered by padding the input signal.
4343 ///
4344 /// # Returns
4345 ///
4346 /// The updated builder instance.
4347 #[inline]
4348 #[must_use]
4349 pub const fn centre(mut self, centre: bool) -> Self {
4350 self.centre = centre;
4351 self
4352 }
4353
4354 /// Build the [`SpectrogramParams`].
4355 ///
4356 /// # Errors
4357 ///
4358 /// Returns an error if required parameters are not set or are invalid.
4359 ///
4360 /// # Returns
4361 ///
4362 /// A `SpectrogramParams` instance.
4363 #[inline]
4364 pub fn build(self) -> SpectrogramResult<SpectrogramParams> {
4365 let sample_rate = self
4366 .sample_rate
4367 .ok_or_else(|| SpectrogramError::invalid_input("sample_rate must be set"))?;
4368 let n_fft = self
4369 .n_fft
4370 .ok_or_else(|| SpectrogramError::invalid_input("n_fft must be set"))?;
4371 let hop_size = self
4372 .hop_size
4373 .ok_or_else(|| SpectrogramError::invalid_input("hop_size must be set"))?;
4374
4375 let stft = StftParams::new(n_fft, hop_size, self.window, self.centre)?;
4376 SpectrogramParams::new(stft, sample_rate)
4377 }
4378
4379 /// Build the [`SpectrogramParams`].
4380 ///
4381 /// # Safety
4382 ///
4383 /// The caller is responsible for ensuring the 'n_fft', 'hop_size' are set.
4384 ///
4385 /// # Returns
4386 ///
4387 /// A `SpectrogramParams` instance.
4388 #[inline]
4389 #[must_use]
4390 pub unsafe fn build_unchecked(self) -> SpectrogramParams {
4391 // safety: is the repsonsibility of the caller
4392 unsafe {
4393 let n_fft = self.n_fft.unwrap_unchecked();
4394 let hop_size = self.hop_size.unwrap_unchecked();
4395 let stft = StftParams::new_unchecked(n_fft, hop_size, self.window, self.centre);
4396 let sample_rate = self.sample_rate.unwrap_unchecked();
4397 SpectrogramParams::new_unchecked(stft, sample_rate)
4398 }
4399 }
4400}
4401
4402//
4403// ========================
4404// Standalone FFT Functions
4405// ========================
4406//
4407
4408/// Compute the real-to-complex FFT of a real-valued signal.
4409///
4410/// This function performs a forward FFT on real-valued input, returning the
4411/// complex frequency domain representation. Only the positive frequencies
4412/// are returned (length = `n_fft/2` + 1) due to conjugate symmetry.
4413///
4414/// # Arguments
4415///
4416/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4417/// * `n_fft` - FFT size
4418///
4419/// # Returns
4420///
4421/// A vector of complex frequency bins with length `n_fft/2` + 1.
4422///
4423/// # Automatic Zero-Padding
4424///
4425/// If the input signal is shorter than `n_fft`, it will be automatically
4426/// zero-padded to the required length. This is standard DSP practice and
4427/// preserves frequency resolution (bin spacing = sample_rate / n_fft).
4428///
4429/// ```
4430/// use spectrograms::{fft, nzu};
4431/// use non_empty_slice::non_empty_vec;
4432///
4433/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4434/// let signal = non_empty_vec![1.0, 2.0, 3.0]; // Only 3 samples
4435/// let spectrum = fft(&signal, nzu!(8))?; // Automatically padded to 8
4436/// assert_eq!(spectrum.len(), 5); // Output: 8/2 + 1 = 5 bins
4437/// # Ok(())
4438/// # }
4439/// ```
4440///
4441/// # Errors
4442///
4443/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4444///
4445/// # Examples
4446///
4447/// ```
4448/// use spectrograms::*;
4449/// use non_empty_slice::non_empty_vec;
4450///
4451/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4452/// let signal = non_empty_vec![0.0; nzu!(512)];
4453/// let spectrum = fft(&signal, nzu!(512))?;
4454///
4455/// assert_eq!(spectrum.len(), 257); // 512/2 + 1
4456/// # Ok(())
4457/// # }
4458/// ```
4459#[inline]
4460pub fn fft(
4461 samples: &NonEmptySlice<f64>,
4462 n_fft: NonZeroUsize,
4463) -> SpectrogramResult<Array1<Complex<f64>>> {
4464 if samples.len() > n_fft {
4465 return Err(SpectrogramError::invalid_input(format!(
4466 "Input length ({}) exceeds FFT size ({})",
4467 samples.len(),
4468 n_fft
4469 )));
4470 }
4471
4472 let out_len = r2c_output_size(n_fft.get());
4473
4474 // Get FFT plan from global cache (or create if first use)
4475 #[cfg(feature = "realfft")]
4476 let mut fft = {
4477 use crate::fft_backend::get_or_create_r2c_plan;
4478 let plan = get_or_create_r2c_plan(n_fft.get())?;
4479 // Clone the plan to get our own mutable copy with independent scratch buffer
4480 // This is cheap - only clones the scratch buffer, not the expensive twiddle factors
4481 (*plan).clone()
4482 };
4483
4484 #[cfg(feature = "fftw")]
4485 let mut fft = {
4486 use std::sync::Arc;
4487 let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
4488 crate::FftwPlan::new(Arc::new(plan))
4489 };
4490
4491 let input = if samples.len() < n_fft {
4492 let mut padded = vec![0.0; n_fft.get()];
4493 padded[..samples.len().get()].copy_from_slice(samples);
4494 // safety: samples.len() < n_fft checked above and n_fft > 0
4495
4496 unsafe { NonEmptyVec::new_unchecked(padded) }
4497 } else {
4498 samples.to_non_empty_vec()
4499 };
4500
4501 let mut output = vec![Complex::new(0.0, 0.0); out_len];
4502 fft.process(&input, &mut output)?;
4503 let output = Array1::from_vec(output);
4504 Ok(output)
4505}
4506
4507#[inline]
4508/// Compute the real-valued fft of a signal.
4509///
4510/// # Arguments
4511/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4512/// * `n_fft` - FFT size
4513///
4514/// # Returns
4515///
4516/// An array with length `n_fft/2` + 1.
4517///
4518/// # Errors
4519///
4520/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4521///
4522/// # Examples
4523///
4524/// ```
4525/// use spectrograms::*;
4526/// use non_empty_slice::non_empty_vec;
4527///
4528/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4529/// let signal = non_empty_vec![0.0; nzu!(512)];
4530/// let rfft_result = rfft(&signal, nzu!(512))?;
4531/// // equivalent to
4532/// let fft_result = fft(&signal, nzu!(512))?;
4533/// let rfft_result = fft_result.mapv(num_complex::Complex::norm);
4534/// # Ok(())
4535/// # }
4536///
4537pub fn rfft(samples: &NonEmptySlice<f64>, n_fft: NonZeroUsize) -> SpectrogramResult<Array1<f64>> {
4538 Ok(fft(samples, n_fft)?.mapv(Complex::norm))
4539}
4540
4541/// Compute the power spectrum of a signal (|X|²).
4542///
4543/// This function applies an optional window function and computes the
4544/// power spectrum via FFT. The result contains only positive frequencies.
4545///
4546/// # Arguments
4547///
4548/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4549/// * `n_fft` - FFT size
4550/// * `window` - Optional window function (None for rectangular window)
4551///
4552/// # Returns
4553///
4554/// A vector of power values with length `n_fft/2` + 1.
4555///
4556/// # Automatic Zero-Padding
4557///
4558/// If the input signal is shorter than `n_fft`, it will be automatically
4559/// zero-padded to the required length. This is standard DSP practice and
4560/// preserves frequency resolution (bin spacing = sample_rate / n_fft).
4561///
4562/// ```
4563/// use spectrograms::{power_spectrum, nzu};
4564/// use non_empty_slice::non_empty_vec;
4565///
4566/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4567/// let signal = non_empty_vec![1.0, 2.0, 3.0]; // Only 3 samples
4568/// let power = power_spectrum(&signal, nzu!(8), None)?;
4569/// assert_eq!(power.len(), nzu!(5)); // Output: 8/2 + 1 = 5 bins
4570/// # Ok(())
4571/// # }
4572/// ```
4573///
4574/// # Errors
4575///
4576/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4577///
4578/// # Examples
4579///
4580/// ```
4581/// use spectrograms::*;
4582/// use non_empty_slice::non_empty_vec;
4583///
4584/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4585/// let signal = non_empty_vec![0.0; nzu!(512)];
4586/// let power = power_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
4587///
4588/// assert_eq!(power.len(), nzu!(257)); // 512/2 + 1
4589/// # Ok(())
4590/// # }
4591/// ```
4592#[inline]
4593pub fn power_spectrum(
4594 samples: &NonEmptySlice<f64>,
4595 n_fft: NonZeroUsize,
4596 window: Option<WindowType>,
4597) -> SpectrogramResult<NonEmptyVec<f64>> {
4598 if samples.len() > n_fft {
4599 return Err(SpectrogramError::invalid_input(format!(
4600 "Input length ({}) exceeds FFT size ({})",
4601 samples.len(),
4602 n_fft
4603 )));
4604 }
4605
4606 let mut windowed = vec![0.0; n_fft.get()];
4607 windowed[..samples.len().get()].copy_from_slice(samples);
4608
4609 if let Some(win_type) = window {
4610 let window_samples = make_window(win_type, n_fft);
4611 for i in 0..n_fft.get() {
4612 windowed[i] *= window_samples[i];
4613 }
4614 }
4615
4616 // safety: windowed is non-empty since n_fft > 0
4617 let windowed = unsafe { NonEmptySlice::new_unchecked(&windowed) };
4618 let fft_result = fft(windowed, n_fft)?;
4619 let fft_result = fft_result
4620 .iter()
4621 .map(num_complex::Complex::norm_sqr)
4622 .collect();
4623 // safety: fft_result is non-empty since fft returned successfully
4624 Ok(unsafe { NonEmptyVec::new_unchecked(fft_result) })
4625}
4626
4627/// Compute the magnitude spectrum of a signal (|X|).
4628///
4629/// This function applies an optional window function and computes the
4630/// magnitude spectrum via FFT. The result contains only positive frequencies.
4631///
4632/// # Arguments
4633///
4634/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4635/// * `n_fft` - FFT size
4636/// * `window` - Optional window function (None for rectangular window)
4637///
4638/// # Automatic Zero-Padding
4639///
4640/// If the input signal is shorter than `n_fft`, it will be automatically
4641/// zero-padded to the required length. This preserves frequency resolution.
4642///
4643/// # Errors
4644///
4645/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4646///
4647/// # Returns
4648///
4649/// A vector of magnitude values with length `n_fft/2` + 1.
4650///
4651/// # Examples
4652///
4653/// ```
4654/// use spectrograms::*;
4655/// use non_empty_slice::non_empty_vec;
4656///
4657/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4658/// let signal = non_empty_vec![0.0; nzu!(512)];
4659/// let magnitude = magnitude_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
4660///
4661/// assert_eq!(magnitude.len(), nzu!(257)); // 512/2 + 1
4662/// # Ok(())
4663/// # }
4664/// ```
4665#[inline]
4666pub fn magnitude_spectrum(
4667 samples: &NonEmptySlice<f64>,
4668 n_fft: NonZeroUsize,
4669 window: Option<WindowType>,
4670) -> SpectrogramResult<NonEmptyVec<f64>> {
4671 let power = power_spectrum(samples, n_fft, window)?;
4672 let power = power.iter().map(|&p| p.sqrt()).collect();
4673 // safety: power is non-empty since power_spectrum returned successfully
4674 Ok(unsafe { NonEmptyVec::new_unchecked(power) })
4675}
4676
4677/// Compute the Short-Time Fourier Transform (STFT) of a signal.
4678///
4679/// This function computes the STFT by applying a sliding window and FFT
4680/// to sequential frames of the input signal.
4681///
4682/// # Arguments
4683///
4684/// * `samples` - Input signal (any type that can be converted to a slice)
4685/// * `n_fft` - FFT size
4686/// * `hop_size` - Number of samples between successive frames
4687/// * `window` - Window function to apply to each frame
4688/// * `center` - If true, pad the signal to center frames
4689///
4690/// # Returns
4691///
4692/// A 2D array with shape (`frequency_bins`, `time_frames`) containing complex STFT values.
4693///
4694/// # Errors
4695///
4696/// Returns an error if:
4697/// - `hop_size` > `n_fft`
4698/// - STFT computation fails
4699///
4700/// # Examples
4701///
4702/// ```
4703/// use spectrograms::*;
4704/// use non_empty_slice::non_empty_vec;
4705///
4706/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4707/// let signal = non_empty_vec![0.0; nzu!(16000)];
4708/// let stft_matrix = stft(&signal, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4709///
4710/// println!("STFT: {} bins x {} frames", stft_matrix.nrows(), stft_matrix.ncols());
4711/// # Ok(())
4712/// # }
4713/// ```
4714#[inline]
4715pub fn stft(
4716 samples: &NonEmptySlice<f64>,
4717 n_fft: NonZeroUsize,
4718 hop_size: NonZeroUsize,
4719 window: WindowType,
4720 center: bool,
4721) -> SpectrogramResult<Array2<Complex<f64>>> {
4722 let stft_params = StftParams::new(n_fft, hop_size, window, center)?;
4723 let params = SpectrogramParams::new(stft_params, 1.0)?; // dummy sample rate
4724
4725 let planner = SpectrogramPlanner::new();
4726 let result = planner.compute_stft(samples, ¶ms)?;
4727
4728 Ok(result.data)
4729}
4730
4731/// Compute the inverse real FFT (complex-to-real IFFT).
4732///
4733/// This function performs an inverse FFT, converting frequency domain data
4734/// back to the time domain. Only the positive frequencies need to be provided
4735/// (length = `n_fft/2` + 1) due to conjugate symmetry.
4736///
4737/// # Arguments
4738///
4739/// * `spectrum` - Complex frequency bins (length should be `n_fft/2` + 1)
4740/// * `n_fft` - FFT size (length of the output signal)
4741///
4742/// # Returns
4743///
4744/// A vector of real-valued time-domain samples with length `n_fft`.
4745///
4746/// # Errors
4747///
4748/// Returns an error if:
4749/// - `spectrum` length doesn't match `n_fft/2` + 1
4750/// - Inverse FFT computation fails
4751///
4752/// # Examples
4753///
4754/// ```
4755/// use spectrograms::*;
4756/// use non_empty_slice::{non_empty_vec, NonEmptySlice};
4757/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4758/// // Forward FFT
4759/// let signal = non_empty_vec![1.0, 0.0, -1.0, 0.0, 1.0, 0.0, -1.0, 0.0];
4760/// let spectrum = fft(&signal, nzu!(8))?;
4761/// let slice = spectrum.as_slice().unwrap();
4762/// let spectrum_slice = NonEmptySlice::new(slice).unwrap();
4763/// // Inverse FFT
4764/// let reconstructed = irfft(spectrum_slice, nzu!(8))?;
4765///
4766/// assert_eq!(reconstructed.len(), nzu!(8));
4767/// # Ok(())
4768/// # }
4769/// ```
4770#[inline]
4771pub fn irfft(
4772 spectrum: &NonEmptySlice<Complex<f64>>,
4773 n_fft: NonZeroUsize,
4774) -> SpectrogramResult<NonEmptyVec<f64>> {
4775 use crate::fft_backend::{C2rPlan, r2c_output_size};
4776
4777 let n_fft = n_fft.get();
4778 let expected_len = r2c_output_size(n_fft);
4779 if spectrum.len().get() != expected_len {
4780 return Err(SpectrogramError::dimension_mismatch(
4781 expected_len,
4782 spectrum.len().get(),
4783 ));
4784 }
4785
4786 // Get inverse FFT plan from global cache (or create if first use)
4787 #[cfg(feature = "realfft")]
4788 let mut ifft = {
4789 use crate::fft_backend::get_or_create_c2r_plan;
4790 let plan = get_or_create_c2r_plan(n_fft)?;
4791 // Clone to get our own mutable copy with independent scratch buffer
4792 (*plan).clone()
4793 };
4794
4795 #[cfg(feature = "fftw")]
4796 let mut ifft = {
4797 use crate::fft_backend::C2rPlanner;
4798 let mut planner = crate::FftwPlanner::new();
4799 planner.plan_c2r(n_fft)?
4800 };
4801
4802 let mut output = vec![0.0; n_fft];
4803 ifft.process(spectrum.as_slice(), &mut output)?;
4804
4805 // Safety: output is non-empty since n_fft > 0
4806 Ok(unsafe { NonEmptyVec::new_unchecked(output) })
4807}
4808
4809/// Reconstruct a time-domain signal from its STFT using overlap-add.
4810///
4811/// This function performs the inverse Short-Time Fourier Transform, converting
4812/// a 2D complex STFT matrix back to a 1D time-domain signal using overlap-add
4813/// synthesis with the specified window function.
4814///
4815/// # Arguments
4816///
4817/// * `stft_matrix` - Complex STFT with shape (`frequency_bins`, `time_frames`)
4818/// * `n_fft` - FFT size
4819/// * `hop_size` - Number of samples between successive frames
4820/// * `window` - Window function to apply (should match forward STFT window)
4821/// * `center` - If true, assume the forward STFT was centered
4822///
4823/// # Returns
4824///
4825/// A vector of reconstructed time-domain samples.
4826///
4827/// # Errors
4828///
4829/// Returns an error if:
4830/// - `stft_matrix` dimensions are inconsistent with `n_fft`
4831/// - `hop_size` > `n_fft`
4832/// - Inverse STFT computation fails
4833///
4834/// # Examples
4835///
4836/// ```
4837/// use spectrograms::*;
4838/// use non_empty_slice::non_empty_vec;
4839///
4840/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4841/// // Generate signal
4842/// let signal = non_empty_vec![1.0; nzu!(16000)];
4843///
4844/// // Forward STFT
4845/// let stft_matrix = stft(&signal, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4846///
4847/// // Inverse STFT
4848/// let reconstructed = istft(&stft_matrix, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4849///
4850/// println!("Original: {} samples", signal.len());
4851/// println!("Reconstructed: {} samples", reconstructed.len());
4852/// # Ok(())
4853/// # }
4854/// ```
4855#[inline]
4856pub fn istft(
4857 stft_matrix: &Array2<Complex<f64>>,
4858 n_fft: NonZeroUsize,
4859 hop_size: NonZeroUsize,
4860 window: WindowType,
4861 center: bool,
4862) -> SpectrogramResult<NonEmptyVec<f64>> {
4863 use crate::fft_backend::{C2rPlan, C2rPlanner, r2c_output_size};
4864
4865 let n_bins = stft_matrix.nrows();
4866 let n_frames = stft_matrix.ncols();
4867
4868 let expected_bins = r2c_output_size(n_fft.get());
4869 if n_bins != expected_bins {
4870 return Err(SpectrogramError::dimension_mismatch(expected_bins, n_bins));
4871 }
4872 if hop_size.get() > n_fft.get() {
4873 return Err(SpectrogramError::invalid_input("hop_size must be <= n_fft"));
4874 }
4875 // Create inverse FFT plan
4876 #[cfg(feature = "realfft")]
4877 let mut ifft = {
4878 let mut planner = crate::RealFftPlanner::new();
4879 planner.plan_c2r(n_fft.get())?
4880 };
4881
4882 #[cfg(feature = "fftw")]
4883 let mut ifft = {
4884 let mut planner = crate::FftwPlanner::new();
4885 planner.plan_c2r(n_fft.get())?
4886 };
4887
4888 // Generate window
4889 let window_samples = make_window(window, n_fft);
4890 let n_fft = n_fft.get();
4891 let hop_size = hop_size.get();
4892 // Calculate output length
4893 let pad = if center { n_fft / 2 } else { 0 };
4894 let output_len = (n_frames - 1) * hop_size + n_fft;
4895 // safety: output_len > 0 since n_frames > 0 and n_fft, hop_size > 0
4896 let output_len = unsafe { NonZeroUsize::new_unchecked(output_len) };
4897 let unpadded_len = output_len.get().saturating_sub(2 * pad);
4898
4899 // Allocate output buffer and normalization buffer
4900 let mut output = non_empty_vec![0.0; output_len];
4901 let mut norm = non_empty_vec![0.0; output_len];
4902
4903 // Overlap-add synthesis
4904 let mut frame_buffer = vec![Complex::new(0.0, 0.0); n_bins];
4905 let mut time_frame = vec![0.0; n_fft];
4906
4907 for frame_idx in 0..n_frames {
4908 // Extract complex frame from STFT matrix
4909 for bin_idx in 0..n_bins {
4910 frame_buffer[bin_idx] = stft_matrix[[bin_idx, frame_idx]];
4911 }
4912
4913 // Inverse FFT
4914 ifft.process(&frame_buffer, &mut time_frame)?;
4915
4916 // Apply window
4917 for i in 0..n_fft {
4918 time_frame[i] *= window_samples[i];
4919 }
4920
4921 // Overlap-add into output buffer
4922 let start = frame_idx * hop_size;
4923 for i in 0..n_fft {
4924 let pos = start + i;
4925 if pos < output_len.get() {
4926 output[pos] += time_frame[i];
4927 norm[pos] += window_samples[i] * window_samples[i];
4928 }
4929 }
4930 }
4931
4932 // Normalize by window energy
4933 for i in 0..output_len.get() {
4934 if norm[i] > 1e-10 {
4935 output[i] /= norm[i];
4936 }
4937 }
4938
4939 // Remove padding if centered
4940 if center && unpadded_len > 0 {
4941 let start = pad;
4942 let end = start + unpadded_len;
4943 // safety: start < end <= output_len, therefore slice is non-empty
4944 output = unsafe {
4945 NonEmptySlice::new_unchecked(&output[start..end.min(output_len.get())])
4946 .to_non_empty_vec()
4947 };
4948 }
4949
4950 Ok(output)
4951}
4952
4953//
4954// ========================
4955// Reusable FFT Plans
4956// ========================
4957//
4958
/// A reusable FFT planner for repeated FFT operations.
///
/// The planner owns a backend planner (selected by the `realfft` or `fftw`
/// feature) and hands every transform request to it, so one instance can be
/// kept around and reused instead of calling the standalone `fft()` function
/// each time.
///
/// # Examples
///
/// ```
/// use spectrograms::*;
/// use non_empty_slice::non_empty_vec;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let mut planner = FftPlanner::new();
///
/// // Process multiple signals of the same size efficiently
/// for _ in 0..100 {
///     let signal = non_empty_vec![0.0; nzu!(512)];
///     let spectrum = planner.fft(&signal, nzu!(512))?;
///     // ... process spectrum ...
/// }
/// # Ok(())
/// # }
/// ```
pub struct FftPlanner {
    /// Backend planner used when the `realfft` feature is enabled.
    #[cfg(feature = "realfft")]
    inner: crate::RealFftPlanner,
    /// Backend planner used when the `fftw` feature is enabled.
    #[cfg(feature = "fftw")]
    inner: crate::FftwPlanner,
}
4988
4989impl FftPlanner {
4990 /// Create a new FFT planner with empty cache.
4991 #[inline]
4992 #[must_use]
4993 pub fn new() -> Self {
4994 Self {
4995 #[cfg(feature = "realfft")]
4996 inner: crate::RealFftPlanner::new(),
4997 #[cfg(feature = "fftw")]
4998 inner: crate::FftwPlanner::new(),
4999 }
5000 }
5001
5002 /// Compute forward FFT, reusing cached plans.
5003 ///
5004 /// This is more efficient than calling the standalone `fft()` function
5005 /// repeatedly for the same FFT size.
5006 ///
5007 /// # Automatic Zero-Padding
5008 ///
5009 /// If the input signal is shorter than `n_fft`, it will be automatically
5010 /// zero-padded to the required length.
5011 ///
5012 /// # Errors
5013 ///
5014 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
5015 ///
5016 /// # Examples
5017 ///
5018 /// ```
5019 /// use spectrograms::*;
5020 /// use non_empty_slice::non_empty_vec;
5021 ///
5022 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
5023 /// let mut planner = FftPlanner::new();
5024 ///
5025 /// let signal = non_empty_vec![1.0; nzu!(512)];
5026 /// let spectrum = planner.fft(&signal, nzu!(512))?;
5027 ///
5028 /// assert_eq!(spectrum.len(), 257); // 512/2 + 1
5029 /// # Ok(())
5030 /// # }
5031 /// ```
5032 #[inline]
5033 pub fn fft(
5034 &mut self,
5035 samples: &NonEmptySlice<f64>,
5036 n_fft: NonZeroUsize,
5037 ) -> SpectrogramResult<Array1<Complex<f64>>> {
5038 use crate::fft_backend::{R2cPlan, R2cPlanner, r2c_output_size};
5039
5040 if samples.len() > n_fft {
5041 return Err(SpectrogramError::invalid_input(format!(
5042 "Input length ({}) exceeds FFT size ({})",
5043 samples.len(),
5044 n_fft
5045 )));
5046 }
5047
5048 let out_len = r2c_output_size(n_fft.get());
5049 let mut plan = self.inner.plan_r2c(n_fft.get())?;
5050
5051 let input = if samples.len() < n_fft {
5052 let mut padded = vec![0.0; n_fft.get()];
5053 padded[..samples.len().get()].copy_from_slice(samples);
5054
5055 // safety: samples.len() < n_fft checked above and n_fft > 0
5056 unsafe { NonEmptyVec::new_unchecked(padded) }
5057 } else {
5058 samples.to_non_empty_vec()
5059 };
5060
5061 let mut output = vec![Complex::new(0.0, 0.0); out_len];
5062 plan.process(&input, &mut output)?;
5063
5064 let output = Array1::from_vec(output);
5065 Ok(output)
5066 }
5067
5068 /// Compute forward real FFT magnitude
5069 ///
5070 /// # Errors
5071 ///
5072 /// Returns an error if:
5073 /// - `n_fft` doesn't match the samples length
5074 ///
5075 ///
5076 #[inline]
5077 pub fn rfft(
5078 &mut self,
5079 samples: &NonEmptySlice<f64>,
5080 n_fft: NonZeroUsize,
5081 ) -> SpectrogramResult<Array1<f64>> {
5082 let fft_with_complex = fft(samples, n_fft)?;
5083 Ok(fft_with_complex.mapv(Complex::norm))
5084 }
5085
5086 /// Compute inverse FFT, reusing cached plans.
5087 ///
5088 /// This is more efficient than calling the standalone `irfft()` function
5089 /// repeatedly for the same FFT size.
5090 ///
5091 /// # Errors
5092 /// Returns an error if:
5093 ///
5094 /// - The calculated expected length of `spectrum` doesn't match its actual length
5095 ///
5096 /// # Examples
5097 ///
5098 /// ```
5099 /// use spectrograms::*;
5100 /// use non_empty_slice::{non_empty_vec, NonEmptySlice};
5101 ///
5102 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
5103 /// let mut planner = FftPlanner::new();
5104 ///
5105 /// // Forward FFT
5106 /// let signal = non_empty_vec![1.0; nzu!(512)];
5107 /// let spectrum = planner.fft(&signal, nzu!(512))?;
5108 ///
5109 /// // Inverse FFT
5110 /// let spectrum_slice = NonEmptySlice::new(spectrum.as_slice().unwrap()).unwrap();
5111 /// let reconstructed = planner.irfft(spectrum_slice, nzu!(512))?;
5112 ///
5113 /// assert_eq!(reconstructed.len(), nzu!(512));
5114 /// # Ok(())
5115 /// # }
5116 /// ```
5117 #[inline]
5118 pub fn irfft(
5119 &mut self,
5120 spectrum: &NonEmptySlice<Complex<f64>>,
5121 n_fft: NonZeroUsize,
5122 ) -> SpectrogramResult<NonEmptyVec<f64>> {
5123 use crate::fft_backend::{C2rPlan, C2rPlanner, r2c_output_size};
5124
5125 let expected_len = r2c_output_size(n_fft.get());
5126 if spectrum.len().get() != expected_len {
5127 return Err(SpectrogramError::dimension_mismatch(
5128 expected_len,
5129 spectrum.len().get(),
5130 ));
5131 }
5132
5133 let mut plan = self.inner.plan_c2r(n_fft.get())?;
5134 let mut output = vec![0.0; n_fft.get()];
5135 plan.process(spectrum, &mut output)?;
5136 // Safety: output is non-empty since n_fft > 0
5137 let output = unsafe { NonEmptyVec::new_unchecked(output) };
5138 Ok(output)
5139 }
5140
5141 /// Compute power spectrum with optional windowing, reusing cached plans.
5142 ///
5143 /// # Automatic Zero-Padding
5144 ///
5145 /// If the input signal is shorter than `n_fft`, it will be automatically
5146 /// zero-padded to the required length.
5147 ///
5148 /// # Errors
5149 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
5150 ///
5151 /// # Examples
5152 ///
5153 /// ```
5154 /// use spectrograms::*;
5155 /// use non_empty_slice::non_empty_vec;
5156 ///
5157 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
5158 /// let mut planner = FftPlanner::new();
5159 ///
5160 /// let signal = non_empty_vec![1.0; nzu!(512)];
5161 /// let power = planner.power_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
5162 ///
5163 /// assert_eq!(power.len(), nzu!(257));
5164 /// # Ok(())
5165 /// # }
5166 /// ```
5167 #[inline]
5168 pub fn power_spectrum(
5169 &mut self,
5170 samples: &NonEmptySlice<f64>,
5171 n_fft: NonZeroUsize,
5172 window: Option<WindowType>,
5173 ) -> SpectrogramResult<NonEmptyVec<f64>> {
5174 if samples.len() > n_fft {
5175 return Err(SpectrogramError::invalid_input(format!(
5176 "Input length ({}) exceeds FFT size ({})",
5177 samples.len(),
5178 n_fft
5179 )));
5180 }
5181
5182 let mut windowed = vec![0.0; n_fft.get()];
5183 windowed[..samples.len().get()].copy_from_slice(samples);
5184 if let Some(win_type) = window {
5185 let window_samples = make_window(win_type, n_fft);
5186 for i in 0..n_fft.get() {
5187 windowed[i] *= window_samples[i];
5188 }
5189 }
5190
5191 // safety: windowed is non-empty since n_fft > 0
5192 let windowed = unsafe { NonEmptySlice::new_unchecked(&windowed) };
5193 let fft_result = self.fft(windowed, n_fft)?;
5194 let f = fft_result
5195 .iter()
5196 .map(num_complex::Complex::norm_sqr)
5197 .collect();
5198 // safety: fft_result is non-empty since fft returned successfully
5199 Ok(unsafe { NonEmptyVec::new_unchecked(f) })
5200 }
5201
5202 /// Compute magnitude spectrum with optional windowing, reusing cached plans.
5203 ///
5204 /// # Automatic Zero-Padding
5205 ///
5206 /// If the input signal is shorter than `n_fft`, it will be automatically
5207 /// zero-padded to the required length.
5208 ///
5209 /// # Errors
5210 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
5211 ///
5212 /// # Examples
5213 ///
5214 /// ```
5215 /// use spectrograms::*;
5216 /// use non_empty_slice::non_empty_vec;
5217 ///
5218 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
5219 /// let mut planner = FftPlanner::new();
5220 ///
5221 /// let signal = non_empty_vec![1.0; nzu!(512)];
5222 /// let magnitude = planner.magnitude_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
5223 ///
5224 /// assert_eq!(magnitude.len(), nzu!(257));
5225 /// # Ok(())
5226 /// # }
5227 /// ```
5228 #[inline]
5229 pub fn magnitude_spectrum(
5230 &mut self,
5231 samples: &NonEmptySlice<f64>,
5232 n_fft: NonZeroUsize,
5233 window: Option<WindowType>,
5234 ) -> SpectrogramResult<NonEmptyVec<f64>> {
5235 let power = self.power_spectrum(samples, n_fft, window)?;
5236 let power = power.iter().map(|&p| p.sqrt()).collect::<Vec<f64>>();
5237 // safety: power is non-empty since power_spectrum returned successfully
5238 Ok(unsafe { NonEmptyVec::new_unchecked(power) })
5239 }
5240}
5241
5242impl Default for FftPlanner {
5243 #[inline]
5244 fn default() -> Self {
5245 Self::new()
5246 }
5247}
5248
#[cfg(test)]
mod tests {
    use super::*;

    /// A hand-built 3x5 sparse matrix multiplies a dense vector correctly.
    #[test]
    fn test_sparse_matrix_basic() {
        let mut sparse = SparseMatrix::new(3, 5);

        // (row, column, value) entries, inserted row by row:
        //   row 0 -> column 1
        //   row 1 -> columns 2 and 3
        //   row 2 -> columns 0 and 4
        let entries = [
            (0, 1, 2.0),
            (1, 2, 0.5),
            (1, 3, 1.5),
            (2, 0, 3.0),
            (2, 4, 1.0),
        ];
        for &(row, col, value) in &entries {
            sparse.set(row, col, value);
        }

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut output = vec![0.0; 3];
        sparse.multiply_vec(&input, &mut output);

        // Row 0: 2.0 * 2.0             = 4.0
        // Row 1: 0.5 * 3.0 + 1.5 * 4.0 = 7.5
        // Row 2: 3.0 * 1.0 + 1.0 * 5.0 = 8.0
        assert_eq!(output, vec![4.0, 7.5, 8.0]);
    }

    /// Setting an explicit zero must not create a stored entry.
    #[test]
    fn test_sparse_matrix_zeros_ignored() {
        let mut sparse = SparseMatrix::new(2, 3);

        sparse.set(0, 0, 1.0);
        sparse.set(0, 1, 0.0); // expected to be dropped
        sparse.set(0, 2, 2.0);

        let stored_values = &sparse.values[0];
        let stored_columns = &sparse.indices[0];

        // Only the two non-zero entries survive ...
        assert_eq!(stored_values.len(), 2);
        assert_eq!(stored_columns.len(), 2);

        // ... and they are the ones at columns 0 and 2.
        assert_eq!(*stored_columns, vec![0, 2]);
        assert_eq!(*stored_values, vec![1.0, 2.0]);
    }

    /// LogHz mapping rows come from linear interpolation: 1-2 non-zeros each.
    #[test]
    fn test_loghz_matrix_sparsity() {
        let sample_rate = 16000.0;
        let n_fft = nzu!(512);
        let n_bins = nzu!(128);
        let f_min = 20.0;
        let f_max = sample_rate / 2.0;

        let (matrix, _freqs) =
            build_loghz_matrix(sample_rate, n_fft, n_bins, f_min, f_max).unwrap();

        for row_idx in 0..matrix.nrows() {
            let nnz = matrix.values[row_idx].len();
            assert!(
                nnz <= 2,
                "Row {row_idx} has {nnz} non-zeros, expected at most 2"
            );
            assert!(nnz >= 1, "Row {row_idx} has no non-zeros");
        }

        // Overall count is bounded by one and two entries per row.
        let total_nnz: usize = matrix.values.iter().map(Vec::len).sum();
        assert!(total_nnz <= n_bins.get() * 2);
        assert!(total_nnz >= n_bins.get());
    }

    /// Mel filterbanks consist of triangular filters and are mostly zeros.
    #[test]
    fn test_mel_matrix_sparsity() {
        let sample_rate = 16000.0;
        let n_fft = nzu!(512);
        let n_mels = nzu!(40);
        let f_min = 0.0;
        let f_max = sample_rate / 2.0;

        let matrix =
            build_mel_filterbank_matrix(sample_rate, n_fft, n_mels, f_min, f_max, MelNorm::None)
                .unwrap();

        let n_fft_bins = r2c_output_size(n_fft.get());

        // Fraction of entries that are not stored.
        let total_nnz: usize = matrix.values.iter().map(Vec::len).sum();
        let total_elements = n_mels.get() * n_fft_bins;
        let sparsity = 1.0 - (total_nnz as f64 / total_elements as f64);

        assert!(
            sparsity > 0.8,
            "Mel matrix sparsity is only {:.1}%, expected >80%",
            sparsity * 100.0
        );

        // Every individual filter must cover fewer than half of the FFT bins.
        for row_idx in 0..matrix.nrows() {
            let nnz = matrix.values[row_idx].len();
            assert!(
                nnz < n_fft_bins / 2,
                "Mel filter {row_idx} is not sparse enough"
            );
        }
    }
}