spectrograms/spectrogram.rs
1use std::marker::PhantomData;
2use std::num::NonZeroUsize;
3use std::ops::{Deref, DerefMut};
4
5use ndarray::{Array1, Array2};
6use non_empty_slice::{NonEmptySlice, NonEmptyVec, non_empty_vec};
7use num_complex::Complex;
8
9#[cfg(feature = "python")]
10use pyo3::prelude::*;
11
12use crate::cqt::CqtKernel;
13use crate::erb::ErbFilterbank;
14use crate::{
15 CqtParams, ErbParams, R2cPlan, SpectrogramError, SpectrogramResult, WindowType,
16 min_max_single_pass, nzu,
17};
/// Small positive constant (`1e-12`).
///
/// NOTE(review): no use of this constant is visible in this part of the file;
/// presumably it serves as a numerical floor (e.g. guarding divisions or
/// logarithms against zero) — confirm at the use sites.
const EPS: f64 = 1.0e-12;
19
20//
21// ========================
22// Sparse Matrix for efficient filterbank multiplication
23// ========================
24//
25
/// Row-oriented sparse matrix tailored to matrix-vector products.
///
/// Every row keeps two parallel lists: the stored (non-negligible) values and
/// the column index of each value. Compared with a classic CSR layout this is
/// more flexible and makes row-by-row construction cheap.
///
/// It targets matrices with only a handful of non-zeros per row, such as mel
/// filterbanks (triangular filters) and logarithmic frequency mappings (linear
/// interpolation between 1-2 adjacent bins).
///
/// For typical spectrograms:
/// - `LogHz` interpolation: Only 1-2 non-zeros per row (~99% sparse)
/// - Mel filterbank: ~10-50 non-zeros per row depending on FFT size (~90-98% sparse)
///
/// Skipping the zero entries avoids useless multiplications, which can be
/// 10-100x faster than a dense matrix-vector product.
#[derive(Clone, Debug)]
struct SparseMatrix {
    /// Row count.
    nrows: usize,
    /// Column count.
    ncols: usize,
    /// Stored values, one inner vector per row.
    values: Vec<Vec<f64>>,
    /// Column index of every stored value; parallel to `values`.
    indices: Vec<Vec<usize>>,
}

impl SparseMatrix {
    /// Build an empty `nrows` x `ncols` matrix (no stored entries).
    fn new(nrows: usize, ncols: usize) -> Self {
        let values: Vec<Vec<f64>> = (0..nrows).map(|_| Vec::new()).collect();
        let indices: Vec<Vec<usize>> = (0..nrows).map(|_| Vec::new()).collect();
        Self {
            nrows,
            ncols,
            values,
            indices,
        }
    }

    /// Append `value` at `(row, col)` unless it is negligible.
    ///
    /// Values whose magnitude is not strictly greater than `1e-10` (including
    /// NaN) are not stored. Entries are appended, not replaced: setting the
    /// same position twice stores two entries for that position.
    ///
    /// # Panics (debug mode only)
    /// Panics in debug builds if `row` or `col` is out of bounds. In release
    /// builds an out-of-bounds call is silently ignored.
    fn set(&mut self, row: usize, col: usize, value: f64) {
        let in_bounds = row < self.nrows && col < self.ncols;
        debug_assert!(
            in_bounds,
            "SparseMatrix index out of bounds: ({}, {}) for {}x{} matrix",
            row,
            col,
            self.nrows,
            self.ncols
        );

        // Keep the `>` comparison: it is false for NaN, so NaN is never stored.
        if in_bounds && value.abs() > 1e-10 {
            self.values[row].push(value);
            self.indices[row].push(col);
        }
    }

    /// Number of rows.
    const fn nrows(&self) -> usize {
        self.nrows
    }

    /// Number of columns.
    const fn ncols(&self) -> usize {
        self.ncols
    }

    /// Sparse matrix-vector product: `out = self * input`.
    ///
    /// Only stored entries are visited, so the cost is proportional to the
    /// number of stored values rather than `nrows * ncols`. Rows without any
    /// stored entry produce `0.0`.
    ///
    /// Lengths are checked with `debug_assert_eq!` only (`input.len() == ncols`,
    /// `out.len() == nrows`); indexing still panics on out-of-range access.
    #[inline]
    fn multiply_vec(&self, input: &[f64], out: &mut [f64]) {
        debug_assert_eq!(input.len(), self.ncols);
        debug_assert_eq!(out.len(), self.nrows);

        let rows = self.values.iter().zip(self.indices.iter());
        for (row_idx, (coeffs, cols)) in rows.enumerate() {
            // Explicit fold from +0.0 keeps the exact accumulation order.
            out[row_idx] = coeffs
                .iter()
                .zip(cols.iter())
                .fold(0.0, |acc, (&coeff, &col)| acc + coeff * input[col]);
        }
    }
}
117
118// Linear frequency
119pub type LinearPowerSpectrogram = Spectrogram<LinearHz, Power>;
120pub type LinearMagnitudeSpectrogram = Spectrogram<LinearHz, Magnitude>;
121pub type LinearDbSpectrogram = Spectrogram<LinearHz, Decibels>;
122pub type LinearSpectrogram<AmpScale> = Spectrogram<LinearHz, AmpScale>;
123
124// Log-frequency (e.g. CQT-style)
125pub type LogHzPowerSpectrogram = Spectrogram<LogHz, Power>;
126pub type LogHzMagnitudeSpectrogram = Spectrogram<LogHz, Magnitude>;
127pub type LogHzDbSpectrogram = Spectrogram<LogHz, Decibels>;
128pub type LogHzSpectrogram<AmpScale> = Spectrogram<LogHz, AmpScale>;
129
130// ERB / gammatone
131pub type ErbPowerSpectrogram = Spectrogram<Erb, Power>;
132pub type ErbMagnitudeSpectrogram = Spectrogram<Erb, Magnitude>;
133pub type ErbDbSpectrogram = Spectrogram<Erb, Decibels>;
134pub type GammatonePowerSpectrogram = ErbPowerSpectrogram;
135pub type GammatoneMagnitudeSpectrogram = ErbMagnitudeSpectrogram;
136pub type GammatoneDbSpectrogram = ErbDbSpectrogram;
137pub type ErbSpectrogram<AmpScale> = Spectrogram<Erb, AmpScale>;
138pub type GammatoneSpectrogram<AmpScale> = ErbSpectrogram<AmpScale>;
139
140// Mel
141pub type MelMagnitudeSpectrogram = Spectrogram<Mel, Magnitude>;
142pub type MelPowerSpectrogram = Spectrogram<Mel, Power>;
143pub type MelDbSpectrogram = Spectrogram<Mel, Decibels>;
144pub type LogMelSpectrogram = MelDbSpectrogram;
145pub type MelSpectrogram<AmpScale> = Spectrogram<Mel, AmpScale>;
146
147// CQT
148pub type CqtPowerSpectrogram = Spectrogram<Cqt, Power>;
149pub type CqtMagnitudeSpectrogram = Spectrogram<Cqt, Magnitude>;
150pub type CqtDbSpectrogram = Spectrogram<Cqt, Decibels>;
151pub type CqtSpectrogram<AmpScale> = Spectrogram<Cqt, AmpScale>;
152
153use crate::fft_backend::r2c_output_size;
154
155/// A spectrogram plan is the compiled, reusable execution object.
156///
157/// It owns:
158/// - FFT plan (reusable)
159/// - window samples
160/// - mapping (identity / mel filterbank / etc.)
161/// - amplitude scaling config
162/// - workspace buffers to avoid allocations in hot loops
163///
164/// It computes one specific spectrogram type: `Spectrogram<FreqScale, AmpScale>`.
165///
166/// # Type Parameters
167///
168/// - `FreqScale`: Frequency scale type (e.g. `LinearHz`, `LogHz`, `Mel`, etc.)
169/// - `AmpScale`: Amplitude scaling type (e.g. `Power`, `Magnitude`, `Decibels`, etc.)
170pub struct SpectrogramPlan<FreqScale, AmpScale>
171where
172 AmpScale: AmpScaleSpec + 'static,
173 FreqScale: Copy + Clone + 'static,
174{
175 params: SpectrogramParams,
176
177 stft: StftPlan,
178 mapping: FrequencyMapping<FreqScale>,
179 scaling: AmplitudeScaling<AmpScale>,
180
181 freq_axis: FrequencyAxis<FreqScale>,
182 workspace: Workspace,
183
184 _amp: PhantomData<AmpScale>,
185}
186
187impl<FreqScale, AmpScale> SpectrogramPlan<FreqScale, AmpScale>
188where
189 AmpScale: AmpScaleSpec + 'static,
190 FreqScale: Copy + Clone + 'static,
191{
192 /// Get the spectrogram parameters used to create this plan.
193 ///
194 /// # Returns
195 ///
196 /// A reference to the `SpectrogramParams` used in this plan.
197 #[inline]
198 #[must_use]
199 pub const fn params(&self) -> &SpectrogramParams {
200 &self.params
201 }
202
203 /// Get the frequency axis for this spectrogram plan.
204 ///
205 /// # Returns
206 ///
207 /// A reference to the `FrequencyAxis<FreqScale>` used in this plan.
208 #[inline]
209 #[must_use]
210 pub const fn freq_axis(&self) -> &FrequencyAxis<FreqScale> {
211 &self.freq_axis
212 }
213
214 /// Compute a spectrogram for a mono signal.
215 ///
216 /// This function performs:
217 /// - framing + windowing
218 /// - FFT per frame
219 /// - magnitude/power
220 /// - frequency mapping (identity/mel/etc.)
221 /// - amplitude scaling (linear or dB)
222 ///
223 /// It allocates the output `Array2` once, but does not allocate per-frame.
224 ///
225 /// # Arguments
226 ///
227 /// * `samples` - Audio samples
228 ///
229 /// # Returns
230 ///
231 /// A `Spectrogram<FreqScale, AmpScale>` containing the computed spectrogram.
232 ///
233 /// # Errors
234 ///
235 /// Returns an error if STFT computation or mapping fails.
236 #[inline]
237 pub fn compute(
238 &mut self,
239 samples: &NonEmptySlice<f64>,
240 ) -> SpectrogramResult<Spectrogram<FreqScale, AmpScale>> {
241 let n_frames = self.stft.frame_count(samples.len())?;
242 let n_bins = self.mapping.output_bins();
243
244 // Create output matrix: (n_bins, n_frames)
245 let mut data = Array2::<f64>::zeros((n_bins.get(), n_frames.get()));
246
247 // Ensure workspace is correctly sized
248 self.workspace
249 .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
250
251 // Main loop: fill each frame (column)
252 for frame_idx in 0..n_frames.get() {
253 // CQT needs unwindowed frames because its kernels already contain windowing.
254 // Other mappings use the FFT spectrum, so they need the windowed frame.
255 if self.mapping.kind.needs_unwindowed_frame() {
256 // Fill frame without windowing (for CQT)
257 self.stft
258 .fill_frame_unwindowed(samples, frame_idx, &mut self.workspace)?;
259 } else {
260 // Compute windowed frame spectrum (for FFT-based mappings)
261 self.stft
262 .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
263 }
264
265 // mapping: spectrum(out_len) -> mapped(n_bins)
266 // For CQT, this uses workspace.frame; for others, workspace.spectrum
267 // For ERB, we need the complex FFT output (fft_out)
268 // We need to borrow workspace fields separately to avoid borrow conflicts
269 let Workspace {
270 spectrum,
271 mapped,
272 frame,
273 ..
274 } = &mut self.workspace;
275
276 self.mapping.apply(spectrum, frame, mapped)?;
277
278 // amplitude scaling in-place on mapped vector
279 self.scaling.apply_in_place(mapped)?;
280
281 // write column into output
282 for (row, &val) in mapped.iter().enumerate() {
283 data[[row, frame_idx]] = val;
284 }
285 }
286
287 let times = build_time_axis_seconds(&self.params, n_frames);
288 let axes = Axes::new(self.freq_axis.clone(), times);
289
290 Ok(Spectrogram::new(data, axes, self.params.clone()))
291 }
292
293 /// Compute a single frame of the spectrogram.
294 ///
295 /// This is useful for streaming/online processing where you want to
296 /// process audio frame-by-frame without computing the entire spectrogram.
297 ///
298 /// # Arguments
299 ///
300 /// * `samples` - Audio samples (must contain at least enough samples for the requested frame)
301 /// * `frame_idx` - Frame index to compute
302 ///
303 /// # Returns
304 ///
305 /// A vector of frequency bin values for the requested frame.
306 ///
307 /// # Errors
308 ///
309 /// Returns an error if the frame index is out of bounds or if STFT computation fails.
310 ///
311 /// # Examples
312 ///
313 /// ```
314 /// use spectrograms::*;
315 /// use non_empty_slice::non_empty_vec;
316 ///
317 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
318 /// let samples = non_empty_vec![0.0; nzu!(16000)];
319 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
320 /// let params = SpectrogramParams::new(stft, 16000.0)?;
321 ///
322 /// let planner = SpectrogramPlanner::new();
323 /// let mut plan = planner.linear_plan::<Power>(¶ms, None)?;
324 ///
325 /// // Compute just the first frame
326 /// let frame = plan.compute_frame(&samples, 0)?;
327 /// assert_eq!(frame.len(), nzu!(257)); // n_fft/2 + 1
328 /// # Ok(())
329 /// # }
330 /// ```
331 #[inline]
332 pub fn compute_frame(
333 &mut self,
334 samples: &NonEmptySlice<f64>,
335 frame_idx: usize,
336 ) -> SpectrogramResult<NonEmptyVec<f64>> {
337 let n_bins = self.mapping.output_bins();
338
339 // Ensure workspace is correctly sized
340 self.workspace
341 .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
342
343 // CQT needs unwindowed frames because its kernels already contain windowing.
344 // Other mappings use the FFT spectrum, so they need the windowed frame.
345 if self.mapping.kind.needs_unwindowed_frame() {
346 // Fill frame without windowing (for CQT)
347 self.stft
348 .fill_frame_unwindowed(samples, frame_idx, &mut self.workspace)?;
349 } else {
350 // Compute windowed frame spectrum (for FFT-based mappings)
351 self.stft
352 .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
353 }
354
355 // Apply mapping (using split borrows to avoid borrow conflicts)
356 let Workspace {
357 spectrum,
358 mapped,
359 frame,
360 ..
361 } = &mut self.workspace;
362
363 self.mapping.apply(spectrum, frame, mapped)?;
364
365 // Apply amplitude scaling
366 self.scaling.apply_in_place(mapped)?;
367
368 Ok(mapped.clone())
369 }
370
371 /// Compute spectrogram into a pre-allocated buffer.
372 ///
373 /// This avoids allocating the output matrix, which is useful when
374 /// you want to reuse buffers or have strict memory requirements.
375 ///
376 /// # Arguments
377 ///
378 /// * `samples` - Audio samples
379 /// * `output` - Pre-allocated output matrix (must be correct size: `n_bins` x `n_frames`)
380 ///
381 /// # Returns
382 ///
383 /// An empty result on success.
384 ///
385 /// # Errors
386 ///
387 /// Returns an error if the output buffer dimensions don't match the expected size.
388 ///
389 /// # Examples
390 ///
391 /// ```
392 /// use spectrograms::*;
393 /// use ndarray::Array2;
394 /// use non_empty_slice::non_empty_vec;
395 ///
396 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
397 /// let samples = non_empty_vec![0.0; nzu!(16000)];
398 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
399 /// let params = SpectrogramParams::new(stft, 16000.0)?;
400 ///
401 /// let planner = SpectrogramPlanner::new();
402 /// let mut plan = planner.linear_plan::<Power>(¶ms, None)?;
403 ///
404 /// // Pre-allocate output buffer
405 /// let mut output = Array2::<f64>::zeros((257, 63));
406 /// plan.compute_into(&samples, &mut output)?;
407 /// # Ok(())
408 /// # }
409 /// ```
410 #[inline]
411 pub fn compute_into(
412 &mut self,
413 samples: &NonEmptySlice<f64>,
414 output: &mut Array2<f64>,
415 ) -> SpectrogramResult<()> {
416 let n_frames = self.stft.frame_count(samples.len())?;
417 let n_bins = self.mapping.output_bins();
418
419 // Validate output dimensions
420 if output.nrows() != n_bins.get() {
421 return Err(SpectrogramError::dimension_mismatch(
422 n_bins.get(),
423 output.nrows(),
424 ));
425 }
426 if output.ncols() != n_frames.get() {
427 return Err(SpectrogramError::dimension_mismatch(
428 n_frames.get(),
429 output.ncols(),
430 ));
431 }
432
433 // Ensure workspace is correctly sized
434 self.workspace
435 .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
436
437 // Main loop: fill each frame (column)
438 for frame_idx in 0..n_frames.get() {
439 // CQT needs unwindowed frames because its kernels already contain windowing.
440 // Other mappings use the FFT spectrum, so they need the windowed frame.
441 if self.mapping.kind.needs_unwindowed_frame() {
442 // Fill frame without windowing (for CQT)
443 self.stft
444 .fill_frame_unwindowed(samples, frame_idx, &mut self.workspace)?;
445 } else {
446 // Compute windowed frame spectrum (for FFT-based mappings)
447 self.stft
448 .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
449 }
450
451 // mapping: spectrum(out_len) -> mapped(n_bins)
452 // For CQT, this uses workspace.frame; for others, workspace.spectrum
453 // For ERB, we need the complex FFT output (fft_out)
454 // We need to borrow workspace fields separately to avoid borrow conflicts
455 let Workspace {
456 spectrum,
457 mapped,
458 frame,
459 ..
460 } = &mut self.workspace;
461
462 self.mapping.apply(spectrum, frame, mapped)?;
463
464 // amplitude scaling in-place on mapped vector
465 self.scaling.apply_in_place(mapped)?;
466
467 // write column into output
468 for (row, &val) in mapped.iter().enumerate() {
469 output[[row, frame_idx]] = val;
470 }
471 }
472
473 Ok(())
474 }
475
476 /// Get the expected output dimensions for a given signal length.
477 ///
478 /// # Arguments
479 ///
480 /// * `signal_length` - Length of the input signal in samples.
481 ///
482 /// # Returns
483 ///
484 /// A tuple `(n_bins, n_frames)` representing the output spectrogram shape.
485 ///
486 /// # Errors
487 ///
488 /// Returns an error if the signal length is too short to produce any frames.
489 ///
490 /// # Examples
491 ///
492 /// ```
493 /// use spectrograms::*;
494 ///
495 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
496 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
497 /// let params = SpectrogramParams::new(stft, 16000.0)?;
498 ///
499 /// let planner = SpectrogramPlanner::new();
500 /// let plan = planner.linear_plan::<Power>(¶ms, None)?;
501 ///
502 /// let (n_bins, n_frames) = plan.output_shape(nzu!(16000))?;
503 /// assert_eq!(n_bins, nzu!(257));
504 /// assert_eq!(n_frames, nzu!(63));
505 /// # Ok(())
506 /// # }
507 /// ```
508 #[inline]
509 pub fn output_shape(
510 &self,
511 signal_length: NonZeroUsize,
512 ) -> SpectrogramResult<(NonZeroUsize, NonZeroUsize)> {
513 let n_frames = self.stft.frame_count(signal_length)?;
514 let n_bins = self.mapping.output_bins();
515 Ok((n_bins, n_frames))
516 }
517}
518
519/// STFT (Short-Time Fourier Transform) result containing complex frequency bins.
520///
521/// This is the raw STFT output before any frequency mapping or amplitude scaling.
522///
523/// # Fields
524///
525/// - `data`: Complex STFT matrix with shape (`frequency_bins`, `time_frames`)
526/// - `frequencies`: Frequency axis in Hz
527/// - `sample_rate`: Sample rate in Hz
528/// - `params`: STFT computation parameters
529#[derive(Debug, Clone)]
530#[non_exhaustive]
531pub struct StftResult {
532 /// Complex STFT matrix with shape (`frequency_bins`, `time_frames`)
533 pub data: Array2<Complex<f64>>,
534 /// Frequency axis in Hz
535 pub frequencies: NonEmptyVec<f64>,
536 /// Sample rate in Hz
537 pub sample_rate: f64,
538 pub params: StftParams,
539}
540
541impl StftResult {
542 /// Get the number of frequency bins.
543 ///
544 /// # Returns
545 ///
546 /// Number of frequency bins in the STFT result.
547 #[inline]
548 #[must_use]
549 pub fn n_bins(&self) -> NonZeroUsize {
550 // safety: nrows() > 0 for NonEmptyVec
551 unsafe { NonZeroUsize::new_unchecked(self.data.nrows()) }
552 }
553
554 /// Get the number of time frames.
555 ///
556 /// # Returns
557 ///
558 /// Number of time frames in the STFT result.
559 #[inline]
560 #[must_use]
561 pub fn n_frames(&self) -> NonZeroUsize {
562 // safety: ncols() > 0 for NonEmptyVec
563 unsafe { NonZeroUsize::new_unchecked(self.data.ncols()) }
564 }
565
566 /// Get the frequency resolution in Hz
567 ///
568 /// # Returns
569 ///
570 /// Frequency bin width in Hz.
571 #[inline]
572 #[must_use]
573 pub fn frequency_resolution(&self) -> f64 {
574 self.sample_rate / self.params.n_fft().get() as f64
575 }
576
577 /// Get the time resolution in seconds.
578 ///
579 /// # Returns
580 ///
581 /// Time between successive frames in seconds.
582 #[inline]
583 #[must_use]
584 pub fn time_resolution(&self) -> f64 {
585 self.params.hop_size().get() as f64 / self.sample_rate
586 }
587
588 /// Normalizes self.data to remove the complex aspect of it.
589 ///
590 /// # Returns
591 ///
592 /// An Array2\<f64\> containing the norms of each complex number in self.data.
593 #[inline]
594 pub fn norm(&self) -> Array2<f64> {
595 self.as_ref().mapv(Complex::norm)
596 }
597}
598
599impl AsRef<Array2<Complex<f64>>> for StftResult {
600 #[inline]
601 fn as_ref(&self) -> &Array2<Complex<f64>> {
602 &self.data
603 }
604}
605
606impl AsMut<Array2<Complex<f64>>> for StftResult {
607 #[inline]
608 fn as_mut(&mut self) -> &mut Array2<Complex<f64>> {
609 &mut self.data
610 }
611}
612
613impl Deref for StftResult {
614 type Target = Array2<Complex<f64>>;
615
616 #[inline]
617 fn deref(&self) -> &Self::Target {
618 &self.data
619 }
620}
621
622impl DerefMut for StftResult {
623 #[inline]
624 fn deref_mut(&mut self) -> &mut Self::Target {
625 &mut self.data
626 }
627}
628
/// Factory for spectrogram plans.
///
/// The planner is the place where:
/// - FFT plans are created
/// - mapping matrices are compiled
/// - axes are computed
///
/// Keeping plan construction here separates it from the output types.
#[derive(Default, Debug)]
#[non_exhaustive]
pub struct SpectrogramPlanner;
640
641impl SpectrogramPlanner {
642 /// Create a new spectrogram planner.
643 ///
644 /// # Returns
645 ///
646 /// A new `SpectrogramPlanner` instance.
647 #[inline]
648 #[must_use]
649 pub const fn new() -> Self {
650 Self
651 }
652
653 /// Compute the Short-Time Fourier Transform (STFT) of a signal.
654 ///
655 /// This returns the raw complex STFT matrix before any frequency mapping
656 /// or amplitude scaling. Useful for applications that need the full complex
657 /// spectrum or custom processing.
658 ///
659 /// # Arguments
660 ///
661 /// * `samples` - Audio samples (any type that can be converted to a slice)
662 /// * `params` - STFT computation parameters
663 ///
664 /// # Returns
665 ///
666 /// An `StftResult` containing the complex STFT matrix and metadata.
667 ///
668 /// # Errors
669 ///
670 /// Returns an error if STFT computation fails.
671 ///
672 /// # Examples
673 ///
674 /// ```
675 /// use spectrograms::*;
676 /// use non_empty_slice::non_empty_vec;
677 ///
678 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
679 /// let samples = non_empty_vec![0.0; nzu!(16000)];
680 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
681 /// let params = SpectrogramParams::new(stft, 16000.0)?;
682 ///
683 /// let planner = SpectrogramPlanner::new();
684 /// let stft_result = planner.compute_stft(&samples, ¶ms)?;
685 ///
686 /// println!("STFT: {} bins x {} frames", stft_result.n_bins(), stft_result.n_frames());
687 /// # Ok(())
688 /// # }
689 /// ```
690 ///
691 /// # Performance Note
692 ///
693 /// This method creates a new FFT plan each time. For processing multiple
694 /// signals, create a reusable plan with `StftPlan::new()` instead.
695 ///
696 /// # Examples
697 ///
698 /// ```rust
699 /// use spectrograms::*;
700 /// use non_empty_slice::non_empty_vec;
701 ///
702 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
703 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
704 /// let params = SpectrogramParams::new(stft, 16000.0)?;
705 ///
706 /// // One-shot (convenient)
707 /// let planner = SpectrogramPlanner::new();
708 /// let stft_result = planner.compute_stft(&non_empty_vec![0.0; nzu!(16000)], ¶ms)?;
709 ///
710 /// // Reusable plan (efficient for batch)
711 /// let mut plan = StftPlan::new(¶ms)?;
712 /// for signal in &[non_empty_vec![0.0; nzu!(16000)], non_empty_vec![1.0; nzu!(16000)]] {
713 /// let stft = plan.compute(&signal, ¶ms)?;
714 /// }
715 /// # Ok(())
716 /// # }
717 /// ```
718 #[inline]
719 pub fn compute_stft(
720 &self,
721 samples: &NonEmptySlice<f64>,
722 params: &SpectrogramParams,
723 ) -> SpectrogramResult<StftResult> {
724 let mut plan = StftPlan::new(params)?;
725 plan.compute(samples, params)
726 }
727
728 /// Compute the power spectrum of a single audio frame.
729 ///
730 /// This is useful for real-time processing or analyzing individual frames.
731 ///
732 /// # Arguments
733 ///
734 /// * `samples` - Audio frame (length ≤ n_fft, will be zero-padded if shorter)
735 /// * `n_fft` - FFT size
736 /// * `window` - Window type to apply
737 ///
738 /// # Returns
739 ///
740 /// A vector of power values (|X|²) with length `n_fft/2` + 1.
741 ///
742 /// # Automatic Zero-Padding
743 ///
744 /// If the input signal is shorter than `n_fft`, it will be automatically
745 /// zero-padded to the required length.
746 ///
747 /// # Errors
748 ///
749 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
750 ///
751 /// # Examples
752 ///
753 /// ```
754 /// use spectrograms::*;
755 /// use non_empty_slice::non_empty_vec;
756 ///
757 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
758 /// let frame = non_empty_vec![0.0; nzu!(512)];
759 ///
760 /// let planner = SpectrogramPlanner::new();
761 /// let power = planner.compute_power_spectrum(frame.as_ref(), nzu!(512), WindowType::Hanning)?;
762 ///
763 /// assert_eq!(power.len(), nzu!(257)); // 512/2 + 1
764 /// # Ok(())
765 /// # }
766 /// ```
767 #[inline]
768 pub fn compute_power_spectrum(
769 &self,
770 samples: &NonEmptySlice<f64>,
771 n_fft: NonZeroUsize,
772 window: WindowType,
773 ) -> SpectrogramResult<NonEmptyVec<f64>> {
774 if samples.len() > n_fft {
775 return Err(SpectrogramError::invalid_input(format!(
776 "Input length ({}) exceeds FFT size ({})",
777 samples.len(),
778 n_fft
779 )));
780 }
781
782 let window_samples = make_window(window, n_fft);
783 let out_len = r2c_output_size(n_fft.get());
784
785 // Create FFT plan
786 #[cfg(feature = "realfft")]
787 let mut fft = {
788 let mut planner = crate::RealFftPlanner::new();
789 let plan = planner.get_or_create(n_fft.get());
790 crate::RealFftPlan::new(n_fft.get(), plan)
791 };
792
793 #[cfg(feature = "fftw")]
794 let mut fft = {
795 use std::sync::Arc;
796 let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
797 crate::FftwPlan::new(Arc::new(plan))
798 };
799
800 // Apply window and compute FFT
801 let mut windowed = vec![0.0; n_fft.get()];
802 for i in 0..samples.len().get() {
803 windowed[i] = samples[i] * window_samples[i];
804 }
805 // The rest is already zero-padded
806 let mut fft_out = vec![Complex::new(0.0, 0.0); out_len];
807 fft.process(&windowed, &mut fft_out)?;
808
809 // Convert to power
810 let power: Vec<f64> = fft_out.iter().map(num_complex::Complex::norm_sqr).collect();
811 // safety: power is non-empty since n_fft > 0
812 Ok(unsafe { NonEmptyVec::new_unchecked(power) })
813 }
814
815 /// Compute the magnitude spectrum of a single audio frame.
816 ///
817 /// This is useful for real-time processing or analyzing individual frames.
818 ///
819 /// # Arguments
820 ///
821 /// * `samples` - Audio frame (length ≤ n_fft, will be zero-padded if shorter)
822 /// * `n_fft` - FFT size
823 /// * `window` - Window type to apply
824 ///
825 /// # Returns
826 ///
827 /// A vector of magnitude values (|X|) with length `n_fft/2` + 1.
828 ///
829 /// # Automatic Zero-Padding
830 ///
831 /// If the input signal is shorter than `n_fft`, it will be automatically
832 /// zero-padded to the required length.
833 ///
834 /// # Errors
835 ///
836 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
837 ///
838 /// # Examples
839 ///
840 /// ```
841 /// use spectrograms::*;
842 /// use non_empty_slice::non_empty_vec;
843 ///
844 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
845 /// let frame = non_empty_vec![0.0; nzu!(512)];
846 ///
847 /// let planner = SpectrogramPlanner::new();
848 /// let magnitude = planner.compute_magnitude_spectrum(frame.as_ref(), nzu!(512), WindowType::Hanning)?;
849 ///
850 /// assert_eq!(magnitude.len(), nzu!(257)); // 512/2 + 1
851 /// # Ok(())
852 /// # }
853 /// ```
854 #[inline]
855 pub fn compute_magnitude_spectrum(
856 &self,
857 samples: &NonEmptySlice<f64>,
858 n_fft: NonZeroUsize,
859 window: WindowType,
860 ) -> SpectrogramResult<NonEmptyVec<f64>> {
861 let power = self.compute_power_spectrum(samples, n_fft, window)?;
862 let power = power.iter().map(|&p| p.sqrt()).collect::<Vec<f64>>();
863 // safety: power is non-empty since power_spectrum returned successfully
864 Ok(unsafe { NonEmptyVec::new_unchecked(power) })
865 }
866
867 /// Build a linear-frequency spectrogram plan.
868 ///
869 /// # Type Parameters
870 ///
871 /// `AmpScale` determines whether output is:
872 /// - Magnitude
873 /// - Power
874 /// - Decibels
875 ///
876 /// # Arguments
877 ///
878 /// * `params` - Spectrogram parameters
879 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale
880 /// is `Decibels`)
881 ///
882 /// # Returns
883 ///
884 /// A `SpectrogramPlan` configured for linear-frequency spectrogram computation.
885 ///
886 /// # Errors
887 ///
888 /// Returns an error if the plan cannot be created due to invalid parameters.
889 #[inline]
890 pub fn linear_plan<AmpScale>(
891 &self,
892 params: &SpectrogramParams,
893 db: Option<&LogParams>, // only used when AmpScale = Decibels
894 ) -> SpectrogramResult<SpectrogramPlan<LinearHz, AmpScale>>
895 where
896 AmpScale: AmpScaleSpec + 'static,
897 {
898 let stft = StftPlan::new(params)?;
899 let mapping = FrequencyMapping::<LinearHz>::new(params)?;
900 let scaling = AmplitudeScaling::<AmpScale>::new(db);
901 let freq_axis = build_frequency_axis::<LinearHz>(params, &mapping);
902
903 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
904
905 Ok(SpectrogramPlan {
906 params: params.clone(),
907 stft,
908 mapping,
909 scaling,
910 freq_axis,
911 workspace,
912 _amp: PhantomData,
913 })
914 }
915
916 /// Build a mel-frequency spectrogram plan.
917 ///
918 /// This compiles a mel filterbank matrix and caches it inside the plan.
919 ///
920 /// # Type Parameters
921 ///
922 /// `AmpScale`: determines whether output is:
923 /// - Magnitude
924 /// - Power
925 /// - Decibels
926 ///
927 /// # Arguments
928 ///
929 /// * `params` - Spectrogram parameters
930 /// * `mel` - Mel-specific parameters
931 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
932 ///
933 /// # Returns
934 ///
935 /// A `SpectrogramPlan` configured for mel spectrogram computation.
936 ///
937 /// # Errors
938 ///
939 /// Returns an error if the plan cannot be created due to invalid parameters.
940 #[inline]
941 pub fn mel_plan<AmpScale>(
942 &self,
943 params: &SpectrogramParams,
944 mel: &MelParams,
945 db: Option<&LogParams>, // only used when AmpScale = Decibels
946 ) -> SpectrogramResult<SpectrogramPlan<Mel, AmpScale>>
947 where
948 AmpScale: AmpScaleSpec + 'static,
949 {
950 // cross-validation: mel range must be compatible with sample rate
951 let nyquist = params.nyquist_hz();
952 if mel.f_max() > nyquist {
953 return Err(SpectrogramError::invalid_input(
954 "mel f_max must be <= Nyquist",
955 ));
956 }
957
958 let stft = StftPlan::new(params)?;
959 let mapping = FrequencyMapping::<Mel>::new_mel(params, mel)?;
960 let scaling = AmplitudeScaling::<AmpScale>::new(db);
961 let freq_axis = build_frequency_axis::<Mel>(params, &mapping);
962
963 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
964
965 Ok(SpectrogramPlan {
966 params: params.clone(),
967 stft,
968 mapping,
969 scaling,
970 freq_axis,
971 workspace,
972 _amp: PhantomData,
973 })
974 }
975
976 /// Build an ERB-scale spectrogram plan.
977 ///
978 /// This creates a spectrogram with ERB-spaced frequency bands using gammatone
979 /// filterbank approximation in the frequency domain.
980 ///
981 /// # Type Parameters
982 ///
983 /// `AmpScale`: determines whether output is:
984 /// - Magnitude
985 /// - Power
986 /// - Decibels
987 ///
988 /// # Arguments
989 ///
990 /// * `params` - Spectrogram parameters
991 /// * `erb` - ERB-specific parameters
992 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
993 ///
994 /// # Returns
995 ///
996 /// A `SpectrogramPlan` configured for ERB spectrogram computation.
997 ///
998 /// # Errors
999 ///
1000 /// Returns an error if the plan cannot be created due to invalid parameters.
1001 #[inline]
1002 pub fn erb_plan<AmpScale>(
1003 &self,
1004 params: &SpectrogramParams,
1005 erb: &ErbParams,
1006 db: Option<&LogParams>,
1007 ) -> SpectrogramResult<SpectrogramPlan<Erb, AmpScale>>
1008 where
1009 AmpScale: AmpScaleSpec + 'static,
1010 {
1011 // cross-validation: erb range must be compatible with sample rate
1012 let nyquist = params.nyquist_hz();
1013 if erb.f_max() > nyquist {
1014 return Err(SpectrogramError::invalid_input(format!(
1015 "f_max={} exceeds Nyquist={}",
1016 erb.f_max(),
1017 nyquist
1018 )));
1019 }
1020
1021 let stft = StftPlan::new(params)?;
1022 let mapping = FrequencyMapping::<Erb>::new_erb(params, erb)?;
1023 let scaling = AmplitudeScaling::<AmpScale>::new(db);
1024 let freq_axis = build_frequency_axis::<Erb>(params, &mapping);
1025
1026 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1027
1028 Ok(SpectrogramPlan {
1029 params: params.clone(),
1030 stft,
1031 mapping,
1032 scaling,
1033 freq_axis,
1034 workspace,
1035 _amp: PhantomData,
1036 })
1037 }
1038
1039 /// Build a log-frequency plan.
1040 ///
1041 /// This creates a spectrogram with logarithmically-spaced frequency bins.
1042 ///
1043 /// # Type Parameters
1044 ///
1045 /// `AmpScale`: determines whether output is:
1046 /// - Magnitude
1047 /// - Power
1048 /// - Decibels
1049 ///
1050 /// # Arguments
1051 ///
1052 /// * `params` - Spectrogram parameters
1053 /// * `loghz` - LogHz-specific parameters
1054 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
1055 ///
1056 /// # Returns
1057 ///
1058 /// A `SpectrogramPlan` configured for log-frequency spectrogram computation.
1059 ///
1060 /// # Errors
1061 ///
1062 /// Returns an error if the plan cannot be created due to invalid parameters.
1063 #[inline]
1064 pub fn log_hz_plan<AmpScale>(
1065 &self,
1066 params: &SpectrogramParams,
1067 loghz: &LogHzParams,
1068 db: Option<&LogParams>,
1069 ) -> SpectrogramResult<SpectrogramPlan<LogHz, AmpScale>>
1070 where
1071 AmpScale: AmpScaleSpec + 'static,
1072 {
1073 // cross-validation: loghz range must be compatible with sample rate
1074 let nyquist = params.nyquist_hz();
1075 if loghz.f_max() > nyquist {
1076 return Err(SpectrogramError::invalid_input(format!(
1077 "f_max={} exceeds Nyquist={}",
1078 loghz.f_max(),
1079 nyquist
1080 )));
1081 }
1082
1083 let stft = StftPlan::new(params)?;
1084 let mapping = FrequencyMapping::<LogHz>::new_loghz(params, loghz)?;
1085 let scaling = AmplitudeScaling::<AmpScale>::new(db);
1086 let freq_axis = build_frequency_axis::<LogHz>(params, &mapping);
1087
1088 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1089
1090 Ok(SpectrogramPlan {
1091 params: params.clone(),
1092 stft,
1093 mapping,
1094 scaling,
1095 freq_axis,
1096 workspace,
1097 _amp: PhantomData,
1098 })
1099 }
1100
1101 /// Build a cqt spectrogram plan.
1102 ///
1103 /// # Type Parameters
1104 ///
1105 /// `AmpScale`: determines whether output is:
1106 /// - Magnitude
1107 /// - Power
1108 /// - Decibels
1109 ///
1110 /// # Arguments
1111 ///
1112 /// * `params` - Spectrogram parameters
1113 /// * `cqt` - CQT-specific parameters
1114 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
1115 ///
1116 /// # Returns
1117 ///
1118 /// A `SpectrogramPlan` configured for CQT spectrogram computation.
1119 ///
1120 /// # Errors
1121 ///
1122 /// Returns an error if the plan cannot be created due to invalid parameters.
1123 #[inline]
1124 pub fn cqt_plan<AmpScale>(
1125 &self,
1126 params: &SpectrogramParams,
1127 cqt: &CqtParams,
1128 db: Option<&LogParams>, // only used when AmpScale = Decibels
1129 ) -> SpectrogramResult<SpectrogramPlan<Cqt, AmpScale>>
1130 where
1131 AmpScale: AmpScaleSpec + 'static,
1132 {
1133 let stft = StftPlan::new(params)?;
1134 let mapping = FrequencyMapping::<Cqt>::new(params, cqt)?;
1135 let scaling = AmplitudeScaling::<AmpScale>::new(db);
1136 let freq_axis = build_frequency_axis::<Cqt>(params, &mapping);
1137
1138 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1139
1140 Ok(SpectrogramPlan {
1141 params: params.clone(),
1142 stft,
1143 mapping,
1144 scaling,
1145 freq_axis,
1146 workspace,
1147 _amp: PhantomData,
1148 })
1149 }
1150}
1151
/// STFT plan containing a reusable FFT plan and scratch buffers.
///
/// This struct performs the Short-Time Fourier Transform (STFT) on audio
/// signals according to the parameters it was built from.
///
/// It bundles the FFT plan, the window samples, and internal buffers so that
/// many frames can be transformed without re-planning or re-allocating.
///
/// # Fields
///
/// - `n_fft`: Size of the FFT.
/// - `hop_size`: Hop size between consecutive frames.
/// - `window`: Windowing function samples.
/// - `centre`: Whether to centre the frames with padding.
/// - `out_len`: Length of the FFT output.
/// - `fft`: Boxed FFT plan for real-to-complex transformation.
/// - `fft_out`: Internal buffer for FFT output.
/// - `frame`: Internal buffer for windowed audio frames.
pub struct StftPlan {
    /// FFT size, i.e. the number of time-domain samples per frame.
    n_fft: NonZeroUsize,
    /// Number of samples between the starts of consecutive frames.
    hop_size: NonZeroUsize,
    /// Window samples multiplied into each frame before the FFT.
    window: NonEmptyVec<f64>,
    /// When `true`, frames are laid out as if the signal had `n_fft / 2`
    /// zeros on both sides.
    centre: bool,

    /// Number of complex bins produced by the real-to-complex FFT.
    out_len: NonZeroUsize,

    /// FFT plan (reused for all frames).
    fft: Box<dyn R2cPlan>,

    /// Internal scratch: complex FFT output of the most recent frame.
    fft_out: NonEmptyVec<Complex<f64>>,
    /// Internal scratch: windowed time-domain frame fed to the FFT.
    frame: NonEmptyVec<f64>,
}
1185
1186impl StftPlan {
1187 /// Create a new STFT plan from parameters.
1188 ///
1189 /// # Arguments
1190 ///
1191 /// * `params` - Spectrogram parameters containing STFT config
1192 ///
1193 /// # Returns
1194 ///
1195 /// A new `StftPlan` instance.
1196 ///
1197 /// # Errors
1198 ///
1199 /// Returns an error if the FFT plan cannot be created.
1200 #[inline]
1201 pub fn new(params: &SpectrogramParams) -> SpectrogramResult<Self> {
1202 let stft = params.stft();
1203 let n_fft = stft.n_fft();
1204 let hop_size = stft.hop_size();
1205 let centre = stft.centre();
1206
1207 let window = make_window(stft.window(), n_fft);
1208
1209 let out_len = r2c_output_size(n_fft.get());
1210 let out_len = NonZeroUsize::new(out_len)
1211 .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1212
1213 #[cfg(feature = "realfft")]
1214 let fft = {
1215 let mut planner = crate::RealFftPlanner::new();
1216 let plan = planner.get_or_create(n_fft.get());
1217 let plan = crate::RealFftPlan::new(n_fft.get(), plan);
1218 Box::new(plan)
1219 };
1220
1221 #[cfg(feature = "fftw")]
1222 let fft = {
1223 use std::sync::Arc;
1224 let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
1225 Box::new(crate::FftwPlan::new(Arc::new(plan)))
1226 };
1227
1228 Ok(Self {
1229 n_fft,
1230 hop_size,
1231 window,
1232 centre,
1233 out_len,
1234 fft,
1235 fft_out: non_empty_vec![Complex::new(0.0, 0.0); out_len],
1236 frame: non_empty_vec![0.0; n_fft],
1237 })
1238 }
1239
1240 fn frame_count(&self, n_samples: NonZeroUsize) -> SpectrogramResult<NonZeroUsize> {
1241 // Framing policy:
1242 // - centre = true: implicit padding of n_fft/2 on both sides
1243 // - centre = false: no padding
1244 //
1245 // Define the number of frames such that each frame has a valid centre sample position.
1246 let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1247 let padded_len = n_samples.get() + 2 * pad;
1248
1249 if padded_len < self.n_fft.get() {
1250 // still produce 1 frame (all padding / partial)
1251 return Ok(nzu!(1));
1252 }
1253
1254 let remaining = padded_len - self.n_fft.get();
1255 let n_frames = remaining / self.hop_size().get() + 1;
1256 let n_frames = NonZeroUsize::new(n_frames).ok_or_else(|| {
1257 SpectrogramError::invalid_input("computed number of frames must be non-zero")
1258 })?;
1259 Ok(n_frames)
1260 }
1261
1262 /// Compute one frame FFT using internal buffers only.
1263 fn compute_frame_fft_simple(
1264 &mut self,
1265 samples: &NonEmptySlice<f64>,
1266 frame_idx: usize,
1267 ) -> SpectrogramResult<()> {
1268 let out = self.frame.as_mut_slice();
1269 debug_assert_eq!(out.len(), self.n_fft.get());
1270
1271 let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1272 let start = frame_idx
1273 .checked_mul(self.hop_size.get())
1274 .ok_or_else(|| SpectrogramError::invalid_input("frame index overflow"))?;
1275
1276 // Fill windowed frame
1277 for (i, sample) in out.iter_mut().enumerate().take(self.n_fft.get()) {
1278 let v_idx = start + i;
1279 let s_idx = v_idx as isize - pad as isize;
1280
1281 let sample_val = if s_idx < 0 || (s_idx as usize) >= samples.len().get() {
1282 0.0
1283 } else {
1284 samples[s_idx as usize]
1285 };
1286 *sample = sample_val * self.window[i];
1287 }
1288
1289 // Compute FFT
1290 let fft_out = self.fft_out.as_mut_slice();
1291 self.fft.process(out, fft_out)?;
1292
1293 Ok(())
1294 }
1295
1296 /// Compute one frame spectrum into workspace:
1297 /// - fills windowed frame
1298 /// - runs FFT
1299 /// - converts to magnitude/power based on `AmpScale` later
1300 fn compute_frame_spectrum(
1301 &mut self,
1302 samples: &NonEmptySlice<f64>,
1303 frame_idx: usize,
1304 workspace: &mut Workspace,
1305 ) -> SpectrogramResult<()> {
1306 let out = workspace.frame.as_mut_slice();
1307
1308 // self.fill_frame(samples, frame_idx, frame)?;
1309 debug_assert_eq!(out.len(), self.n_fft.get());
1310
1311 let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1312 let start = frame_idx
1313 .checked_mul(self.hop_size().get())
1314 .ok_or_else(|| SpectrogramError::invalid_input("frame index overflow"))?;
1315
1316 // The "virtual" signal is samples with pad zeros on both sides.
1317 // Virtual index 0..padded_len
1318 // Map virtual index to original samples by subtracting pad.
1319 for (i, sample) in out.iter_mut().enumerate().take(self.n_fft.get()) {
1320 let v_idx = start + i;
1321 let s_idx = v_idx as isize - pad as isize;
1322
1323 let sample_val = if s_idx < 0 || (s_idx as usize) >= samples.len().get() {
1324 0.0
1325 } else {
1326 samples[s_idx as usize]
1327 };
1328
1329 *sample = sample_val * self.window[i];
1330 }
1331 let fft_out = workspace.fft_out.as_mut_slice();
1332 // FFT
1333 self.fft.process(out, fft_out)?;
1334
1335 // Convert complex spectrum to linear magnitude OR power here? No:
1336 // Keep "spectrum" as power by default? That would entangle semantics.
1337 //
1338 // Instead, we store magnitude^2 (power) as the canonical intermediate,
1339 // and let AmpScale decide later whether output is magnitude or power.
1340 //
1341 // This is consistent and avoids recomputing norms multiple times.
1342 for (i, c) in workspace.fft_out.iter().enumerate() {
1343 workspace.spectrum[i] = c.norm_sqr();
1344 }
1345
1346 Ok(())
1347 }
1348
1349 /// Fill a time-domain frame WITHOUT applying the window.
1350 ///
1351 /// This is used for CQT mapping, where the CQT kernels already contain
1352 /// windowing applied during kernel generation. Applying the STFT window
1353 /// would result in double-windowing.
1354 ///
1355 /// # Arguments
1356 ///
1357 /// * `samples` - Input audio samples
1358 /// * `frame_idx` - Frame index
1359 /// * `workspace` - Workspace containing the frame buffer to fill
1360 ///
1361 /// # Errors
1362 ///
1363 /// Returns an error if frame index is out of bounds.
1364 fn fill_frame_unwindowed(
1365 &self,
1366 samples: &NonEmptySlice<f64>,
1367 frame_idx: usize,
1368 workspace: &mut Workspace,
1369 ) -> SpectrogramResult<()> {
1370 let out = workspace.frame.as_mut_slice();
1371 debug_assert_eq!(out.len(), self.n_fft.get());
1372
1373 let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1374 let start = frame_idx
1375 .checked_mul(self.hop_size().get())
1376 .ok_or_else(|| SpectrogramError::invalid_input("frame index overflow"))?;
1377
1378 // Fill frame WITHOUT windowing
1379 for (i, sample) in out.iter_mut().enumerate().take(self.n_fft.get()) {
1380 let v_idx = start + i;
1381 let s_idx = v_idx as isize - pad as isize;
1382
1383 let sample_val = if s_idx < 0 || (s_idx as usize) >= samples.len().get() {
1384 0.0
1385 } else {
1386 samples[s_idx as usize]
1387 };
1388
1389 // No window multiplication - just copy the sample
1390 *sample = sample_val;
1391 }
1392
1393 Ok(())
1394 }
1395
1396 /// Compute the full STFT for a signal, returning an `StftResult`.
1397 ///
1398 /// This is a convenience method that handles frame iteration and
1399 /// builds the complete STFT matrix.
1400 ///
1401 ///
1402 /// # Arguments
1403 ///
1404 /// * `samples` - Input audio samples
1405 /// * `params` - STFT computation parameters
1406 ///
1407 /// # Returns
1408 ///
1409 /// An `StftResult` containing the complex STFT matrix and metadata.
1410 ///
1411 /// # Errors
1412 ///
1413 /// Returns an error if computation fails.
1414 ///
1415 /// # Examples
1416 ///
1417 /// ```rust
1418 /// use spectrograms::*;
1419 /// use non_empty_slice::non_empty_vec;
1420 ///
1421 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1422 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1423 /// let params = SpectrogramParams::new(stft, 16000.0)?;
1424 /// let mut plan = StftPlan::new(¶ms)?;
1425 ///
1426 /// let samples = non_empty_vec![0.0; nzu!(16000)];
1427 /// let stft_result = plan.compute(&samples, ¶ms)?;
1428 ///
1429 /// println!("STFT: {} bins x {} frames", stft_result.n_bins(), stft_result.n_frames());
1430 /// # Ok(())
1431 /// # }
1432 /// ```
1433 #[inline]
1434 pub fn compute(
1435 &mut self,
1436 samples: &NonEmptySlice<f64>,
1437 params: &SpectrogramParams,
1438 ) -> SpectrogramResult<StftResult> {
1439 let n_frames = self.frame_count(samples.len())?;
1440 let n_bins = self.out_len;
1441
1442 // Allocate output matrix (frequency_bins x time_frames)
1443 let mut data = Array2::<Complex<f64>>::zeros((n_bins.get(), n_frames.get()));
1444
1445 // Compute each frame
1446 for frame_idx in 0..n_frames.get() {
1447 self.compute_frame_fft_simple(samples, frame_idx)?;
1448
1449 // Copy from internal buffer to output
1450 for (bin_idx, &value) in self.fft_out.iter().enumerate() {
1451 data[[bin_idx, frame_idx]] = value;
1452 }
1453 }
1454
1455 // Build frequency axis
1456 let frequencies: Vec<f64> = (0..n_bins.get())
1457 .map(|k| k as f64 * params.sample_rate_hz() / params.stft().n_fft().get() as f64)
1458 .collect();
1459 // SAFETY: n_bins > 0
1460 let frequencies = unsafe { NonEmptyVec::new_unchecked(frequencies) };
1461
1462 Ok(StftResult {
1463 data,
1464 frequencies,
1465 sample_rate: params.sample_rate_hz(),
1466 params: params.stft().clone(),
1467 })
1468 }
1469
1470 /// Compute a single frame of STFT, returning the complex spectrum.
1471 ///
1472 /// This is useful for streaming/online processing where you want to
1473 /// process audio frame-by-frame.
1474 ///
1475 /// # Arguments
1476 ///
1477 /// * `samples` - Input audio samples
1478 /// * `frame_idx` - Index of the frame to compute
1479 ///
1480 /// # Returns
1481 ///
1482 /// A `NonEmptyVec` containing the complex spectrum for the specified frame.
1483 ///
1484 /// # Errors
1485 ///
1486 /// Returns an error if computation fails.
1487 ///
1488 /// # Examples
1489 ///
1490 /// ```rust
1491 /// use spectrograms::*;
1492 /// use non_empty_slice::non_empty_vec;
1493 ///
1494 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1495 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1496 /// let params = SpectrogramParams::new(stft, 16000.0)?;
1497 /// let mut plan = StftPlan::new(¶ms)?;
1498 ///
1499 /// let samples = non_empty_vec![0.0; nzu!(16000)];
1500 /// let (_, n_frames) = plan.output_shape(samples.len())?;
1501 ///
1502 /// for frame_idx in 0..n_frames.get() {
1503 /// let spectrum = plan.compute_frame_simple(&samples, frame_idx)?;
1504 /// // Process spectrum...
1505 /// }
1506 /// # Ok(())
1507 /// # }
1508 /// ```
1509 #[inline]
1510 pub fn compute_frame_simple(
1511 &mut self,
1512 samples: &NonEmptySlice<f64>,
1513 frame_idx: usize,
1514 ) -> SpectrogramResult<NonEmptyVec<Complex<f64>>> {
1515 self.compute_frame_fft_simple(samples, frame_idx)?;
1516 Ok(self.fft_out.clone())
1517 }
1518
1519 /// Compute STFT into a pre-allocated buffer.
1520 ///
1521 /// This avoids allocating the output matrix, useful for reusing buffers.
1522 ///
1523 /// # Arguments
1524 ///
1525 /// * `samples` - Input audio samples
1526 /// * `output` - Pre-allocated output buffer (shape: `n_bins` x `n_frames`)
1527 ///
1528 /// # Returns
1529 ///
1530 /// `Ok(())` on success, or an error if dimensions mismatch.
1531 ///
1532 /// # Errors
1533 ///
1534 /// Returns `DimensionMismatch` error if the output buffer has incorrect shape.
1535 ///
1536 /// # Examples
1537 ///
1538 /// ```rust
1539 /// use spectrograms::*;
1540 /// use ndarray::Array2;
1541 /// use num_complex::Complex;
1542 /// use non_empty_slice::non_empty_vec;
1543 ///
1544 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1545 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1546 /// let params = SpectrogramParams::new(stft, 16000.0)?;
1547 /// let mut plan = StftPlan::new(¶ms)?;
1548 ///
1549 /// let samples = non_empty_vec![0.0; nzu!(16000)];
1550 /// let (n_bins, n_frames) = plan.output_shape(samples.len())?;
1551 /// let mut output = Array2::<Complex<f64>>::zeros((n_bins.get(), n_frames.get()));
1552 ///
1553 /// plan.compute_into(&samples, &mut output)?;
1554 /// # Ok(())
1555 /// # }
1556 /// ```
1557 #[inline]
1558 pub fn compute_into(
1559 &mut self,
1560 samples: &NonEmptySlice<f64>,
1561 output: &mut Array2<Complex<f64>>,
1562 ) -> SpectrogramResult<()> {
1563 let n_frames = self.frame_count(samples.len())?;
1564 let n_bins = self.out_len;
1565
1566 // Validate output dimensions
1567 if output.nrows() != n_bins.get() {
1568 return Err(SpectrogramError::dimension_mismatch(
1569 n_bins.get(),
1570 output.nrows(),
1571 ));
1572 }
1573 if output.ncols() != n_frames.get() {
1574 return Err(SpectrogramError::dimension_mismatch(
1575 n_frames.get(),
1576 output.ncols(),
1577 ));
1578 }
1579
1580 // Compute into pre-allocated buffer
1581 for frame_idx in 0..n_frames.get() {
1582 self.compute_frame_fft_simple(samples, frame_idx)?;
1583
1584 for (bin_idx, &value) in self.fft_out.iter().enumerate() {
1585 output[[bin_idx, frame_idx]] = value;
1586 }
1587 }
1588
1589 Ok(())
1590 }
1591
1592 /// Get the expected output dimensions for a given signal length.
1593 ///
1594 /// # Arguments
1595 ///
1596 /// * `signal_length` - Length of the input signal in samples
1597 ///
1598 /// # Returns
1599 ///
1600 /// A tuple `(n_frequency_bins, n_time_frames)`.
1601 ///
1602 /// # Errors
1603 ///
1604 /// Returns an error if the computed number of frames is invalid.
1605 #[inline]
1606 pub fn output_shape(
1607 &self,
1608 signal_length: NonZeroUsize,
1609 ) -> SpectrogramResult<(NonZeroUsize, NonZeroUsize)> {
1610 let n_frames = self.frame_count(signal_length)?;
1611 Ok((self.out_len, n_frames))
1612 }
1613
1614 /// Get the number of frequency bins in the output.
1615 ///
1616 /// # Returns
1617 ///
1618 /// The number of frequency bins.
1619 #[inline]
1620 #[must_use]
1621 pub const fn n_bins(&self) -> NonZeroUsize {
1622 self.out_len
1623 }
1624
1625 /// Get the FFT size.
1626 ///
1627 /// # Returns
1628 ///
1629 /// The FFT size.
1630 #[inline]
1631 #[must_use]
1632 pub const fn n_fft(&self) -> NonZeroUsize {
1633 self.n_fft
1634 }
1635
1636 /// Get the hop size.
1637 ///
1638 /// # Returns
1639 ///
1640 /// The hop size.
1641 #[inline]
1642 #[must_use]
1643 pub const fn hop_size(&self) -> NonZeroUsize {
1644 self.hop_size
1645 }
1646}
1647
/// Concrete frequency-mapping strategy applied to each frame.
///
/// Each variant carries the pre-computed data its mapping needs.
#[derive(Debug, Clone)]
enum MappingKind {
    /// Pass the FFT power spectrum through unchanged (`out_len` bins).
    Identity {
        out_len: NonZeroUsize,
    },
    /// Mel filterbank applied as a sparse matrix.
    Mel {
        matrix: SparseMatrix,
    }, // shape: (n_mels, out_len)
    /// Log-frequency mapping applied as a sparse matrix; `frequencies` holds
    /// the output-bin frequencies computed when the mapping was built.
    LogHz {
        matrix: SparseMatrix,
        frequencies: NonEmptyVec<f64>,
    }, // shape: (n_bins, out_len)
    /// ERB filterbank applied to the power spectrum.
    Erb {
        filterbank: ErbFilterbank,
    },
    /// Constant-Q kernel applied to the time-domain frame.
    Cqt {
        kernel: CqtKernel,
    },
}
1667
1668impl MappingKind {
1669 /// Check if this mapping requires unwindowed time-domain frames.
1670 ///
1671 /// CQT kernels already contain windowing applied during kernel generation,
1672 /// so they should receive unwindowed frames to avoid double-windowing.
1673 /// All other mappings work on FFT spectra and don't use the time-domain frame.
1674 const fn needs_unwindowed_frame(&self) -> bool {
1675 matches!(self, Self::Cqt { .. })
1676 }
1677}
1678
/// Typed mapping wrapper.
///
/// Pairs a runtime [`MappingKind`] with a compile-time frequency-scale marker
/// so that constructors can be specialised per scale.
#[derive(Debug, Clone)]
struct FrequencyMapping<FreqScale> {
    /// The concrete mapping strategy and its pre-computed data.
    kind: MappingKind,
    /// Zero-sized marker tying this mapping to its frequency scale.
    _marker: PhantomData<FreqScale>,
}
1685
1686impl FrequencyMapping<LinearHz> {
1687 fn new(params: &SpectrogramParams) -> SpectrogramResult<Self> {
1688 let out_len = r2c_output_size(params.stft().n_fft().get());
1689 let out_len = NonZeroUsize::new(out_len)
1690 .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1691 Ok(Self {
1692 kind: MappingKind::Identity { out_len },
1693 _marker: PhantomData,
1694 })
1695 }
1696}
1697
1698impl FrequencyMapping<Mel> {
1699 fn new_mel(params: &SpectrogramParams, mel: &MelParams) -> SpectrogramResult<Self> {
1700 let n_fft = params.stft().n_fft();
1701 let out_len = r2c_output_size(n_fft.get());
1702 let out_len = NonZeroUsize::new(out_len)
1703 .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1704
1705 // Validate: mel bins must be <= something sensible
1706 if mel.n_mels() > nzu!(10_000) {
1707 return Err(SpectrogramError::invalid_input(
1708 "n_mels is unreasonably large",
1709 ));
1710 }
1711
1712 let matrix = build_mel_filterbank_matrix(
1713 params.sample_rate_hz(),
1714 n_fft,
1715 mel.n_mels(),
1716 mel.f_min(),
1717 mel.f_max(),
1718 mel.norm(),
1719 )?;
1720
1721 // matrix must be (n_mels, out_len)
1722 if matrix.nrows() != mel.n_mels().get() || matrix.ncols() != out_len.get() {
1723 return Err(SpectrogramError::invalid_input(
1724 "mel filterbank matrix shape mismatch",
1725 ));
1726 }
1727
1728 Ok(Self {
1729 kind: MappingKind::Mel { matrix },
1730 _marker: PhantomData,
1731 })
1732 }
1733}
1734
1735impl FrequencyMapping<LogHz> {
1736 fn new_loghz(params: &SpectrogramParams, loghz: &LogHzParams) -> SpectrogramResult<Self> {
1737 let n_fft = params.stft().n_fft();
1738 let out_len = r2c_output_size(n_fft.get());
1739 let out_len = NonZeroUsize::new(out_len)
1740 .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1741 // Validate: n_bins must be <= something sensible
1742 if loghz.n_bins() > nzu!(10_000) {
1743 return Err(SpectrogramError::invalid_input(
1744 "n_bins is unreasonably large",
1745 ));
1746 }
1747
1748 let (matrix, frequencies) = build_loghz_matrix(
1749 params.sample_rate_hz(),
1750 n_fft,
1751 loghz.n_bins(),
1752 loghz.f_min(),
1753 loghz.f_max(),
1754 )?;
1755
1756 // matrix must be (n_bins, out_len)
1757 if matrix.nrows() != loghz.n_bins().get() || matrix.ncols() != out_len.get() {
1758 return Err(SpectrogramError::invalid_input(
1759 "loghz matrix shape mismatch",
1760 ));
1761 }
1762
1763 Ok(Self {
1764 kind: MappingKind::LogHz {
1765 matrix,
1766 frequencies,
1767 },
1768 _marker: PhantomData,
1769 })
1770 }
1771}
1772
1773impl FrequencyMapping<Erb> {
1774 fn new_erb(params: &SpectrogramParams, erb: &crate::erb::ErbParams) -> SpectrogramResult<Self> {
1775 let n_fft = params.stft().n_fft();
1776 let sample_rate = params.sample_rate_hz();
1777
1778 // Validate: n_filters must be <= something sensible
1779 if erb.n_filters() > nzu!(10_000) {
1780 return Err(SpectrogramError::invalid_input(
1781 "n_filters is unreasonably large",
1782 ));
1783 }
1784
1785 // Generate ERB filterbank with pre-computed frequency responses
1786 let filterbank = crate::erb::ErbFilterbank::generate(erb, sample_rate, n_fft)?;
1787
1788 Ok(Self {
1789 kind: MappingKind::Erb { filterbank },
1790 _marker: PhantomData,
1791 })
1792 }
1793}
1794
1795impl FrequencyMapping<Cqt> {
1796 fn new(params: &SpectrogramParams, cqt: &CqtParams) -> SpectrogramResult<Self> {
1797 let sample_rate = params.sample_rate_hz();
1798 let n_fft = params.stft().n_fft();
1799
1800 // Validate that frequency range is reasonable
1801 let f_max = cqt.bin_frequency(cqt.num_bins().get().saturating_sub(1));
1802 if f_max >= sample_rate / 2.0 {
1803 return Err(SpectrogramError::invalid_input(
1804 "CQT maximum frequency must be below Nyquist frequency",
1805 ));
1806 }
1807
1808 // Generate CQT kernel using n_fft as the signal length for kernel generation
1809 let kernel = CqtKernel::generate(cqt, sample_rate, n_fft);
1810
1811 Ok(Self {
1812 kind: MappingKind::Cqt { kernel },
1813 _marker: PhantomData,
1814 })
1815 }
1816}
1817
1818impl<FreqScale> FrequencyMapping<FreqScale> {
1819 const fn output_bins(&self) -> NonZeroUsize {
1820 // safety: all variants ensure output bins > 0 OR rely on a matrix that is guaranteed to have rows > 0
1821 match &self.kind {
1822 MappingKind::Identity { out_len } => *out_len,
1823 // safety: matrix.nrows() > 0
1824 MappingKind::LogHz { matrix, .. } | MappingKind::Mel { matrix } => unsafe {
1825 NonZeroUsize::new_unchecked(matrix.nrows())
1826 },
1827 MappingKind::Erb { filterbank, .. } => filterbank.num_filters(),
1828 MappingKind::Cqt { kernel, .. } => kernel.num_bins(),
1829 }
1830 }
1831
1832 fn apply(
1833 &self,
1834 spectrum: &NonEmptySlice<f64>,
1835 frame: &NonEmptySlice<f64>,
1836 out: &mut NonEmptySlice<f64>,
1837 ) -> SpectrogramResult<()> {
1838 match &self.kind {
1839 MappingKind::Identity { out_len } => {
1840 if spectrum.len() != *out_len {
1841 return Err(SpectrogramError::dimension_mismatch(
1842 (*out_len).get(),
1843 spectrum.len().get(),
1844 ));
1845 }
1846 if out.len() != *out_len {
1847 return Err(SpectrogramError::dimension_mismatch(
1848 (*out_len).get(),
1849 out.len().get(),
1850 ));
1851 }
1852 out.copy_from_slice(spectrum);
1853 Ok(())
1854 }
1855 MappingKind::LogHz { matrix, .. } | MappingKind::Mel { matrix } => {
1856 let out_bins = matrix.nrows();
1857 let in_bins = matrix.ncols();
1858
1859 if spectrum.len().get() != in_bins {
1860 return Err(SpectrogramError::dimension_mismatch(
1861 in_bins,
1862 spectrum.len().get(),
1863 ));
1864 }
1865 if out.len().get() != out_bins {
1866 return Err(SpectrogramError::dimension_mismatch(
1867 out_bins,
1868 out.len().get(),
1869 ));
1870 }
1871
1872 // Sparse matrix-vector multiplication: out = matrix * spectrum
1873 matrix.multiply_vec(spectrum, out);
1874 Ok(())
1875 }
1876 MappingKind::Erb { filterbank } => {
1877 // Apply ERB filterbank using pre-computed frequency responses
1878 // The filterbank already has |H(f)|^2 pre-computed, so we just
1879 // apply it to the power spectrum
1880 let erb_out = filterbank.apply_to_power_spectrum(spectrum)?;
1881
1882 if out.len().get() != erb_out.len().get() {
1883 return Err(SpectrogramError::dimension_mismatch(
1884 erb_out.len().get(),
1885 out.len().get(),
1886 ));
1887 }
1888
1889 out.copy_from_slice(&erb_out);
1890 Ok(())
1891 }
1892 MappingKind::Cqt { kernel } => {
1893 // CQT works on time-domain unwindowed frame (not FFT spectrum).
1894 // The CQT kernels contain windowing applied during kernel generation,
1895 // so the input frame should be unwindowed to avoid double-windowing.
1896 let cqt_complex = kernel.apply(frame)?;
1897
1898 if out.len().get() != cqt_complex.len().get() {
1899 return Err(SpectrogramError::dimension_mismatch(
1900 cqt_complex.len().get(),
1901 out.len().get(),
1902 ));
1903 }
1904
1905 // Convert complex coefficients to power (|z|^2)
1906 // This matches the convention where intermediate values are in power domain
1907 for (i, c) in cqt_complex.iter().enumerate() {
1908 out[i] = c.norm_sqr();
1909 }
1910
1911 Ok(())
1912 }
1913 }
1914 }
1915
1916 fn frequencies_hz(&self, params: &SpectrogramParams) -> NonEmptyVec<f64> {
1917 match &self.kind {
1918 MappingKind::Identity { out_len } => {
1919 // Standard R2C bins: k * sr / n_fft
1920 let n_fft = params.stft().n_fft().get() as f64;
1921 let sr = params.sample_rate_hz();
1922 let df = sr / n_fft;
1923
1924 let mut f = Vec::with_capacity((*out_len).get());
1925 for k in 0..(*out_len).get() {
1926 f.push(k as f64 * df);
1927 }
1928 // safety: out_len > 0
1929 unsafe { NonEmptyVec::new_unchecked(f) }
1930 }
1931 MappingKind::Mel { matrix } => {
1932 // For mel, the axis is defined by the mel band centre frequencies.
1933 // We compute and store them consistently with how we built the filterbank.
1934 let n_mels = matrix.nrows();
1935 // safety: n_mels > 0
1936 let n_mels = unsafe { NonZeroUsize::new_unchecked(n_mels) };
1937 mel_band_centres_hz(n_mels, params.sample_rate_hz(), params.nyquist_hz())
1938 }
1939 MappingKind::LogHz { frequencies, .. } => {
1940 // Frequencies are stored when the mapping is created
1941 frequencies.clone()
1942 }
1943 MappingKind::Erb { filterbank, .. } => {
1944 // ERB center frequencies
1945 filterbank.center_frequencies().to_non_empty_vec()
1946 }
1947 MappingKind::Cqt { kernel, .. } => {
1948 // CQT center frequencies from the kernel
1949 kernel.frequencies().to_non_empty_vec()
1950 }
1951 }
1952 }
1953}
1954
1955//
1956// ========================
1957// Amplitude scaling
1958// ========================
1959//
1960
/// Per-scale behaviour selected by the `AmpScale` type parameter.
///
/// Implementors describe how a power-domain value is turned into the
/// amplitude scale they represent.
pub trait AmpScaleSpec: Sized + Send + Sync {
    /// Apply conversion from power-domain value to the desired amplitude scale.
    ///
    /// # Arguments
    ///
    /// - `power`: input power-domain value.
    ///
    /// # Returns
    ///
    /// Converted amplitude value.
    fn apply_from_power(power: f64) -> f64;

    /// Apply dB conversion in-place on a power-domain vector.
    ///
    /// This is a no-op for Power and Magnitude scales.
    ///
    /// # Arguments
    ///
    /// - `x`: power-domain values to convert to dB in-place.
    /// - `floor_db`: dB floor value to apply.
    ///
    /// # Returns
    ///
    /// `SpectrogramResult<()>`: Ok on success, error on invalid input.
    ///
    /// # Errors
    ///
    /// - If `floor_db` is not finite (scales that perform the dB conversion).
    fn apply_db_in_place(x: &mut [f64], floor_db: f64) -> SpectrogramResult<()>;
}
1992
1993impl AmpScaleSpec for Power {
1994 #[inline]
1995 fn apply_from_power(power: f64) -> f64 {
1996 power
1997 }
1998
1999 #[inline]
2000 fn apply_db_in_place(_x: &mut [f64], _floor_db: f64) -> SpectrogramResult<()> {
2001 Ok(())
2002 }
2003}
2004
2005impl AmpScaleSpec for Magnitude {
2006 #[inline]
2007 fn apply_from_power(power: f64) -> f64 {
2008 power.sqrt()
2009 }
2010
2011 #[inline]
2012 fn apply_db_in_place(_x: &mut [f64], _floor_db: f64) -> SpectrogramResult<()> {
2013 Ok(())
2014 }
2015}
2016
2017impl AmpScaleSpec for Decibels {
2018 #[inline]
2019 fn apply_from_power(power: f64) -> f64 {
2020 // dB conversion is applied in batch, not here.
2021 power
2022 }
2023
2024 #[inline]
2025 fn apply_db_in_place(x: &mut [f64], floor_db: f64) -> SpectrogramResult<()> {
2026 // Convert power -> dB: 10*log10(max(power, eps))
2027 // where eps is derived from floor_db to ensure consistency
2028 if !floor_db.is_finite() {
2029 return Err(SpectrogramError::invalid_input("floor_db must be finite"));
2030 }
2031
2032 // Convert floor_db to linear scale to get epsilon
2033 // e.g., floor_db = -80 dB -> eps = 10^(-80/10) = 1e-8
2034 let eps = 10.0_f64.powf(floor_db / 10.0);
2035
2036 for v in x.iter_mut() {
2037 // Clamp power to epsilon before log to avoid log(0) and ensure floor
2038 *v = 10.0 * v.max(eps).log10();
2039 }
2040 Ok(())
2041 }
2042}
2043
/// Amplitude scaling configuration.
///
/// This handles conversion from power-domain intermediate to the desired amplitude scale (Power, Magnitude, Decibels).
#[derive(Debug, Clone)]
struct AmplitudeScaling<AmpScale> {
    /// dB floor taken from the supplied `LogParams`; `None` when no
    /// `LogParams` were given, in which case the dB step is skipped.
    db_floor: Option<f64>,
    /// Zero-sized marker selecting the amplitude scale at compile time.
    _marker: PhantomData<AmpScale>,
}
2052
2053impl<AmpScale> AmplitudeScaling<AmpScale>
2054where
2055 AmpScale: AmpScaleSpec + 'static,
2056{
2057 fn new(db: Option<&LogParams>) -> Self {
2058 let db_floor = db.map(LogParams::floor_db);
2059 Self {
2060 db_floor,
2061 _marker: PhantomData,
2062 }
2063 }
2064
2065 /// Apply amplitude scaling in-place on a mapped spectrum vector.
2066 ///
2067 /// The input vector is assumed to be in the *power* domain (|X|^2),
2068 /// because the STFT stage produces power as the canonical intermediate.
2069 ///
2070 /// - Power: leaves values unchanged.
2071 /// - Magnitude: sqrt(power).
2072 /// - Decibels: converts power -> dB and floors at `db_floor`.
2073 pub fn apply_in_place(&self, x: &mut [f64]) -> SpectrogramResult<()> {
2074 // Convert from canonical power-domain intermediate into the requested linear domain.
2075 for v in x.iter_mut() {
2076 *v = AmpScale::apply_from_power(*v);
2077 }
2078
2079 // Apply dB conversion if configured (no-op for Power/Magnitude via trait impls).
2080 if let Some(floor_db) = self.db_floor {
2081 AmpScale::apply_db_in_place(x, floor_db)?;
2082 }
2083
2084 Ok(())
2085 }
2086}
2087
2088#[derive(Debug, Clone)]
2089struct Workspace {
2090 spectrum: NonEmptyVec<f64>, // out_len (power spectrum)
2091 mapped: NonEmptyVec<f64>, // n_bins (after mapping)
2092 frame: NonEmptyVec<f64>, // n_fft (windowed frame for FFT)
2093 fft_out: NonEmptyVec<Complex<f64>>, // out_len (FFT output)
2094}
2095
2096impl Workspace {
2097 fn new(n_fft: NonZeroUsize, out_len: NonZeroUsize, n_bins: NonZeroUsize) -> Self {
2098 Self {
2099 spectrum: non_empty_vec![0.0; out_len],
2100 mapped: non_empty_vec![0.0; n_bins],
2101 frame: non_empty_vec![0.0; n_fft],
2102 fft_out: non_empty_vec![Complex::new(0.0, 0.0); out_len],
2103 }
2104 }
2105
2106 fn ensure_sizes(&mut self, n_fft: NonZeroUsize, out_len: NonZeroUsize, n_bins: NonZeroUsize) {
2107 if self.spectrum.len() != out_len {
2108 self.spectrum.resize(out_len, 0.0);
2109 }
2110 if self.mapped.len() != n_bins {
2111 self.mapped.resize(n_bins, 0.0);
2112 }
2113 if self.frame.len() != n_fft {
2114 self.frame.resize(n_fft, 0.0);
2115 }
2116 if self.fft_out.len() != out_len {
2117 self.fft_out.resize(out_len, Complex::new(0.0, 0.0));
2118 }
2119 }
2120}
2121
2122fn build_frequency_axis<FreqScale>(
2123 params: &SpectrogramParams,
2124 mapping: &FrequencyMapping<FreqScale>,
2125) -> FrequencyAxis<FreqScale>
2126where
2127 FreqScale: Copy + Clone + 'static,
2128{
2129 let frequencies = mapping.frequencies_hz(params);
2130 FrequencyAxis::new(frequencies)
2131}
2132
2133fn build_time_axis_seconds(params: &SpectrogramParams, n_frames: NonZeroUsize) -> NonEmptyVec<f64> {
2134 let dt = params.frame_period_seconds();
2135 let mut times = Vec::with_capacity(n_frames.get());
2136
2137 for i in 0..n_frames.get() {
2138 times.push(i as f64 * dt);
2139 }
2140
2141 // safety: times is guaranteed non-empty since n_frames > 0
2142
2143 unsafe { NonEmptyVec::new_unchecked(times) }
2144}
2145
2146/// Generate window function samples.
2147///
2148/// Supports various window types including Rectangular, Hanning, Hamming, Blackman, Kaiser, and Gaussian.
2149///
2150/// # Arguments
2151///
2152/// * `window` - The type of window function to generate.
2153/// * `n_fft` - The size of the FFT, which determines the length of the window.
2154///
2155/// # Returns
2156///
2157/// A `NonEmptyVec<f64>` containing the window function samples.
2158///
2159/// # Panics
2160///
2161/// Panics if a custom window is provided with a size that does not match `n_fft`.
2162#[inline]
2163#[must_use]
2164pub fn make_window(window: WindowType, n_fft: NonZeroUsize) -> NonEmptyVec<f64> {
2165 let n_fft = n_fft.get();
2166 let mut w = vec![0.0; n_fft];
2167
2168 match window {
2169 WindowType::Rectangular => {
2170 w.fill(1.0);
2171 }
2172 WindowType::Hanning => {
2173 // Hann: 0.5 - 0.5*cos(2Ï€n/(N-1))
2174 let n1 = (n_fft - 1) as f64;
2175 for (n, v) in w.iter_mut().enumerate() {
2176 *v = 0.5f64.mul_add(-(2.0 * std::f64::consts::PI * (n as f64) / n1).cos(), 0.5);
2177 }
2178 }
2179 WindowType::Hamming => {
2180 // Hamming: 0.54 - 0.46*cos(2Ï€n/(N-1))
2181 let n1 = (n_fft - 1) as f64;
2182 for (n, v) in w.iter_mut().enumerate() {
2183 *v = 0.46f64.mul_add(-(2.0 * std::f64::consts::PI * (n as f64) / n1).cos(), 0.54);
2184 }
2185 }
2186 WindowType::Blackman => {
2187 // Blackman: 0.42 - 0.5*cos(2Ï€n/(N-1)) + 0.08*cos(4Ï€n/(N-1))
2188 let n1 = (n_fft - 1) as f64;
2189 for (n, v) in w.iter_mut().enumerate() {
2190 let a = 2.0 * std::f64::consts::PI * (n as f64) / n1;
2191 *v = 0.08f64.mul_add((2.0 * a).cos(), 0.5f64.mul_add(-a.cos(), 0.42));
2192 }
2193 }
2194 WindowType::Kaiser { beta } => {
2195 if n_fft == 1 {
2196 w[0] = 1.0;
2197 } else {
2198 let denom = modified_bessel_i0(beta);
2199 let n_max = (n_fft - 1) as f64 / 2.0;
2200
2201 for (i, value) in w.iter_mut().enumerate() {
2202 let n = i as f64 - n_max;
2203 let ratio = if n_max == 0.0 {
2204 0.0
2205 } else {
2206 let normalized = n / n_max;
2207 (1.0 - normalized * normalized).max(0.0)
2208 };
2209 let arg = beta * ratio.sqrt();
2210 *value = if denom == 0.0 {
2211 0.0
2212 } else {
2213 modified_bessel_i0(arg) / denom
2214 };
2215 }
2216 }
2217 }
2218 WindowType::Gaussian { std } => (0..n_fft).for_each(|i| {
2219 let n = i as f64;
2220 let center: f64 = (n_fft - 1) as f64 / 2.0;
2221 let exponent: f64 = -0.5 * ((n - center) / std).powi(2);
2222 w[i] = exponent.exp();
2223 }),
2224 WindowType::Custom { coefficients, size } => {
2225 assert!(
2226 size.get() == n_fft,
2227 "Custom window size mismatch: expected {}, got {}. \
2228 Custom windows must be pre-computed with the exact FFT size.",
2229 n_fft,
2230 size.get()
2231 );
2232 w.copy_from_slice(&coefficients);
2233 }
2234 }
2235
2236 // safety: window is guaranteed non-empty since n_fft > 0
2237 unsafe { NonEmptyVec::new_unchecked(w) }
2238}
2239
/// Modified Bessel function of the first kind, order zero: `I0(x)`.
///
/// Uses the polynomial approximations of Abramowitz & Stegun §9.8:
///
/// - `|x| <= 3.75` (9.8.1): `I0(x) ≈ 1 + Σ c_k t^(2k)` with `t = x / 3.75`.
/// - `|x| > 3.75` (9.8.2): `sqrt(|x|) * exp(-|x|) * I0(x) ≈ Σ d_k t^k` with
///   `t = 3.75 / |x|`.
///
/// The leading coefficient of the second polynomial, `0.39894228`, already is
/// `1 / sqrt(2π)`, so the prefactor is just `exp(|x|) / sqrt(|x|)`. Dividing by
/// `sqrt(2π)` again would shrink every large-argument result by about 0.399
/// and make the function discontinuous at `|x| = 3.75`.
///
/// `I0` is even, so the sign of `x` does not affect the result.
fn modified_bessel_i0(x: f64) -> f64 {
    let ax = x.abs();
    if ax <= 3.75 {
        let t = x / 3.75;
        let t2 = t * t;
        1.0 + t2
            * (3.515_622_9
                + t2 * (3.089_942_4
                    + t2 * (1.206_749_2
                        + t2 * (0.265_973_2 + t2 * (0.036_076_8 + t2 * 0.004_581_3)))))
    } else {
        let t = 3.75 / ax;
        let poly = 0.398_942_28
            + t * (0.013_285_92
                + t * (0.002_253_19
                    + t * (-0.001_575_65
                        + t * (0.009_162_81
                            + t * (-0.020_577_06
                                + t * (0.026_355_37
                                    + t * (-0.016_476_33 + t * 0.003_923_77)))))));

        // No extra 1/sqrt(2π) here: it is folded into the polynomial's first coefficient.
        (ax.exp() / ax.sqrt()) * poly
    }
}
2263
/// Map a frequency in Hz onto the Slaney mel scale (librosa default, `htk=False`).
///
/// The scale is piecewise:
/// - below 1000 Hz it is linear: `mel = hz / (200/3)`
/// - from 1000 Hz upwards it is logarithmic: `mel = 15 + ln(hz / 1000) / log_step`
fn hz_to_mel(hz: f64) -> f64 {
    const LINEAR_ORIGIN_HZ: f64 = 0.0;
    const HZ_PER_MEL: f64 = 200.0 / 3.0; // ~66.667
    const BREAK_HZ: f64 = 1000.0;
    const BREAK_MEL: f64 = (BREAK_HZ - LINEAR_ORIGIN_HZ) / HZ_PER_MEL; // = 15.0
    const LOG_STEP: f64 = 0.068_751_777_420_949_23; // ln(6.4) / 27

    // Logarithmic region: at or above the break frequency.
    if hz >= BREAK_HZ {
        return BREAK_MEL + (hz / BREAK_HZ).ln() / LOG_STEP;
    }

    // Linear region.
    (hz - LINEAR_ORIGIN_HZ) / HZ_PER_MEL
}
2285
/// Map a Slaney mel value back to Hz (librosa default, `htk=False`).
///
/// This is the inverse of `hz_to_mel`: linear below the break point
/// (1000 Hz) and exponential at or above it.
fn mel_to_hz(mel: f64) -> f64 {
    const LINEAR_ORIGIN_HZ: f64 = 0.0;
    const HZ_PER_MEL: f64 = 200.0 / 3.0; // ~66.667
    const BREAK_HZ: f64 = 1000.0;
    const BREAK_MEL: f64 = (BREAK_HZ - LINEAR_ORIGIN_HZ) / HZ_PER_MEL; // = 15.0
    const LOG_STEP: f64 = 0.068_751_777_420_949_23; // ln(6.4) / 27

    // Exponential region: at or above the break point.
    if mel >= BREAK_MEL {
        return BREAK_HZ * (LOG_STEP * (mel - BREAK_MEL)).exp();
    }

    // Linear region.
    HZ_PER_MEL.mul_add(mel, LINEAR_ORIGIN_HZ)
}
2304
2305fn build_mel_filterbank_matrix(
2306 sample_rate_hz: f64,
2307 n_fft: NonZeroUsize,
2308 n_mels: NonZeroUsize,
2309 f_min: f64,
2310 f_max: f64,
2311 norm: MelNorm,
2312) -> SpectrogramResult<SparseMatrix> {
2313 if sample_rate_hz <= 0.0 || !sample_rate_hz.is_finite() {
2314 return Err(SpectrogramError::invalid_input(
2315 "sample_rate_hz must be finite and > 0",
2316 ));
2317 }
2318 if f_min < 0.0 || f_min.is_infinite() {
2319 return Err(SpectrogramError::invalid_input("f_min must be >= 0"));
2320 }
2321 if f_max <= f_min {
2322 return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
2323 }
2324 if f_max > sample_rate_hz * 0.5 {
2325 return Err(SpectrogramError::invalid_input("f_max must be <= Nyquist"));
2326 }
2327 let n_mels = n_mels.get();
2328 let n_fft = n_fft.get();
2329 let out_len = r2c_output_size(n_fft);
2330
2331 // FFT bin frequencies
2332 let df = sample_rate_hz / n_fft as f64;
2333
2334 // Mel points: n_mels + 2 (for triangular edges)
2335 let mel_min = hz_to_mel(f_min);
2336 let mel_max = hz_to_mel(f_max);
2337
2338 let n_points = n_mels + 2;
2339 let step = (mel_max - mel_min) / (n_points - 1) as f64;
2340
2341 let mut mel_points = Vec::with_capacity(n_points);
2342 for i in 0..n_points {
2343 mel_points.push((i as f64).mul_add(step, mel_min));
2344 }
2345
2346 let mut hz_points = Vec::with_capacity(n_points);
2347 for m in &mel_points {
2348 hz_points.push(mel_to_hz(*m));
2349 }
2350
2351 // Build filterbank as sparse matrix (librosa-style, in frequency space)
2352 // This builds triangular filters based on actual frequencies, not bin indices
2353 let mut fb = SparseMatrix::new(n_mels, out_len);
2354
2355 for m in 0..n_mels {
2356 let freq_left = hz_points[m];
2357 let freq_center = hz_points[m + 1];
2358 let freq_right = hz_points[m + 2];
2359
2360 let fdiff_left = freq_center - freq_left;
2361 let fdiff_right = freq_right - freq_center;
2362
2363 if fdiff_left == 0.0 || fdiff_right == 0.0 {
2364 // Degenerate triangle, skip
2365 continue;
2366 }
2367
2368 // For each FFT bin, compute the triangular weight based on its frequency
2369 for k in 0..out_len {
2370 let bin_freq = k as f64 * df;
2371
2372 // Lower slope: rises from freq_left to freq_center
2373 let lower = (bin_freq - freq_left) / fdiff_left;
2374
2375 // Upper slope: falls from freq_center to freq_right
2376 let upper = (freq_right - bin_freq) / fdiff_right;
2377
2378 // Triangle is the minimum of the two slopes, clipped to [0, 1]
2379 let weight = lower.min(upper).clamp(0.0, 1.0);
2380
2381 if weight > 0.0 {
2382 fb.set(m, k, weight);
2383 }
2384 }
2385 }
2386
2387 // Apply normalization
2388 match norm {
2389 MelNorm::None => {
2390 // No normalization needed
2391 }
2392 MelNorm::Slaney => {
2393 // Slaney-style area normalization: 2 / (hz_max - hz_min) for each triangle
2394 // NOTE: Uses Hz bandwidth, not mel bandwidth (to match librosa's implementation)
2395 for m in 0..n_mels {
2396 let mel_left = mel_points[m];
2397 let mel_right = mel_points[m + 2];
2398 let hz_left = mel_to_hz(mel_left);
2399 let hz_right = mel_to_hz(mel_right);
2400 let enorm = 2.0 / (hz_right - hz_left);
2401
2402 // Normalize all values in this row
2403 for val in &mut fb.values[m] {
2404 *val *= enorm;
2405 }
2406 }
2407 }
2408 MelNorm::L1 => {
2409 // L1 normalization: sum of weights = 1.0
2410 for m in 0..n_mels {
2411 let sum: f64 = fb.values[m].iter().sum();
2412 if sum > 0.0 {
2413 let normalizer = 1.0 / sum;
2414 for val in &mut fb.values[m] {
2415 *val *= normalizer;
2416 }
2417 }
2418 }
2419 }
2420 MelNorm::L2 => {
2421 // L2 normalization: L2 norm = 1.0
2422 for m in 0..n_mels {
2423 let norm_val: f64 = fb.values[m].iter().map(|&v| v * v).sum::<f64>().sqrt();
2424 if norm_val > 0.0 {
2425 let normalizer = 1.0 / norm_val;
2426 for val in &mut fb.values[m] {
2427 *val *= normalizer;
2428 }
2429 }
2430 }
2431 }
2432 }
2433
2434 Ok(fb)
2435}
2436
2437/// Build a logarithmic frequency interpolation matrix.
2438///
2439/// Maps linearly-spaced FFT bins to logarithmically-spaced frequency bins
2440/// using linear interpolation.
2441fn build_loghz_matrix(
2442 sample_rate_hz: f64,
2443 n_fft: NonZeroUsize,
2444 n_bins: NonZeroUsize,
2445 f_min: f64,
2446 f_max: f64,
2447) -> SpectrogramResult<(SparseMatrix, NonEmptyVec<f64>)> {
2448 if sample_rate_hz <= 0.0 || !sample_rate_hz.is_finite() {
2449 return Err(SpectrogramError::invalid_input(
2450 "sample_rate_hz must be finite and > 0",
2451 ));
2452 }
2453 if f_min <= 0.0 || f_min.is_infinite() {
2454 return Err(SpectrogramError::invalid_input(
2455 "f_min must be finite and > 0",
2456 ));
2457 }
2458 if f_max <= f_min {
2459 return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
2460 }
2461 if f_max > sample_rate_hz * 0.5 {
2462 return Err(SpectrogramError::invalid_input("f_max must be <= Nyquist"));
2463 }
2464
2465 let n_bins = n_bins.get();
2466 let n_fft = n_fft.get();
2467
2468 let out_len = r2c_output_size(n_fft);
2469 let df = sample_rate_hz / n_fft as f64;
2470
2471 // Generate logarithmically-spaced frequencies
2472 let log_f_min = f_min.ln();
2473 let log_f_max = f_max.ln();
2474 let log_step = (log_f_max - log_f_min) / (n_bins - 1) as f64;
2475
2476 let mut log_frequencies = Vec::with_capacity(n_bins);
2477 for i in 0..n_bins {
2478 let log_f = (i as f64).mul_add(log_step, log_f_min);
2479 log_frequencies.push(log_f.exp());
2480 }
2481 // safety: n_bins > 0
2482 let log_frequencies = unsafe { NonEmptyVec::new_unchecked(log_frequencies) };
2483
2484 // Build interpolation matrix as sparse matrix
2485 let mut matrix = SparseMatrix::new(n_bins, out_len);
2486
2487 for (bin_idx, &target_freq) in log_frequencies.iter().enumerate() {
2488 // Find the two FFT bins that bracket this frequency
2489 let exact_bin = target_freq / df;
2490 let lower_bin = exact_bin.floor() as usize;
2491 let upper_bin = (exact_bin.ceil() as usize).min(out_len - 1);
2492
2493 if lower_bin >= out_len {
2494 continue;
2495 }
2496
2497 if lower_bin == upper_bin {
2498 // Exact match
2499 matrix.set(bin_idx, lower_bin, 1.0);
2500 } else {
2501 // Linear interpolation
2502 let frac = exact_bin - lower_bin as f64;
2503 matrix.set(bin_idx, lower_bin, 1.0 - frac);
2504 if upper_bin < out_len {
2505 matrix.set(bin_idx, upper_bin, frac);
2506 }
2507 }
2508 }
2509
2510 Ok((matrix, log_frequencies))
2511}
2512
2513fn mel_band_centres_hz(
2514 n_mels: NonZeroUsize,
2515 sample_rate_hz: f64,
2516 nyquist_hz: f64,
2517) -> NonEmptyVec<f64> {
2518 let f_min = 0.0;
2519 let f_max = nyquist_hz.min(sample_rate_hz * 0.5);
2520
2521 let mel_min = hz_to_mel(f_min);
2522 let mel_max = hz_to_mel(f_max);
2523 let n_mels = n_mels.get();
2524 let step = (mel_max - mel_min) / (n_mels + 1) as f64;
2525
2526 let mut centres = Vec::with_capacity(n_mels);
2527 for i in 0..n_mels {
2528 let mel = (i as f64 + 1.0).mul_add(step, mel_min);
2529 centres.push(mel_to_hz(mel));
2530 }
2531 // safety: centres is guaranteed non-empty since n_mels > 0
2532 unsafe { NonEmptyVec::new_unchecked(centres) }
2533}
2534
/// Spectrogram structure holding the computed spectrogram data and metadata.
///
/// # Type Parameters
///
/// * `FreqScale`: The frequency scale type (e.g., `LinearHz`, `Mel`, `LogHz`, etc.).
/// * `AmpScale`: The amplitude scale type (e.g., `Power`, `Magnitude`, `Decibels`).
///
/// # Fields
///
/// * `data`: A 2D array containing the spectrogram data.
/// * `axes`: The axes of the spectrogram (frequency and time).
/// * `params`: The parameters used to compute the spectrogram.
/// * `_amp`: A phantom data marker for the amplitude scale type.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Spectrogram<FreqScale, AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
    FreqScale: Copy + Clone + 'static,
{
    // Rows are frequency bins and columns are time frames; the crate-internal
    // constructor debug-asserts that both match the lengths of the axes.
    data: Array2<f64>,
    // Frequency and time axes describing the rows and columns of `data`.
    axes: Axes<FreqScale>,
    // Parameters the spectrogram was computed with.
    params: SpectrogramParams,
    // Zero-sized marker for the amplitude scale; skipped by serde.
    #[cfg_attr(feature = "serde", serde(skip))]
    _amp: PhantomData<AmpScale>,
}
2561
2562impl<FreqScale, AmpScale> Spectrogram<FreqScale, AmpScale>
2563where
2564 AmpScale: AmpScaleSpec + 'static,
2565 FreqScale: Copy + Clone + 'static,
2566{
2567 /// Get the X-axis label.
2568 ///
2569 /// # Returns
2570 ///
2571 /// A static string slice representing the X-axis label.
2572 #[inline]
2573 #[must_use]
2574 pub const fn x_axis_label() -> &'static str {
2575 "Time (s)"
2576 }
2577
2578 /// Get the Y-axis label based on the frequency scale type.
2579 ///
2580 /// # Returns
2581 ///
2582 /// A static string slice representing the Y-axis label.
2583 #[inline]
2584 #[must_use]
2585 pub fn y_axis_label() -> &'static str {
2586 match std::any::TypeId::of::<FreqScale>() {
2587 id if id == std::any::TypeId::of::<LinearHz>() => "Frequency (Hz)",
2588 id if id == std::any::TypeId::of::<Mel>() => "Frequency (Mel)",
2589 id if id == std::any::TypeId::of::<LogHz>() => "Frequency (Log Hz)",
2590 id if id == std::any::TypeId::of::<Erb>() => "Frequency (ERB)",
2591 id if id == std::any::TypeId::of::<Cqt>() => "Frequency (CQT Bins)",
2592 _ => "Frequency",
2593 }
2594 }
2595
2596 /// Internal constructor. Only callable inside the crate.
2597 ///
2598 /// All inputs must already be validated and consistent.
2599 pub(crate) fn new(data: Array2<f64>, axes: Axes<FreqScale>, params: SpectrogramParams) -> Self {
2600 debug_assert_eq!(data.nrows(), axes.frequencies().len().get());
2601 debug_assert_eq!(data.ncols(), axes.times().len().get());
2602
2603 Self {
2604 data,
2605 axes,
2606 params,
2607 _amp: PhantomData,
2608 }
2609 }
2610
2611 /// Set spectrogram data matrix
2612 ///
2613 /// # Arguments
2614 ///
2615 /// * `data` - The new spectrogram data matrix.
2616 #[inline]
2617 pub fn set_data(&mut self, data: Array2<f64>) {
2618 self.data = data;
2619 }
2620
2621 /// Spectrogram data matrix
2622 ///
2623 /// # Returns
2624 ///
2625 /// A reference to the spectrogram data matrix.
2626 #[inline]
2627 #[must_use]
2628 pub const fn data(&self) -> &Array2<f64> {
2629 &self.data
2630 }
2631
2632 /// Consume the spectrogram and extract the data matrix.
2633 ///
2634 /// This method moves the data out of the spectrogram, consuming it.
2635 /// Useful for transferring ownership to Python without copying.
2636 ///
2637 /// # Returns
2638 ///
2639 /// The owned spectrogram data matrix.
2640 #[inline]
2641 #[must_use]
2642 pub fn into_data(self) -> Array2<f64> {
2643 self.data
2644 }
2645
2646 /// Axes of the spectrogram
2647 ///
2648 /// # Returns
2649 ///
2650 /// A reference to the axes of the spectrogram.
2651 #[inline]
2652 #[must_use]
2653 pub const fn axes(&self) -> &Axes<FreqScale> {
2654 &self.axes
2655 }
2656
2657 /// Frequency axis in Hz
2658 ///
2659 /// # Returns
2660 ///
2661 /// A reference to the frequency axis in Hz.
2662 #[inline]
2663 #[must_use]
2664 pub fn frequencies(&self) -> &NonEmptySlice<f64> {
2665 self.axes.frequencies()
2666 }
2667
2668 /// Frequency range in Hz (min, max)
2669 ///
2670 /// # Returns
2671 ///
2672 /// A tuple containing the minimum and maximum frequencies in Hz.
2673 #[inline]
2674 #[must_use]
2675 pub const fn frequency_range(&self) -> (f64, f64) {
2676 self.axes.frequency_range()
2677 }
2678
2679 /// Time axis in seconds
2680 ///
2681 /// # Returns
2682 ///
2683 /// A reference to the time axis in seconds.
2684 #[inline]
2685 #[must_use]
2686 pub fn times(&self) -> &NonEmptySlice<f64> {
2687 self.axes.times()
2688 }
2689
2690 /// Spectrogram computation parameters
2691 ///
2692 /// # Returns
2693 ///
2694 /// A reference to the spectrogram computation parameters.
2695 #[inline]
2696 #[must_use]
2697 pub const fn params(&self) -> &SpectrogramParams {
2698 &self.params
2699 }
2700
2701 /// Duration of the spectrogram in seconds
2702 ///
2703 /// # Returns
2704 ///
2705 /// The duration of the spectrogram in seconds.
2706 #[inline]
2707 #[must_use]
2708 pub fn duration(&self) -> f64 {
2709 self.axes.duration()
2710 }
2711
2712 /// If this is a dB spectrogram, return the (min, max) dB values. otherwise do the maths to compute dB range.
2713 ///
2714 /// # Returns
2715 ///
2716 /// The (min, max) dB values of the spectrogram, or `None` if the amplitude scale is unknown.
2717 #[inline]
2718 #[must_use]
2719 pub fn db_range(&self) -> Option<(f64, f64)> {
2720 let type_self = std::any::TypeId::of::<AmpScale>();
2721
2722 if type_self == std::any::TypeId::of::<Decibels>() {
2723 let (min, max) = min_max_single_pass(self.data.as_slice()?);
2724 Some((min, max))
2725 } else if type_self == std::any::TypeId::of::<Power>() {
2726 // Not a dB spectrogram; compute dB range from power values
2727 let mut min_db = f64::INFINITY;
2728 let mut max_db = f64::NEG_INFINITY;
2729 for &v in &self.data {
2730 let db = 10.0 * (v + EPS).log10();
2731 if db < min_db {
2732 min_db = db;
2733 }
2734 if db > max_db {
2735 max_db = db;
2736 }
2737 }
2738 Some((min_db, max_db))
2739 } else if type_self == std::any::TypeId::of::<Magnitude>() {
2740 // Not a dB spectrogram; compute dB range from magnitude values
2741 let mut min_db = f64::INFINITY;
2742 let mut max_db = f64::NEG_INFINITY;
2743
2744 for &v in &self.data {
2745 let power = v * v;
2746 let db = 10.0 * (power + EPS).log10();
2747 if db < min_db {
2748 min_db = db;
2749 }
2750 if db > max_db {
2751 max_db = db;
2752 }
2753 }
2754
2755 Some((min_db, max_db))
2756 } else {
2757 // Unknown AmpScale type; return dummy values
2758 None
2759 }
2760 }
2761
2762 /// Number of frequency bins
2763 ///
2764 /// # Returns
2765 ///
2766 /// The number of frequency bins in the spectrogram.
2767 #[inline]
2768 #[must_use]
2769 pub fn n_bins(&self) -> NonZeroUsize {
2770 // safety: data.nrows() > 0 is guaranteed by construction
2771 unsafe { NonZeroUsize::new_unchecked(self.data.nrows()) }
2772 }
2773
2774 /// Number of time frames in the spectrogram
2775 ///
2776 /// # Returns
2777 ///
2778 /// The number of time frames (columns) in the spectrogram.
2779 #[inline]
2780 #[must_use]
2781 pub fn n_frames(&self) -> NonZeroUsize {
2782 // safety: data.ncols() > 0 is guaranteed by construction
2783 unsafe { NonZeroUsize::new_unchecked(self.data.ncols()) }
2784 }
2785}
2786
/// Borrow a spectrogram as its underlying data matrix.
impl<FreqScale, AmpScale> AsRef<Array2<f64>> for Spectrogram<FreqScale, AmpScale>
where
    FreqScale: Copy + Clone + 'static,
    AmpScale: AmpScaleSpec + 'static,
{
    /// Returns a reference to the stored data matrix.
    #[inline]
    fn as_ref(&self) -> &Array2<f64> {
        &self.data
    }
}
2797
/// Lets a spectrogram be used wherever a `&Array2<f64>` is expected, so the
/// array's read-only methods are callable directly on the spectrogram.
impl<FreqScale, AmpScale> Deref for Spectrogram<FreqScale, AmpScale>
where
    FreqScale: Copy + Clone + 'static,
    AmpScale: AmpScaleSpec + 'static,
{
    type Target = Array2<f64>;

    /// Returns a reference to the stored data matrix.
    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.data
    }
}
2810
2811impl<AmpScale> Spectrogram<LinearHz, AmpScale>
2812where
2813 AmpScale: AmpScaleSpec + 'static,
2814{
2815 /// Compute a linear-frequency spectrogram from audio samples.
2816 ///
2817 /// This is a convenience method that creates a planner internally and computes
2818 /// the spectrogram in one call. For processing multiple signals with the same
2819 /// parameters, use [`SpectrogramPlanner::linear_plan`] to create a reusable plan.
2820 ///
2821 /// # Arguments
2822 ///
2823 /// * `samples` - Audio samples (any type that can be converted to a slice)
2824 /// * `params` - Spectrogram computation parameters
2825 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2826 ///
2827 /// # Returns
2828 ///
2829 /// A linear-frequency spectrogram with the specified amplitude scale.
2830 ///
2831 /// # Errors
2832 ///
2833 /// Returns an error if:
2834 /// - The samples slice is empty
2835 /// - Parameters are invalid
2836 /// - FFT computation fails
2837 ///
2838 /// # Examples
2839 ///
2840 /// ```
2841 /// use spectrograms::*;
2842 /// use non_empty_slice::non_empty_vec;
2843 ///
2844 /// # fn example() -> SpectrogramResult<()> {
2845 /// // Create a simple test signal
2846 /// let sample_rate = 16000.0;
2847 /// let samples_vec: Vec<f64> = (0..16000).map(|i| {
2848 /// (2.0 * std::f64::consts::PI * 440.0 * i as f64 / sample_rate).sin()
2849 /// }).collect();
2850 /// let samples = non_empty_slice::NonEmptyVec::new(samples_vec).unwrap();
2851 ///
2852 /// // Set up parameters
2853 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2854 /// let params = SpectrogramParams::new(stft, sample_rate)?;
2855 ///
2856 /// // Compute power spectrogram
2857 /// let spec = LinearPowerSpectrogram::compute(&samples, ¶ms, None)?;
2858 ///
2859 /// println!("Computed spectrogram: {} bins x {} frames", spec.n_bins(), spec.n_frames());
2860 /// # Ok(())
2861 /// # }
2862 /// ```
2863 #[inline]
2864 pub fn compute(
2865 samples: &NonEmptySlice<f64>,
2866 params: &SpectrogramParams,
2867 db: Option<&LogParams>,
2868 ) -> SpectrogramResult<Self> {
2869 let planner = SpectrogramPlanner::new();
2870 let mut plan = planner.linear_plan(params, db)?;
2871 plan.compute(samples)
2872 }
2873}
2874
2875impl<AmpScale> Spectrogram<Mel, AmpScale>
2876where
2877 AmpScale: AmpScaleSpec + 'static,
2878{
2879 /// Compute a mel-frequency spectrogram from audio samples.
2880 ///
2881 /// This is a convenience method that creates a planner internally and computes
2882 /// the spectrogram in one call. For processing multiple signals with the same
2883 /// parameters, use [`SpectrogramPlanner::mel_plan`] to create a reusable plan.
2884 ///
2885 /// # Arguments
2886 ///
2887 /// * `samples` - Audio samples (any type that can be converted to a slice)
2888 /// * `params` - Spectrogram computation parameters
2889 /// * `mel` - Mel filterbank parameters
2890 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2891 ///
2892 /// # Returns
2893 ///
2894 /// A mel-frequency spectrogram with the specified amplitude scale.
2895 ///
2896 /// # Errors
2897 ///
2898 /// Returns an error if:
2899 /// - The samples slice is empty
2900 /// - Parameters are invalid
2901 /// - Mel `f_max` exceeds Nyquist frequency
2902 /// - FFT computation fails
2903 ///
2904 /// # Examples
2905 ///
2906 /// ```
2907 /// use spectrograms::*;
2908 /// use non_empty_slice::non_empty_vec;
2909 ///
2910 /// # fn example() -> SpectrogramResult<()> {
2911 /// // Create a simple test signal
2912 /// let sample_rate = 16000.0;
2913 /// let samples_vec: Vec<f64> = (0..16000).map(|i| {
2914 /// (2.0 * std::f64::consts::PI * 440.0 * i as f64 / sample_rate).sin()
2915 /// }).collect();
2916 /// let samples = non_empty_slice::NonEmptyVec::new(samples_vec).unwrap();
2917 ///
2918 /// // Set up parameters
2919 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2920 /// let params = SpectrogramParams::new(stft, sample_rate)?;
2921 /// let mel = MelParams::new(nzu!(80), 0.0, 8000.0)?;
2922 ///
2923 /// // Compute mel spectrogram in dB scale
2924 /// let db = LogParams::new(-80.0)?;
2925 /// let spec = MelDbSpectrogram::compute(&samples, ¶ms, &mel, Some(&db))?;
2926 ///
2927 /// println!("Computed mel spectrogram: {} mels x {} frames", spec.n_bins(), spec.n_frames());
2928 /// # Ok(())
2929 /// # }
2930 /// ```
2931 #[inline]
2932 pub fn compute(
2933 samples: &NonEmptySlice<f64>,
2934 params: &SpectrogramParams,
2935 mel: &MelParams,
2936 db: Option<&LogParams>,
2937 ) -> SpectrogramResult<Self> {
2938 let planner = SpectrogramPlanner::new();
2939 let mut plan = planner.mel_plan(params, mel, db)?;
2940 plan.compute(samples)
2941 }
2942}
2943
2944impl<AmpScale> Spectrogram<Erb, AmpScale>
2945where
2946 AmpScale: AmpScaleSpec + 'static,
2947{
2948 /// Compute an ERB-frequency spectrogram from audio samples.
2949 ///
2950 /// This is a convenience method that creates a planner internally and computes
2951 /// the spectrogram in one call. For processing multiple signals with the same
2952 /// parameters, use [`SpectrogramPlanner::erb_plan`] to create a reusable plan.
2953 ///
2954 /// # Arguments
2955 ///
2956 /// * `samples` - Audio samples (any type that can be converted to a slice)
2957 /// * `params` - Spectrogram computation parameters
2958 /// * `erb` - ERB frequency scale parameters
2959 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2960 ///
2961 /// # Returns
2962 ///
2963 /// An ERB-scale spectrogram with the specified amplitude scale.
2964 ///
2965 /// # Errors
2966 ///
2967 /// Returns an error if:
2968 /// - The samples slice is empty
2969 /// - Parameters are invalid
2970 /// - FFT computation fails
2971 ///
2972 /// # Examples
2973 ///
2974 /// ```
2975 /// use spectrograms::*;
2976 /// use non_empty_slice::non_empty_vec;
2977 ///
2978 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
2979 /// let samples = non_empty_vec![0.0; nzu!(16000)];
2980 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2981 /// let params = SpectrogramParams::new(stft, 16000.0)?;
2982 /// let erb = ErbParams::speech_standard();
2983 ///
2984 /// let spec = ErbPowerSpectrogram::compute(&samples, ¶ms, &erb, None)?;
2985 /// assert_eq!(spec.n_bins(), nzu!(40));
2986 /// # Ok(())
2987 /// # }
2988 /// ```
2989 #[inline]
2990 pub fn compute(
2991 samples: &NonEmptySlice<f64>,
2992 params: &SpectrogramParams,
2993 erb: &ErbParams,
2994 db: Option<&LogParams>,
2995 ) -> SpectrogramResult<Self> {
2996 let planner = SpectrogramPlanner::new();
2997 let mut plan = planner.erb_plan(params, erb, db)?;
2998 plan.compute(samples)
2999 }
3000}
3001
3002impl<AmpScale> Spectrogram<LogHz, AmpScale>
3003where
3004 AmpScale: AmpScaleSpec + 'static,
3005{
3006 /// Compute a logarithmic-frequency spectrogram from audio samples.
3007 ///
3008 /// This is a convenience method that creates a planner internally and computes
3009 /// the spectrogram in one call. For processing multiple signals with the same
3010 /// parameters, use [`SpectrogramPlanner::log_hz_plan`] to create a reusable plan.
3011 ///
3012 /// # Arguments
3013 ///
3014 /// * `samples` - Audio samples (any type that can be converted to a slice)
3015 /// * `params` - Spectrogram computation parameters
3016 /// * `loghz` - Logarithmic frequency scale parameters
3017 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
3018 ///
3019 /// # Returns
3020 ///
3021 /// A logarithmic-frequency spectrogram with the specified amplitude scale.
3022 ///
3023 /// # Errors
3024 ///
3025 /// Returns an error if:
3026 /// - The samples slice is empty
3027 /// - Parameters are invalid
3028 /// - FFT computation fails
3029 ///
3030 /// # Examples
3031 ///
3032 /// ```
3033 /// use spectrograms::*;
3034 /// use non_empty_slice::non_empty_vec;
3035 ///
3036 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
3037 /// let samples = non_empty_vec![0.0; nzu!(16000)];
3038 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
3039 /// let params = SpectrogramParams::new(stft, 16000.0)?;
3040 /// let loghz = LogHzParams::new(nzu!(128), 20.0, 8000.0)?;
3041 ///
3042 /// let spec = LogHzPowerSpectrogram::compute(&samples, ¶ms, &loghz, None)?;
3043 /// assert_eq!(spec.n_bins(), nzu!(128));
3044 /// # Ok(())
3045 /// # }
3046 /// ```
3047 #[inline]
3048 pub fn compute(
3049 samples: &NonEmptySlice<f64>,
3050 params: &SpectrogramParams,
3051 loghz: &LogHzParams,
3052 db: Option<&LogParams>,
3053 ) -> SpectrogramResult<Self> {
3054 let planner = SpectrogramPlanner::new();
3055 let mut plan = planner.log_hz_plan(params, loghz, db)?;
3056 plan.compute(samples)
3057 }
3058}
3059
3060impl<AmpScale> Spectrogram<Cqt, AmpScale>
3061where
3062 AmpScale: AmpScaleSpec + 'static,
3063{
3064 /// Compute a constant-Q transform (CQT) spectrogram from audio samples.
3065 ///
3066 /// # Arguments
3067 ///
3068 /// * `samples` - Audio samples (any type that can be converted to a slice)
3069 /// * `params` - Spectrogram computation parameters
3070 /// * `cqt` - CQT parameters
3071 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
3072 ///
3073 /// # Returns
3074 ///
3075 /// A CQT spectrogram with the specified amplitude scale.
3076 ///
3077 /// # Errors
3078 ///
3079 /// Returns an error if:
3080 ///
3081 #[inline]
3082 pub fn compute(
3083 samples: &NonEmptySlice<f64>,
3084 params: &SpectrogramParams,
3085 cqt: &CqtParams,
3086 db: Option<&LogParams>,
3087 ) -> SpectrogramResult<Self> {
3088 let planner = SpectrogramPlanner::new();
3089 let mut plan = planner.cqt_plan(params, cqt, db)?;
3090 plan.compute(samples)
3091 }
3092}
3093
3094// ========================
3095// Display implementations
3096// ========================
3097
3098/// Helper function to get amplitude scale name
3099fn amp_scale_name<AmpScale>() -> &'static str
3100where
3101 AmpScale: AmpScaleSpec + 'static,
3102{
3103 match std::any::TypeId::of::<AmpScale>() {
3104 id if id == std::any::TypeId::of::<Power>() => "Power",
3105 id if id == std::any::TypeId::of::<Magnitude>() => "Magnitude",
3106 id if id == std::any::TypeId::of::<Decibels>() => "Decibels",
3107 _ => "Unknown",
3108 }
3109}
3110
3111/// Helper function to get frequency scale name
3112fn freq_scale_name<FreqScale>() -> &'static str
3113where
3114 FreqScale: Copy + Clone + 'static,
3115{
3116 match std::any::TypeId::of::<FreqScale>() {
3117 id if id == std::any::TypeId::of::<LinearHz>() => "Linear Hz",
3118 id if id == std::any::TypeId::of::<LogHz>() => "Log Hz",
3119 id if id == std::any::TypeId::of::<Mel>() => "Mel",
3120 id if id == std::any::TypeId::of::<Erb>() => "ERB",
3121 id if id == std::any::TypeId::of::<Cqt>() => "CQT",
3122 _ => "Unknown",
3123 }
3124}
3125
3126impl<FreqScale, AmpScale> core::fmt::Display for Spectrogram<FreqScale, AmpScale>
3127where
3128 AmpScale: AmpScaleSpec + 'static,
3129 FreqScale: Copy + Clone + 'static,
3130{
3131 #[inline]
3132 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
3133 let (freq_min, freq_max) = self.frequency_range();
3134 let duration = self.duration();
3135 let (rows, cols) = self.data.dim();
3136
3137 // Alternative formatting (#) provides more detailed output with data
3138 if f.alternate() {
3139 writeln!(f, "Spectrogram {{")?;
3140 writeln!(f, " Frequency Scale: {}", freq_scale_name::<FreqScale>())?;
3141 writeln!(f, " Amplitude Scale: {}", amp_scale_name::<AmpScale>())?;
3142 writeln!(f, " Shape: {rows} frequency bins × {cols} time frames")?;
3143 writeln!(f, " Frequency Range: {freq_min:.2} Hz - {freq_max:.2} Hz")?;
3144 writeln!(f, " Duration: {duration:.3} s")?;
3145 writeln!(f)?;
3146
3147 // Display parameters
3148 writeln!(f, " Parameters:")?;
3149 writeln!(f, " Sample Rate: {} Hz", self.params.sample_rate_hz())?;
3150 writeln!(f, " FFT Size: {}", self.params.stft().n_fft())?;
3151 writeln!(f, " Hop Size: {}", self.params.stft().hop_size())?;
3152 writeln!(f, " Window: {:?}", self.params.stft().window())?;
3153 writeln!(f, " Centered: {}", self.params.stft().centre())?;
3154 writeln!(f)?;
3155
3156 // Display data statistics
3157 let data_slice = self.data.as_slice().unwrap_or(&[]);
3158 if !data_slice.is_empty() {
3159 let (min_val, max_val) = min_max_single_pass(data_slice);
3160 let mean = data_slice.iter().sum::<f64>() / data_slice.len() as f64;
3161 writeln!(f, " Data Statistics:")?;
3162 writeln!(f, " Min: {min_val:.6}")?;
3163 writeln!(f, " Max: {max_val:.6}")?;
3164 writeln!(f, " Mean: {mean:.6}")?;
3165 writeln!(f)?;
3166 }
3167
3168 // Display actual data (truncated if too large)
3169 writeln!(f, " Data Matrix:")?;
3170 let max_rows_to_display = 5;
3171 let max_cols_to_display = 5;
3172
3173 for i in 0..rows.min(max_rows_to_display) {
3174 write!(f, " [")?;
3175 for j in 0..cols.min(max_cols_to_display) {
3176 if j > 0 {
3177 write!(f, ", ")?;
3178 }
3179 write!(f, "{:9.4}", self.data[[i, j]])?;
3180 }
3181 if cols > max_cols_to_display {
3182 write!(f, ", ... ({} more)", cols - max_cols_to_display)?;
3183 }
3184 writeln!(f, "]")?;
3185 }
3186
3187 if rows > max_rows_to_display {
3188 writeln!(f, " ... ({} more rows)", rows - max_rows_to_display)?;
3189 }
3190
3191 write!(f, "}}")?;
3192 } else {
3193 // Default formatting: compact summary
3194 write!(
3195 f,
3196 "Spectrogram<{}, {}>[{}x{}] ({:.2}-{:.2} Hz, {:.3}s)",
3197 freq_scale_name::<FreqScale>(),
3198 amp_scale_name::<AmpScale>(),
3199 rows,
3200 cols,
3201 freq_min,
3202 freq_max,
3203 duration
3204 )?;
3205 }
3206
3207 Ok(())
3208 }
3209}
3210
/// Frequency axis of a spectrogram, tagged at the type level with its frequency scale.
///
/// `FreqScale` is a marker only: it is carried through `PhantomData` and no value of
/// that type is stored.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FrequencyAxis<FreqScale>
where
    FreqScale: Copy + Clone + 'static,
{
    /// One frequency value per bin (the accessors document these as Hz).
    frequencies: NonEmptyVec<f64>,
    /// Zero-sized tag tying the axis to `FreqScale`; skipped by serde.
    #[cfg_attr(feature = "serde", serde(skip))]
    _marker: PhantomData<FreqScale>,
}
3221
3222impl<FreqScale> FrequencyAxis<FreqScale>
3223where
3224 FreqScale: Copy + Clone + 'static,
3225{
3226 pub(crate) const fn new(frequencies: NonEmptyVec<f64>) -> Self {
3227 Self {
3228 frequencies,
3229 _marker: PhantomData,
3230 }
3231 }
3232
3233 /// Get the frequency values in Hz.
3234 ///
3235 /// # Returns
3236 ///
3237 /// Returns a non-empty slice of frequencies.
3238 #[inline]
3239 #[must_use]
3240 pub fn frequencies(&self) -> &NonEmptySlice<f64> {
3241 &self.frequencies
3242 }
3243
3244 /// Get the frequency range (min, max) in Hz.
3245 ///
3246 /// # Returns
3247 ///
3248 /// Returns a tuple containing the minimum and maximum frequency.
3249 #[inline]
3250 #[must_use]
3251 pub const fn frequency_range(&self) -> (f64, f64) {
3252 let data = self.frequencies.as_slice();
3253 let min = data[0];
3254 let max_idx = data.len().saturating_sub(1); // safe for non-empty
3255 let max = data[max_idx];
3256 (min, max)
3257 }
3258
3259 /// Get the number of frequency bins.
3260 ///
3261 /// # Returns
3262 ///
3263 /// Returns the number of frequency bins as a NonZeroUsize.
3264 #[inline]
3265 #[must_use]
3266 pub const fn len(&self) -> NonZeroUsize {
3267 self.frequencies.len()
3268 }
3269}
3270
/// Spectrogram axes container.
///
/// Holds the frequency axis (tagged with the `FreqScale` marker) together with the
/// time axis.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Axes<FreqScale>
where
    FreqScale: Copy + Clone + 'static,
{
    /// Frequency axis.
    freq: FrequencyAxis<FreqScale>,
    /// Time values (the accessors document these as seconds).
    times: NonEmptyVec<f64>,
}
3283
3284impl<FreqScale> Axes<FreqScale>
3285where
3286 FreqScale: Copy + Clone + 'static,
3287{
3288 pub(crate) const fn new(freq: FrequencyAxis<FreqScale>, times: NonEmptyVec<f64>) -> Self {
3289 Self { freq, times }
3290 }
3291
3292 /// Get the frequency values in Hz.
3293 ///
3294 /// # Returns
3295 ///
3296 /// Returns a non-empty slice of frequencies.
3297 #[inline]
3298 #[must_use]
3299 pub fn frequencies(&self) -> &NonEmptySlice<f64> {
3300 self.freq.frequencies()
3301 }
3302
3303 /// Get the time values in seconds.
3304 ///
3305 /// # Returns
3306 ///
3307 /// Returns a non-empty slice of time values.
3308 #[inline]
3309 #[must_use]
3310 pub fn times(&self) -> &NonEmptySlice<f64> {
3311 &self.times
3312 }
3313
3314 /// Get the frequency range (min, max) in Hz.
3315 ///
3316 /// # Returns
3317 ///
3318 /// Returns a tuple containing the minimum and maximum frequency.
3319 #[inline]
3320 #[must_use]
3321 pub const fn frequency_range(&self) -> (f64, f64) {
3322 self.freq.frequency_range()
3323 }
3324
3325 /// Get the duration of the spectrogram in seconds.
3326 ///
3327 /// # Returns
3328 ///
3329 /// Returns the duration in seconds.
3330 #[inline]
3331 #[must_use]
3332 pub fn duration(&self) -> f64 {
3333 *self.times.last()
3334 }
3335}
3336
3337// Enum types for frequency and amplitude scales
3338
/// Linear frequency scale
///
/// Marker type used as a frequency-scale type parameter. Its single `_Phantom`
/// variant carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LinearHz {
    /// Placeholder variant; carries no data.
    _Phantom,
}
3347
/// Logarithmic frequency scale
///
/// Marker type used as a frequency-scale type parameter. Its single `_Phantom`
/// variant carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LogHz {
    /// Placeholder variant; carries no data.
    _Phantom,
}
3356
/// Mel frequency scale
///
/// Marker type used as a frequency-scale type parameter. Its single `_Phantom`
/// variant carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Mel {
    /// Placeholder variant; carries no data.
    _Phantom,
}
3365
/// ERB/gammatone frequency scale
///
/// Marker type used as a frequency-scale type parameter. Its single `_Phantom`
/// variant carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Erb {
    /// Placeholder variant; carries no data.
    _Phantom,
}
/// Alias for [`Erb`]; both names refer to the same marker type.
pub type Gammatone = Erb;
3375
/// Constant-Q Transform frequency scale
///
/// Marker type used as a frequency-scale type parameter (e.g. `Spectrogram<Cqt, _>`).
/// Its single `_Phantom` variant carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Cqt {
    /// Placeholder variant; carries no data.
    _Phantom,
}
3384
3385// Amplitude scales
3386
/// Power amplitude scale
///
/// Marker type used as an amplitude-scale type parameter. Its single `_Phantom`
/// variant carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Power {
    /// Placeholder variant; carries no data.
    _Phantom,
}
3395
/// Decibel amplitude scale
///
/// Marker type used as an amplitude-scale type parameter. Its single `_Phantom`
/// variant carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Decibels {
    /// Placeholder variant; carries no data.
    _Phantom,
}
3404
/// Magnitude amplitude scale
///
/// Marker type used as an amplitude-scale type parameter. Its single `_Phantom`
/// variant carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Magnitude {
    /// Placeholder variant; carries no data.
    _Phantom,
}
3413
/// STFT parameters for spectrogram computation.
///
/// * `n_fft`: Size of the FFT window.
/// * `hop_size`: Number of samples between successive frames.
/// * `window`: Window function to apply to each frame.
/// * `centre`: Whether to pad the input signal so that frames are centered.
///
/// Values built through [`StftParams::new`] satisfy `hop_size <= n_fft`, and a custom
/// window's size equals `n_fft`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StftParams {
    /// Size of the FFT window.
    n_fft: NonZeroUsize,
    /// Number of samples between successive frames.
    hop_size: NonZeroUsize,
    /// Window function applied to each frame.
    window: WindowType,
    /// Whether frames are centered by padding the input signal.
    centre: bool,
}
3428
3429impl StftParams {
3430 /// Create new STFT parameters.
3431 ///
3432 /// # Arguments
3433 ///
3434 /// * `n_fft` - Size of the FFT window
3435 /// * `hop_size` - Number of samples between successive frames
3436 /// * `window` - Window function to apply to each frame
3437 /// * `centre` - Whether to pad the input signal so that frames are centered
3438 ///
3439 /// # Errors
3440 ///
3441 /// Returns an error if:
3442 /// - `hop_size` > `n_fft`
3443 /// - Custom window size doesn't match `n_fft`
3444 ///
3445 /// # Returns
3446 ///
3447 /// New `StftParams` instance.
3448 #[inline]
3449 pub fn new(
3450 n_fft: NonZeroUsize,
3451 hop_size: NonZeroUsize,
3452 window: WindowType,
3453 centre: bool,
3454 ) -> SpectrogramResult<Self> {
3455 if hop_size.get() > n_fft.get() {
3456 return Err(SpectrogramError::invalid_input("hop_size must be <= n_fft"));
3457 }
3458
3459 // Validate custom window size matches n_fft
3460 if let WindowType::Custom { size, .. } = &window {
3461 if size.get() != n_fft.get() {
3462 return Err(SpectrogramError::invalid_input(format!(
3463 "Custom window size ({}) must match n_fft ({})",
3464 size.get(),
3465 n_fft.get()
3466 )));
3467 }
3468 }
3469
3470 Ok(Self {
3471 n_fft,
3472 hop_size,
3473 window,
3474 centre,
3475 })
3476 }
3477
3478 const unsafe fn new_unchecked(
3479 n_fft: NonZeroUsize,
3480 hop_size: NonZeroUsize,
3481 window: WindowType,
3482 centre: bool,
3483 ) -> Self {
3484 Self {
3485 n_fft,
3486 hop_size,
3487 window,
3488 centre,
3489 }
3490 }
3491
3492 /// Get the FFT window size.
3493 ///
3494 /// # Returns
3495 ///
3496 /// The FFT window size.
3497 #[inline]
3498 #[must_use]
3499 pub const fn n_fft(&self) -> NonZeroUsize {
3500 self.n_fft
3501 }
3502
3503 /// Get the hop size (samples between successive frames).
3504 ///
3505 /// # Returns
3506 ///
3507 /// The hop size.
3508 #[inline]
3509 #[must_use]
3510 pub const fn hop_size(&self) -> NonZeroUsize {
3511 self.hop_size
3512 }
3513
3514 /// Get the window function.
3515 ///
3516 /// # Returns
3517 ///
3518 /// The window function.
3519 #[inline]
3520 #[must_use]
3521 pub fn window(&self) -> WindowType {
3522 self.window.clone()
3523 }
3524
3525 /// Get whether frames are centered (input signal is padded).
3526 ///
3527 /// # Returns
3528 ///
3529 /// `true` if frames are centered, `false` otherwise.
3530 #[inline]
3531 #[must_use]
3532 pub const fn centre(&self) -> bool {
3533 self.centre
3534 }
3535
3536 /// Create a builder for STFT parameters.
3537 ///
3538 /// # Returns
3539 ///
3540 /// A `StftParamsBuilder` instance.
3541 ///
3542 /// # Examples
3543 ///
3544 /// ```
3545 /// use spectrograms::{StftParams, WindowType, nzu};
3546 ///
3547 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
3548 /// let stft = StftParams::builder()
3549 /// .n_fft(nzu!(2048))
3550 /// .hop_size(nzu!(512))
3551 /// .window(WindowType::Hanning)
3552 /// .centre(true)
3553 /// .build()?;
3554 ///
3555 /// assert_eq!(stft.n_fft(), nzu!(2048));
3556 /// assert_eq!(stft.hop_size(), nzu!(512));
3557 /// # Ok(())
3558 /// # }
3559 /// ```
3560 #[inline]
3561 #[must_use]
3562 pub fn builder() -> StftParamsBuilder {
3563 StftParamsBuilder::default()
3564 }
3565}
3566
/// Builder for [`StftParams`].
///
/// `n_fft` and `hop_size` start unset and must be provided before `build`. The
/// window and centering flag start from the values given by the `Default` impl.
#[derive(Debug, Clone)]
pub struct StftParamsBuilder {
    /// FFT window size; `None` until set.
    n_fft: Option<NonZeroUsize>,
    /// Hop size in samples; `None` until set.
    hop_size: Option<NonZeroUsize>,
    /// Window function to apply to each frame.
    window: WindowType,
    /// Whether frames are centered by padding the input signal.
    centre: bool,
}
3575
3576impl Default for StftParamsBuilder {
3577 #[inline]
3578 fn default() -> Self {
3579 Self {
3580 n_fft: None,
3581 hop_size: None,
3582 window: WindowType::Hanning,
3583 centre: true,
3584 }
3585 }
3586}
3587
3588impl StftParamsBuilder {
3589 /// Set the FFT window size.
3590 ///
3591 /// # Arguments
3592 ///
3593 /// * `n_fft` - Size of the FFT window
3594 ///
3595 /// # Returns
3596 ///
3597 /// The builder with the updated FFT window size.
3598 #[inline]
3599 #[must_use]
3600 pub const fn n_fft(mut self, n_fft: NonZeroUsize) -> Self {
3601 self.n_fft = Some(n_fft);
3602 self
3603 }
3604
3605 /// Set the hop size (samples between successive frames).
3606 ///
3607 /// # Arguments
3608 ///
3609 /// * `hop_size` - Number of samples between successive frames
3610 ///
3611 /// # Returns
3612 ///
3613 /// The builder with the updated hop size.
3614 #[inline]
3615 #[must_use]
3616 pub const fn hop_size(mut self, hop_size: NonZeroUsize) -> Self {
3617 self.hop_size = Some(hop_size);
3618 self
3619 }
3620
3621 /// Set the window function.
3622 ///
3623 /// # Arguments
3624 ///
3625 /// * `window` - Window function to apply to each frame
3626 ///
3627 /// # Returns
3628 ///
3629 /// The builder with the updated window function.
3630 #[inline]
3631 #[must_use]
3632 pub fn window(mut self, window: WindowType) -> Self {
3633 self.window = window;
3634 self
3635 }
3636
3637 /// Set whether to center frames (pad input signal).
3638 #[inline]
3639 #[must_use]
3640 pub const fn centre(mut self, centre: bool) -> Self {
3641 self.centre = centre;
3642 self
3643 }
3644
3645 /// Build the [`StftParams`].
3646 ///
3647 /// # Errors
3648 ///
3649 /// Returns an error if:
3650 /// - `n_fft` or `hop_size` are not set or are zero
3651 /// - `hop_size` > `n_fft`
3652 #[inline]
3653 pub fn build(self) -> SpectrogramResult<StftParams> {
3654 let n_fft = self
3655 .n_fft
3656 .ok_or_else(|| SpectrogramError::invalid_input("n_fft must be set"))?;
3657 let hop_size = self
3658 .hop_size
3659 .ok_or_else(|| SpectrogramError::invalid_input("hop_size must be set"))?;
3660
3661 StftParams::new(n_fft, hop_size, self.window, self.centre)
3662 }
3663}
3664
3665//
3666// ========================
3667// Mel parameters
3668// ========================
3669//
3670
/// Mel filterbank normalization strategy.
///
/// Selects how the triangular mel filters are normalized after construction.
/// The default is [`MelNorm::None`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum MelNorm {
    /// No normalization (triangular filters with peak = 1.0).
    ///
    /// This is the default and fastest option.
    #[default]
    None,

    /// Slaney-style area normalization (librosa default).
    ///
    /// Each mel filter is divided by its bandwidth in Hz: `2.0 / (f_max - f_min)`.
    /// This ensures constant energy per mel band regardless of bandwidth.
    ///
    /// Use this for compatibility with librosa's default behavior.
    Slaney,

    /// L1 normalization (sum of weights = 1.0).
    ///
    /// Each mel filter's weights are divided by their sum.
    /// Useful when each filter should act as a weighted average.
    L1,

    /// L2 normalization (Euclidean norm = 1.0).
    ///
    /// Each mel filter's weights are divided by their L2 norm,
    /// giving unit-norm filters in the L2 sense.
    L2,
}
3705
/// Mel filter bank parameters
///
/// * `n_mels`: Number of mel bands
/// * `f_min`: Minimum frequency (Hz)
/// * `f_max`: Maximum frequency (Hz)
/// * `norm`: Filterbank normalization strategy
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MelParams {
    /// Number of mel bands.
    n_mels: NonZeroUsize,
    /// Minimum frequency (Hz).
    f_min: f64,
    /// Maximum frequency (Hz).
    f_max: f64,
    /// Filterbank normalization strategy.
    norm: MelNorm,
}
3720
3721impl MelParams {
3722 /// Create new mel filter bank parameters.
3723 ///
3724 /// # Arguments
3725 ///
3726 /// * `n_mels` - Number of mel bands
3727 /// * `f_min` - Minimum frequency (Hz)
3728 /// * `f_max` - Maximum frequency (Hz)
3729 ///
3730 /// # Errors
3731 ///
3732 /// Returns an error if:
3733 /// - `f_min` is not >= 0
3734 /// - `f_max` is not > `f_min`
3735 ///
3736 /// # Returns
3737 ///
3738 /// A `MelParams` instance with no normalization (default).
3739 #[inline]
3740 pub fn new(n_mels: NonZeroUsize, f_min: f64, f_max: f64) -> SpectrogramResult<Self> {
3741 Self::with_norm(n_mels, f_min, f_max, MelNorm::None)
3742 }
3743
3744 /// Create new mel filter bank parameters with specified normalization.
3745 ///
3746 /// # Arguments
3747 ///
3748 /// * `n_mels` - Number of mel bands
3749 /// * `f_min` - Minimum frequency (Hz)
3750 /// * `f_max` - Maximum frequency (Hz)
3751 /// * `norm` - Filterbank normalization strategy
3752 ///
3753 /// # Errors
3754 ///
3755 /// Returns an error if:
3756 /// - `f_min` is not >= 0
3757 /// - `f_max` is not > `f_min`
3758 ///
3759 /// # Returns
3760 ///
3761 /// A `MelParams` instance.
3762 #[inline]
3763 pub fn with_norm(
3764 n_mels: NonZeroUsize,
3765 f_min: f64,
3766 f_max: f64,
3767 norm: MelNorm,
3768 ) -> SpectrogramResult<Self> {
3769 if f_min < 0.0 {
3770 return Err(SpectrogramError::invalid_input("f_min must be >= 0"));
3771 }
3772
3773 if f_max <= f_min {
3774 return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
3775 }
3776
3777 Ok(Self {
3778 n_mels,
3779 f_min,
3780 f_max,
3781 norm,
3782 })
3783 }
3784
3785 /// Create new mel filter bank parameters.
3786 ///
3787 /// # Arguments
3788 ///
3789 /// * `n_mels` - Number of mel bands
3790 /// * `f_min` - Minimum frequency (Hz)
3791 /// * `f_max` - Maximum frequency (Hz)
3792 ///
3793 /// # Returns
3794 ///
3795 /// A `MelParams` instance with no normalization (default).
3796 pub const unsafe fn new_unchecked(n_mels: NonZeroUsize, f_min: f64, f_max: f64) -> Self {
3797 Self {
3798 n_mels,
3799 f_min,
3800 f_max,
3801 norm: MelNorm::None,
3802 }
3803 }
3804
3805 /// Get the number of mel bands.
3806 ///
3807 /// # Returns
3808 ///
3809 /// The number of mel bands.
3810 #[inline]
3811 #[must_use]
3812 pub const fn n_mels(&self) -> NonZeroUsize {
3813 self.n_mels
3814 }
3815
3816 /// Get the minimum frequency (Hz).
3817 ///
3818 /// # Returns
3819 ///
3820 /// The minimum frequency in Hz.
3821 #[inline]
3822 #[must_use]
3823 pub const fn f_min(&self) -> f64 {
3824 self.f_min
3825 }
3826
3827 /// Get the maximum frequency (Hz).
3828 ///
3829 /// # Returns
3830 ///
3831 /// The maximum frequency in Hz.
3832 #[inline]
3833 #[must_use]
3834 pub const fn f_max(&self) -> f64 {
3835 self.f_max
3836 }
3837
3838 /// Get the filterbank normalization strategy.
3839 ///
3840 /// # Returns
3841 ///
3842 /// The normalization strategy.
3843 #[inline]
3844 #[must_use]
3845 pub const fn norm(&self) -> MelNorm {
3846 self.norm
3847 }
3848
3849 /// Create standard mel filterbank parameters.
3850 ///
3851 /// Uses 128 mel bands from 0 Hz to the Nyquist frequency.
3852 ///
3853 /// # Arguments
3854 ///
3855 /// * `sample_rate` - Sample rate in Hz (used to determine `f_max`)
3856 ///
3857 /// # Returns
3858 ///
3859 /// A `MelParams` instance with standard settings.
3860 ///
3861 /// # Panics
3862 ///
3863 /// Panics if `sample_rate` is not greater than 0.
3864 #[inline]
3865 #[must_use]
3866 pub const fn standard(sample_rate: f64) -> Self {
3867 assert!(sample_rate > 0.0);
3868 // safety: parameters are known to be valid
3869 unsafe { Self::new_unchecked(nzu!(128), 0.0, sample_rate / 2.0) }
3870 }
3871
3872 /// Create mel filterbank parameters optimized for speech.
3873 ///
3874 /// Uses 40 mel bands from 0 Hz to 8000 Hz (typical speech bandwidth).
3875 ///
3876 /// # Returns
3877 ///
3878 /// A `MelParams` instance with speech-optimized settings.
3879 #[inline]
3880 #[must_use]
3881 pub const fn speech_standard() -> Self {
3882 // safety: parameters are known to be valid
3883 unsafe { Self::new_unchecked(nzu!(40), 0.0, 8000.0) }
3884 }
3885}
3886
3887//
3888// ========================
3889// LogHz parameters
3890// ========================
3891//
3892
/// Logarithmic frequency scale parameters
///
/// * `n_bins`: Number of logarithmically-spaced frequency bins
/// * `f_min`: Minimum frequency (Hz)
/// * `f_max`: Maximum frequency (Hz)
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LogHzParams {
    /// Number of logarithmically-spaced frequency bins.
    n_bins: NonZeroUsize,
    /// Minimum frequency (Hz).
    f_min: f64,
    /// Maximum frequency (Hz).
    f_max: f64,
}
3905
3906impl LogHzParams {
3907 /// Create new logarithmic frequency scale parameters.
3908 ///
3909 /// # Arguments
3910 ///
3911 /// * `n_bins` - Number of logarithmically-spaced frequency bins
3912 /// * `f_min` - Minimum frequency (Hz)
3913 /// * `f_max` - Maximum frequency (Hz)
3914 ///
3915 /// # Errors
3916 ///
3917 /// Returns an error if:
3918 /// - `f_min` is not finite and > 0
3919 /// - `f_max` is not > `f_min`
3920 ///
3921 /// # Returns
3922 ///
3923 /// A `LogHzParams` instance.
3924 #[inline]
3925 pub fn new(n_bins: NonZeroUsize, f_min: f64, f_max: f64) -> SpectrogramResult<Self> {
3926 if !(f_min > 0.0 && f_min.is_finite()) {
3927 return Err(SpectrogramError::invalid_input(
3928 "f_min must be finite and > 0",
3929 ));
3930 }
3931
3932 if f_max <= f_min {
3933 return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
3934 }
3935
3936 Ok(Self {
3937 n_bins,
3938 f_min,
3939 f_max,
3940 })
3941 }
3942
3943 const unsafe fn new_unchecked(n_bins: NonZeroUsize, f_min: f64, f_max: f64) -> Self {
3944 Self {
3945 n_bins,
3946 f_min,
3947 f_max,
3948 }
3949 }
3950
3951 /// Get the number of frequency bins.
3952 ///
3953 /// # Returns
3954 ///
3955 /// The number of frequency bins.
3956 #[inline]
3957 #[must_use]
3958 pub const fn n_bins(&self) -> NonZeroUsize {
3959 self.n_bins
3960 }
3961
3962 /// Get the minimum frequency (Hz).
3963 ///
3964 /// # Returns
3965 ///
3966 /// The minimum frequency in Hz.
3967 #[inline]
3968 #[must_use]
3969 pub const fn f_min(&self) -> f64 {
3970 self.f_min
3971 }
3972
3973 /// Get the maximum frequency (Hz).
3974 ///
3975 /// # Returns
3976 ///
3977 /// The maximum frequency in Hz.
3978 #[inline]
3979 #[must_use]
3980 pub const fn f_max(&self) -> f64 {
3981 self.f_max
3982 }
3983
3984 /// Create standard logarithmic frequency parameters.
3985 ///
3986 /// Uses 128 log bins from 20 Hz to the Nyquist frequency.
3987 ///
3988 /// # Arguments
3989 ///
3990 /// * `sample_rate` - Sample rate in Hz (used to determine `f_max`)
3991 #[inline]
3992 #[must_use]
3993 pub fn standard(sample_rate: f64) -> Self {
3994 // safety: parameters are known to be valid
3995 unsafe { Self::new_unchecked(nzu!(128), 20.0, sample_rate / 2.0) }
3996 }
3997
3998 /// Create logarithmic frequency parameters optimized for music.
3999 ///
4000 /// Uses 84 bins (7 octaves * 12 bins/octave) from 27.5 Hz (A0) to 4186 Hz (C8).
4001 #[inline]
4002 #[must_use]
4003 pub const fn music_standard() -> Self {
4004 // safety: parameters are known to be valid
4005 unsafe { Self::new_unchecked(nzu!(84), 27.5, 4186.0) }
4006 }
4007}
4008
4009//
4010// ========================
4011// Log scaling parameters
4012// ========================
4013//
4014
/// Logarithmic (decibel) scaling parameters.
///
/// * `floor_db`: Minimum dB value (floor) for logarithmic scaling
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LogParams {
    /// Minimum dB value (floor) for logarithmic scaling.
    floor_db: f64,
}
4020
4021impl LogParams {
4022 /// Create new logarithmic scaling parameters.
4023 ///
4024 /// # Arguments
4025 ///
4026 /// * `floor_db` - Minimum dB value (floor) for logarithmic scaling
4027 ///
4028 /// # Errors
4029 ///
4030 /// Returns an error if `floor_db` is not finite.
4031 ///
4032 /// # Returns
4033 ///
4034 /// A `LogParams` instance.
4035 #[inline]
4036 pub fn new(floor_db: f64) -> SpectrogramResult<Self> {
4037 if !floor_db.is_finite() {
4038 return Err(SpectrogramError::invalid_input("floor_db must be finite"));
4039 }
4040
4041 Ok(Self { floor_db })
4042 }
4043
4044 /// Create new logarithmic scaling parameters.
4045 ///
4046 /// # Arguments
4047 ///
4048 /// * `floor_db` - Minimum dB value (floor) for logarithmic scaling
4049 ///
4050
4051 ///
4052 /// # Returns
4053 ///
4054 /// A `LogParams` instance.
4055 #[inline]
4056 pub fn new_unchecked(floor_db: f64) -> Self {
4057 Self { floor_db }
4058 }
4059
4060 /// Get the floor dB value.
4061 #[inline]
4062 #[must_use]
4063 pub const fn floor_db(&self) -> f64 {
4064 self.floor_db
4065 }
4066}
4067
/// Spectrogram computation parameters.
///
/// * `stft`: STFT parameters
/// * `sample_rate_hz`: Sample rate in Hz
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SpectrogramParams {
    /// STFT parameters (FFT size, hop size, window, centering).
    pub(crate) stft: StftParams,
    /// Sample rate in Hz; `SpectrogramParams::new` requires it to be finite and > 0.
    pub(crate) sample_rate_hz: f64,
}
4078
4079impl SpectrogramParams {
4080 /// Create new spectrogram parameters.
4081 ///
4082 /// # Arguments
4083 ///
4084 /// * `stft` - STFT parameters
4085 /// * `sample_rate_hz` - Sample rate in Hz
4086 ///
4087 /// # Errors
4088 ///
4089 /// Returns an error if the sample rate is not positive and finite.
4090 ///
4091 /// # Returns
4092 ///
4093 /// A `SpectrogramParams` instance.
4094 #[inline]
4095 pub fn new(stft: StftParams, sample_rate_hz: f64) -> SpectrogramResult<Self> {
4096 if !(sample_rate_hz > 0.0 && sample_rate_hz.is_finite()) {
4097 return Err(SpectrogramError::invalid_input(
4098 "sample_rate_hz must be finite and > 0",
4099 ));
4100 }
4101
4102 Ok(Self {
4103 stft,
4104 sample_rate_hz,
4105 })
4106 }
4107
4108 /// Create new spectrogram parameters without checking the arguments.
4109 ///
4110 /// # Arguments
4111 ///
4112 /// * `stft` - STFT parameters
4113 /// * `sample_rate_hz` - Sample rate in Hz
4114 ///
4115 /// # Returns
4116 ///
4117 /// A `SpectrogramParams` instance.
4118 #[inline]
4119 pub fn new_unchecked(stft: StftParams, sample_rate_hz: f64) -> Self {
4120 Self {
4121 stft,
4122 sample_rate_hz,
4123 }
4124 }
4125
4126 /// Create a builder for spectrogram parameters.
4127 ///
4128 /// # Errors
4129 ///
4130 /// Returns an error if required parameters are not set or are invalid.
4131 ///
4132 /// # Returns
4133 ///
4134 /// A builder for [`SpectrogramParams`].
4135 ///
4136 /// # Examples
4137 ///
4138 /// ```
4139 /// use spectrograms::{SpectrogramParams, WindowType, nzu};
4140 ///
4141 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4142 /// let params = SpectrogramParams::builder()
4143 /// .sample_rate(16000.0)
4144 /// .n_fft(nzu!(512))
4145 /// .hop_size(nzu!(256))
4146 /// .window(WindowType::Hanning)
4147 /// .centre(true)
4148 /// .build()?;
4149 ///
4150 /// assert_eq!(params.sample_rate_hz(), 16000.0);
4151 /// # Ok(())
4152 /// # }
4153 /// ```
4154 #[inline]
4155 #[must_use]
4156 pub fn builder() -> SpectrogramParamsBuilder {
4157 SpectrogramParamsBuilder::default()
4158 }
4159
4160 /// Create default parameters for speech processing.
4161 ///
4162 /// # Arguments
4163 ///
4164 /// * `sample_rate_hz` - Sample rate in Hz
4165 ///
4166 /// # Returns
4167 ///
4168 /// A `SpectrogramParams` instance with default settings for music analysis.
4169 ///
4170 /// # Errors
4171 ///
4172 /// Returns an error if the sample rate is not positive and finite.
4173 ///
4174 /// Uses:
4175 /// - `n_fft`: 512 (32ms at 16kHz)
4176 /// - `hop_size`: 160 (10ms at 16kHz)
4177 /// - window: Hanning
4178 /// - centre: true
4179 #[inline]
4180 pub fn speech_default(sample_rate_hz: f64) -> SpectrogramResult<Self> {
4181 // safety: parameters are known to be valid
4182 let stft =
4183 unsafe { StftParams::new_unchecked(nzu!(512), nzu!(160), WindowType::Hanning, true) };
4184
4185 Self::new(stft, sample_rate_hz)
4186 }
4187
4188 /// Create default parameters for music processing.
4189 ///
4190 /// # Arguments
4191 ///
4192 /// * `sample_rate_hz` - Sample rate in Hz
4193 ///
4194 /// # Returns
4195 ///
4196 /// A `SpectrogramParams` instance with default settings for music analysis.
4197 ///
4198 /// # Errors
4199 ///
4200 /// Returns an error if the sample rate is not positive and finite.
4201 ///
4202 /// Uses:
4203 /// - `n_fft`: 2048 (46ms at 44.1kHz)
4204 /// - `hop_size`: 512 (11.6ms at 44.1kHz)
4205 /// - window: Hanning
4206 /// - centre: true
4207 #[inline]
4208 pub fn music_default(sample_rate_hz: f64) -> SpectrogramResult<Self> {
4209 // safety: parameters are known to be valid
4210 let stft =
4211 unsafe { StftParams::new_unchecked(nzu!(2048), nzu!(512), WindowType::Hanning, true) };
4212 Self::new(stft, sample_rate_hz)
4213 }
4214
4215 /// Get the STFT parameters.
4216 #[inline]
4217 #[must_use]
4218 pub const fn stft(&self) -> &StftParams {
4219 &self.stft
4220 }
4221
4222 /// Get the sample rate in Hz.
4223 #[inline]
4224 #[must_use]
4225 pub const fn sample_rate_hz(&self) -> f64 {
4226 self.sample_rate_hz
4227 }
4228
4229 /// Get the frame period in seconds.
4230 #[inline]
4231 #[must_use]
4232 #[allow(clippy::cast_precision_loss)]
4233 pub fn frame_period_seconds(&self) -> f64 {
4234 self.stft.hop_size().get() as f64 / self.sample_rate_hz
4235 }
4236
4237 /// Get the Nyquist frequency in Hz.
4238 #[inline]
4239 #[must_use]
4240 pub fn nyquist_hz(&self) -> f64 {
4241 self.sample_rate_hz * 0.5
4242 }
4243}
4244
/// Builder for [`SpectrogramParams`].
///
/// `sample_rate`, `n_fft` and `hop_size` start unset and must be provided before
/// `build`. The window and centering flag start from the values given by the
/// `Default` impl.
#[derive(Debug, Clone)]
pub struct SpectrogramParamsBuilder {
    /// Sample rate in Hz; `None` until set.
    sample_rate: Option<f64>,
    /// FFT window size; `None` until set.
    n_fft: Option<NonZeroUsize>,
    /// Hop size in samples; `None` until set.
    hop_size: Option<NonZeroUsize>,
    /// Window function to apply to each frame.
    window: WindowType,
    /// Whether frames are centered by padding the input signal.
    centre: bool,
}
4254
4255impl Default for SpectrogramParamsBuilder {
4256 #[inline]
4257 fn default() -> Self {
4258 Self {
4259 sample_rate: None,
4260 n_fft: None,
4261 hop_size: None,
4262 window: WindowType::Hanning,
4263 centre: true,
4264 }
4265 }
4266}
4267
4268impl SpectrogramParamsBuilder {
    /// Set the sample rate in Hz.
    ///
    /// The value is stored as given; it is only validated when `build` is called.
    ///
    /// # Arguments
    ///
    /// * `sample_rate` - Sample rate in Hz.
    ///
    /// # Returns
    ///
    /// The updated builder instance.
    #[inline]
    #[must_use]
    pub const fn sample_rate(mut self, sample_rate: f64) -> Self {
        self.sample_rate = Some(sample_rate);
        self
    }
4284
    /// Set the FFT window size.
    ///
    /// The relation to the hop size is only checked when `build` is called.
    ///
    /// # Arguments
    ///
    /// * `n_fft` - FFT size.
    ///
    /// # Returns
    ///
    /// The updated builder instance.
    #[inline]
    #[must_use]
    pub const fn n_fft(mut self, n_fft: NonZeroUsize) -> Self {
        self.n_fft = Some(n_fft);
        self
    }
4300
    /// Set the hop size (samples between successive frames).
    ///
    /// The relation to the FFT size is only checked when `build` is called.
    ///
    /// # Arguments
    ///
    /// * `hop_size` - Hop size in samples.
    ///
    /// # Returns
    ///
    /// The updated builder instance.
    #[inline]
    #[must_use]
    pub const fn hop_size(mut self, hop_size: NonZeroUsize) -> Self {
        self.hop_size = Some(hop_size);
        self
    }
4316
4317 /// Set the window function.
4318 ///
4319 /// # Arguments
4320 ///
4321 /// * `window` - Window function to apply to each frame.
4322 ///
4323 /// # Returns
4324 ///
4325 /// The updated builder instance.
4326 #[inline]
4327 #[must_use]
4328 pub fn window(mut self, window: WindowType) -> Self {
4329 self.window = window;
4330 self
4331 }
4332
4333 /// Set whether to center frames (pad input signal).
4334 ///
4335 /// # Arguments
4336 ///
4337 /// * `centre` - If true, frames are centered by padding the input signal.
4338 ///
4339 /// # Returns
4340 ///
4341 /// The updated builder instance.
4342 #[inline]
4343 #[must_use]
4344 pub const fn centre(mut self, centre: bool) -> Self {
4345 self.centre = centre;
4346 self
4347 }
4348
4349 /// Build the [`SpectrogramParams`].
4350 ///
4351 /// # Errors
4352 ///
4353 /// Returns an error if required parameters are not set or are invalid.
4354 ///
4355 /// # Returns
4356 ///
4357 /// A `SpectrogramParams` instance.
4358 #[inline]
4359 pub fn build(self) -> SpectrogramResult<SpectrogramParams> {
4360 let sample_rate = self
4361 .sample_rate
4362 .ok_or_else(|| SpectrogramError::invalid_input("sample_rate must be set"))?;
4363 let n_fft = self
4364 .n_fft
4365 .ok_or_else(|| SpectrogramError::invalid_input("n_fft must be set"))?;
4366 let hop_size = self
4367 .hop_size
4368 .ok_or_else(|| SpectrogramError::invalid_input("hop_size must be set"))?;
4369
4370 let stft = StftParams::new(n_fft, hop_size, self.window, self.centre)?;
4371 SpectrogramParams::new(stft, sample_rate)
4372 }
4373
4374 /// Build the [`SpectrogramParams`].
4375 ///
4376 /// # Safety
4377 ///
4378 /// The caller is responsible for ensuring the 'n_fft', 'hop_size' are set.
4379 ///
4380 /// # Returns
4381 ///
4382 /// A `SpectrogramParams` instance.
4383 #[inline]
4384 pub unsafe fn build_unchecked(self) -> SpectrogramParams {
4385 // safety: is the repsonsibility of the caller
4386 unsafe {
4387 let n_fft = self.n_fft.unwrap_unchecked();
4388 let hop_size = self.hop_size.unwrap_unchecked();
4389 let stft = StftParams::new_unchecked(n_fft, hop_size, self.window, self.centre);
4390 let sample_rate = self.sample_rate.unwrap_unchecked();
4391 SpectrogramParams::new_unchecked(stft, sample_rate)
4392 }
4393
4394 }
4395
4396}
4397
4398//
4399// ========================
4400// Standalone FFT Functions
4401// ========================
4402//
4403
4404/// Compute the real-to-complex FFT of a real-valued signal.
4405///
4406/// This function performs a forward FFT on real-valued input, returning the
4407/// complex frequency domain representation. Only the positive frequencies
4408/// are returned (length = `n_fft/2` + 1) due to conjugate symmetry.
4409///
4410/// # Arguments
4411///
4412/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4413/// * `n_fft` - FFT size
4414///
4415/// # Returns
4416///
4417/// A vector of complex frequency bins with length `n_fft/2` + 1.
4418///
4419/// # Automatic Zero-Padding
4420///
4421/// If the input signal is shorter than `n_fft`, it will be automatically
4422/// zero-padded to the required length. This is standard DSP practice and
4423/// preserves frequency resolution (bin spacing = sample_rate / n_fft).
4424///
4425/// ```
4426/// use spectrograms::{fft, nzu};
4427/// use non_empty_slice::non_empty_vec;
4428///
4429/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4430/// let signal = non_empty_vec![1.0, 2.0, 3.0]; // Only 3 samples
4431/// let spectrum = fft(&signal, nzu!(8))?; // Automatically padded to 8
4432/// assert_eq!(spectrum.len(), 5); // Output: 8/2 + 1 = 5 bins
4433/// # Ok(())
4434/// # }
4435/// ```
4436///
4437/// # Errors
4438///
4439/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4440///
4441/// # Examples
4442///
4443/// ```
4444/// use spectrograms::*;
4445/// use non_empty_slice::non_empty_vec;
4446///
4447/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4448/// let signal = non_empty_vec![0.0; nzu!(512)];
4449/// let spectrum = fft(&signal, nzu!(512))?;
4450///
4451/// assert_eq!(spectrum.len(), 257); // 512/2 + 1
4452/// # Ok(())
4453/// # }
4454/// ```
4455#[inline]
4456pub fn fft(
4457 samples: &NonEmptySlice<f64>,
4458 n_fft: NonZeroUsize,
4459) -> SpectrogramResult<Array1<Complex<f64>>> {
4460 if samples.len() > n_fft {
4461 return Err(SpectrogramError::invalid_input(format!(
4462 "Input length ({}) exceeds FFT size ({})",
4463 samples.len(),
4464 n_fft
4465 )));
4466 }
4467
4468 let out_len = r2c_output_size(n_fft.get());
4469
4470 // Get FFT plan from global cache (or create if first use)
4471 #[cfg(feature = "realfft")]
4472 let mut fft = {
4473 use crate::fft_backend::get_or_create_r2c_plan;
4474 let plan = get_or_create_r2c_plan(n_fft.get())?;
4475 // Clone the plan to get our own mutable copy with independent scratch buffer
4476 // This is cheap - only clones the scratch buffer, not the expensive twiddle factors
4477 (*plan).clone()
4478 };
4479
4480 #[cfg(feature = "fftw")]
4481 let mut fft = {
4482 use std::sync::Arc;
4483 let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
4484 crate::FftwPlan::new(Arc::new(plan))
4485 };
4486
4487 let input = if samples.len() < n_fft {
4488 let mut padded = vec![0.0; n_fft.get()];
4489 padded[..samples.len().get()].copy_from_slice(samples);
4490 // safety: samples.len() < n_fft checked above and n_fft > 0
4491
4492 unsafe { NonEmptyVec::new_unchecked(padded) }
4493 } else {
4494 samples.to_non_empty_vec()
4495 };
4496
4497 let mut output = vec![Complex::new(0.0, 0.0); out_len];
4498 fft.process(&input, &mut output)?;
4499 let output = Array1::from_vec(output);
4500 Ok(output)
4501}
4502
4503#[inline]
4504/// Compute the real-valued fft of a signal.
4505///
4506/// # Arguments
4507/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4508/// * `n_fft` - FFT size
4509///
4510/// # Returns
4511///
4512/// An array with length `n_fft/2` + 1.
4513///
4514/// # Errors
4515///
4516/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4517///
4518/// # Examples
4519///
4520/// ```
4521/// use spectrograms::*;
4522/// use non_empty_slice::non_empty_vec;
4523///
4524/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4525/// let signal = non_empty_vec![0.0; nzu!(512)];
4526/// let rfft_result = rfft(&signal, nzu!(512))?;
4527/// // equivalent to
4528/// let fft_result = fft(&signal, nzu!(512))?;
4529/// let rfft_result = fft_result.mapv(num_complex::Complex::norm);
4530/// # Ok(())
4531/// # }
4532///
4533pub fn rfft(samples: &NonEmptySlice<f64>, n_fft: NonZeroUsize) -> SpectrogramResult<Array1<f64>> {
4534 Ok(fft(samples, n_fft)?.mapv(Complex::norm))
4535}
4536
4537/// Compute the power spectrum of a signal (|X|²).
4538///
4539/// This function applies an optional window function and computes the
4540/// power spectrum via FFT. The result contains only positive frequencies.
4541///
4542/// # Arguments
4543///
4544/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4545/// * `n_fft` - FFT size
4546/// * `window` - Optional window function (None for rectangular window)
4547///
4548/// # Returns
4549///
4550/// A vector of power values with length `n_fft/2` + 1.
4551///
4552/// # Automatic Zero-Padding
4553///
4554/// If the input signal is shorter than `n_fft`, it will be automatically
4555/// zero-padded to the required length. This is standard DSP practice and
4556/// preserves frequency resolution (bin spacing = sample_rate / n_fft).
4557///
4558/// ```
4559/// use spectrograms::{power_spectrum, nzu};
4560/// use non_empty_slice::non_empty_vec;
4561///
4562/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4563/// let signal = non_empty_vec![1.0, 2.0, 3.0]; // Only 3 samples
4564/// let power = power_spectrum(&signal, nzu!(8), None)?;
4565/// assert_eq!(power.len(), nzu!(5)); // Output: 8/2 + 1 = 5 bins
4566/// # Ok(())
4567/// # }
4568/// ```
4569///
4570/// # Errors
4571///
4572/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4573///
4574/// # Examples
4575///
4576/// ```
4577/// use spectrograms::*;
4578/// use non_empty_slice::non_empty_vec;
4579///
4580/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4581/// let signal = non_empty_vec![0.0; nzu!(512)];
4582/// let power = power_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
4583///
4584/// assert_eq!(power.len(), nzu!(257)); // 512/2 + 1
4585/// # Ok(())
4586/// # }
4587/// ```
4588#[inline]
4589pub fn power_spectrum(
4590 samples: &NonEmptySlice<f64>,
4591 n_fft: NonZeroUsize,
4592 window: Option<WindowType>,
4593) -> SpectrogramResult<NonEmptyVec<f64>> {
4594 if samples.len() > n_fft {
4595 return Err(SpectrogramError::invalid_input(format!(
4596 "Input length ({}) exceeds FFT size ({})",
4597 samples.len(),
4598 n_fft
4599 )));
4600 }
4601
4602 let mut windowed = vec![0.0; n_fft.get()];
4603 windowed[..samples.len().get()].copy_from_slice(samples);
4604
4605 if let Some(win_type) = window {
4606 let window_samples = make_window(win_type, n_fft);
4607 for i in 0..n_fft.get() {
4608 windowed[i] *= window_samples[i];
4609 }
4610 }
4611
4612 // safety: windowed is non-empty since n_fft > 0
4613 let windowed = unsafe { NonEmptySlice::new_unchecked(&windowed) };
4614 let fft_result = fft(windowed, n_fft)?;
4615 let fft_result = fft_result
4616 .iter()
4617 .map(num_complex::Complex::norm_sqr)
4618 .collect();
4619 // safety: fft_result is non-empty since fft returned successfully
4620 Ok(unsafe { NonEmptyVec::new_unchecked(fft_result) })
4621}
4622
4623/// Compute the magnitude spectrum of a signal (|X|).
4624///
4625/// This function applies an optional window function and computes the
4626/// magnitude spectrum via FFT. The result contains only positive frequencies.
4627///
4628/// # Arguments
4629///
4630/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4631/// * `n_fft` - FFT size
4632/// * `window` - Optional window function (None for rectangular window)
4633///
4634/// # Automatic Zero-Padding
4635///
4636/// If the input signal is shorter than `n_fft`, it will be automatically
4637/// zero-padded to the required length. This preserves frequency resolution.
4638///
4639/// # Errors
4640///
4641/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4642///
4643/// # Returns
4644///
4645/// A vector of magnitude values with length `n_fft/2` + 1.
4646///
4647/// # Examples
4648///
4649/// ```
4650/// use spectrograms::*;
4651/// use non_empty_slice::non_empty_vec;
4652///
4653/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4654/// let signal = non_empty_vec![0.0; nzu!(512)];
4655/// let magnitude = magnitude_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
4656///
4657/// assert_eq!(magnitude.len(), nzu!(257)); // 512/2 + 1
4658/// # Ok(())
4659/// # }
4660/// ```
4661#[inline]
4662pub fn magnitude_spectrum(
4663 samples: &NonEmptySlice<f64>,
4664 n_fft: NonZeroUsize,
4665 window: Option<WindowType>,
4666) -> SpectrogramResult<NonEmptyVec<f64>> {
4667 let power = power_spectrum(samples, n_fft, window)?;
4668 let power = power.iter().map(|&p| p.sqrt()).collect();
4669 // safety: power is non-empty since power_spectrum returned successfully
4670 Ok(unsafe { NonEmptyVec::new_unchecked(power) })
4671}
4672
4673/// Compute the Short-Time Fourier Transform (STFT) of a signal.
4674///
4675/// This function computes the STFT by applying a sliding window and FFT
4676/// to sequential frames of the input signal.
4677///
4678/// # Arguments
4679///
4680/// * `samples` - Input signal (any type that can be converted to a slice)
4681/// * `n_fft` - FFT size
4682/// * `hop_size` - Number of samples between successive frames
4683/// * `window` - Window function to apply to each frame
4684/// * `center` - If true, pad the signal to center frames
4685///
4686/// # Returns
4687///
4688/// A 2D array with shape (`frequency_bins`, `time_frames`) containing complex STFT values.
4689///
4690/// # Errors
4691///
4692/// Returns an error if:
4693/// - `hop_size` > `n_fft`
4694/// - STFT computation fails
4695///
4696/// # Examples
4697///
4698/// ```
4699/// use spectrograms::*;
4700/// use non_empty_slice::non_empty_vec;
4701///
4702/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4703/// let signal = non_empty_vec![0.0; nzu!(16000)];
4704/// let stft_matrix = stft(&signal, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4705///
4706/// println!("STFT: {} bins x {} frames", stft_matrix.nrows(), stft_matrix.ncols());
4707/// # Ok(())
4708/// # }
4709/// ```
4710#[inline]
4711pub fn stft(
4712 samples: &NonEmptySlice<f64>,
4713 n_fft: NonZeroUsize,
4714 hop_size: NonZeroUsize,
4715 window: WindowType,
4716 center: bool,
4717) -> SpectrogramResult<Array2<Complex<f64>>> {
4718 let stft_params = StftParams::new(n_fft, hop_size, window, center)?;
4719 let params = SpectrogramParams::new(stft_params, 1.0)?; // dummy sample rate
4720
4721 let planner = SpectrogramPlanner::new();
4722 let result = planner.compute_stft(samples, ¶ms)?;
4723
4724 Ok(result.data)
4725}
4726
4727/// Compute the inverse real FFT (complex-to-real IFFT).
4728///
4729/// This function performs an inverse FFT, converting frequency domain data
4730/// back to the time domain. Only the positive frequencies need to be provided
4731/// (length = `n_fft/2` + 1) due to conjugate symmetry.
4732///
4733/// # Arguments
4734///
4735/// * `spectrum` - Complex frequency bins (length should be `n_fft/2` + 1)
4736/// * `n_fft` - FFT size (length of the output signal)
4737///
4738/// # Returns
4739///
4740/// A vector of real-valued time-domain samples with length `n_fft`.
4741///
4742/// # Errors
4743///
4744/// Returns an error if:
4745/// - `spectrum` length doesn't match `n_fft/2` + 1
4746/// - Inverse FFT computation fails
4747///
4748/// # Examples
4749///
4750/// ```
4751/// use spectrograms::*;
4752/// use non_empty_slice::{non_empty_vec, NonEmptySlice};
4753/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4754/// // Forward FFT
4755/// let signal = non_empty_vec![1.0, 0.0, -1.0, 0.0, 1.0, 0.0, -1.0, 0.0];
4756/// let spectrum = fft(&signal, nzu!(8))?;
4757/// let slice = spectrum.as_slice().unwrap();
4758/// let spectrum_slice = NonEmptySlice::new(slice).unwrap();
4759/// // Inverse FFT
4760/// let reconstructed = irfft(spectrum_slice, nzu!(8))?;
4761///
4762/// assert_eq!(reconstructed.len(), nzu!(8));
4763/// # Ok(())
4764/// # }
4765/// ```
4766#[inline]
4767pub fn irfft(
4768 spectrum: &NonEmptySlice<Complex<f64>>,
4769 n_fft: NonZeroUsize,
4770) -> SpectrogramResult<NonEmptyVec<f64>> {
4771 use crate::fft_backend::{C2rPlan, r2c_output_size};
4772
4773 let n_fft = n_fft.get();
4774 let expected_len = r2c_output_size(n_fft);
4775 if spectrum.len().get() != expected_len {
4776 return Err(SpectrogramError::dimension_mismatch(
4777 expected_len,
4778 spectrum.len().get(),
4779 ));
4780 }
4781
4782 // Get inverse FFT plan from global cache (or create if first use)
4783 #[cfg(feature = "realfft")]
4784 let mut ifft = {
4785 use crate::fft_backend::get_or_create_c2r_plan;
4786 let plan = get_or_create_c2r_plan(n_fft)?;
4787 // Clone to get our own mutable copy with independent scratch buffer
4788 (*plan).clone()
4789 };
4790
4791 #[cfg(feature = "fftw")]
4792 let mut ifft = {
4793 use crate::fft_backend::C2rPlanner;
4794 let mut planner = crate::FftwPlanner::new();
4795 planner.plan_c2r(n_fft)?
4796 };
4797
4798 let mut output = vec![0.0; n_fft];
4799 ifft.process(spectrum.as_slice(), &mut output)?;
4800
4801 // Safety: output is non-empty since n_fft > 0
4802 Ok(unsafe { NonEmptyVec::new_unchecked(output) })
4803}
4804
4805/// Reconstruct a time-domain signal from its STFT using overlap-add.
4806///
4807/// This function performs the inverse Short-Time Fourier Transform, converting
4808/// a 2D complex STFT matrix back to a 1D time-domain signal using overlap-add
4809/// synthesis with the specified window function.
4810///
4811/// # Arguments
4812///
4813/// * `stft_matrix` - Complex STFT with shape (`frequency_bins`, `time_frames`)
4814/// * `n_fft` - FFT size
4815/// * `hop_size` - Number of samples between successive frames
4816/// * `window` - Window function to apply (should match forward STFT window)
4817/// * `center` - If true, assume the forward STFT was centered
4818///
4819/// # Returns
4820///
4821/// A vector of reconstructed time-domain samples.
4822///
4823/// # Errors
4824///
4825/// Returns an error if:
4826/// - `stft_matrix` dimensions are inconsistent with `n_fft`
4827/// - `hop_size` > `n_fft`
4828/// - Inverse STFT computation fails
4829///
4830/// # Examples
4831///
4832/// ```
4833/// use spectrograms::*;
4834/// use non_empty_slice::non_empty_vec;
4835///
4836/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4837/// // Generate signal
4838/// let signal = non_empty_vec![1.0; nzu!(16000)];
4839///
4840/// // Forward STFT
4841/// let stft_matrix = stft(&signal, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4842///
4843/// // Inverse STFT
4844/// let reconstructed = istft(&stft_matrix, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4845///
4846/// println!("Original: {} samples", signal.len());
4847/// println!("Reconstructed: {} samples", reconstructed.len());
4848/// # Ok(())
4849/// # }
4850/// ```
4851#[inline]
4852pub fn istft(
4853 stft_matrix: &Array2<Complex<f64>>,
4854 n_fft: NonZeroUsize,
4855 hop_size: NonZeroUsize,
4856 window: WindowType,
4857 center: bool,
4858) -> SpectrogramResult<NonEmptyVec<f64>> {
4859 use crate::fft_backend::{C2rPlan, C2rPlanner, r2c_output_size};
4860
4861 let n_bins = stft_matrix.nrows();
4862 let n_frames = stft_matrix.ncols();
4863
4864 let expected_bins = r2c_output_size(n_fft.get());
4865 if n_bins != expected_bins {
4866 return Err(SpectrogramError::dimension_mismatch(expected_bins, n_bins));
4867 }
4868 if hop_size.get() > n_fft.get() {
4869 return Err(SpectrogramError::invalid_input("hop_size must be <= n_fft"));
4870 }
4871 // Create inverse FFT plan
4872 #[cfg(feature = "realfft")]
4873 let mut ifft = {
4874 let mut planner = crate::RealFftPlanner::new();
4875 planner.plan_c2r(n_fft.get())?
4876 };
4877
4878 #[cfg(feature = "fftw")]
4879 let mut ifft = {
4880 let mut planner = crate::FftwPlanner::new();
4881 planner.plan_c2r(n_fft.get())?
4882 };
4883
4884 // Generate window
4885 let window_samples = make_window(window, n_fft);
4886 let n_fft = n_fft.get();
4887 let hop_size = hop_size.get();
4888 // Calculate output length
4889 let pad = if center { n_fft / 2 } else { 0 };
4890 let output_len = (n_frames - 1) * hop_size + n_fft;
4891 // safety: output_len > 0 since n_frames > 0 and n_fft, hop_size > 0
4892 let output_len = unsafe { NonZeroUsize::new_unchecked(output_len) };
4893 let unpadded_len = output_len.get().saturating_sub(2 * pad);
4894
4895 // Allocate output buffer and normalization buffer
4896 let mut output = non_empty_vec![0.0; output_len];
4897 let mut norm = non_empty_vec![0.0; output_len];
4898
4899 // Overlap-add synthesis
4900 let mut frame_buffer = vec![Complex::new(0.0, 0.0); n_bins];
4901 let mut time_frame = vec![0.0; n_fft];
4902
4903 for frame_idx in 0..n_frames {
4904 // Extract complex frame from STFT matrix
4905 for bin_idx in 0..n_bins {
4906 frame_buffer[bin_idx] = stft_matrix[[bin_idx, frame_idx]];
4907 }
4908
4909 // Inverse FFT
4910 ifft.process(&frame_buffer, &mut time_frame)?;
4911
4912 // Apply window
4913 for i in 0..n_fft {
4914 time_frame[i] *= window_samples[i];
4915 }
4916
4917 // Overlap-add into output buffer
4918 let start = frame_idx * hop_size;
4919 for i in 0..n_fft {
4920 let pos = start + i;
4921 if pos < output_len.get() {
4922 output[pos] += time_frame[i];
4923 norm[pos] += window_samples[i] * window_samples[i];
4924 }
4925 }
4926 }
4927
4928 // Normalize by window energy
4929 for i in 0..output_len.get() {
4930 if norm[i] > 1e-10 {
4931 output[i] /= norm[i];
4932 }
4933 }
4934
4935 // Remove padding if centered
4936 if center && unpadded_len > 0 {
4937 let start = pad;
4938 let end = start + unpadded_len;
4939 // safety: start < end <= output_len, therefore slice is non-empty
4940 output = unsafe {
4941 NonEmptySlice::new_unchecked(&output[start..end.min(output_len.get())])
4942 .to_non_empty_vec()
4943 };
4944 }
4945
4946 Ok(output)
4947}
4948
4949//
4950// ========================
4951// Reusable FFT Plans
4952// ========================
4953//
4954
/// A reusable FFT planner for efficient repeated FFT operations.
///
/// This planner caches FFT plans internally, making repeated FFT operations
/// of the same size much more efficient than calling `fft()` repeatedly.
///
/// The wrapped backend planner is selected at compile time by the `realfft`
/// or `fftw` cargo feature.
///
/// # Examples
///
/// ```
/// use spectrograms::*;
/// use non_empty_slice::non_empty_vec;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let mut planner = FftPlanner::new();
///
/// // Process multiple signals of the same size efficiently
/// for _ in 0..100 {
///     let signal = non_empty_vec![0.0; nzu!(512)];
///     let spectrum = planner.fft(&signal, nzu!(512))?;
///     // ... process spectrum ...
/// }
/// # Ok(())
/// # }
/// ```
pub struct FftPlanner {
    /// Backend planner compiled in when the `realfft` feature is enabled.
    #[cfg(feature = "realfft")]
    inner: crate::RealFftPlanner,
    /// Backend planner compiled in when the `fftw` feature is enabled.
    #[cfg(feature = "fftw")]
    inner: crate::FftwPlanner,
}
4984
4985impl FftPlanner {
4986 /// Create a new FFT planner with empty cache.
4987 #[inline]
4988 #[must_use]
4989 pub fn new() -> Self {
4990 Self {
4991 #[cfg(feature = "realfft")]
4992 inner: crate::RealFftPlanner::new(),
4993 #[cfg(feature = "fftw")]
4994 inner: crate::FftwPlanner::new(),
4995 }
4996 }
4997
4998 /// Compute forward FFT, reusing cached plans.
4999 ///
5000 /// This is more efficient than calling the standalone `fft()` function
5001 /// repeatedly for the same FFT size.
5002 ///
5003 /// # Automatic Zero-Padding
5004 ///
5005 /// If the input signal is shorter than `n_fft`, it will be automatically
5006 /// zero-padded to the required length.
5007 ///
5008 /// # Errors
5009 ///
5010 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
5011 ///
5012 /// # Examples
5013 ///
5014 /// ```
5015 /// use spectrograms::*;
5016 /// use non_empty_slice::non_empty_vec;
5017 ///
5018 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
5019 /// let mut planner = FftPlanner::new();
5020 ///
5021 /// let signal = non_empty_vec![1.0; nzu!(512)];
5022 /// let spectrum = planner.fft(&signal, nzu!(512))?;
5023 ///
5024 /// assert_eq!(spectrum.len(), 257); // 512/2 + 1
5025 /// # Ok(())
5026 /// # }
5027 /// ```
5028 #[inline]
5029 pub fn fft(
5030 &mut self,
5031 samples: &NonEmptySlice<f64>,
5032 n_fft: NonZeroUsize,
5033 ) -> SpectrogramResult<Array1<Complex<f64>>> {
5034 use crate::fft_backend::{R2cPlan, R2cPlanner, r2c_output_size};
5035
5036 if samples.len() > n_fft {
5037 return Err(SpectrogramError::invalid_input(format!(
5038 "Input length ({}) exceeds FFT size ({})",
5039 samples.len(),
5040 n_fft
5041 )));
5042 }
5043
5044 let out_len = r2c_output_size(n_fft.get());
5045 let mut plan = self.inner.plan_r2c(n_fft.get())?;
5046
5047 let input = if samples.len() < n_fft {
5048 let mut padded = vec![0.0; n_fft.get()];
5049 padded[..samples.len().get()].copy_from_slice(samples);
5050
5051 // safety: samples.len() < n_fft checked above and n_fft > 0
5052 unsafe { NonEmptyVec::new_unchecked(padded) }
5053 } else {
5054 samples.to_non_empty_vec()
5055 };
5056
5057 let mut output = vec![Complex::new(0.0, 0.0); out_len];
5058 plan.process(&input, &mut output)?;
5059
5060 let output = Array1::from_vec(output);
5061 Ok(output)
5062 }
5063
5064 /// Compute forward real FFT magnitude
5065 ///
5066 /// # Errors
5067 ///
5068 /// Returns an error if:
5069 /// - `n_fft` doesn't match the samples length
5070 ///
5071 ///
5072 #[inline]
5073 pub fn rfft(
5074 &mut self,
5075 samples: &NonEmptySlice<f64>,
5076 n_fft: NonZeroUsize,
5077 ) -> SpectrogramResult<Array1<f64>> {
5078 let fft_with_complex = fft(samples, n_fft)?;
5079 Ok(fft_with_complex.mapv(Complex::norm))
5080 }
5081
5082 /// Compute inverse FFT, reusing cached plans.
5083 ///
5084 /// This is more efficient than calling the standalone `irfft()` function
5085 /// repeatedly for the same FFT size.
5086 ///
5087 /// # Errors
5088 /// Returns an error if:
5089 ///
5090 /// - The calculated expected length of `spectrum` doesn't match its actual length
5091 ///
5092 /// # Examples
5093 ///
5094 /// ```
5095 /// use spectrograms::*;
5096 /// use non_empty_slice::{non_empty_vec, NonEmptySlice};
5097 ///
5098 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
5099 /// let mut planner = FftPlanner::new();
5100 ///
5101 /// // Forward FFT
5102 /// let signal = non_empty_vec![1.0; nzu!(512)];
5103 /// let spectrum = planner.fft(&signal, nzu!(512))?;
5104 ///
5105 /// // Inverse FFT
5106 /// let spectrum_slice = NonEmptySlice::new(spectrum.as_slice().unwrap()).unwrap();
5107 /// let reconstructed = planner.irfft(spectrum_slice, nzu!(512))?;
5108 ///
5109 /// assert_eq!(reconstructed.len(), nzu!(512));
5110 /// # Ok(())
5111 /// # }
5112 /// ```
5113 #[inline]
5114 pub fn irfft(
5115 &mut self,
5116 spectrum: &NonEmptySlice<Complex<f64>>,
5117 n_fft: NonZeroUsize,
5118 ) -> SpectrogramResult<NonEmptyVec<f64>> {
5119 use crate::fft_backend::{C2rPlan, C2rPlanner, r2c_output_size};
5120
5121 let expected_len = r2c_output_size(n_fft.get());
5122 if spectrum.len().get() != expected_len {
5123 return Err(SpectrogramError::dimension_mismatch(
5124 expected_len,
5125 spectrum.len().get(),
5126 ));
5127 }
5128
5129 let mut plan = self.inner.plan_c2r(n_fft.get())?;
5130 let mut output = vec![0.0; n_fft.get()];
5131 plan.process(spectrum, &mut output)?;
5132 // Safety: output is non-empty since n_fft > 0
5133 let output = unsafe { NonEmptyVec::new_unchecked(output) };
5134 Ok(output)
5135 }
5136
5137 /// Compute power spectrum with optional windowing, reusing cached plans.
5138 ///
5139 /// # Automatic Zero-Padding
5140 ///
5141 /// If the input signal is shorter than `n_fft`, it will be automatically
5142 /// zero-padded to the required length.
5143 ///
5144 /// # Errors
5145 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
5146 ///
5147 /// # Examples
5148 ///
5149 /// ```
5150 /// use spectrograms::*;
5151 /// use non_empty_slice::non_empty_vec;
5152 ///
5153 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
5154 /// let mut planner = FftPlanner::new();
5155 ///
5156 /// let signal = non_empty_vec![1.0; nzu!(512)];
5157 /// let power = planner.power_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
5158 ///
5159 /// assert_eq!(power.len(), nzu!(257));
5160 /// # Ok(())
5161 /// # }
5162 /// ```
5163 #[inline]
5164 pub fn power_spectrum(
5165 &mut self,
5166 samples: &NonEmptySlice<f64>,
5167 n_fft: NonZeroUsize,
5168 window: Option<WindowType>,
5169 ) -> SpectrogramResult<NonEmptyVec<f64>> {
5170 if samples.len() > n_fft {
5171 return Err(SpectrogramError::invalid_input(format!(
5172 "Input length ({}) exceeds FFT size ({})",
5173 samples.len(),
5174 n_fft
5175 )));
5176 }
5177
5178 let mut windowed = vec![0.0; n_fft.get()];
5179 windowed[..samples.len().get()].copy_from_slice(samples);
5180 if let Some(win_type) = window {
5181 let window_samples = make_window(win_type, n_fft);
5182 for i in 0..n_fft.get() {
5183 windowed[i] *= window_samples[i];
5184 }
5185 }
5186
5187 // safety: windowed is non-empty since n_fft > 0
5188 let windowed = unsafe { NonEmptySlice::new_unchecked(&windowed) };
5189 let fft_result = self.fft(windowed, n_fft)?;
5190 let f = fft_result
5191 .iter()
5192 .map(num_complex::Complex::norm_sqr)
5193 .collect();
5194 // safety: fft_result is non-empty since fft returned successfully
5195 Ok(unsafe { NonEmptyVec::new_unchecked(f) })
5196 }
5197
5198 /// Compute magnitude spectrum with optional windowing, reusing cached plans.
5199 ///
5200 /// # Automatic Zero-Padding
5201 ///
5202 /// If the input signal is shorter than `n_fft`, it will be automatically
5203 /// zero-padded to the required length.
5204 ///
5205 /// # Errors
5206 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
5207 ///
5208 /// # Examples
5209 ///
5210 /// ```
5211 /// use spectrograms::*;
5212 /// use non_empty_slice::non_empty_vec;
5213 ///
5214 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
5215 /// let mut planner = FftPlanner::new();
5216 ///
5217 /// let signal = non_empty_vec![1.0; nzu!(512)];
5218 /// let magnitude = planner.magnitude_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
5219 ///
5220 /// assert_eq!(magnitude.len(), nzu!(257));
5221 /// # Ok(())
5222 /// # }
5223 /// ```
5224 #[inline]
5225 pub fn magnitude_spectrum(
5226 &mut self,
5227 samples: &NonEmptySlice<f64>,
5228 n_fft: NonZeroUsize,
5229 window: Option<WindowType>,
5230 ) -> SpectrogramResult<NonEmptyVec<f64>> {
5231 let power = self.power_spectrum(samples, n_fft, window)?;
5232 let power = power.iter().map(|&p| p.sqrt()).collect::<Vec<f64>>();
5233 // safety: power is non-empty since power_spectrum returned successfully
5234 Ok(unsafe { NonEmptyVec::new_unchecked(power) })
5235 }
5236}
5237
5238impl Default for FftPlanner {
5239 #[inline]
5240 fn default() -> Self {
5241 Self::new()
5242 }
5243}
5244
#[cfg(test)]
mod tests {
    use super::*;

    /// Multiplying a hand-built 3x5 sparse matrix by a dense vector gives the
    /// expected row sums.
    #[test]
    fn test_sparse_matrix_basic() {
        // (row, column, value) triples, inserted in row order:
        //   row 0: column 1
        //   row 1: columns 2 and 3
        //   row 2: columns 0 and 4
        let entries: [(usize, usize, f64); 5] = [
            (0, 1, 2.0),
            (1, 2, 0.5),
            (1, 3, 1.5),
            (2, 0, 3.0),
            (2, 4, 1.0),
        ];

        let mut sparse = SparseMatrix::new(3, 5);
        for &(row, col, value) in &entries {
            sparse.set(row, col, value);
        }

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut output = vec![0.0; 3];
        sparse.multiply_vec(&input, &mut output);

        // Row 0: 2.0 * 2.0             = 4.0
        // Row 1: 0.5 * 3.0 + 1.5 * 4.0 = 7.5
        // Row 2: 3.0 * 1.0 + 1.0 * 5.0 = 8.0
        let expected = [4.0, 7.5, 8.0];
        for (row, want) in expected.iter().enumerate() {
            assert_eq!(output[row], *want);
        }
    }

    /// Explicit zeros passed to `set` must not be stored.
    #[test]
    fn test_sparse_matrix_zeros_ignored() {
        let mut sparse = SparseMatrix::new(2, 3);

        sparse.set(0, 0, 1.0);
        sparse.set(0, 1, 0.0); // Should be ignored
        sparse.set(0, 2, 2.0);

        let stored_values = &sparse.values[0];
        let stored_columns = &sparse.indices[0];

        // Only the two non-zero entries survive in row 0
        assert_eq!(stored_values.len(), 2);
        assert_eq!(stored_columns.len(), 2);

        // ... at columns 0 and 2, with their original values
        assert_eq!(*stored_columns, vec![0, 2]);
        assert_eq!(*stored_values, vec![1.0, 2.0]);
    }

    /// LogHz matrices use linear interpolation, so every row holds one or two
    /// non-zero entries.
    #[test]
    fn test_loghz_matrix_sparsity() {
        let sample_rate = 16000.0;
        let n_fft = nzu!(512);
        let n_bins = nzu!(128);
        let f_min = 20.0;
        let f_max = sample_rate / 2.0;

        let (matrix, _freqs) =
            build_loghz_matrix(sample_rate, n_fft, n_bins, f_min, f_max).unwrap();

        for row_idx in 0..matrix.nrows() {
            let nnz = matrix.values[row_idx].len();
            assert!(nnz >= 1, "Row {} has no non-zeros", row_idx);
            assert!(
                nnz <= 2,
                "Row {} has {} non-zeros, expected at most 2",
                row_idx,
                nnz
            );
        }

        // Overall count is bounded by one and two entries per row
        let total_nnz: usize = matrix.values.iter().map(Vec::len).sum();
        assert!(total_nnz >= n_bins.get()); // At least 1 per row
        assert!(total_nnz <= n_bins.get() * 2);
    }

    /// Mel filterbanks are triangular filters, so the matrix must be mostly
    /// zeros, both overall and per filter.
    #[test]
    fn test_mel_matrix_sparsity() {
        let sample_rate = 16000.0;
        let n_fft = nzu!(512);
        let n_mels = nzu!(40);
        let f_min = 0.0;
        let f_max = sample_rate / 2.0;

        let matrix =
            build_mel_filterbank_matrix(sample_rate, n_fft, n_mels, f_min, f_max, MelNorm::None)
                .unwrap();

        let n_fft_bins = r2c_output_size(n_fft.get());

        // Fraction of entries that are not stored
        let total_elements = n_mels.get() * n_fft_bins;
        let total_nnz: usize = matrix.values.iter().map(Vec::len).sum();
        let density = total_nnz as f64 / total_elements as f64;
        let sparsity = 1.0 - density;

        // Mel filterbanks should be >80% sparse
        assert!(
            sparsity > 0.8,
            "Mel matrix sparsity is only {:.1}%, expected >80%",
            sparsity * 100.0
        );

        // Each mel filter should have significantly fewer than n_fft_bins non-zeros
        let per_filter_limit = n_fft_bins / 2;
        for row_idx in 0..matrix.nrows() {
            let nnz = matrix.values[row_idx].len();
            assert!(
                nnz < per_filter_limit,
                "Mel filter {} is not sparse enough",
                row_idx
            );
        }
    }
}