spectrograms/spectrogram.rs
1use std::marker::PhantomData;
2use std::num::NonZeroUsize;
3use std::ops::{Deref, DerefMut};
4
5use ndarray::{Array1, Array2};
6use non_empty_slice::{NonEmptySlice, NonEmptyVec, non_empty_vec};
7use num_complex::Complex;
8
9#[cfg(feature = "python")]
10use pyo3::prelude::*;
11
12use crate::cqt::CqtKernel;
13use crate::erb::ErbFilterbank;
14use crate::{
15 CqtParams, ErbParams, R2cPlan, SpectrogramError, SpectrogramResult, WindowType,
16 min_max_single_pass, nzu,
17};
// Small positive constant (1e-12). It is not referenced in the visible part of this
// file; presumably it serves as a numerical floor (e.g. before logarithms or
// divisions) — TODO confirm against its call sites.
const EPS: f64 = 1e-12;
19
20//
21// ========================
22// Sparse Matrix for efficient filterbank multiplication
23// ========================
24//
25
/// Row-oriented sparse matrix tuned for matrix-vector products.
///
/// Every row keeps two parallel lists: the stored values and the column index of
/// each stored value. Rows can therefore be filled independently and in any order,
/// which is more flexible than a packed CSR layout.
///
/// The intended use is matrices with only a handful of entries per row, such as
/// mel filterbanks (triangular filters) or log-frequency interpolation weights
/// (1-2 adjacent bins per row). Skipping the zero entries avoids the wasted
/// multiplications a dense matrix-vector product would perform.
#[derive(Debug, Clone)]
struct SparseMatrix {
    /// Row count.
    nrows: usize,
    /// Column count.
    ncols: usize,
    /// Stored entries of each row, in insertion order.
    values: Vec<Vec<f64>>,
    /// Column index of each stored entry; `indices[r][k]` pairs with `values[r][k]`.
    indices: Vec<Vec<usize>>,
}

impl SparseMatrix {
    /// Build an `nrows` x `ncols` matrix with no stored entries.
    fn new(nrows: usize, ncols: usize) -> Self {
        let values: Vec<Vec<f64>> = std::iter::repeat_with(Vec::new).take(nrows).collect();
        let indices: Vec<Vec<usize>> = std::iter::repeat_with(Vec::new).take(nrows).collect();
        Self {
            nrows,
            ncols,
            values,
            indices,
        }
    }

    /// Append `value` at (`row`, `col`) unless it is negligible.
    ///
    /// Values whose magnitude does not exceed `1e-10` (including NaN) are dropped.
    /// Entries are appended, not overwritten: setting the same position twice
    /// stores two entries for it.
    ///
    /// # Panics (debug mode only)
    /// Panics in debug builds if `row` or `col` is out of bounds. In release
    /// builds an out-of-bounds position is silently ignored.
    fn set(&mut self, row: usize, col: usize, value: f64) {
        let in_bounds = row < self.nrows && col < self.ncols;
        debug_assert!(
            in_bounds,
            "SparseMatrix index out of bounds: ({}, {}) for {}x{} matrix",
            row,
            col,
            self.nrows,
            self.ncols
        );

        // The threshold keeps numerically-zero weights out of the hot loop.
        if in_bounds && value.abs() > 1e-10 {
            self.values[row].push(value);
            self.indices[row].push(col);
        }
    }

    /// Number of rows.
    const fn nrows(&self) -> usize {
        self.nrows
    }

    /// Number of columns.
    const fn ncols(&self) -> usize {
        self.ncols
    }

    /// Sparse matrix-vector product: `out = self * input`.
    ///
    /// `input` must have `ncols` elements and `out` must have `nrows` elements
    /// (checked in debug builds only). Each element of `out` is overwritten with
    /// the dot product of the corresponding row's stored entries and `input`.
    #[inline]
    fn multiply_vec(&self, input: &[f64], out: &mut [f64]) {
        debug_assert_eq!(input.len(), self.ncols);
        debug_assert_eq!(out.len(), self.nrows);

        let rows = self.values.iter().zip(self.indices.iter());
        for (row, (weights, columns)) in rows.enumerate() {
            // Left-to-right accumulation starting from 0.0 keeps the summation order fixed.
            out[row] = weights
                .iter()
                .zip(columns.iter())
                .fold(0.0, |acc, (&weight, &col)| acc + weight * input[col]);
        }
    }
}
117
118// Linear frequency
119pub type LinearPowerSpectrogram = Spectrogram<LinearHz, Power>;
120pub type LinearMagnitudeSpectrogram = Spectrogram<LinearHz, Magnitude>;
121pub type LinearDbSpectrogram = Spectrogram<LinearHz, Decibels>;
122pub type LinearSpectrogram<AmpScale> = Spectrogram<LinearHz, AmpScale>;
123
124// Log-frequency (e.g. CQT-style)
125pub type LogHzPowerSpectrogram = Spectrogram<LogHz, Power>;
126pub type LogHzMagnitudeSpectrogram = Spectrogram<LogHz, Magnitude>;
127pub type LogHzDbSpectrogram = Spectrogram<LogHz, Decibels>;
128pub type LogHzSpectrogram<AmpScale> = Spectrogram<LogHz, AmpScale>;
129
130// ERB / gammatone
131pub type ErbPowerSpectrogram = Spectrogram<Erb, Power>;
132pub type ErbMagnitudeSpectrogram = Spectrogram<Erb, Magnitude>;
133pub type ErbDbSpectrogram = Spectrogram<Erb, Decibels>;
134pub type GammatonePowerSpectrogram = ErbPowerSpectrogram;
135pub type GammatoneMagnitudeSpectrogram = ErbMagnitudeSpectrogram;
136pub type GammatoneDbSpectrogram = ErbDbSpectrogram;
137pub type ErbSpectrogram<AmpScale> = Spectrogram<Erb, AmpScale>;
138pub type GammatoneSpectrogram<AmpScale> = ErbSpectrogram<AmpScale>;
139
140// Mel
141pub type MelMagnitudeSpectrogram = Spectrogram<Mel, Magnitude>;
142pub type MelPowerSpectrogram = Spectrogram<Mel, Power>;
143pub type MelDbSpectrogram = Spectrogram<Mel, Decibels>;
144pub type LogMelSpectrogram = MelDbSpectrogram;
145pub type MelSpectrogram<AmpScale> = Spectrogram<Mel, AmpScale>;
146
147// CQT
148pub type CqtPowerSpectrogram = Spectrogram<Cqt, Power>;
149pub type CqtMagnitudeSpectrogram = Spectrogram<Cqt, Magnitude>;
150pub type CqtDbSpectrogram = Spectrogram<Cqt, Decibels>;
151pub type CqtSpectrogram<AmpScale> = Spectrogram<Cqt, AmpScale>;
152
153use crate::fft_backend::r2c_output_size;
154
155/// A spectrogram plan is the compiled, reusable execution object.
156///
157/// It owns:
158/// - FFT plan (reusable)
159/// - window samples
160/// - mapping (identity / mel filterbank / etc.)
161/// - amplitude scaling config
162/// - workspace buffers to avoid allocations in hot loops
163///
164/// It computes one specific spectrogram type: `Spectrogram<FreqScale, AmpScale>`.
165///
166/// # Type Parameters
167///
168/// - `FreqScale`: Frequency scale type (e.g. `LinearHz`, `LogHz`, `Mel`, etc.)
169/// - `AmpScale`: Amplitude scaling type (e.g. `Power`, `Magnitude`, `Decibels`, etc.)
170pub struct SpectrogramPlan<FreqScale, AmpScale>
171where
172 AmpScale: AmpScaleSpec + 'static,
173 FreqScale: Copy + Clone + 'static,
174{
175 params: SpectrogramParams,
176
177 stft: StftPlan,
178 mapping: FrequencyMapping<FreqScale>,
179 scaling: AmplitudeScaling<AmpScale>,
180
181 freq_axis: FrequencyAxis<FreqScale>,
182 workspace: Workspace,
183
184 _amp: PhantomData<AmpScale>,
185}
186
187impl<FreqScale, AmpScale> SpectrogramPlan<FreqScale, AmpScale>
188where
189 AmpScale: AmpScaleSpec + 'static,
190 FreqScale: Copy + Clone + 'static,
191{
192 /// Get the spectrogram parameters used to create this plan.
193 ///
194 /// # Returns
195 ///
196 /// A reference to the `SpectrogramParams` used in this plan.
197 #[inline]
198 #[must_use]
199 pub const fn params(&self) -> &SpectrogramParams {
200 &self.params
201 }
202
203 /// Get the frequency axis for this spectrogram plan.
204 ///
205 /// # Returns
206 ///
207 /// A reference to the `FrequencyAxis<FreqScale>` used in this plan.
208 #[inline]
209 #[must_use]
210 pub const fn freq_axis(&self) -> &FrequencyAxis<FreqScale> {
211 &self.freq_axis
212 }
213
214 /// Compute a spectrogram for a mono signal.
215 ///
216 /// This function performs:
217 /// - framing + windowing
218 /// - FFT per frame
219 /// - magnitude/power
220 /// - frequency mapping (identity/mel/etc.)
221 /// - amplitude scaling (linear or dB)
222 ///
223 /// It allocates the output `Array2` once, but does not allocate per-frame.
224 ///
225 /// # Arguments
226 ///
227 /// * `samples` - Audio samples
228 ///
229 /// # Returns
230 ///
231 /// A `Spectrogram<FreqScale, AmpScale>` containing the computed spectrogram.
232 ///
233 /// # Errors
234 ///
235 /// Returns an error if STFT computation or mapping fails.
236 #[inline]
237 pub fn compute(
238 &mut self,
239 samples: &NonEmptySlice<f64>,
240 ) -> SpectrogramResult<Spectrogram<FreqScale, AmpScale>> {
241 let n_frames = self.stft.frame_count(samples.len())?;
242 let n_bins = self.mapping.output_bins();
243
244 // Create output matrix: (n_bins, n_frames)
245 let mut data = Array2::<f64>::zeros((n_bins.get(), n_frames.get()));
246
247 // Ensure workspace is correctly sized
248 self.workspace
249 .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
250
251 // Main loop: fill each frame (column)
252 for frame_idx in 0..n_frames.get() {
253 // CQT needs unwindowed frames because its kernels already contain windowing.
254 // Other mappings use the FFT spectrum, so they need the windowed frame.
255 if self.mapping.kind.needs_unwindowed_frame() {
256 // Fill frame without windowing (for CQT)
257 self.stft
258 .fill_frame_unwindowed(samples, frame_idx, &mut self.workspace)?;
259 } else {
260 // Compute windowed frame spectrum (for FFT-based mappings)
261 self.stft
262 .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
263 }
264
265 // mapping: spectrum(out_len) -> mapped(n_bins)
266 // For CQT, this uses workspace.frame; for others, workspace.spectrum
267 // For ERB, we need the complex FFT output (fft_out)
268 // We need to borrow workspace fields separately to avoid borrow conflicts
269 let Workspace {
270 spectrum,
271 mapped,
272 frame,
273 ..
274 } = &mut self.workspace;
275
276 self.mapping.apply(spectrum, frame, mapped)?;
277
278 // amplitude scaling in-place on mapped vector
279 self.scaling.apply_in_place(mapped)?;
280
281 // write column into output
282 for (row, &val) in mapped.iter().enumerate() {
283 data[[row, frame_idx]] = val;
284 }
285 }
286
287 let times = build_time_axis_seconds(&self.params, n_frames);
288 let axes = Axes::new(self.freq_axis.clone(), times);
289
290 Ok(Spectrogram::new(data, axes, self.params.clone()))
291 }
292
293 /// Compute a single frame of the spectrogram.
294 ///
295 /// This is useful for streaming/online processing where you want to
296 /// process audio frame-by-frame without computing the entire spectrogram.
297 ///
298 /// # Arguments
299 ///
300 /// * `samples` - Audio samples (must contain at least enough samples for the requested frame)
301 /// * `frame_idx` - Frame index to compute
302 ///
303 /// # Returns
304 ///
305 /// A vector of frequency bin values for the requested frame.
306 ///
307 /// # Errors
308 ///
309 /// Returns an error if the frame index is out of bounds or if STFT computation fails.
310 ///
311 /// # Examples
312 ///
313 /// ```
314 /// use spectrograms::*;
315 /// use non_empty_slice::non_empty_vec;
316 ///
317 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
318 /// let samples = non_empty_vec![0.0; nzu!(16000)];
319 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
320 /// let params = SpectrogramParams::new(stft, 16000.0)?;
321 ///
322 /// let planner = SpectrogramPlanner::new();
323 /// let mut plan = planner.linear_plan::<Power>(¶ms, None)?;
324 ///
325 /// // Compute just the first frame
326 /// let frame = plan.compute_frame(&samples, 0)?;
327 /// assert_eq!(frame.len(), nzu!(257)); // n_fft/2 + 1
328 /// # Ok(())
329 /// # }
330 /// ```
331 #[inline]
332 pub fn compute_frame(
333 &mut self,
334 samples: &NonEmptySlice<f64>,
335 frame_idx: usize,
336 ) -> SpectrogramResult<NonEmptyVec<f64>> {
337 let n_bins = self.mapping.output_bins();
338
339 // Ensure workspace is correctly sized
340 self.workspace
341 .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
342
343 // CQT needs unwindowed frames because its kernels already contain windowing.
344 // Other mappings use the FFT spectrum, so they need the windowed frame.
345 if self.mapping.kind.needs_unwindowed_frame() {
346 // Fill frame without windowing (for CQT)
347 self.stft
348 .fill_frame_unwindowed(samples, frame_idx, &mut self.workspace)?;
349 } else {
350 // Compute windowed frame spectrum (for FFT-based mappings)
351 self.stft
352 .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
353 }
354
355 // Apply mapping (using split borrows to avoid borrow conflicts)
356 let Workspace {
357 spectrum,
358 mapped,
359 frame,
360 ..
361 } = &mut self.workspace;
362
363 self.mapping.apply(spectrum, frame, mapped)?;
364
365 // Apply amplitude scaling
366 self.scaling.apply_in_place(mapped)?;
367
368 Ok(mapped.clone())
369 }
370
371 /// Compute spectrogram into a pre-allocated buffer.
372 ///
373 /// This avoids allocating the output matrix, which is useful when
374 /// you want to reuse buffers or have strict memory requirements.
375 ///
376 /// # Arguments
377 ///
378 /// * `samples` - Audio samples
379 /// * `output` - Pre-allocated output matrix (must be correct size: `n_bins` x `n_frames`)
380 ///
381 /// # Returns
382 ///
383 /// An empty result on success.
384 ///
385 /// # Errors
386 ///
387 /// Returns an error if the output buffer dimensions don't match the expected size.
388 ///
389 /// # Examples
390 ///
391 /// ```
392 /// use spectrograms::*;
393 /// use ndarray::Array2;
394 /// use non_empty_slice::non_empty_vec;
395 ///
396 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
397 /// let samples = non_empty_vec![0.0; nzu!(16000)];
398 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
399 /// let params = SpectrogramParams::new(stft, 16000.0)?;
400 ///
401 /// let planner = SpectrogramPlanner::new();
402 /// let mut plan = planner.linear_plan::<Power>(¶ms, None)?;
403 ///
404 /// // Pre-allocate output buffer
405 /// let mut output = Array2::<f64>::zeros((257, 63));
406 /// plan.compute_into(&samples, &mut output)?;
407 /// # Ok(())
408 /// # }
409 /// ```
410 #[inline]
411 pub fn compute_into(
412 &mut self,
413 samples: &NonEmptySlice<f64>,
414 output: &mut Array2<f64>,
415 ) -> SpectrogramResult<()> {
416 let n_frames = self.stft.frame_count(samples.len())?;
417 let n_bins = self.mapping.output_bins();
418
419 // Validate output dimensions
420 if output.nrows() != n_bins.get() {
421 return Err(SpectrogramError::dimension_mismatch(
422 n_bins.get(),
423 output.nrows(),
424 ));
425 }
426 if output.ncols() != n_frames.get() {
427 return Err(SpectrogramError::dimension_mismatch(
428 n_frames.get(),
429 output.ncols(),
430 ));
431 }
432
433 // Ensure workspace is correctly sized
434 self.workspace
435 .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
436
437 // Main loop: fill each frame (column)
438 for frame_idx in 0..n_frames.get() {
439 // CQT needs unwindowed frames because its kernels already contain windowing.
440 // Other mappings use the FFT spectrum, so they need the windowed frame.
441 if self.mapping.kind.needs_unwindowed_frame() {
442 // Fill frame without windowing (for CQT)
443 self.stft
444 .fill_frame_unwindowed(samples, frame_idx, &mut self.workspace)?;
445 } else {
446 // Compute windowed frame spectrum (for FFT-based mappings)
447 self.stft
448 .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
449 }
450
451 // mapping: spectrum(out_len) -> mapped(n_bins)
452 // For CQT, this uses workspace.frame; for others, workspace.spectrum
453 // For ERB, we need the complex FFT output (fft_out)
454 // We need to borrow workspace fields separately to avoid borrow conflicts
455 let Workspace {
456 spectrum,
457 mapped,
458 frame,
459 ..
460 } = &mut self.workspace;
461
462 self.mapping.apply(spectrum, frame, mapped)?;
463
464 // amplitude scaling in-place on mapped vector
465 self.scaling.apply_in_place(mapped)?;
466
467 // write column into output
468 for (row, &val) in mapped.iter().enumerate() {
469 output[[row, frame_idx]] = val;
470 }
471 }
472
473 Ok(())
474 }
475
476 /// Get the expected output dimensions for a given signal length.
477 ///
478 /// # Arguments
479 ///
480 /// * `signal_length` - Length of the input signal in samples.
481 ///
482 /// # Returns
483 ///
484 /// A tuple `(n_bins, n_frames)` representing the output spectrogram shape.
485 ///
486 /// # Errors
487 ///
488 /// Returns an error if the signal length is too short to produce any frames.
489 ///
490 /// # Examples
491 ///
492 /// ```
493 /// use spectrograms::*;
494 ///
495 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
496 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
497 /// let params = SpectrogramParams::new(stft, 16000.0)?;
498 ///
499 /// let planner = SpectrogramPlanner::new();
500 /// let plan = planner.linear_plan::<Power>(¶ms, None)?;
501 ///
502 /// let (n_bins, n_frames) = plan.output_shape(nzu!(16000))?;
503 /// assert_eq!(n_bins, nzu!(257));
504 /// assert_eq!(n_frames, nzu!(63));
505 /// # Ok(())
506 /// # }
507 /// ```
508 #[inline]
509 pub fn output_shape(
510 &self,
511 signal_length: NonZeroUsize,
512 ) -> SpectrogramResult<(NonZeroUsize, NonZeroUsize)> {
513 let n_frames = self.stft.frame_count(signal_length)?;
514 let n_bins = self.mapping.output_bins();
515 Ok((n_bins, n_frames))
516 }
517}
518
519/// STFT (Short-Time Fourier Transform) result containing complex frequency bins.
520///
521/// This is the raw STFT output before any frequency mapping or amplitude scaling.
522///
523/// # Fields
524///
525/// - `data`: Complex STFT matrix with shape (`frequency_bins`, `time_frames`)
526/// - `frequencies`: Frequency axis in Hz
527/// - `sample_rate`: Sample rate in Hz
528/// - `params`: STFT computation parameters
529#[derive(Debug, Clone)]
530#[non_exhaustive]
531pub struct StftResult {
532 /// Complex STFT matrix with shape (`frequency_bins`, `time_frames`)
533 pub data: Array2<Complex<f64>>,
534 /// Frequency axis in Hz
535 pub frequencies: NonEmptyVec<f64>,
536 /// Sample rate in Hz
537 pub sample_rate: f64,
538 pub params: StftParams,
539}
540
541impl StftResult {
542 /// Get the number of frequency bins.
543 ///
544 /// # Returns
545 ///
546 /// Number of frequency bins in the STFT result.
547 #[inline]
548 #[must_use]
549 pub fn n_bins(&self) -> NonZeroUsize {
550 // safety: nrows() > 0 for NonEmptyVec
551 unsafe { NonZeroUsize::new_unchecked(self.data.nrows()) }
552 }
553
554 /// Get the number of time frames.
555 ///
556 /// # Returns
557 ///
558 /// Number of time frames in the STFT result.
559 #[inline]
560 #[must_use]
561 pub fn n_frames(&self) -> NonZeroUsize {
562 // safety: ncols() > 0 for NonEmptyVec
563 unsafe { NonZeroUsize::new_unchecked(self.data.ncols()) }
564 }
565
566 /// Get the frequency resolution in Hz
567 ///
568 /// # Returns
569 ///
570 /// Frequency bin width in Hz.
571 #[inline]
572 #[must_use]
573 pub fn frequency_resolution(&self) -> f64 {
574 self.sample_rate / self.params.n_fft().get() as f64
575 }
576
577 /// Get the time resolution in seconds.
578 ///
579 /// # Returns
580 ///
581 /// Time between successive frames in seconds.
582 #[inline]
583 #[must_use]
584 pub fn time_resolution(&self) -> f64 {
585 self.params.hop_size().get() as f64 / self.sample_rate
586 }
587
588 /// Normalizes self.data to remove the complex aspect of it.
589 ///
590 /// # Returns
591 ///
592 /// An Array2\<f64\> containing the norms of each complex number in self.data.
593 #[inline]
594 pub fn norm(&self) -> Array2<f64> {
595 self.as_ref().mapv(Complex::norm)
596 }
597}
598
599impl AsRef<Array2<Complex<f64>>> for StftResult {
600 #[inline]
601 fn as_ref(&self) -> &Array2<Complex<f64>> {
602 &self.data
603 }
604}
605
606impl AsMut<Array2<Complex<f64>>> for StftResult {
607 #[inline]
608 fn as_mut(&mut self) -> &mut Array2<Complex<f64>> {
609 &mut self.data
610 }
611}
612
613impl Deref for StftResult {
614 type Target = Array2<Complex<f64>>;
615
616 #[inline]
617 fn deref(&self) -> &Self::Target {
618 &self.data
619 }
620}
621
622impl DerefMut for StftResult {
623 #[inline]
624 fn deref_mut(&mut self) -> &mut Self::Target {
625 &mut self.data
626 }
627}
628
/// A planner is an object that can build spectrogram plans.
///
/// The planner is where:
/// - FFT plans are created
/// - mapping matrices are compiled
/// - axes are computed
///
/// Keeping plan construction here separates it from the spectrogram output types.
/// The planner itself is a unit struct and carries no state.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct SpectrogramPlanner;
640
641impl SpectrogramPlanner {
/// Create a new spectrogram planner.
///
/// The planner is a unit struct, so this performs no work and can be used in
/// `const` contexts.
///
/// # Returns
///
/// A new `SpectrogramPlanner` instance.
#[inline]
#[must_use]
pub const fn new() -> Self {
    Self
}
652
653 /// Compute the Short-Time Fourier Transform (STFT) of a signal.
654 ///
655 /// This returns the raw complex STFT matrix before any frequency mapping
656 /// or amplitude scaling. Useful for applications that need the full complex
657 /// spectrum or custom processing.
658 ///
659 /// # Arguments
660 ///
661 /// * `samples` - Audio samples (any type that can be converted to a slice)
662 /// * `params` - STFT computation parameters
663 ///
664 /// # Returns
665 ///
666 /// An `StftResult` containing the complex STFT matrix and metadata.
667 ///
668 /// # Errors
669 ///
670 /// Returns an error if STFT computation fails.
671 ///
672 /// # Examples
673 ///
674 /// ```
675 /// use spectrograms::*;
676 /// use non_empty_slice::non_empty_vec;
677 ///
678 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
679 /// let samples = non_empty_vec![0.0; nzu!(16000)];
680 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
681 /// let params = SpectrogramParams::new(stft, 16000.0)?;
682 ///
683 /// let planner = SpectrogramPlanner::new();
684 /// let stft_result = planner.compute_stft(&samples, ¶ms)?;
685 ///
686 /// println!("STFT: {} bins x {} frames", stft_result.n_bins(), stft_result.n_frames());
687 /// # Ok(())
688 /// # }
689 /// ```
690 ///
691 /// # Performance Note
692 ///
693 /// This method creates a new FFT plan each time. For processing multiple
694 /// signals, create a reusable plan with `StftPlan::new()` instead.
695 ///
696 /// # Examples
697 ///
698 /// ```rust
699 /// use spectrograms::*;
700 /// use non_empty_slice::non_empty_vec;
701 ///
702 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
703 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
704 /// let params = SpectrogramParams::new(stft, 16000.0)?;
705 ///
706 /// // One-shot (convenient)
707 /// let planner = SpectrogramPlanner::new();
708 /// let stft_result = planner.compute_stft(&non_empty_vec![0.0; nzu!(16000)], ¶ms)?;
709 ///
710 /// // Reusable plan (efficient for batch)
711 /// let mut plan = StftPlan::new(¶ms)?;
712 /// for signal in &[non_empty_vec![0.0; nzu!(16000)], non_empty_vec![1.0; nzu!(16000)]] {
713 /// let stft = plan.compute(&signal, ¶ms)?;
714 /// }
715 /// # Ok(())
716 /// # }
717 /// ```
718 #[inline]
719 pub fn compute_stft(
720 &self,
721 samples: &NonEmptySlice<f64>,
722 params: &SpectrogramParams,
723 ) -> SpectrogramResult<StftResult> {
724 let mut plan = StftPlan::new(params)?;
725 plan.compute(samples, params)
726 }
727
728 /// Compute the power spectrum of a single audio frame.
729 ///
730 /// This is useful for real-time processing or analyzing individual frames.
731 ///
732 /// # Arguments
733 ///
734 /// * `samples` - Audio frame (length ≤ n_fft, will be zero-padded if shorter)
735 /// * `n_fft` - FFT size
736 /// * `window` - Window type to apply
737 ///
738 /// # Returns
739 ///
740 /// A vector of power values (|X|²) with length `n_fft/2` + 1.
741 ///
742 /// # Automatic Zero-Padding
743 ///
744 /// If the input signal is shorter than `n_fft`, it will be automatically
745 /// zero-padded to the required length.
746 ///
747 /// # Errors
748 ///
749 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
750 ///
751 /// # Examples
752 ///
753 /// ```
754 /// use spectrograms::*;
755 /// use non_empty_slice::non_empty_vec;
756 ///
757 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
758 /// let frame = non_empty_vec![0.0; nzu!(512)];
759 ///
760 /// let planner = SpectrogramPlanner::new();
761 /// let power = planner.compute_power_spectrum(frame.as_ref(), nzu!(512), WindowType::Hanning)?;
762 ///
763 /// assert_eq!(power.len(), nzu!(257)); // 512/2 + 1
764 /// # Ok(())
765 /// # }
766 /// ```
767 #[inline]
768 pub fn compute_power_spectrum(
769 &self,
770 samples: &NonEmptySlice<f64>,
771 n_fft: NonZeroUsize,
772 window: WindowType,
773 ) -> SpectrogramResult<NonEmptyVec<f64>> {
774 if samples.len() > n_fft {
775 return Err(SpectrogramError::invalid_input(format!(
776 "Input length ({}) exceeds FFT size ({})",
777 samples.len(),
778 n_fft
779 )));
780 }
781
782 let window_samples = make_window(window, n_fft);
783 let out_len = r2c_output_size(n_fft.get());
784
785 // Create FFT plan
786 #[cfg(feature = "realfft")]
787 let mut fft = {
788 let mut planner = crate::RealFftPlanner::new();
789 let plan = planner.get_or_create(n_fft.get());
790 crate::RealFftPlan::new(n_fft.get(), plan)
791 };
792
793 #[cfg(feature = "fftw")]
794 let mut fft = {
795 use std::sync::Arc;
796 let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
797 crate::FftwPlan::new(Arc::new(plan))
798 };
799
800 // Apply window and compute FFT
801 let mut windowed = vec![0.0; n_fft.get()];
802 for i in 0..samples.len().get() {
803 windowed[i] = samples[i] * window_samples[i];
804 }
805 // The rest is already zero-padded
806 let mut fft_out = vec![Complex::new(0.0, 0.0); out_len];
807 fft.process(&windowed, &mut fft_out)?;
808
809 // Convert to power
810 let power: Vec<f64> = fft_out.iter().map(num_complex::Complex::norm_sqr).collect();
811 // safety: power is non-empty since n_fft > 0
812 Ok(unsafe { NonEmptyVec::new_unchecked(power) })
813 }
814
815 /// Compute the magnitude spectrum of a single audio frame.
816 ///
817 /// This is useful for real-time processing or analyzing individual frames.
818 ///
819 /// # Arguments
820 ///
821 /// * `samples` - Audio frame (length ≤ n_fft, will be zero-padded if shorter)
822 /// * `n_fft` - FFT size
823 /// * `window` - Window type to apply
824 ///
825 /// # Returns
826 ///
827 /// A vector of magnitude values (|X|) with length `n_fft/2` + 1.
828 ///
829 /// # Automatic Zero-Padding
830 ///
831 /// If the input signal is shorter than `n_fft`, it will be automatically
832 /// zero-padded to the required length.
833 ///
834 /// # Errors
835 ///
836 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
837 ///
838 /// # Examples
839 ///
840 /// ```
841 /// use spectrograms::*;
842 /// use non_empty_slice::non_empty_vec;
843 ///
844 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
845 /// let frame = non_empty_vec![0.0; nzu!(512)];
846 ///
847 /// let planner = SpectrogramPlanner::new();
848 /// let magnitude = planner.compute_magnitude_spectrum(frame.as_ref(), nzu!(512), WindowType::Hanning)?;
849 ///
850 /// assert_eq!(magnitude.len(), nzu!(257)); // 512/2 + 1
851 /// # Ok(())
852 /// # }
853 /// ```
854 #[inline]
855 pub fn compute_magnitude_spectrum(
856 &self,
857 samples: &NonEmptySlice<f64>,
858 n_fft: NonZeroUsize,
859 window: WindowType,
860 ) -> SpectrogramResult<NonEmptyVec<f64>> {
861 let power = self.compute_power_spectrum(samples, n_fft, window)?;
862 let power = power.iter().map(|&p| p.sqrt()).collect::<Vec<f64>>();
863 // safety: power is non-empty since power_spectrum returned successfully
864 Ok(unsafe { NonEmptyVec::new_unchecked(power) })
865 }
866
867 /// Build a linear-frequency spectrogram plan.
868 ///
869 /// # Type Parameters
870 ///
871 /// `AmpScale` determines whether output is:
872 /// - Magnitude
873 /// - Power
874 /// - Decibels
875 ///
876 /// # Arguments
877 ///
878 /// * `params` - Spectrogram parameters
879 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale
880 /// is `Decibels`)
881 ///
882 /// # Returns
883 ///
884 /// A `SpectrogramPlan` configured for linear-frequency spectrogram computation.
885 ///
886 /// # Errors
887 ///
888 /// Returns an error if the plan cannot be created due to invalid parameters.
889 #[inline]
890 pub fn linear_plan<AmpScale>(
891 &self,
892 params: &SpectrogramParams,
893 db: Option<&LogParams>, // only used when AmpScale = Decibels
894 ) -> SpectrogramResult<SpectrogramPlan<LinearHz, AmpScale>>
895 where
896 AmpScale: AmpScaleSpec + 'static,
897 {
898 let stft = StftPlan::new(params)?;
899 let mapping = FrequencyMapping::<LinearHz>::new(params)?;
900 let scaling = AmplitudeScaling::<AmpScale>::new(db);
901 let freq_axis = build_frequency_axis::<LinearHz>(params, &mapping);
902
903 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
904
905 Ok(SpectrogramPlan {
906 params: params.clone(),
907 stft,
908 mapping,
909 scaling,
910 freq_axis,
911 workspace,
912 _amp: PhantomData,
913 })
914 }
915
916 /// Build a mel-frequency spectrogram plan.
917 ///
918 /// This compiles a mel filterbank matrix and caches it inside the plan.
919 ///
920 /// # Type Parameters
921 ///
922 /// `AmpScale`: determines whether output is:
923 /// - Magnitude
924 /// - Power
925 /// - Decibels
926 ///
927 /// # Arguments
928 ///
929 /// * `params` - Spectrogram parameters
930 /// * `mel` - Mel-specific parameters
931 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
932 ///
933 /// # Returns
934 ///
935 /// A `SpectrogramPlan` configured for mel spectrogram computation.
936 ///
937 /// # Errors
938 ///
939 /// Returns an error if the plan cannot be created due to invalid parameters.
940 #[inline]
941 pub fn mel_plan<AmpScale>(
942 &self,
943 params: &SpectrogramParams,
944 mel: &MelParams,
945 db: Option<&LogParams>, // only used when AmpScale = Decibels
946 ) -> SpectrogramResult<SpectrogramPlan<Mel, AmpScale>>
947 where
948 AmpScale: AmpScaleSpec + 'static,
949 {
950 // cross-validation: mel range must be compatible with sample rate
951 let nyquist = params.nyquist_hz();
952 if mel.f_max() > nyquist {
953 return Err(SpectrogramError::invalid_input(
954 "mel f_max must be <= Nyquist",
955 ));
956 }
957
958 let stft = StftPlan::new(params)?;
959 let mapping = FrequencyMapping::<Mel>::new_mel(params, mel)?;
960 let scaling = AmplitudeScaling::<AmpScale>::new(db);
961 let freq_axis = build_frequency_axis::<Mel>(params, &mapping);
962
963 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
964
965 Ok(SpectrogramPlan {
966 params: params.clone(),
967 stft,
968 mapping,
969 scaling,
970 freq_axis,
971 workspace,
972 _amp: PhantomData,
973 })
974 }
975
976 /// Build an ERB-scale spectrogram plan.
977 ///
978 /// This creates a spectrogram with ERB-spaced frequency bands using gammatone
979 /// filterbank approximation in the frequency domain.
980 ///
981 /// # Type Parameters
982 ///
983 /// `AmpScale`: determines whether output is:
984 /// - Magnitude
985 /// - Power
986 /// - Decibels
987 ///
988 /// # Arguments
989 ///
990 /// * `params` - Spectrogram parameters
991 /// * `erb` - ERB-specific parameters
992 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
993 ///
994 /// # Returns
995 ///
996 /// A `SpectrogramPlan` configured for ERB spectrogram computation.
997 ///
998 /// # Errors
999 ///
1000 /// Returns an error if the plan cannot be created due to invalid parameters.
1001 #[inline]
1002 pub fn erb_plan<AmpScale>(
1003 &self,
1004 params: &SpectrogramParams,
1005 erb: &ErbParams,
1006 db: Option<&LogParams>,
1007 ) -> SpectrogramResult<SpectrogramPlan<Erb, AmpScale>>
1008 where
1009 AmpScale: AmpScaleSpec + 'static,
1010 {
1011 // cross-validation: erb range must be compatible with sample rate
1012 let nyquist = params.nyquist_hz();
1013 if erb.f_max() > nyquist {
1014 return Err(SpectrogramError::invalid_input(format!(
1015 "f_max={} exceeds Nyquist={}",
1016 erb.f_max(),
1017 nyquist
1018 )));
1019 }
1020
1021 let stft = StftPlan::new(params)?;
1022 let mapping = FrequencyMapping::<Erb>::new_erb(params, erb)?;
1023 let scaling = AmplitudeScaling::<AmpScale>::new(db);
1024 let freq_axis = build_frequency_axis::<Erb>(params, &mapping);
1025
1026 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1027
1028 Ok(SpectrogramPlan {
1029 params: params.clone(),
1030 stft,
1031 mapping,
1032 scaling,
1033 freq_axis,
1034 workspace,
1035 _amp: PhantomData,
1036 })
1037 }
1038
1039 /// Build a log-frequency plan.
1040 ///
1041 /// This creates a spectrogram with logarithmically-spaced frequency bins.
1042 ///
1043 /// # Type Parameters
1044 ///
1045 /// `AmpScale`: determines whether output is:
1046 /// - Magnitude
1047 /// - Power
1048 /// - Decibels
1049 ///
1050 /// # Arguments
1051 ///
1052 /// * `params` - Spectrogram parameters
1053 /// * `loghz` - LogHz-specific parameters
1054 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
1055 ///
1056 /// # Returns
1057 ///
1058 /// A `SpectrogramPlan` configured for log-frequency spectrogram computation.
1059 ///
1060 /// # Errors
1061 ///
1062 /// Returns an error if the plan cannot be created due to invalid parameters.
1063 #[inline]
1064 pub fn log_hz_plan<AmpScale>(
1065 &self,
1066 params: &SpectrogramParams,
1067 loghz: &LogHzParams,
1068 db: Option<&LogParams>,
1069 ) -> SpectrogramResult<SpectrogramPlan<LogHz, AmpScale>>
1070 where
1071 AmpScale: AmpScaleSpec + 'static,
1072 {
1073 // cross-validation: loghz range must be compatible with sample rate
1074 let nyquist = params.nyquist_hz();
1075 if loghz.f_max() > nyquist {
1076 return Err(SpectrogramError::invalid_input(format!(
1077 "f_max={} exceeds Nyquist={}",
1078 loghz.f_max(),
1079 nyquist
1080 )));
1081 }
1082
1083 let stft = StftPlan::new(params)?;
1084 let mapping = FrequencyMapping::<LogHz>::new_loghz(params, loghz)?;
1085 let scaling = AmplitudeScaling::<AmpScale>::new(db);
1086 let freq_axis = build_frequency_axis::<LogHz>(params, &mapping);
1087
1088 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1089
1090 Ok(SpectrogramPlan {
1091 params: params.clone(),
1092 stft,
1093 mapping,
1094 scaling,
1095 freq_axis,
1096 workspace,
1097 _amp: PhantomData,
1098 })
1099 }
1100
1101 /// Build a cqt spectrogram plan.
1102 ///
1103 /// # Type Parameters
1104 ///
1105 /// `AmpScale`: determines whether output is:
1106 /// - Magnitude
1107 /// - Power
1108 /// - Decibels
1109 ///
1110 /// # Arguments
1111 ///
1112 /// * `params` - Spectrogram parameters
1113 /// * `cqt` - CQT-specific parameters
1114 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
1115 ///
1116 /// # Returns
1117 ///
1118 /// A `SpectrogramPlan` configured for CQT spectrogram computation.
1119 ///
1120 /// # Errors
1121 ///
1122 /// Returns an error if the plan cannot be created due to invalid parameters.
1123 #[inline]
1124 pub fn cqt_plan<AmpScale>(
1125 &self,
1126 params: &SpectrogramParams,
1127 cqt: &CqtParams,
1128 db: Option<&LogParams>, // only used when AmpScale = Decibels
1129 ) -> SpectrogramResult<SpectrogramPlan<Cqt, AmpScale>>
1130 where
1131 AmpScale: AmpScaleSpec + 'static,
1132 {
1133 let stft = StftPlan::new(params)?;
1134 let mapping = FrequencyMapping::<Cqt>::new(params, cqt)?;
1135 let scaling = AmplitudeScaling::<AmpScale>::new(db);
1136 let freq_axis = build_frequency_axis::<Cqt>(params, &mapping);
1137
1138 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1139
1140 Ok(SpectrogramPlan {
1141 params: params.clone(),
1142 stft,
1143 mapping,
1144 scaling,
1145 freq_axis,
1146 workspace,
1147 _amp: PhantomData,
1148 })
1149 }
1150}
1151
1152/// STFT plan containing reusable FFT plan and buffers.
1153///
1154/// This struct is responsible for performing the Short-Time Fourier Transform (STFT)
1155/// on audio signals based on the provided parameters.
1156///
1157/// It encapsulates the FFT plan, windowing function, and internal buffers to efficiently
1158/// compute the STFT for multiple frames of audio data.
1159///
1160/// # Fields
1161///
1162/// - `n_fft`: Size of the FFT.
1163/// - `hop_size`: Hop size between consecutive frames.
1164/// - `window`: Windowing function samples.
1165/// - `centre`: Whether to centre the frames with padding.
1166/// - `out_len`: Length of the FFT output.
1167/// - `fft`: Boxed FFT plan for real-to-complex transformation.
1168/// - `fft_out`: Internal buffer for FFT output.
1169/// - `frame`: Internal buffer for windowed audio frames.
1170pub struct StftPlan {
1171 n_fft: NonZeroUsize,
1172 hop_size: NonZeroUsize,
1173 window: NonEmptyVec<f64>,
1174 centre: bool,
1175
1176 out_len: NonZeroUsize,
1177
1178 // FFT plan (reused for all frames)
1179 fft: Box<dyn R2cPlan>,
1180
1181 // internal scratch
1182 fft_out: NonEmptyVec<Complex<f64>>,
1183 frame: NonEmptyVec<f64>,
1184}
1185
1186impl StftPlan {
1187 /// Create a new STFT plan from parameters.
1188 ///
1189 /// # Arguments
1190 ///
1191 /// * `params` - Spectrogram parameters containing STFT config
1192 ///
1193 /// # Returns
1194 ///
1195 /// A new `StftPlan` instance.
1196 ///
1197 /// # Errors
1198 ///
1199 /// Returns an error if the FFT plan cannot be created.
1200 #[inline]
1201 pub fn new(params: &SpectrogramParams) -> SpectrogramResult<Self> {
1202 let stft = params.stft();
1203 let n_fft = stft.n_fft();
1204 let hop_size = stft.hop_size();
1205 let centre = stft.centre();
1206
1207 let window = make_window(stft.window(), n_fft);
1208
1209 let out_len = r2c_output_size(n_fft.get());
1210 let out_len = NonZeroUsize::new(out_len)
1211 .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1212
1213 #[cfg(feature = "realfft")]
1214 let fft = {
1215 let mut planner = crate::RealFftPlanner::new();
1216 let plan = planner.get_or_create(n_fft.get());
1217 let plan = crate::RealFftPlan::new(n_fft.get(), plan);
1218 Box::new(plan)
1219 };
1220
1221 #[cfg(feature = "fftw")]
1222 let fft = {
1223 use std::sync::Arc;
1224 let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
1225 Box::new(crate::FftwPlan::new(Arc::new(plan)))
1226 };
1227
1228 Ok(Self {
1229 n_fft,
1230 hop_size,
1231 window,
1232 centre,
1233 out_len,
1234 fft,
1235 fft_out: non_empty_vec![Complex::new(0.0, 0.0); out_len],
1236 frame: non_empty_vec![0.0; n_fft],
1237 })
1238 }
1239
1240 fn frame_count(&self, n_samples: NonZeroUsize) -> SpectrogramResult<NonZeroUsize> {
1241 // Framing policy:
1242 // - centre = true: implicit padding of n_fft/2 on both sides
1243 // - centre = false: no padding
1244 //
1245 // Define the number of frames such that each frame has a valid centre sample position.
1246 let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1247 let padded_len = n_samples.get() + 2 * pad;
1248
1249 if padded_len < self.n_fft.get() {
1250 // still produce 1 frame (all padding / partial)
1251 return Ok(nzu!(1));
1252 }
1253
1254 let remaining = padded_len - self.n_fft.get();
1255 let n_frames = remaining / self.hop_size().get() + 1;
1256 let n_frames = NonZeroUsize::new(n_frames).ok_or_else(|| {
1257 SpectrogramError::invalid_input("computed number of frames must be non-zero")
1258 })?;
1259 Ok(n_frames)
1260 }
1261
1262 /// Compute one frame FFT using internal buffers only.
1263 fn compute_frame_fft_simple(
1264 &mut self,
1265 samples: &NonEmptySlice<f64>,
1266 frame_idx: usize,
1267 ) -> SpectrogramResult<()> {
1268 let out = self.frame.as_mut_slice();
1269 debug_assert_eq!(out.len(), self.n_fft.get());
1270
1271 let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1272 let start = frame_idx
1273 .checked_mul(self.hop_size.get())
1274 .ok_or_else(|| SpectrogramError::invalid_input("frame index overflow"))?;
1275
1276 // Fill windowed frame
1277 for (i, sample) in out.iter_mut().enumerate().take(self.n_fft.get()) {
1278 let v_idx = start + i;
1279 let s_idx = v_idx as isize - pad as isize;
1280
1281 let sample_val = if s_idx < 0 || (s_idx as usize) >= samples.len().get() {
1282 0.0
1283 } else {
1284 samples[s_idx as usize]
1285 };
1286 *sample = sample_val * self.window[i];
1287 }
1288
1289 // Compute FFT
1290 let fft_out = self.fft_out.as_mut_slice();
1291 self.fft.process(out, fft_out)?;
1292
1293 Ok(())
1294 }
1295
1296 /// Compute one frame spectrum into workspace:
1297 /// - fills windowed frame
1298 /// - runs FFT
1299 /// - converts to magnitude/power based on `AmpScale` later
1300 fn compute_frame_spectrum(
1301 &mut self,
1302 samples: &NonEmptySlice<f64>,
1303 frame_idx: usize,
1304 workspace: &mut Workspace,
1305 ) -> SpectrogramResult<()> {
1306 let out = workspace.frame.as_mut_slice();
1307
1308 // self.fill_frame(samples, frame_idx, frame)?;
1309 debug_assert_eq!(out.len(), self.n_fft.get());
1310
1311 let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1312 let start = frame_idx
1313 .checked_mul(self.hop_size().get())
1314 .ok_or_else(|| SpectrogramError::invalid_input("frame index overflow"))?;
1315
1316 // The "virtual" signal is samples with pad zeros on both sides.
1317 // Virtual index 0..padded_len
1318 // Map virtual index to original samples by subtracting pad.
1319 for (i, sample) in out.iter_mut().enumerate().take(self.n_fft.get()) {
1320 let v_idx = start + i;
1321 let s_idx = v_idx as isize - pad as isize;
1322
1323 let sample_val = if s_idx < 0 || (s_idx as usize) >= samples.len().get() {
1324 0.0
1325 } else {
1326 samples[s_idx as usize]
1327 };
1328
1329 *sample = sample_val * self.window[i];
1330 }
1331 let fft_out = workspace.fft_out.as_mut_slice();
1332 // FFT
1333 self.fft.process(out, fft_out)?;
1334
1335 // Convert complex spectrum to linear magnitude OR power here? No:
1336 // Keep "spectrum" as power by default? That would entangle semantics.
1337 //
1338 // Instead, we store magnitude^2 (power) as the canonical intermediate,
1339 // and let AmpScale decide later whether output is magnitude or power.
1340 //
1341 // This is consistent and avoids recomputing norms multiple times.
1342 for (i, c) in workspace.fft_out.iter().enumerate() {
1343 workspace.spectrum[i] = c.norm_sqr();
1344 }
1345
1346 Ok(())
1347 }
1348
1349 /// Fill a time-domain frame WITHOUT applying the window.
1350 ///
1351 /// This is used for CQT mapping, where the CQT kernels already contain
1352 /// windowing applied during kernel generation. Applying the STFT window
1353 /// would result in double-windowing.
1354 ///
1355 /// # Arguments
1356 ///
1357 /// * `samples` - Input audio samples
1358 /// * `frame_idx` - Frame index
1359 /// * `workspace` - Workspace containing the frame buffer to fill
1360 ///
1361 /// # Errors
1362 ///
1363 /// Returns an error if frame index is out of bounds.
1364 fn fill_frame_unwindowed(
1365 &self,
1366 samples: &NonEmptySlice<f64>,
1367 frame_idx: usize,
1368 workspace: &mut Workspace,
1369 ) -> SpectrogramResult<()> {
1370 let out = workspace.frame.as_mut_slice();
1371 debug_assert_eq!(out.len(), self.n_fft.get());
1372
1373 let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1374 let start = frame_idx
1375 .checked_mul(self.hop_size().get())
1376 .ok_or_else(|| SpectrogramError::invalid_input("frame index overflow"))?;
1377
1378 // Fill frame WITHOUT windowing
1379 for (i, sample) in out.iter_mut().enumerate().take(self.n_fft.get()) {
1380 let v_idx = start + i;
1381 let s_idx = v_idx as isize - pad as isize;
1382
1383 let sample_val = if s_idx < 0 || (s_idx as usize) >= samples.len().get() {
1384 0.0
1385 } else {
1386 samples[s_idx as usize]
1387 };
1388
1389 // No window multiplication - just copy the sample
1390 *sample = sample_val;
1391 }
1392
1393 Ok(())
1394 }
1395
1396 /// Compute the full STFT for a signal, returning an `StftResult`.
1397 ///
1398 /// This is a convenience method that handles frame iteration and
1399 /// builds the complete STFT matrix.
1400 ///
1401 ///
1402 /// # Arguments
1403 ///
1404 /// * `samples` - Input audio samples
1405 /// * `params` - STFT computation parameters
1406 ///
1407 /// # Returns
1408 ///
1409 /// An `StftResult` containing the complex STFT matrix and metadata.
1410 ///
1411 /// # Errors
1412 ///
1413 /// Returns an error if computation fails.
1414 ///
1415 /// # Examples
1416 ///
1417 /// ```rust
1418 /// use spectrograms::*;
1419 /// use non_empty_slice::non_empty_vec;
1420 ///
1421 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1422 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1423 /// let params = SpectrogramParams::new(stft, 16000.0)?;
1424 /// let mut plan = StftPlan::new(¶ms)?;
1425 ///
1426 /// let samples = non_empty_vec![0.0; nzu!(16000)];
1427 /// let stft_result = plan.compute(&samples, ¶ms)?;
1428 ///
1429 /// println!("STFT: {} bins x {} frames", stft_result.n_bins(), stft_result.n_frames());
1430 /// # Ok(())
1431 /// # }
1432 /// ```
1433 #[inline]
1434 pub fn compute(
1435 &mut self,
1436 samples: &NonEmptySlice<f64>,
1437 params: &SpectrogramParams,
1438 ) -> SpectrogramResult<StftResult> {
1439 let n_frames = self.frame_count(samples.len())?;
1440 let n_bins = self.out_len;
1441
1442 // Allocate output matrix (frequency_bins x time_frames)
1443 let mut data = Array2::<Complex<f64>>::zeros((n_bins.get(), n_frames.get()));
1444
1445 // Compute each frame
1446 for frame_idx in 0..n_frames.get() {
1447 self.compute_frame_fft_simple(samples, frame_idx)?;
1448
1449 // Copy from internal buffer to output
1450 for (bin_idx, &value) in self.fft_out.iter().enumerate() {
1451 data[[bin_idx, frame_idx]] = value;
1452 }
1453 }
1454
1455 // Build frequency axis
1456 let frequencies: Vec<f64> = (0..n_bins.get())
1457 .map(|k| k as f64 * params.sample_rate_hz() / params.stft().n_fft().get() as f64)
1458 .collect();
1459 // SAFETY: n_bins > 0
1460 let frequencies = unsafe { NonEmptyVec::new_unchecked(frequencies) };
1461
1462 Ok(StftResult {
1463 data,
1464 frequencies,
1465 sample_rate: params.sample_rate_hz(),
1466 params: params.stft().clone(),
1467 })
1468 }
1469
1470 /// Compute a single frame of STFT, returning the complex spectrum.
1471 ///
1472 /// This is useful for streaming/online processing where you want to
1473 /// process audio frame-by-frame.
1474 ///
1475 /// # Arguments
1476 ///
1477 /// * `samples` - Input audio samples
1478 /// * `frame_idx` - Index of the frame to compute
1479 ///
1480 /// # Returns
1481 ///
1482 /// A `NonEmptyVec` containing the complex spectrum for the specified frame.
1483 ///
1484 /// # Errors
1485 ///
1486 /// Returns an error if computation fails.
1487 ///
1488 /// # Examples
1489 ///
1490 /// ```rust
1491 /// use spectrograms::*;
1492 /// use non_empty_slice::non_empty_vec;
1493 ///
1494 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1495 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1496 /// let params = SpectrogramParams::new(stft, 16000.0)?;
1497 /// let mut plan = StftPlan::new(¶ms)?;
1498 ///
1499 /// let samples = non_empty_vec![0.0; nzu!(16000)];
1500 /// let (_, n_frames) = plan.output_shape(samples.len())?;
1501 ///
1502 /// for frame_idx in 0..n_frames.get() {
1503 /// let spectrum = plan.compute_frame_simple(&samples, frame_idx)?;
1504 /// // Process spectrum...
1505 /// }
1506 /// # Ok(())
1507 /// # }
1508 /// ```
1509 #[inline]
1510 pub fn compute_frame_simple(
1511 &mut self,
1512 samples: &NonEmptySlice<f64>,
1513 frame_idx: usize,
1514 ) -> SpectrogramResult<NonEmptyVec<Complex<f64>>> {
1515 self.compute_frame_fft_simple(samples, frame_idx)?;
1516 Ok(self.fft_out.clone())
1517 }
1518
1519 /// Compute STFT into a pre-allocated buffer.
1520 ///
1521 /// This avoids allocating the output matrix, useful for reusing buffers.
1522 ///
1523 /// # Arguments
1524 ///
1525 /// * `samples` - Input audio samples
1526 /// * `output` - Pre-allocated output buffer (shape: `n_bins` x `n_frames`)
1527 ///
1528 /// # Returns
1529 ///
1530 /// `Ok(())` on success, or an error if dimensions mismatch.
1531 ///
1532 /// # Errors
1533 ///
1534 /// Returns `DimensionMismatch` error if the output buffer has incorrect shape.
1535 ///
1536 /// # Examples
1537 ///
1538 /// ```rust
1539 /// use spectrograms::*;
1540 /// use ndarray::Array2;
1541 /// use num_complex::Complex;
1542 /// use non_empty_slice::non_empty_vec;
1543 ///
1544 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1545 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1546 /// let params = SpectrogramParams::new(stft, 16000.0)?;
1547 /// let mut plan = StftPlan::new(¶ms)?;
1548 ///
1549 /// let samples = non_empty_vec![0.0; nzu!(16000)];
1550 /// let (n_bins, n_frames) = plan.output_shape(samples.len())?;
1551 /// let mut output = Array2::<Complex<f64>>::zeros((n_bins.get(), n_frames.get()));
1552 ///
1553 /// plan.compute_into(&samples, &mut output)?;
1554 /// # Ok(())
1555 /// # }
1556 /// ```
1557 #[inline]
1558 pub fn compute_into(
1559 &mut self,
1560 samples: &NonEmptySlice<f64>,
1561 output: &mut Array2<Complex<f64>>,
1562 ) -> SpectrogramResult<()> {
1563 let n_frames = self.frame_count(samples.len())?;
1564 let n_bins = self.out_len;
1565
1566 // Validate output dimensions
1567 if output.nrows() != n_bins.get() {
1568 return Err(SpectrogramError::dimension_mismatch(
1569 n_bins.get(),
1570 output.nrows(),
1571 ));
1572 }
1573 if output.ncols() != n_frames.get() {
1574 return Err(SpectrogramError::dimension_mismatch(
1575 n_frames.get(),
1576 output.ncols(),
1577 ));
1578 }
1579
1580 // Compute into pre-allocated buffer
1581 for frame_idx in 0..n_frames.get() {
1582 self.compute_frame_fft_simple(samples, frame_idx)?;
1583
1584 for (bin_idx, &value) in self.fft_out.iter().enumerate() {
1585 output[[bin_idx, frame_idx]] = value;
1586 }
1587 }
1588
1589 Ok(())
1590 }
1591
1592 /// Get the expected output dimensions for a given signal length.
1593 ///
1594 /// # Arguments
1595 ///
1596 /// * `signal_length` - Length of the input signal in samples
1597 ///
1598 /// # Returns
1599 ///
1600 /// A tuple `(n_frequency_bins, n_time_frames)`.
1601 ///
1602 /// # Errors
1603 ///
1604 /// Returns an error if the computed number of frames is invalid.
1605 #[inline]
1606 pub fn output_shape(
1607 &self,
1608 signal_length: NonZeroUsize,
1609 ) -> SpectrogramResult<(NonZeroUsize, NonZeroUsize)> {
1610 let n_frames = self.frame_count(signal_length)?;
1611 Ok((self.out_len, n_frames))
1612 }
1613
1614 /// Get the number of frequency bins in the output.
1615 ///
1616 /// # Returns
1617 ///
1618 /// The number of frequency bins.
1619 #[inline]
1620 #[must_use]
1621 pub const fn n_bins(&self) -> NonZeroUsize {
1622 self.out_len
1623 }
1624
1625 /// Get the FFT size.
1626 ///
1627 /// # Returns
1628 ///
1629 /// The FFT size.
1630 #[inline]
1631 #[must_use]
1632 pub const fn n_fft(&self) -> NonZeroUsize {
1633 self.n_fft
1634 }
1635
1636 /// Get the hop size.
1637 ///
1638 /// # Returns
1639 ///
1640 /// The hop size.
1641 #[inline]
1642 #[must_use]
1643 pub const fn hop_size(&self) -> NonZeroUsize {
1644 self.hop_size
1645 }
1646}
1647
1648#[derive(Debug, Clone)]
1649enum MappingKind {
1650 Identity {
1651 out_len: NonZeroUsize,
1652 },
1653 Mel {
1654 matrix: SparseMatrix,
1655 }, // shape: (n_mels, out_len)
1656 LogHz {
1657 matrix: SparseMatrix,
1658 frequencies: NonEmptyVec<f64>,
1659 }, // shape: (n_bins, out_len)
1660 Erb {
1661 filterbank: ErbFilterbank,
1662 },
1663 Cqt {
1664 kernel: CqtKernel,
1665 },
1666}
1667
1668impl MappingKind {
1669 /// Check if this mapping requires unwindowed time-domain frames.
1670 ///
1671 /// CQT kernels already contain windowing applied during kernel generation,
1672 /// so they should receive unwindowed frames to avoid double-windowing.
1673 /// All other mappings work on FFT spectra and don't use the time-domain frame.
1674 const fn needs_unwindowed_frame(&self) -> bool {
1675 matches!(self, Self::Cqt { .. })
1676 }
1677}
1678
1679/// Typed mapping wrapper.
1680#[derive(Debug, Clone)]
1681struct FrequencyMapping<FreqScale> {
1682 kind: MappingKind,
1683 _marker: PhantomData<FreqScale>,
1684}
1685
1686impl FrequencyMapping<LinearHz> {
1687 fn new(params: &SpectrogramParams) -> SpectrogramResult<Self> {
1688 let out_len = r2c_output_size(params.stft().n_fft().get());
1689 let out_len = NonZeroUsize::new(out_len)
1690 .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1691 Ok(Self {
1692 kind: MappingKind::Identity { out_len },
1693 _marker: PhantomData,
1694 })
1695 }
1696}
1697
1698impl FrequencyMapping<Mel> {
1699 fn new_mel(params: &SpectrogramParams, mel: &MelParams) -> SpectrogramResult<Self> {
1700 let n_fft = params.stft().n_fft();
1701 let out_len = r2c_output_size(n_fft.get());
1702 let out_len = NonZeroUsize::new(out_len)
1703 .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1704
1705 // Validate: mel bins must be <= something sensible
1706 if mel.n_mels() > nzu!(10_000) {
1707 return Err(SpectrogramError::invalid_input(
1708 "n_mels is unreasonably large",
1709 ));
1710 }
1711
1712 let matrix = build_mel_filterbank_matrix(
1713 params.sample_rate_hz(),
1714 n_fft,
1715 mel.n_mels(),
1716 mel.f_min(),
1717 mel.f_max(),
1718 mel.norm(),
1719 )?;
1720
1721 // matrix must be (n_mels, out_len)
1722 if matrix.nrows() != mel.n_mels().get() || matrix.ncols() != out_len.get() {
1723 return Err(SpectrogramError::invalid_input(
1724 "mel filterbank matrix shape mismatch",
1725 ));
1726 }
1727
1728 Ok(Self {
1729 kind: MappingKind::Mel { matrix },
1730 _marker: PhantomData,
1731 })
1732 }
1733}
1734
1735impl FrequencyMapping<LogHz> {
1736 fn new_loghz(params: &SpectrogramParams, loghz: &LogHzParams) -> SpectrogramResult<Self> {
1737 let n_fft = params.stft().n_fft();
1738 let out_len = r2c_output_size(n_fft.get());
1739 let out_len = NonZeroUsize::new(out_len)
1740 .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1741 // Validate: n_bins must be <= something sensible
1742 if loghz.n_bins() > nzu!(10_000) {
1743 return Err(SpectrogramError::invalid_input(
1744 "n_bins is unreasonably large",
1745 ));
1746 }
1747
1748 let (matrix, frequencies) = build_loghz_matrix(
1749 params.sample_rate_hz(),
1750 n_fft,
1751 loghz.n_bins(),
1752 loghz.f_min(),
1753 loghz.f_max(),
1754 )?;
1755
1756 // matrix must be (n_bins, out_len)
1757 if matrix.nrows() != loghz.n_bins().get() || matrix.ncols() != out_len.get() {
1758 return Err(SpectrogramError::invalid_input(
1759 "loghz matrix shape mismatch",
1760 ));
1761 }
1762
1763 Ok(Self {
1764 kind: MappingKind::LogHz {
1765 matrix,
1766 frequencies,
1767 },
1768 _marker: PhantomData,
1769 })
1770 }
1771}
1772
1773impl FrequencyMapping<Erb> {
1774 fn new_erb(params: &SpectrogramParams, erb: &crate::erb::ErbParams) -> SpectrogramResult<Self> {
1775 let n_fft = params.stft().n_fft();
1776 let sample_rate = params.sample_rate_hz();
1777
1778 // Validate: n_filters must be <= something sensible
1779 if erb.n_filters() > nzu!(10_000) {
1780 return Err(SpectrogramError::invalid_input(
1781 "n_filters is unreasonably large",
1782 ));
1783 }
1784
1785 // Generate ERB filterbank with pre-computed frequency responses
1786 let filterbank = crate::erb::ErbFilterbank::generate(erb, sample_rate, n_fft)?;
1787
1788 Ok(Self {
1789 kind: MappingKind::Erb { filterbank },
1790 _marker: PhantomData,
1791 })
1792 }
1793}
1794
1795impl FrequencyMapping<Cqt> {
1796 fn new(params: &SpectrogramParams, cqt: &CqtParams) -> SpectrogramResult<Self> {
1797 let sample_rate = params.sample_rate_hz();
1798 let n_fft = params.stft().n_fft();
1799
1800 // Validate that frequency range is reasonable
1801 let f_max = cqt.bin_frequency(cqt.num_bins().get().saturating_sub(1));
1802 if f_max >= sample_rate / 2.0 {
1803 return Err(SpectrogramError::invalid_input(
1804 "CQT maximum frequency must be below Nyquist frequency",
1805 ));
1806 }
1807
1808 // Generate CQT kernel using n_fft as the signal length for kernel generation
1809 let kernel = CqtKernel::generate(cqt, sample_rate, n_fft);
1810
1811 Ok(Self {
1812 kind: MappingKind::Cqt { kernel },
1813 _marker: PhantomData,
1814 })
1815 }
1816}
1817
1818impl<FreqScale> FrequencyMapping<FreqScale> {
1819 const fn output_bins(&self) -> NonZeroUsize {
1820 // safety: all variants ensure output bins > 0 OR rely on a matrix that is guaranteed to have rows > 0
1821 match &self.kind {
1822 MappingKind::Identity { out_len } => *out_len,
1823 // safety: matrix.nrows() > 0
1824 MappingKind::LogHz { matrix, .. } | MappingKind::Mel { matrix } => unsafe {
1825 NonZeroUsize::new_unchecked(matrix.nrows())
1826 },
1827 MappingKind::Erb { filterbank, .. } => filterbank.num_filters(),
1828 MappingKind::Cqt { kernel, .. } => kernel.num_bins(),
1829 }
1830 }
1831
1832 fn apply(
1833 &self,
1834 spectrum: &NonEmptySlice<f64>,
1835 frame: &NonEmptySlice<f64>,
1836 out: &mut NonEmptySlice<f64>,
1837 ) -> SpectrogramResult<()> {
1838 match &self.kind {
1839 MappingKind::Identity { out_len } => {
1840 if spectrum.len() != *out_len {
1841 return Err(SpectrogramError::dimension_mismatch(
1842 (*out_len).get(),
1843 spectrum.len().get(),
1844 ));
1845 }
1846 if out.len() != *out_len {
1847 return Err(SpectrogramError::dimension_mismatch(
1848 (*out_len).get(),
1849 out.len().get(),
1850 ));
1851 }
1852 out.copy_from_slice(spectrum);
1853 Ok(())
1854 }
1855 MappingKind::LogHz { matrix, .. } | MappingKind::Mel { matrix } => {
1856 let out_bins = matrix.nrows();
1857 let in_bins = matrix.ncols();
1858
1859 if spectrum.len().get() != in_bins {
1860 return Err(SpectrogramError::dimension_mismatch(
1861 in_bins,
1862 spectrum.len().get(),
1863 ));
1864 }
1865 if out.len().get() != out_bins {
1866 return Err(SpectrogramError::dimension_mismatch(
1867 out_bins,
1868 out.len().get(),
1869 ));
1870 }
1871
1872 // Sparse matrix-vector multiplication: out = matrix * spectrum
1873 matrix.multiply_vec(spectrum, out);
1874 Ok(())
1875 }
1876 MappingKind::Erb { filterbank } => {
1877 // Apply ERB filterbank using pre-computed frequency responses
1878 // The filterbank already has |H(f)|^2 pre-computed, so we just
1879 // apply it to the power spectrum
1880 let erb_out = filterbank.apply_to_power_spectrum(spectrum)?;
1881
1882 if out.len().get() != erb_out.len().get() {
1883 return Err(SpectrogramError::dimension_mismatch(
1884 erb_out.len().get(),
1885 out.len().get(),
1886 ));
1887 }
1888
1889 out.copy_from_slice(&erb_out);
1890 Ok(())
1891 }
1892 MappingKind::Cqt { kernel } => {
1893 // CQT works on time-domain unwindowed frame (not FFT spectrum).
1894 // The CQT kernels contain windowing applied during kernel generation,
1895 // so the input frame should be unwindowed to avoid double-windowing.
1896 let cqt_complex = kernel.apply(frame)?;
1897
1898 if out.len().get() != cqt_complex.len().get() {
1899 return Err(SpectrogramError::dimension_mismatch(
1900 cqt_complex.len().get(),
1901 out.len().get(),
1902 ));
1903 }
1904
1905 // Convert complex coefficients to power (|z|^2)
1906 // This matches the convention where intermediate values are in power domain
1907 for (i, c) in cqt_complex.iter().enumerate() {
1908 out[i] = c.norm_sqr();
1909 }
1910
1911 Ok(())
1912 }
1913 }
1914 }
1915
1916 fn frequencies_hz(&self, params: &SpectrogramParams) -> NonEmptyVec<f64> {
1917 match &self.kind {
1918 MappingKind::Identity { out_len } => {
1919 // Standard R2C bins: k * sr / n_fft
1920 let n_fft = params.stft().n_fft().get() as f64;
1921 let sr = params.sample_rate_hz();
1922 let df = sr / n_fft;
1923
1924 let mut f = Vec::with_capacity((*out_len).get());
1925 for k in 0..(*out_len).get() {
1926 f.push(k as f64 * df);
1927 }
1928 // safety: out_len > 0
1929 unsafe { NonEmptyVec::new_unchecked(f) }
1930 }
1931 MappingKind::Mel { matrix } => {
1932 // For mel, the axis is defined by the mel band centre frequencies.
1933 // We compute and store them consistently with how we built the filterbank.
1934 let n_mels = matrix.nrows();
1935 // safety: n_mels > 0
1936 let n_mels = unsafe { NonZeroUsize::new_unchecked(n_mels) };
1937 mel_band_centres_hz(n_mels, params.sample_rate_hz(), params.nyquist_hz())
1938 }
1939 MappingKind::LogHz { frequencies, .. } => {
1940 // Frequencies are stored when the mapping is created
1941 frequencies.clone()
1942 }
1943 MappingKind::Erb { filterbank, .. } => {
1944 // ERB center frequencies
1945 filterbank.center_frequencies().to_non_empty_vec()
1946 }
1947 MappingKind::Cqt { kernel, .. } => {
1948 // CQT center frequencies from the kernel
1949 kernel.frequencies().to_non_empty_vec()
1950 }
1951 }
1952 }
1953}
1954
1955//
1956// ========================
1957// Amplitude scaling
1958// ========================
1959//
1960
1961/// Marker trait so we can specialise behaviour by `AmpScale`.
1962pub trait AmpScaleSpec: Sized + Send + Sync {
1963 /// Apply conversion from power-domain value to the desired amplitude scale.
1964 ///
1965 /// # Arguments
1966 ///
1967 /// - `power`: input power-domain value.
1968 ///
1969 /// # Returns
1970 ///
1971 /// Converted amplitude value.
1972 fn apply_from_power(power: f64) -> f64;
1973
1974 /// Apply dB conversion in-place on a power-domain vector.
1975 ///
1976 /// This is a no-op for Power and Magnitude scales.
1977 ///
1978 /// # Arguments
1979 ///
1980 /// - `x`: power-domain values to convert to dB in-place.
1981 /// - `floor_db`: dB floor value to apply.
1982 ///
1983 /// # Returns
1984 ///
1985 /// `SpectrogramResult<()>`: Ok on success, error on invalid input.
1986 ///
1987 /// # Errors
1988 ///
1989 /// - If `floor_db` is not finite.
1990 fn apply_db_in_place(x: &mut [f64], floor_db: f64) -> SpectrogramResult<()>;
1991}
1992
1993impl AmpScaleSpec for Power {
1994 #[inline]
1995 fn apply_from_power(power: f64) -> f64 {
1996 power
1997 }
1998
1999 #[inline]
2000 fn apply_db_in_place(_x: &mut [f64], _floor_db: f64) -> SpectrogramResult<()> {
2001 Ok(())
2002 }
2003}
2004
2005impl AmpScaleSpec for Magnitude {
2006 #[inline]
2007 fn apply_from_power(power: f64) -> f64 {
2008 power.sqrt()
2009 }
2010
2011 #[inline]
2012 fn apply_db_in_place(_x: &mut [f64], _floor_db: f64) -> SpectrogramResult<()> {
2013 Ok(())
2014 }
2015}
2016
2017impl AmpScaleSpec for Decibels {
2018 #[inline]
2019 fn apply_from_power(power: f64) -> f64 {
2020 // dB conversion is applied in batch, not here.
2021 power
2022 }
2023
2024 #[inline]
2025 fn apply_db_in_place(x: &mut [f64], floor_db: f64) -> SpectrogramResult<()> {
2026 // Convert power -> dB: 10*log10(max(power, eps))
2027 // where eps is derived from floor_db to ensure consistency
2028 if !floor_db.is_finite() {
2029 return Err(SpectrogramError::invalid_input("floor_db must be finite"));
2030 }
2031
2032 // Convert floor_db to linear scale to get epsilon
2033 // e.g., floor_db = -80 dB -> eps = 10^(-80/10) = 1e-8
2034 let eps = 10.0_f64.powf(floor_db / 10.0);
2035
2036 for v in x.iter_mut() {
2037 // Clamp power to epsilon before log to avoid log(0) and ensure floor
2038 *v = 10.0 * v.max(eps).log10();
2039 }
2040 Ok(())
2041 }
2042}
2043
/// Configuration of the final amplitude-scaling stage.
///
/// Converts the power-domain intermediate into the amplitude scale selected by
/// `AmpScale` (Power, Magnitude, Decibels).
#[derive(Debug, Clone)]
struct AmplitudeScaling<AmpScale> {
    /// dB floor taken from the logarithmic scaling parameters, if any were given.
    db_floor: Option<f64>,
    /// Zero-sized marker carrying the amplitude-scale type.
    _marker: PhantomData<AmpScale>,
}
2052
2053impl<AmpScale> AmplitudeScaling<AmpScale>
2054where
2055 AmpScale: AmpScaleSpec + 'static,
2056{
2057 fn new(db: Option<&LogParams>) -> Self {
2058 let db_floor = db.map(LogParams::floor_db);
2059 Self {
2060 db_floor,
2061 _marker: PhantomData,
2062 }
2063 }
2064
2065 /// Apply amplitude scaling in-place on a mapped spectrum vector.
2066 ///
2067 /// The input vector is assumed to be in the *power* domain (|X|^2),
2068 /// because the STFT stage produces power as the canonical intermediate.
2069 ///
2070 /// - Power: leaves values unchanged.
2071 /// - Magnitude: sqrt(power).
2072 /// - Decibels: converts power -> dB and floors at `db_floor`.
2073 pub fn apply_in_place(&self, x: &mut [f64]) -> SpectrogramResult<()> {
2074 // Convert from canonical power-domain intermediate into the requested linear domain.
2075 for v in x.iter_mut() {
2076 *v = AmpScale::apply_from_power(*v);
2077 }
2078
2079 // Apply dB conversion if configured (no-op for Power/Magnitude via trait impls).
2080 if let Some(floor_db) = self.db_floor {
2081 AmpScale::apply_db_in_place(x, floor_db)?;
2082 }
2083
2084 Ok(())
2085 }
2086}
2087
2088#[derive(Debug, Clone)]
2089struct Workspace {
2090 spectrum: NonEmptyVec<f64>, // out_len (power spectrum)
2091 mapped: NonEmptyVec<f64>, // n_bins (after mapping)
2092 frame: NonEmptyVec<f64>, // n_fft (windowed frame for FFT)
2093 fft_out: NonEmptyVec<Complex<f64>>, // out_len (FFT output)
2094}
2095
2096impl Workspace {
2097 fn new(n_fft: NonZeroUsize, out_len: NonZeroUsize, n_bins: NonZeroUsize) -> Self {
2098 Self {
2099 spectrum: non_empty_vec![0.0; out_len],
2100 mapped: non_empty_vec![0.0; n_bins],
2101 frame: non_empty_vec![0.0; n_fft],
2102 fft_out: non_empty_vec![Complex::new(0.0, 0.0); out_len],
2103 }
2104 }
2105
2106 fn ensure_sizes(&mut self, n_fft: NonZeroUsize, out_len: NonZeroUsize, n_bins: NonZeroUsize) {
2107 if self.spectrum.len() != out_len {
2108 self.spectrum.resize(out_len, 0.0);
2109 }
2110 if self.mapped.len() != n_bins {
2111 self.mapped.resize(n_bins, 0.0);
2112 }
2113 if self.frame.len() != n_fft {
2114 self.frame.resize(n_fft, 0.0);
2115 }
2116 if self.fft_out.len() != out_len {
2117 self.fft_out.resize(out_len, Complex::new(0.0, 0.0));
2118 }
2119 }
2120}
2121
2122fn build_frequency_axis<FreqScale>(
2123 params: &SpectrogramParams,
2124 mapping: &FrequencyMapping<FreqScale>,
2125) -> FrequencyAxis<FreqScale>
2126where
2127 FreqScale: Copy + Clone + 'static,
2128{
2129 let frequencies = mapping.frequencies_hz(params);
2130 FrequencyAxis::new(frequencies)
2131}
2132
2133fn build_time_axis_seconds(params: &SpectrogramParams, n_frames: NonZeroUsize) -> NonEmptyVec<f64> {
2134 let dt = params.frame_period_seconds();
2135 let mut times = Vec::with_capacity(n_frames.get());
2136
2137 for i in 0..n_frames.get() {
2138 times.push(i as f64 * dt);
2139 }
2140
2141 // safety: times is guaranteed non-empty since n_frames > 0
2142
2143 unsafe { NonEmptyVec::new_unchecked(times) }
2144}
2145
2146/// Generate window function samples.
2147///
2148/// Supports various window types including Rectangular, Hanning, Hamming, Blackman, Kaiser, and Gaussian.
2149///
2150/// # Arguments
2151///
2152/// * `window` - The type of window function to generate.
2153/// * `n_fft` - The size of the FFT, which determines the length of the window.
2154///
2155/// # Returns
2156///
2157/// A `NonEmptyVec<f64>` containing the window function samples.
2158///
2159/// # Panics
2160///
2161/// Panics if a custom window is provided with a size that does not match `n_fft`.
2162#[inline]
2163#[must_use]
2164pub fn make_window(window: WindowType, n_fft: NonZeroUsize) -> NonEmptyVec<f64> {
2165 let n_fft = n_fft.get();
2166 let mut w = vec![0.0; n_fft];
2167
2168 match window {
2169 WindowType::Rectangular => {
2170 w.fill(1.0);
2171 }
2172 WindowType::Hanning => {
2173 // Hann: 0.5 - 0.5*cos(2Ï€n/(N-1))
2174 let n1 = (n_fft - 1) as f64;
2175 for (n, v) in w.iter_mut().enumerate() {
2176 *v = 0.5f64.mul_add(-(2.0 * std::f64::consts::PI * (n as f64) / n1).cos(), 0.5);
2177 }
2178 }
2179 WindowType::Hamming => {
2180 // Hamming: 0.54 - 0.46*cos(2Ï€n/(N-1))
2181 let n1 = (n_fft - 1) as f64;
2182 for (n, v) in w.iter_mut().enumerate() {
2183 *v = 0.46f64.mul_add(-(2.0 * std::f64::consts::PI * (n as f64) / n1).cos(), 0.54);
2184 }
2185 }
2186 WindowType::Blackman => {
2187 // Blackman: 0.42 - 0.5*cos(2Ï€n/(N-1)) + 0.08*cos(4Ï€n/(N-1))
2188 let n1 = (n_fft - 1) as f64;
2189 for (n, v) in w.iter_mut().enumerate() {
2190 let a = 2.0 * std::f64::consts::PI * (n as f64) / n1;
2191 *v = 0.08f64.mul_add((2.0 * a).cos(), 0.5f64.mul_add(-a.cos(), 0.42));
2192 }
2193 }
2194 WindowType::Kaiser { beta } => {
2195 if n_fft == 1 {
2196 w[0] = 1.0;
2197 } else {
2198 let denom = modified_bessel_i0(beta);
2199 let n_max = (n_fft - 1) as f64 / 2.0;
2200
2201 for (i, value) in w.iter_mut().enumerate() {
2202 let n = i as f64 - n_max;
2203 let ratio = if n_max == 0.0 {
2204 0.0
2205 } else {
2206 let normalized = n / n_max;
2207 (1.0 - normalized * normalized).max(0.0)
2208 };
2209 let arg = beta * ratio.sqrt();
2210 *value = if denom == 0.0 {
2211 0.0
2212 } else {
2213 modified_bessel_i0(arg) / denom
2214 };
2215 }
2216 }
2217 }
2218 WindowType::Gaussian { std } => (0..n_fft).for_each(|i| {
2219 let n = i as f64;
2220 let center: f64 = (n_fft - 1) as f64 / 2.0;
2221 let exponent: f64 = -0.5 * ((n - center) / std).powi(2);
2222 w[i] = exponent.exp();
2223 }),
2224 WindowType::Custom { coefficients, size } => {
2225 assert!(
2226 size.get() == n_fft,
2227 "Custom window size mismatch: expected {}, got {}. \
2228 Custom windows must be pre-computed with the exact FFT size.",
2229 n_fft,
2230 size.get()
2231 );
2232 w.copy_from_slice(&coefficients);
2233 }
2234 }
2235
2236 // safety: window is guaranteed non-empty since n_fft > 0
2237 unsafe { NonEmptyVec::new_unchecked(w) }
2238}
2239
/// Modified Bessel function of the first kind, order zero: `I0(x)`.
///
/// Uses the classic polynomial approximations (Abramowitz & Stegun 9.8.1 for
/// `|x| <= 3.75`, 9.8.2 for `|x| > 3.75`). `I0` is even, so only `|x|` matters.
fn modified_bessel_i0(x: f64) -> f64 {
    let ax = x.abs();
    if ax <= 3.75 {
        // Power series in t^2 with t = x / 3.75.
        let t = x / 3.75;
        let t2 = t * t;
        1.0 + t2
            * (3.515_622_9
                + t2 * (3.089_942_4
                    + t2 * (1.206_749_2
                        + t2 * (0.265_973_2 + t2 * (0.036_076_8 + t2 * 0.004_581_3)))))
    } else {
        // Asymptotic form: sqrt(x) * e^(-x) * I0(x) = poly(3.75 / x).
        let t = 3.75 / ax;
        let poly = 0.398_942_28
            + t * (0.013_285_92
                + t * (0.002_253_19
                    + t * (-0.001_575_65
                        + t * (0.009_162_81
                            + t * (-0.020_577_06
                                + t * (0.026_355_37 + t * (-0.016_476_33 + t * 0.003_923_77)))))));

        // The leading coefficient 0.398_942_28 already equals 1/sqrt(2π), so the
        // prefactor is e^x / sqrt(x) only; dividing by sqrt(2π) again would
        // shrink the result by ~2.5x and break continuity at |x| = 3.75.
        (ax.exp() / ax.sqrt()) * poly
    }
}
2263
/// Map a frequency in Hz onto the Slaney mel scale (librosa default, htk=False).
///
/// The scale is piecewise:
/// - below 1000 Hz it is linear: mel = hz / (200/3)
/// - from 1000 Hz upward it is logarithmic: mel = 15 + ln(hz/1000) / log_step
fn hz_to_mel(hz: f64) -> f64 {
    const LINEAR_ORIGIN_HZ: f64 = 0.0;
    const HZ_PER_MEL: f64 = 200.0 / 3.0; // ~66.667
    const BREAK_HZ: f64 = 1000.0;
    const BREAK_MEL: f64 = (BREAK_HZ - LINEAR_ORIGIN_HZ) / HZ_PER_MEL; // = 15.0
    const LOG_STEP: f64 = 0.068_751_777_420_949_23; // ln(6.4) / 27

    match hz {
        // Logarithmic region
        f if f >= BREAK_HZ => BREAK_MEL + (f / BREAK_HZ).ln() / LOG_STEP,
        // Linear region
        f => (f - LINEAR_ORIGIN_HZ) / HZ_PER_MEL,
    }
}
2285
/// Map a Slaney mel value back to Hz (librosa default, htk=False).
///
/// This is the inverse of `hz_to_mel`: linear below mel 15 (1000 Hz) and
/// exponential from there upward.
fn mel_to_hz(mel: f64) -> f64 {
    const LINEAR_ORIGIN_HZ: f64 = 0.0;
    const HZ_PER_MEL: f64 = 200.0 / 3.0; // ~66.667
    const BREAK_HZ: f64 = 1000.0;
    const BREAK_MEL: f64 = (BREAK_HZ - LINEAR_ORIGIN_HZ) / HZ_PER_MEL; // = 15.0
    const LOG_STEP: f64 = 0.068_751_777_420_949_23; // ln(6.4) / 27

    match mel {
        // Logarithmic region
        m if m >= BREAK_MEL => BREAK_HZ * (LOG_STEP * (m - BREAK_MEL)).exp(),
        // Linear region
        m => HZ_PER_MEL.mul_add(m, LINEAR_ORIGIN_HZ),
    }
}
2304
2305fn build_mel_filterbank_matrix(
2306 sample_rate_hz: f64,
2307 n_fft: NonZeroUsize,
2308 n_mels: NonZeroUsize,
2309 f_min: f64,
2310 f_max: f64,
2311 norm: MelNorm,
2312) -> SpectrogramResult<SparseMatrix> {
2313 if sample_rate_hz <= 0.0 || !sample_rate_hz.is_finite() {
2314 return Err(SpectrogramError::invalid_input(
2315 "sample_rate_hz must be finite and > 0",
2316 ));
2317 }
2318 if f_min < 0.0 || f_min.is_infinite() {
2319 return Err(SpectrogramError::invalid_input("f_min must be >= 0"));
2320 }
2321 if f_max <= f_min {
2322 return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
2323 }
2324 if f_max > sample_rate_hz * 0.5 {
2325 return Err(SpectrogramError::invalid_input("f_max must be <= Nyquist"));
2326 }
2327 let n_mels = n_mels.get();
2328 let n_fft = n_fft.get();
2329 let out_len = r2c_output_size(n_fft);
2330
2331 // FFT bin frequencies
2332 let df = sample_rate_hz / n_fft as f64;
2333
2334 // Mel points: n_mels + 2 (for triangular edges)
2335 let mel_min = hz_to_mel(f_min);
2336 let mel_max = hz_to_mel(f_max);
2337
2338 let n_points = n_mels + 2;
2339 let step = (mel_max - mel_min) / (n_points - 1) as f64;
2340
2341 let mut mel_points = Vec::with_capacity(n_points);
2342 for i in 0..n_points {
2343 mel_points.push((i as f64).mul_add(step, mel_min));
2344 }
2345
2346 let mut hz_points = Vec::with_capacity(n_points);
2347 for m in &mel_points {
2348 hz_points.push(mel_to_hz(*m));
2349 }
2350
2351 // Build filterbank as sparse matrix (librosa-style, in frequency space)
2352 // This builds triangular filters based on actual frequencies, not bin indices
2353 let mut fb = SparseMatrix::new(n_mels, out_len);
2354
2355 for m in 0..n_mels {
2356 let freq_left = hz_points[m];
2357 let freq_center = hz_points[m + 1];
2358 let freq_right = hz_points[m + 2];
2359
2360 let fdiff_left = freq_center - freq_left;
2361 let fdiff_right = freq_right - freq_center;
2362
2363 if fdiff_left == 0.0 || fdiff_right == 0.0 {
2364 // Degenerate triangle, skip
2365 continue;
2366 }
2367
2368 // For each FFT bin, compute the triangular weight based on its frequency
2369 for k in 0..out_len {
2370 let bin_freq = k as f64 * df;
2371
2372 // Lower slope: rises from freq_left to freq_center
2373 let lower = (bin_freq - freq_left) / fdiff_left;
2374
2375 // Upper slope: falls from freq_center to freq_right
2376 let upper = (freq_right - bin_freq) / fdiff_right;
2377
2378 // Triangle is the minimum of the two slopes, clipped to [0, 1]
2379 let weight = lower.min(upper).clamp(0.0, 1.0);
2380
2381 if weight > 0.0 {
2382 fb.set(m, k, weight);
2383 }
2384 }
2385 }
2386
2387 // Apply normalization
2388 match norm {
2389 MelNorm::None => {
2390 // No normalization needed
2391 }
2392 MelNorm::Slaney => {
2393 // Slaney-style area normalization: 2 / (hz_max - hz_min) for each triangle
2394 // NOTE: Uses Hz bandwidth, not mel bandwidth (to match librosa's implementation)
2395 for m in 0..n_mels {
2396 let mel_left = mel_points[m];
2397 let mel_right = mel_points[m + 2];
2398 let hz_left = mel_to_hz(mel_left);
2399 let hz_right = mel_to_hz(mel_right);
2400 let enorm = 2.0 / (hz_right - hz_left);
2401
2402 // Normalize all values in this row
2403 for val in &mut fb.values[m] {
2404 *val *= enorm;
2405 }
2406 }
2407 }
2408 MelNorm::L1 => {
2409 // L1 normalization: sum of weights = 1.0
2410 for m in 0..n_mels {
2411 let sum: f64 = fb.values[m].iter().sum();
2412 if sum > 0.0 {
2413 let normalizer = 1.0 / sum;
2414 for val in &mut fb.values[m] {
2415 *val *= normalizer;
2416 }
2417 }
2418 }
2419 }
2420 MelNorm::L2 => {
2421 // L2 normalization: L2 norm = 1.0
2422 for m in 0..n_mels {
2423 let norm_val: f64 = fb.values[m].iter().map(|&v| v * v).sum::<f64>().sqrt();
2424 if norm_val > 0.0 {
2425 let normalizer = 1.0 / norm_val;
2426 for val in &mut fb.values[m] {
2427 *val *= normalizer;
2428 }
2429 }
2430 }
2431 }
2432 }
2433
2434 Ok(fb)
2435}
2436
2437/// Build a logarithmic frequency interpolation matrix.
2438///
2439/// Maps linearly-spaced FFT bins to logarithmically-spaced frequency bins
2440/// using linear interpolation.
2441fn build_loghz_matrix(
2442 sample_rate_hz: f64,
2443 n_fft: NonZeroUsize,
2444 n_bins: NonZeroUsize,
2445 f_min: f64,
2446 f_max: f64,
2447) -> SpectrogramResult<(SparseMatrix, NonEmptyVec<f64>)> {
2448 if sample_rate_hz <= 0.0 || !sample_rate_hz.is_finite() {
2449 return Err(SpectrogramError::invalid_input(
2450 "sample_rate_hz must be finite and > 0",
2451 ));
2452 }
2453 if f_min <= 0.0 || f_min.is_infinite() {
2454 return Err(SpectrogramError::invalid_input(
2455 "f_min must be finite and > 0",
2456 ));
2457 }
2458 if f_max <= f_min {
2459 return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
2460 }
2461 if f_max > sample_rate_hz * 0.5 {
2462 return Err(SpectrogramError::invalid_input("f_max must be <= Nyquist"));
2463 }
2464
2465 let n_bins = n_bins.get();
2466 let n_fft = n_fft.get();
2467
2468 let out_len = r2c_output_size(n_fft);
2469 let df = sample_rate_hz / n_fft as f64;
2470
2471 // Generate logarithmically-spaced frequencies
2472 let log_f_min = f_min.ln();
2473 let log_f_max = f_max.ln();
2474 let log_step = (log_f_max - log_f_min) / (n_bins - 1) as f64;
2475
2476 let mut log_frequencies = Vec::with_capacity(n_bins);
2477 for i in 0..n_bins {
2478 let log_f = (i as f64).mul_add(log_step, log_f_min);
2479 log_frequencies.push(log_f.exp());
2480 }
2481 // safety: n_bins > 0
2482 let log_frequencies = unsafe { NonEmptyVec::new_unchecked(log_frequencies) };
2483
2484 // Build interpolation matrix as sparse matrix
2485 let mut matrix = SparseMatrix::new(n_bins, out_len);
2486
2487 for (bin_idx, &target_freq) in log_frequencies.iter().enumerate() {
2488 // Find the two FFT bins that bracket this frequency
2489 let exact_bin = target_freq / df;
2490 let lower_bin = exact_bin.floor() as usize;
2491 let upper_bin = (exact_bin.ceil() as usize).min(out_len - 1);
2492
2493 if lower_bin >= out_len {
2494 continue;
2495 }
2496
2497 if lower_bin == upper_bin {
2498 // Exact match
2499 matrix.set(bin_idx, lower_bin, 1.0);
2500 } else {
2501 // Linear interpolation
2502 let frac = exact_bin - lower_bin as f64;
2503 matrix.set(bin_idx, lower_bin, 1.0 - frac);
2504 if upper_bin < out_len {
2505 matrix.set(bin_idx, upper_bin, frac);
2506 }
2507 }
2508 }
2509
2510 Ok((matrix, log_frequencies))
2511}
2512
2513fn mel_band_centres_hz(
2514 n_mels: NonZeroUsize,
2515 sample_rate_hz: f64,
2516 nyquist_hz: f64,
2517) -> NonEmptyVec<f64> {
2518 let f_min = 0.0;
2519 let f_max = nyquist_hz.min(sample_rate_hz * 0.5);
2520
2521 let mel_min = hz_to_mel(f_min);
2522 let mel_max = hz_to_mel(f_max);
2523 let n_mels = n_mels.get();
2524 let step = (mel_max - mel_min) / (n_mels + 1) as f64;
2525
2526 let mut centres = Vec::with_capacity(n_mels);
2527 for i in 0..n_mels {
2528 let mel = (i as f64 + 1.0).mul_add(step, mel_min);
2529 centres.push(mel_to_hz(mel));
2530 }
2531 // safety: centres is guaranteed non-empty since n_mels > 0
2532 unsafe { NonEmptyVec::new_unchecked(centres) }
2533}
2534
2535/// Spectrogram structure holding the computed spectrogram data and metadata.
2536///
2537/// # Type Parameters
2538///
2539/// * `FreqScale`: The frequency scale type (e.g., `LinearHz`, `Mel`, `LogHz`, etc.).
2540/// * `AmpScale`: The amplitude scale type (e.g., `Power`, `Magnitude`, `Decibels`).
2541///
2542/// # Fields
2543///
2544/// * `data`: A 2D array containing the spectrogram data.
2545/// * `axes`: The axes of the spectrogram (frequency and time).
2546/// * `params`: The parameters used to compute the spectrogram.
2547/// * `_amp`: A phantom data marker for the amplitude scale type.
2548#[derive(Debug, Clone)]
2549#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
2550pub struct Spectrogram<FreqScale, AmpScale>
2551where
2552 AmpScale: AmpScaleSpec + 'static,
2553 FreqScale: Copy + Clone + 'static,
2554{
2555 data: Array2<f64>,
2556 axes: Axes<FreqScale>,
2557 params: SpectrogramParams,
2558 #[cfg_attr(feature = "serde", serde(skip))]
2559 _amp: PhantomData<AmpScale>,
2560}
2561
2562impl<FreqScale, AmpScale> Spectrogram<FreqScale, AmpScale>
2563where
2564 AmpScale: AmpScaleSpec + 'static,
2565 FreqScale: Copy + Clone + 'static,
2566{
2567 /// Get the X-axis label.
2568 ///
2569 /// # Returns
2570 ///
2571 /// A static string slice representing the X-axis label.
2572 #[inline]
2573 #[must_use]
2574 pub const fn x_axis_label() -> &'static str {
2575 "Time (s)"
2576 }
2577
2578 /// Get the Y-axis label based on the frequency scale type.
2579 ///
2580 /// # Returns
2581 ///
2582 /// A static string slice representing the Y-axis label.
2583 #[inline]
2584 #[must_use]
2585 pub fn y_axis_label() -> &'static str {
2586 match std::any::TypeId::of::<FreqScale>() {
2587 id if id == std::any::TypeId::of::<LinearHz>() => "Frequency (Hz)",
2588 id if id == std::any::TypeId::of::<Mel>() => "Frequency (Mel)",
2589 id if id == std::any::TypeId::of::<LogHz>() => "Frequency (Log Hz)",
2590 id if id == std::any::TypeId::of::<Erb>() => "Frequency (ERB)",
2591 id if id == std::any::TypeId::of::<Cqt>() => "Frequency (CQT Bins)",
2592 _ => "Frequency",
2593 }
2594 }
2595
2596 /// Internal constructor. Only callable inside the crate.
2597 ///
2598 /// All inputs must already be validated and consistent.
2599 pub(crate) fn new(data: Array2<f64>, axes: Axes<FreqScale>, params: SpectrogramParams) -> Self {
2600 debug_assert_eq!(data.nrows(), axes.frequencies().len().get());
2601 debug_assert_eq!(data.ncols(), axes.times().len().get());
2602
2603 Self {
2604 data,
2605 axes,
2606 params,
2607 _amp: PhantomData,
2608 }
2609 }
2610
2611 /// Set spectrogram data matrix
2612 ///
2613 /// # Arguments
2614 ///
2615 /// * `data` - The new spectrogram data matrix.
2616 #[inline]
2617 pub fn set_data(&mut self, data: Array2<f64>) {
2618 self.data = data;
2619 }
2620
2621 /// Spectrogram data matrix
2622 ///
2623 /// # Returns
2624 ///
2625 /// A reference to the spectrogram data matrix.
2626 #[inline]
2627 #[must_use]
2628 pub const fn data(&self) -> &Array2<f64> {
2629 &self.data
2630 }
2631
2632 /// Consume the spectrogram and extract the data matrix.
2633 ///
2634 /// This method moves the data out of the spectrogram, consuming it.
2635 /// Useful for transferring ownership to Python without copying.
2636 ///
2637 /// # Returns
2638 ///
2639 /// The owned spectrogram data matrix.
2640 #[inline]
2641 #[must_use]
2642 pub fn into_data(self) -> Array2<f64> {
2643 self.data
2644 }
2645
2646 /// Axes of the spectrogram
2647 ///
2648 /// # Returns
2649 ///
2650 /// A reference to the axes of the spectrogram.
2651 #[inline]
2652 #[must_use]
2653 pub const fn axes(&self) -> &Axes<FreqScale> {
2654 &self.axes
2655 }
2656
2657 /// Frequency axis in Hz
2658 ///
2659 /// # Returns
2660 ///
2661 /// A reference to the frequency axis in Hz.
2662 #[inline]
2663 #[must_use]
2664 pub fn frequencies(&self) -> &NonEmptySlice<f64> {
2665 self.axes.frequencies()
2666 }
2667
2668 /// Frequency range in Hz (min, max)
2669 ///
2670 /// # Returns
2671 ///
2672 /// A tuple containing the minimum and maximum frequencies in Hz.
2673 #[inline]
2674 #[must_use]
2675 pub const fn frequency_range(&self) -> (f64, f64) {
2676 self.axes.frequency_range()
2677 }
2678
2679 /// Time axis in seconds
2680 ///
2681 /// # Returns
2682 ///
2683 /// A reference to the time axis in seconds.
2684 #[inline]
2685 #[must_use]
2686 pub fn times(&self) -> &NonEmptySlice<f64> {
2687 self.axes.times()
2688 }
2689
2690 /// Spectrogram computation parameters
2691 ///
2692 /// # Returns
2693 ///
2694 /// A reference to the spectrogram computation parameters.
2695 #[inline]
2696 #[must_use]
2697 pub const fn params(&self) -> &SpectrogramParams {
2698 &self.params
2699 }
2700
2701 /// Duration of the spectrogram in seconds
2702 ///
2703 /// # Returns
2704 ///
2705 /// The duration of the spectrogram in seconds.
2706 #[inline]
2707 #[must_use]
2708 pub fn duration(&self) -> f64 {
2709 self.axes.duration()
2710 }
2711
2712 /// If this is a dB spectrogram, return the (min, max) dB values. otherwise do the maths to compute dB range.
2713 ///
2714 /// # Returns
2715 ///
2716 /// The (min, max) dB values of the spectrogram, or `None` if the amplitude scale is unknown.
2717 #[inline]
2718 #[must_use]
2719 pub fn db_range(&self) -> Option<(f64, f64)> {
2720 let type_self = std::any::TypeId::of::<AmpScale>();
2721
2722 if type_self == std::any::TypeId::of::<Decibels>() {
2723 let (min, max) = min_max_single_pass(self.data.as_slice()?);
2724 Some((min, max))
2725 } else if type_self == std::any::TypeId::of::<Power>() {
2726 // Not a dB spectrogram; compute dB range from power values
2727 let mut min_db = f64::INFINITY;
2728 let mut max_db = f64::NEG_INFINITY;
2729 for &v in &self.data {
2730 let db = 10.0 * (v + EPS).log10();
2731 if db < min_db {
2732 min_db = db;
2733 }
2734 if db > max_db {
2735 max_db = db;
2736 }
2737 }
2738 Some((min_db, max_db))
2739 } else if type_self == std::any::TypeId::of::<Magnitude>() {
2740 // Not a dB spectrogram; compute dB range from magnitude values
2741 let mut min_db = f64::INFINITY;
2742 let mut max_db = f64::NEG_INFINITY;
2743
2744 for &v in &self.data {
2745 let power = v * v;
2746 let db = 10.0 * (power + EPS).log10();
2747 if db < min_db {
2748 min_db = db;
2749 }
2750 if db > max_db {
2751 max_db = db;
2752 }
2753 }
2754
2755 Some((min_db, max_db))
2756 } else {
2757 // Unknown AmpScale type; return dummy values
2758 None
2759 }
2760 }
2761
2762 /// Number of frequency bins
2763 ///
2764 /// # Returns
2765 ///
2766 /// The number of frequency bins in the spectrogram.
2767 #[inline]
2768 #[must_use]
2769 pub fn n_bins(&self) -> NonZeroUsize {
2770 // safety: data.nrows() > 0 is guaranteed by construction
2771 unsafe { NonZeroUsize::new_unchecked(self.data.nrows()) }
2772 }
2773
2774 /// Number of time frames in the spectrogram
2775 ///
2776 /// # Returns
2777 ///
2778 /// The number of time frames (columns) in the spectrogram.
2779 #[inline]
2780 #[must_use]
2781 pub fn n_frames(&self) -> NonZeroUsize {
2782 // safety: data.ncols() > 0 is guaranteed by construction
2783 unsafe { NonZeroUsize::new_unchecked(self.data.ncols()) }
2784 }
2785}
2786
2787impl<FreqScale, AmpScale> AsRef<Array2<f64>> for Spectrogram<FreqScale, AmpScale>
2788where
2789 FreqScale: Copy + Clone + 'static,
2790 AmpScale: AmpScaleSpec + 'static,
2791{
2792 #[inline]
2793 fn as_ref(&self) -> &Array2<f64> {
2794 &self.data
2795 }
2796}
2797
2798impl<FreqScale, AmpScale> Deref for Spectrogram<FreqScale, AmpScale>
2799where
2800 FreqScale: Copy + Clone + 'static,
2801 AmpScale: AmpScaleSpec + 'static,
2802{
2803 type Target = Array2<f64>;
2804
2805 #[inline]
2806 fn deref(&self) -> &Self::Target {
2807 &self.data
2808 }
2809}
2810
2811impl<AmpScale> Spectrogram<LinearHz, AmpScale>
2812where
2813 AmpScale: AmpScaleSpec + 'static,
2814{
2815 /// Compute a linear-frequency spectrogram from audio samples.
2816 ///
2817 /// This is a convenience method that creates a planner internally and computes
2818 /// the spectrogram in one call. For processing multiple signals with the same
2819 /// parameters, use [`SpectrogramPlanner::linear_plan`] to create a reusable plan.
2820 ///
2821 /// # Arguments
2822 ///
2823 /// * `samples` - Audio samples (any type that can be converted to a slice)
2824 /// * `params` - Spectrogram computation parameters
2825 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2826 ///
2827 /// # Returns
2828 ///
2829 /// A linear-frequency spectrogram with the specified amplitude scale.
2830 ///
2831 /// # Errors
2832 ///
2833 /// Returns an error if:
2834 /// - The samples slice is empty
2835 /// - Parameters are invalid
2836 /// - FFT computation fails
2837 ///
2838 /// # Examples
2839 ///
2840 /// ```
2841 /// use spectrograms::*;
2842 /// use non_empty_slice::non_empty_vec;
2843 ///
2844 /// # fn example() -> SpectrogramResult<()> {
2845 /// // Create a simple test signal
2846 /// let sample_rate = 16000.0;
2847 /// let samples_vec: Vec<f64> = (0..16000).map(|i| {
2848 /// (2.0 * std::f64::consts::PI * 440.0 * i as f64 / sample_rate).sin()
2849 /// }).collect();
2850 /// let samples = non_empty_slice::NonEmptyVec::new(samples_vec).unwrap();
2851 ///
2852 /// // Set up parameters
2853 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2854 /// let params = SpectrogramParams::new(stft, sample_rate)?;
2855 ///
2856 /// // Compute power spectrogram
2857 /// let spec = LinearPowerSpectrogram::compute(&samples, ¶ms, None)?;
2858 ///
2859 /// println!("Computed spectrogram: {} bins x {} frames", spec.n_bins(), spec.n_frames());
2860 /// # Ok(())
2861 /// # }
2862 /// ```
2863 #[inline]
2864 pub fn compute(
2865 samples: &NonEmptySlice<f64>,
2866 params: &SpectrogramParams,
2867 db: Option<&LogParams>,
2868 ) -> SpectrogramResult<Self> {
2869 let planner = SpectrogramPlanner::new();
2870 let mut plan = planner.linear_plan(params, db)?;
2871 plan.compute(samples)
2872 }
2873}
2874
2875impl<AmpScale> Spectrogram<Mel, AmpScale>
2876where
2877 AmpScale: AmpScaleSpec + 'static,
2878{
2879 /// Compute a mel-frequency spectrogram from audio samples.
2880 ///
2881 /// This is a convenience method that creates a planner internally and computes
2882 /// the spectrogram in one call. For processing multiple signals with the same
2883 /// parameters, use [`SpectrogramPlanner::mel_plan`] to create a reusable plan.
2884 ///
2885 /// # Arguments
2886 ///
2887 /// * `samples` - Audio samples (any type that can be converted to a slice)
2888 /// * `params` - Spectrogram computation parameters
2889 /// * `mel` - Mel filterbank parameters
2890 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2891 ///
2892 /// # Returns
2893 ///
2894 /// A mel-frequency spectrogram with the specified amplitude scale.
2895 ///
2896 /// # Errors
2897 ///
2898 /// Returns an error if:
2899 /// - The samples slice is empty
2900 /// - Parameters are invalid
2901 /// - Mel `f_max` exceeds Nyquist frequency
2902 /// - FFT computation fails
2903 ///
2904 /// # Examples
2905 ///
2906 /// ```
2907 /// use spectrograms::*;
2908 /// use non_empty_slice::non_empty_vec;
2909 ///
2910 /// # fn example() -> SpectrogramResult<()> {
2911 /// // Create a simple test signal
2912 /// let sample_rate = 16000.0;
2913 /// let samples_vec: Vec<f64> = (0..16000).map(|i| {
2914 /// (2.0 * std::f64::consts::PI * 440.0 * i as f64 / sample_rate).sin()
2915 /// }).collect();
2916 /// let samples = non_empty_slice::NonEmptyVec::new(samples_vec).unwrap();
2917 ///
2918 /// // Set up parameters
2919 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2920 /// let params = SpectrogramParams::new(stft, sample_rate)?;
2921 /// let mel = MelParams::new(nzu!(80), 0.0, 8000.0)?;
2922 ///
2923 /// // Compute mel spectrogram in dB scale
2924 /// let db = LogParams::new(-80.0)?;
2925 /// let spec = MelDbSpectrogram::compute(&samples, ¶ms, &mel, Some(&db))?;
2926 ///
2927 /// println!("Computed mel spectrogram: {} mels x {} frames", spec.n_bins(), spec.n_frames());
2928 /// # Ok(())
2929 /// # }
2930 /// ```
2931 #[inline]
2932 pub fn compute(
2933 samples: &NonEmptySlice<f64>,
2934 params: &SpectrogramParams,
2935 mel: &MelParams,
2936 db: Option<&LogParams>,
2937 ) -> SpectrogramResult<Self> {
2938 let planner = SpectrogramPlanner::new();
2939 let mut plan = planner.mel_plan(params, mel, db)?;
2940 plan.compute(samples)
2941 }
2942}
2943
2944impl<AmpScale> Spectrogram<Erb, AmpScale>
2945where
2946 AmpScale: AmpScaleSpec + 'static,
2947{
2948 /// Compute an ERB-frequency spectrogram from audio samples.
2949 ///
2950 /// This is a convenience method that creates a planner internally and computes
2951 /// the spectrogram in one call. For processing multiple signals with the same
2952 /// parameters, use [`SpectrogramPlanner::erb_plan`] to create a reusable plan.
2953 ///
2954 /// # Arguments
2955 ///
2956 /// * `samples` - Audio samples (any type that can be converted to a slice)
2957 /// * `params` - Spectrogram computation parameters
2958 /// * `erb` - ERB frequency scale parameters
2959 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2960 ///
2961 /// # Returns
2962 ///
2963 /// An ERB-scale spectrogram with the specified amplitude scale.
2964 ///
2965 /// # Errors
2966 ///
2967 /// Returns an error if:
2968 /// - The samples slice is empty
2969 /// - Parameters are invalid
2970 /// - FFT computation fails
2971 ///
2972 /// # Examples
2973 ///
2974 /// ```
2975 /// use spectrograms::*;
2976 /// use non_empty_slice::non_empty_vec;
2977 ///
2978 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
2979 /// let samples = non_empty_vec![0.0; nzu!(16000)];
2980 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2981 /// let params = SpectrogramParams::new(stft, 16000.0)?;
2982 /// let erb = ErbParams::speech_standard();
2983 ///
2984 /// let spec = ErbPowerSpectrogram::compute(&samples, ¶ms, &erb, None)?;
2985 /// assert_eq!(spec.n_bins(), nzu!(40));
2986 /// # Ok(())
2987 /// # }
2988 /// ```
2989 #[inline]
2990 pub fn compute(
2991 samples: &NonEmptySlice<f64>,
2992 params: &SpectrogramParams,
2993 erb: &ErbParams,
2994 db: Option<&LogParams>,
2995 ) -> SpectrogramResult<Self> {
2996 let planner = SpectrogramPlanner::new();
2997 let mut plan = planner.erb_plan(params, erb, db)?;
2998 plan.compute(samples)
2999 }
3000}
3001
3002impl<AmpScale> Spectrogram<LogHz, AmpScale>
3003where
3004 AmpScale: AmpScaleSpec + 'static,
3005{
3006 /// Compute a logarithmic-frequency spectrogram from audio samples.
3007 ///
3008 /// This is a convenience method that creates a planner internally and computes
3009 /// the spectrogram in one call. For processing multiple signals with the same
3010 /// parameters, use [`SpectrogramPlanner::log_hz_plan`] to create a reusable plan.
3011 ///
3012 /// # Arguments
3013 ///
3014 /// * `samples` - Audio samples (any type that can be converted to a slice)
3015 /// * `params` - Spectrogram computation parameters
3016 /// * `loghz` - Logarithmic frequency scale parameters
3017 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
3018 ///
3019 /// # Returns
3020 ///
3021 /// A logarithmic-frequency spectrogram with the specified amplitude scale.
3022 ///
3023 /// # Errors
3024 ///
3025 /// Returns an error if:
3026 /// - The samples slice is empty
3027 /// - Parameters are invalid
3028 /// - FFT computation fails
3029 ///
3030 /// # Examples
3031 ///
3032 /// ```
3033 /// use spectrograms::*;
3034 /// use non_empty_slice::non_empty_vec;
3035 ///
3036 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
3037 /// let samples = non_empty_vec![0.0; nzu!(16000)];
3038 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
3039 /// let params = SpectrogramParams::new(stft, 16000.0)?;
3040 /// let loghz = LogHzParams::new(nzu!(128), 20.0, 8000.0)?;
3041 ///
3042 /// let spec = LogHzPowerSpectrogram::compute(&samples, ¶ms, &loghz, None)?;
3043 /// assert_eq!(spec.n_bins(), nzu!(128));
3044 /// # Ok(())
3045 /// # }
3046 /// ```
3047 #[inline]
3048 pub fn compute(
3049 samples: &NonEmptySlice<f64>,
3050 params: &SpectrogramParams,
3051 loghz: &LogHzParams,
3052 db: Option<&LogParams>,
3053 ) -> SpectrogramResult<Self> {
3054 let planner = SpectrogramPlanner::new();
3055 let mut plan = planner.log_hz_plan(params, loghz, db)?;
3056 plan.compute(samples)
3057 }
3058}
3059
3060impl<AmpScale> Spectrogram<Cqt, AmpScale>
3061where
3062 AmpScale: AmpScaleSpec + 'static,
3063{
3064 /// Compute a constant-Q transform (CQT) spectrogram from audio samples.
3065 ///
3066 /// # Arguments
3067 ///
3068 /// * `samples` - Audio samples (any type that can be converted to a slice)
3069 /// * `params` - Spectrogram computation parameters
3070 /// * `cqt` - CQT parameters
3071 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
3072 ///
3073 /// # Returns
3074 ///
3075 /// A CQT spectrogram with the specified amplitude scale.
3076 ///
3077 /// # Errors
3078 ///
3079 /// Returns an error if:
3080 ///
3081 #[inline]
3082 pub fn compute(
3083 samples: &NonEmptySlice<f64>,
3084 params: &SpectrogramParams,
3085 cqt: &CqtParams,
3086 db: Option<&LogParams>,
3087 ) -> SpectrogramResult<Self> {
3088 let planner = SpectrogramPlanner::new();
3089 let mut plan = planner.cqt_plan(params, cqt, db)?;
3090 plan.compute(samples)
3091 }
3092}
3093
3094// ========================
3095// Display implementations
3096// ========================
3097
3098/// Helper function to get amplitude scale name
3099fn amp_scale_name<AmpScale>() -> &'static str
3100where
3101 AmpScale: AmpScaleSpec + 'static,
3102{
3103 match std::any::TypeId::of::<AmpScale>() {
3104 id if id == std::any::TypeId::of::<Power>() => "Power",
3105 id if id == std::any::TypeId::of::<Magnitude>() => "Magnitude",
3106 id if id == std::any::TypeId::of::<Decibels>() => "Decibels",
3107 _ => "Unknown",
3108 }
3109}
3110
3111/// Helper function to get frequency scale name
3112fn freq_scale_name<FreqScale>() -> &'static str
3113where
3114 FreqScale: Copy + Clone + 'static,
3115{
3116 match std::any::TypeId::of::<FreqScale>() {
3117 id if id == std::any::TypeId::of::<LinearHz>() => "Linear Hz",
3118 id if id == std::any::TypeId::of::<LogHz>() => "Log Hz",
3119 id if id == std::any::TypeId::of::<Mel>() => "Mel",
3120 id if id == std::any::TypeId::of::<Erb>() => "ERB",
3121 id if id == std::any::TypeId::of::<Cqt>() => "CQT",
3122 _ => "Unknown",
3123 }
3124}
3125
3126impl<FreqScale, AmpScale> core::fmt::Display for Spectrogram<FreqScale, AmpScale>
3127where
3128 AmpScale: AmpScaleSpec + 'static,
3129 FreqScale: Copy + Clone + 'static,
3130{
3131 #[inline]
3132 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
3133 let (freq_min, freq_max) = self.frequency_range();
3134 let duration = self.duration();
3135 let (rows, cols) = self.data.dim();
3136
3137 // Alternative formatting (#) provides more detailed output with data
3138 if f.alternate() {
3139 writeln!(f, "Spectrogram {{")?;
3140 writeln!(f, " Frequency Scale: {}", freq_scale_name::<FreqScale>())?;
3141 writeln!(f, " Amplitude Scale: {}", amp_scale_name::<AmpScale>())?;
3142 writeln!(f, " Shape: {rows} frequency bins × {cols} time frames")?;
3143 writeln!(f, " Frequency Range: {freq_min:.2} Hz - {freq_max:.2} Hz")?;
3144 writeln!(f, " Duration: {duration:.3} s")?;
3145 writeln!(f)?;
3146
3147 // Display parameters
3148 writeln!(f, " Parameters:")?;
3149 writeln!(f, " Sample Rate: {} Hz", self.params.sample_rate_hz())?;
3150 writeln!(f, " FFT Size: {}", self.params.stft().n_fft())?;
3151 writeln!(f, " Hop Size: {}", self.params.stft().hop_size())?;
3152 writeln!(f, " Window: {:?}", self.params.stft().window())?;
3153 writeln!(f, " Centered: {}", self.params.stft().centre())?;
3154 writeln!(f)?;
3155
3156 // Display data statistics
3157 let data_slice = self.data.as_slice().unwrap_or(&[]);
3158 if !data_slice.is_empty() {
3159 let (min_val, max_val) = min_max_single_pass(data_slice);
3160 let mean = data_slice.iter().sum::<f64>() / data_slice.len() as f64;
3161 writeln!(f, " Data Statistics:")?;
3162 writeln!(f, " Min: {min_val:.6}")?;
3163 writeln!(f, " Max: {max_val:.6}")?;
3164 writeln!(f, " Mean: {mean:.6}")?;
3165 writeln!(f)?;
3166 }
3167
3168 // Display actual data (truncated if too large)
3169 writeln!(f, " Data Matrix:")?;
3170 let max_rows_to_display = 5;
3171 let max_cols_to_display = 5;
3172
3173 for i in 0..rows.min(max_rows_to_display) {
3174 write!(f, " [")?;
3175 for j in 0..cols.min(max_cols_to_display) {
3176 if j > 0 {
3177 write!(f, ", ")?;
3178 }
3179 write!(f, "{:9.4}", self.data[[i, j]])?;
3180 }
3181 if cols > max_cols_to_display {
3182 write!(f, ", ... ({} more)", cols - max_cols_to_display)?;
3183 }
3184 writeln!(f, "]")?;
3185 }
3186
3187 if rows > max_rows_to_display {
3188 writeln!(f, " ... ({} more rows)", rows - max_rows_to_display)?;
3189 }
3190
3191 write!(f, "}}")?;
3192 } else {
3193 // Default formatting: compact summary
3194 write!(
3195 f,
3196 "Spectrogram<{}, {}>[{}x{}] ({:.2}-{:.2} Hz, {:.3}s)",
3197 freq_scale_name::<FreqScale>(),
3198 amp_scale_name::<AmpScale>(),
3199 rows,
3200 cols,
3201 freq_min,
3202 freq_max,
3203 duration
3204 )?;
3205 }
3206
3207 Ok(())
3208 }
3209}
3210
3211#[derive(Debug, Clone)]
3212#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
3213pub struct FrequencyAxis<FreqScale>
3214where
3215 FreqScale: Copy + Clone + 'static,
3216{
3217 frequencies: NonEmptyVec<f64>,
3218 #[cfg_attr(feature = "serde", serde(skip))]
3219 _marker: PhantomData<FreqScale>,
3220}
3221
3222impl<FreqScale> FrequencyAxis<FreqScale>
3223where
3224 FreqScale: Copy + Clone + 'static,
3225{
3226 pub(crate) const fn new(frequencies: NonEmptyVec<f64>) -> Self {
3227 Self {
3228 frequencies,
3229 _marker: PhantomData,
3230 }
3231 }
3232
3233 /// Get the frequency values in Hz.
3234 ///
3235 /// # Returns
3236 ///
3237 /// Returns a non-empty slice of frequencies.
3238 #[inline]
3239 #[must_use]
3240 pub fn frequencies(&self) -> &NonEmptySlice<f64> {
3241 &self.frequencies
3242 }
3243
3244 /// Get the frequency range (min, max) in Hz.
3245 ///
3246 /// # Returns
3247 ///
3248 /// Returns a tuple containing the minimum and maximum frequency.
3249 #[inline]
3250 #[must_use]
3251 pub const fn frequency_range(&self) -> (f64, f64) {
3252 let data = self.frequencies.as_slice();
3253 let min = data[0];
3254 let max_idx = data.len().saturating_sub(1); // safe for non-empty
3255 let max = data[max_idx];
3256 (min, max)
3257 }
3258
3259 /// Get the number of frequency bins.
3260 ///
3261 /// # Returns
3262 ///
3263 /// Returns the number of frequency bins as a NonZeroUsize.
3264 #[inline]
3265 #[must_use]
3266 pub const fn len(&self) -> NonZeroUsize {
3267 self.frequencies.len()
3268 }
3269}
3270
3271/// Spectrogram axes container.
3272///
3273/// Holds frequency and time axes.
3274#[derive(Debug, Clone)]
3275#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
3276pub struct Axes<FreqScale>
3277where
3278 FreqScale: Copy + Clone + 'static,
3279{
3280 freq: FrequencyAxis<FreqScale>,
3281 times: NonEmptyVec<f64>,
3282}
3283
3284impl<FreqScale> Axes<FreqScale>
3285where
3286 FreqScale: Copy + Clone + 'static,
3287{
3288 pub(crate) const fn new(freq: FrequencyAxis<FreqScale>, times: NonEmptyVec<f64>) -> Self {
3289 Self { freq, times }
3290 }
3291
3292 /// Get the frequency values in Hz.
3293 ///
3294 /// # Returns
3295 ///
3296 /// Returns a non-empty slice of frequencies.
3297 #[inline]
3298 #[must_use]
3299 pub fn frequencies(&self) -> &NonEmptySlice<f64> {
3300 self.freq.frequencies()
3301 }
3302
3303 /// Get the time values in seconds.
3304 ///
3305 /// # Returns
3306 ///
3307 /// Returns a non-empty slice of time values.
3308 #[inline]
3309 #[must_use]
3310 pub fn times(&self) -> &NonEmptySlice<f64> {
3311 &self.times
3312 }
3313
3314 /// Get the frequency range (min, max) in Hz.
3315 ///
3316 /// # Returns
3317 ///
3318 /// Returns a tuple containing the minimum and maximum frequency.
3319 #[inline]
3320 #[must_use]
3321 pub const fn frequency_range(&self) -> (f64, f64) {
3322 self.freq.frequency_range()
3323 }
3324
3325 /// Get the duration of the spectrogram in seconds.
3326 ///
3327 /// # Returns
3328 ///
3329 /// Returns the duration in seconds.
3330 #[inline]
3331 #[must_use]
3332 pub fn duration(&self) -> f64 {
3333 *self.times.last()
3334 }
3335}
3336
3337// Enum types for frequency and amplitude scales
3338
/// Linear frequency scale.
///
/// Type-level marker: it is used as a `FreqScale` type parameter and is
/// recognised by its `TypeId` when the scale name is formatted as "Linear Hz".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LinearHz {
    /// Sole variant; carries no data.
    _Phantom,
}
3347
/// Logarithmic frequency scale.
///
/// Type-level marker: it is used as a `FreqScale` type parameter and is
/// recognised by its `TypeId` when the scale name is formatted as "Log Hz".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LogHz {
    /// Sole variant; carries no data.
    _Phantom,
}
3356
/// Mel frequency scale.
///
/// Type-level marker: it is used as a `FreqScale` type parameter and is
/// recognised by its `TypeId` when the scale name is formatted as "Mel".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Mel {
    /// Sole variant; carries no data.
    _Phantom,
}
3365
/// ERB/gammatone frequency scale.
///
/// Type-level marker: it is used as a `FreqScale` type parameter and is
/// recognised by its `TypeId` when the scale name is formatted as "ERB".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Erb {
    /// Sole variant; carries no data.
    _Phantom,
}
/// Alias for [`Erb`]: both names refer to the same marker type.
pub type Gammatone = Erb;
3375
/// Constant-Q Transform frequency scale.
///
/// Type-level marker: it is used as a `FreqScale` type parameter (for example
/// `Spectrogram<Cqt, AmpScale>`) and is recognised by its `TypeId` when the
/// scale name is formatted as "CQT".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Cqt {
    /// Sole variant; carries no data.
    _Phantom,
}
3384
3385// Amplitude scales
3386
/// Power amplitude scale.
///
/// Type-level marker: it is used as an `AmpScale` type parameter and is
/// recognised by its `TypeId` when the scale name is formatted as "Power".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Power {
    /// Sole variant; carries no data.
    _Phantom,
}
3395
/// Decibel amplitude scale.
///
/// Type-level marker: it is used as an `AmpScale` type parameter and is
/// recognised by its `TypeId` when the scale name is formatted as "Decibels".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Decibels {
    /// Sole variant; carries no data.
    _Phantom,
}
3404
/// Magnitude amplitude scale.
///
/// Type-level marker: it is used as an `AmpScale` type parameter and is
/// recognised by its `TypeId` when the scale name is formatted as "Magnitude".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Magnitude {
    /// Sole variant; carries no data.
    _Phantom,
}
3413
3414/// STFT parameters for spectrogram computation.
3415///
3416/// * `n_fft`: Size of the FFT window.
3417/// * `hop_size`: Number of samples between successive frames.
3418/// * window: Window function to apply to each frame.
3419/// * centre: Whether to pad the input signal so that frames are centered.
3420#[derive(Debug, Clone, PartialEq)]
3421#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
3422pub struct StftParams {
3423 n_fft: NonZeroUsize,
3424 hop_size: NonZeroUsize,
3425 window: WindowType,
3426 centre: bool,
3427}
3428
3429impl StftParams {
3430 /// Create new STFT parameters.
3431 ///
3432 /// # Arguments
3433 ///
3434 /// * `n_fft` - Size of the FFT window
3435 /// * `hop_size` - Number of samples between successive frames
3436 /// * `window` - Window function to apply to each frame
3437 /// * `centre` - Whether to pad the input signal so that frames are centered
3438 ///
3439 /// # Errors
3440 ///
3441 /// Returns an error if:
3442 /// - `hop_size` > `n_fft`
3443 /// - Custom window size doesn't match `n_fft`
3444 ///
3445 /// # Returns
3446 ///
3447 /// New `StftParams` instance.
3448 #[inline]
3449 pub fn new(
3450 n_fft: NonZeroUsize,
3451 hop_size: NonZeroUsize,
3452 window: WindowType,
3453 centre: bool,
3454 ) -> SpectrogramResult<Self> {
3455 if hop_size.get() > n_fft.get() {
3456 return Err(SpectrogramError::invalid_input("hop_size must be <= n_fft"));
3457 }
3458
3459 // Validate custom window size matches n_fft
3460 if let WindowType::Custom { size, .. } = &window {
3461 if size.get() != n_fft.get() {
3462 return Err(SpectrogramError::invalid_input(format!(
3463 "Custom window size ({}) must match n_fft ({})",
3464 size.get(),
3465 n_fft.get()
3466 )));
3467 }
3468 }
3469
3470 Ok(Self {
3471 n_fft,
3472 hop_size,
3473 window,
3474 centre,
3475 })
3476 }
3477
3478 const unsafe fn new_unchecked(
3479 n_fft: NonZeroUsize,
3480 hop_size: NonZeroUsize,
3481 window: WindowType,
3482 centre: bool,
3483 ) -> Self {
3484 Self {
3485 n_fft,
3486 hop_size,
3487 window,
3488 centre,
3489 }
3490 }
3491
3492 /// Get the FFT window size.
3493 ///
3494 /// # Returns
3495 ///
3496 /// The FFT window size.
3497 #[inline]
3498 #[must_use]
3499 pub const fn n_fft(&self) -> NonZeroUsize {
3500 self.n_fft
3501 }
3502
3503 /// Get the hop size (samples between successive frames).
3504 ///
3505 /// # Returns
3506 ///
3507 /// The hop size.
3508 #[inline]
3509 #[must_use]
3510 pub const fn hop_size(&self) -> NonZeroUsize {
3511 self.hop_size
3512 }
3513
3514 /// Get the window function.
3515 ///
3516 /// # Returns
3517 ///
3518 /// The window function.
3519 #[inline]
3520 #[must_use]
3521 pub fn window(&self) -> WindowType {
3522 self.window.clone()
3523 }
3524
3525 /// Get whether frames are centered (input signal is padded).
3526 ///
3527 /// # Returns
3528 ///
3529 /// `true` if frames are centered, `false` otherwise.
3530 #[inline]
3531 #[must_use]
3532 pub const fn centre(&self) -> bool {
3533 self.centre
3534 }
3535
3536 /// Create a builder for STFT parameters.
3537 ///
3538 /// # Returns
3539 ///
3540 /// A `StftParamsBuilder` instance.
3541 ///
3542 /// # Examples
3543 ///
3544 /// ```
3545 /// use spectrograms::{StftParams, WindowType, nzu};
3546 ///
3547 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
3548 /// let stft = StftParams::builder()
3549 /// .n_fft(nzu!(2048))
3550 /// .hop_size(nzu!(512))
3551 /// .window(WindowType::Hanning)
3552 /// .centre(true)
3553 /// .build()?;
3554 ///
3555 /// assert_eq!(stft.n_fft(), nzu!(2048));
3556 /// assert_eq!(stft.hop_size(), nzu!(512));
3557 /// # Ok(())
3558 /// # }
3559 /// ```
3560 #[inline]
3561 #[must_use]
3562 pub fn builder() -> StftParamsBuilder {
3563 StftParamsBuilder::default()
3564 }
3565}
3566
3567/// Builder for [`StftParams`].
3568#[derive(Debug, Clone)]
3569pub struct StftParamsBuilder {
3570 n_fft: Option<NonZeroUsize>,
3571 hop_size: Option<NonZeroUsize>,
3572 window: WindowType,
3573 centre: bool,
3574}
3575
3576impl Default for StftParamsBuilder {
3577 #[inline]
3578 fn default() -> Self {
3579 Self {
3580 n_fft: None,
3581 hop_size: None,
3582 window: WindowType::Hanning,
3583 centre: true,
3584 }
3585 }
3586}
3587
3588impl StftParamsBuilder {
3589 /// Set the FFT window size.
3590 ///
3591 /// # Arguments
3592 ///
3593 /// * `n_fft` - Size of the FFT window
3594 ///
3595 /// # Returns
3596 ///
3597 /// The builder with the updated FFT window size.
3598 #[inline]
3599 #[must_use]
3600 pub const fn n_fft(mut self, n_fft: NonZeroUsize) -> Self {
3601 self.n_fft = Some(n_fft);
3602 self
3603 }
3604
3605 /// Set the hop size (samples between successive frames).
3606 ///
3607 /// # Arguments
3608 ///
3609 /// * `hop_size` - Number of samples between successive frames
3610 ///
3611 /// # Returns
3612 ///
3613 /// The builder with the updated hop size.
3614 #[inline]
3615 #[must_use]
3616 pub const fn hop_size(mut self, hop_size: NonZeroUsize) -> Self {
3617 self.hop_size = Some(hop_size);
3618 self
3619 }
3620
3621 /// Set the window function.
3622 ///
3623 /// # Arguments
3624 ///
3625 /// * `window` - Window function to apply to each frame
3626 ///
3627 /// # Returns
3628 ///
3629 /// The builder with the updated window function.
3630 #[inline]
3631 #[must_use]
3632 pub fn window(mut self, window: WindowType) -> Self {
3633 self.window = window;
3634 self
3635 }
3636
3637 /// Set whether to center frames (pad input signal).
3638 #[inline]
3639 #[must_use]
3640 pub const fn centre(mut self, centre: bool) -> Self {
3641 self.centre = centre;
3642 self
3643 }
3644
3645 /// Build the [`StftParams`].
3646 ///
3647 /// # Errors
3648 ///
3649 /// Returns an error if:
3650 /// - `n_fft` or `hop_size` are not set or are zero
3651 /// - `hop_size` > `n_fft`
3652 #[inline]
3653 pub fn build(self) -> SpectrogramResult<StftParams> {
3654 let n_fft = self
3655 .n_fft
3656 .ok_or_else(|| SpectrogramError::invalid_input("n_fft must be set"))?;
3657 let hop_size = self
3658 .hop_size
3659 .ok_or_else(|| SpectrogramError::invalid_input("hop_size must be set"))?;
3660
3661 StftParams::new(n_fft, hop_size, self.window, self.centre)
3662 }
3663}
3664
3665//
3666// ========================
3667// Mel parameters
3668// ========================
3669//
3670
/// Normalization strategy for the mel filterbank.
///
/// Selects how the triangular mel filters are normalized after construction.
/// The default is [`MelNorm::None`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum MelNorm {
    /// No normalization (triangular filters with peak = 1.0).
    ///
    /// This is the default and fastest option.
    #[default]
    None,

    /// Slaney-style area normalization (librosa default).
    ///
    /// Each mel filter is divided by its bandwidth in Hz: `2.0 / (f_max - f_min)`.
    /// This ensures constant energy per mel band regardless of bandwidth.
    ///
    /// Use this for compatibility with librosa's default behavior.
    Slaney,

    /// L1 normalization (sum of weights = 1.0).
    ///
    /// Each mel filter's weights are divided by their sum.
    /// Useful when you want each filter to act as a weighted average.
    L1,

    /// L2 normalization (Euclidean norm = 1.0).
    ///
    /// Each mel filter's weights are divided by their L2 norm.
    /// Provides unit-norm filters in the L2 sense.
    L2,
}
3705
3706/// Mel filter bank parameters
3707///
3708/// * `n_mels`: Number of mel bands
3709/// * `f_min`: Minimum frequency (Hz)
3710/// * `f_max`: Maximum frequency (Hz)
3711/// * `norm`: Filterbank normalization strategy
3712#[derive(Debug, Clone, Copy, PartialEq)]
3713#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
3714pub struct MelParams {
3715 n_mels: NonZeroUsize,
3716 f_min: f64,
3717 f_max: f64,
3718 norm: MelNorm,
3719}
3720
3721impl MelParams {
3722 /// Create new mel filter bank parameters.
3723 ///
3724 /// # Arguments
3725 ///
3726 /// * `n_mels` - Number of mel bands
3727 /// * `f_min` - Minimum frequency (Hz)
3728 /// * `f_max` - Maximum frequency (Hz)
3729 ///
3730 /// # Errors
3731 ///
3732 /// Returns an error if:
3733 /// - `f_min` is not >= 0
3734 /// - `f_max` is not > `f_min`
3735 ///
3736 /// # Returns
3737 ///
3738 /// A `MelParams` instance with no normalization (default).
3739 #[inline]
3740 pub fn new(n_mels: NonZeroUsize, f_min: f64, f_max: f64) -> SpectrogramResult<Self> {
3741 Self::with_norm(n_mels, f_min, f_max, MelNorm::None)
3742 }
3743
3744 /// Create new mel filter bank parameters with specified normalization.
3745 ///
3746 /// # Arguments
3747 ///
3748 /// * `n_mels` - Number of mel bands
3749 /// * `f_min` - Minimum frequency (Hz)
3750 /// * `f_max` - Maximum frequency (Hz)
3751 /// * `norm` - Filterbank normalization strategy
3752 ///
3753 /// # Errors
3754 ///
3755 /// Returns an error if:
3756 /// - `f_min` is not >= 0
3757 /// - `f_max` is not > `f_min`
3758 ///
3759 /// # Returns
3760 ///
3761 /// A `MelParams` instance.
3762 #[inline]
3763 pub fn with_norm(
3764 n_mels: NonZeroUsize,
3765 f_min: f64,
3766 f_max: f64,
3767 norm: MelNorm,
3768 ) -> SpectrogramResult<Self> {
3769 if f_min < 0.0 {
3770 return Err(SpectrogramError::invalid_input("f_min must be >= 0"));
3771 }
3772
3773 if f_max <= f_min {
3774 return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
3775 }
3776
3777 Ok(Self {
3778 n_mels,
3779 f_min,
3780 f_max,
3781 norm,
3782 })
3783 }
3784
3785 const unsafe fn new_unchecked(n_mels: NonZeroUsize, f_min: f64, f_max: f64) -> Self {
3786 Self {
3787 n_mels,
3788 f_min,
3789 f_max,
3790 norm: MelNorm::None,
3791 }
3792 }
3793
3794 /// Get the number of mel bands.
3795 ///
3796 /// # Returns
3797 ///
3798 /// The number of mel bands.
3799 #[inline]
3800 #[must_use]
3801 pub const fn n_mels(&self) -> NonZeroUsize {
3802 self.n_mels
3803 }
3804
3805 /// Get the minimum frequency (Hz).
3806 ///
3807 /// # Returns
3808 ///
3809 /// The minimum frequency in Hz.
3810 #[inline]
3811 #[must_use]
3812 pub const fn f_min(&self) -> f64 {
3813 self.f_min
3814 }
3815
3816 /// Get the maximum frequency (Hz).
3817 ///
3818 /// # Returns
3819 ///
3820 /// The maximum frequency in Hz.
3821 #[inline]
3822 #[must_use]
3823 pub const fn f_max(&self) -> f64 {
3824 self.f_max
3825 }
3826
3827 /// Get the filterbank normalization strategy.
3828 ///
3829 /// # Returns
3830 ///
3831 /// The normalization strategy.
3832 #[inline]
3833 #[must_use]
3834 pub const fn norm(&self) -> MelNorm {
3835 self.norm
3836 }
3837
3838 /// Create standard mel filterbank parameters.
3839 ///
3840 /// Uses 128 mel bands from 0 Hz to the Nyquist frequency.
3841 ///
3842 /// # Arguments
3843 ///
3844 /// * `sample_rate` - Sample rate in Hz (used to determine `f_max`)
3845 ///
3846 /// # Returns
3847 ///
3848 /// A `MelParams` instance with standard settings.
3849 ///
3850 /// # Panics
3851 ///
3852 /// Panics if `sample_rate` is not greater than 0.
3853 #[inline]
3854 #[must_use]
3855 pub const fn standard(sample_rate: f64) -> Self {
3856 assert!(sample_rate > 0.0);
3857 // safety: parameters are known to be valid
3858 unsafe { Self::new_unchecked(nzu!(128), 0.0, sample_rate / 2.0) }
3859 }
3860
3861 /// Create mel filterbank parameters optimized for speech.
3862 ///
3863 /// Uses 40 mel bands from 0 Hz to 8000 Hz (typical speech bandwidth).
3864 ///
3865 /// # Returns
3866 ///
3867 /// A `MelParams` instance with speech-optimized settings.
3868 #[inline]
3869 #[must_use]
3870 pub const fn speech_standard() -> Self {
3871 // safety: parameters are known to be valid
3872 unsafe { Self::new_unchecked(nzu!(40), 0.0, 8000.0) }
3873 }
3874}
3875
3876//
3877// ========================
3878// LogHz parameters
3879// ========================
3880//
3881
3882/// Logarithmic frequency scale parameters
3883///
3884/// * `n_bins`: Number of logarithmically-spaced frequency bins
3885/// * `f_min`: Minimum frequency (Hz)
3886/// * `f_max`: Maximum frequency (Hz)
3887#[derive(Debug, Clone, Copy, PartialEq)]
3888#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
3889pub struct LogHzParams {
3890 n_bins: NonZeroUsize,
3891 f_min: f64,
3892 f_max: f64,
3893}
3894
3895impl LogHzParams {
3896 /// Create new logarithmic frequency scale parameters.
3897 ///
3898 /// # Arguments
3899 ///
3900 /// * `n_bins` - Number of logarithmically-spaced frequency bins
3901 /// * `f_min` - Minimum frequency (Hz)
3902 /// * `f_max` - Maximum frequency (Hz)
3903 ///
3904 /// # Errors
3905 ///
3906 /// Returns an error if:
3907 /// - `f_min` is not finite and > 0
3908 /// - `f_max` is not > `f_min`
3909 ///
3910 /// # Returns
3911 ///
3912 /// A `LogHzParams` instance.
3913 #[inline]
3914 pub fn new(n_bins: NonZeroUsize, f_min: f64, f_max: f64) -> SpectrogramResult<Self> {
3915 if !(f_min > 0.0 && f_min.is_finite()) {
3916 return Err(SpectrogramError::invalid_input(
3917 "f_min must be finite and > 0",
3918 ));
3919 }
3920
3921 if f_max <= f_min {
3922 return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
3923 }
3924
3925 Ok(Self {
3926 n_bins,
3927 f_min,
3928 f_max,
3929 })
3930 }
3931
3932 const unsafe fn new_unchecked(n_bins: NonZeroUsize, f_min: f64, f_max: f64) -> Self {
3933 Self {
3934 n_bins,
3935 f_min,
3936 f_max,
3937 }
3938 }
3939
3940 /// Get the number of frequency bins.
3941 ///
3942 /// # Returns
3943 ///
3944 /// The number of frequency bins.
3945 #[inline]
3946 #[must_use]
3947 pub const fn n_bins(&self) -> NonZeroUsize {
3948 self.n_bins
3949 }
3950
3951 /// Get the minimum frequency (Hz).
3952 ///
3953 /// # Returns
3954 ///
3955 /// The minimum frequency in Hz.
3956 #[inline]
3957 #[must_use]
3958 pub const fn f_min(&self) -> f64 {
3959 self.f_min
3960 }
3961
3962 /// Get the maximum frequency (Hz).
3963 ///
3964 /// # Returns
3965 ///
3966 /// The maximum frequency in Hz.
3967 #[inline]
3968 #[must_use]
3969 pub const fn f_max(&self) -> f64 {
3970 self.f_max
3971 }
3972
3973 /// Create standard logarithmic frequency parameters.
3974 ///
3975 /// Uses 128 log bins from 20 Hz to the Nyquist frequency.
3976 ///
3977 /// # Arguments
3978 ///
3979 /// * `sample_rate` - Sample rate in Hz (used to determine `f_max`)
3980 #[inline]
3981 #[must_use]
3982 pub fn standard(sample_rate: f64) -> Self {
3983 // safety: parameters are known to be valid
3984 unsafe { Self::new_unchecked(nzu!(128), 20.0, sample_rate / 2.0) }
3985 }
3986
3987 /// Create logarithmic frequency parameters optimized for music.
3988 ///
3989 /// Uses 84 bins (7 octaves * 12 bins/octave) from 27.5 Hz (A0) to 4186 Hz (C8).
3990 #[inline]
3991 #[must_use]
3992 pub const fn music_standard() -> Self {
3993 // safety: parameters are known to be valid
3994 unsafe { Self::new_unchecked(nzu!(84), 27.5, 4186.0) }
3995 }
3996}
3997
3998//
3999// ========================
4000// Log scaling parameters
4001// ========================
4002//
4003
4004#[derive(Debug, Clone, Copy, PartialEq)]
4005#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
4006pub struct LogParams {
4007 floor_db: f64,
4008}
4009
4010impl LogParams {
4011 /// Create new logarithmic scaling parameters.
4012 ///
4013 /// # Arguments
4014 ///
4015 /// * `floor_db` - Minimum dB value (floor) for logarithmic scaling
4016 ///
4017 /// # Errors
4018 ///
4019 /// Returns an error if `floor_db` is not finite.
4020 ///
4021 /// # Returns
4022 ///
4023 /// A `LogParams` instance.
4024 #[inline]
4025 pub fn new(floor_db: f64) -> SpectrogramResult<Self> {
4026 if !floor_db.is_finite() {
4027 return Err(SpectrogramError::invalid_input("floor_db must be finite"));
4028 }
4029
4030 Ok(Self { floor_db })
4031 }
4032
4033 /// Get the floor dB value.
4034 #[inline]
4035 #[must_use]
4036 pub const fn floor_db(&self) -> f64 {
4037 self.floor_db
4038 }
4039}
4040
4041/// Spectrogram computation parameters.
4042///
4043/// * `stft`: STFT parameters
4044/// * `sample_rate_hz`: Sample rate in Hz
4045#[derive(Debug, Clone)]
4046#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
4047pub struct SpectrogramParams {
4048 stft: StftParams,
4049 sample_rate_hz: f64,
4050}
4051
4052impl SpectrogramParams {
4053 /// Create new spectrogram parameters.
4054 ///
4055 /// # Arguments
4056 ///
4057 /// * `stft` - STFT parameters
4058 /// * `sample_rate_hz` - Sample rate in Hz
4059 ///
4060 /// # Errors
4061 ///
4062 /// Returns an error if the sample rate is not positive and finite.
4063 ///
4064 /// # Returns
4065 ///
4066 /// A `SpectrogramParams` instance.
4067 #[inline]
4068 pub fn new(stft: StftParams, sample_rate_hz: f64) -> SpectrogramResult<Self> {
4069 if !(sample_rate_hz > 0.0 && sample_rate_hz.is_finite()) {
4070 return Err(SpectrogramError::invalid_input(
4071 "sample_rate_hz must be finite and > 0",
4072 ));
4073 }
4074
4075 Ok(Self {
4076 stft,
4077 sample_rate_hz,
4078 })
4079 }
4080
4081 /// Create a builder for spectrogram parameters.
4082 ///
4083 /// # Errors
4084 ///
4085 /// Returns an error if required parameters are not set or are invalid.
4086 ///
4087 /// # Returns
4088 ///
4089 /// A builder for [`SpectrogramParams`].
4090 ///
4091 /// # Examples
4092 ///
4093 /// ```
4094 /// use spectrograms::{SpectrogramParams, WindowType, nzu};
4095 ///
4096 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4097 /// let params = SpectrogramParams::builder()
4098 /// .sample_rate(16000.0)
4099 /// .n_fft(nzu!(512))
4100 /// .hop_size(nzu!(256))
4101 /// .window(WindowType::Hanning)
4102 /// .centre(true)
4103 /// .build()?;
4104 ///
4105 /// assert_eq!(params.sample_rate_hz(), 16000.0);
4106 /// # Ok(())
4107 /// # }
4108 /// ```
4109 #[inline]
4110 #[must_use]
4111 pub fn builder() -> SpectrogramParamsBuilder {
4112 SpectrogramParamsBuilder::default()
4113 }
4114
4115 /// Create default parameters for speech processing.
4116 ///
4117 /// # Arguments
4118 ///
4119 /// * `sample_rate_hz` - Sample rate in Hz
4120 ///
4121 /// # Returns
4122 ///
4123 /// A `SpectrogramParams` instance with default settings for music analysis.
4124 ///
4125 /// # Errors
4126 ///
4127 /// Returns an error if the sample rate is not positive and finite.
4128 ///
4129 /// Uses:
4130 /// - `n_fft`: 512 (32ms at 16kHz)
4131 /// - `hop_size`: 160 (10ms at 16kHz)
4132 /// - window: Hanning
4133 /// - centre: true
4134 #[inline]
4135 pub fn speech_default(sample_rate_hz: f64) -> SpectrogramResult<Self> {
4136 // safety: parameters are known to be valid
4137 let stft =
4138 unsafe { StftParams::new_unchecked(nzu!(512), nzu!(160), WindowType::Hanning, true) };
4139
4140 Self::new(stft, sample_rate_hz)
4141 }
4142
4143 /// Create default parameters for music processing.
4144 ///
4145 /// # Arguments
4146 ///
4147 /// * `sample_rate_hz` - Sample rate in Hz
4148 ///
4149 /// # Returns
4150 ///
4151 /// A `SpectrogramParams` instance with default settings for music analysis.
4152 ///
4153 /// # Errors
4154 ///
4155 /// Returns an error if the sample rate is not positive and finite.
4156 ///
4157 /// Uses:
4158 /// - `n_fft`: 2048 (46ms at 44.1kHz)
4159 /// - `hop_size`: 512 (11.6ms at 44.1kHz)
4160 /// - window: Hanning
4161 /// - centre: true
4162 #[inline]
4163 pub fn music_default(sample_rate_hz: f64) -> SpectrogramResult<Self> {
4164 // safety: parameters are known to be valid
4165 let stft =
4166 unsafe { StftParams::new_unchecked(nzu!(2048), nzu!(512), WindowType::Hanning, true) };
4167 Self::new(stft, sample_rate_hz)
4168 }
4169
4170 /// Get the STFT parameters.
4171 #[inline]
4172 #[must_use]
4173 pub const fn stft(&self) -> &StftParams {
4174 &self.stft
4175 }
4176
4177 /// Get the sample rate in Hz.
4178 #[inline]
4179 #[must_use]
4180 pub const fn sample_rate_hz(&self) -> f64 {
4181 self.sample_rate_hz
4182 }
4183
4184 /// Get the frame period in seconds.
4185 #[inline]
4186 #[must_use]
4187 #[allow(clippy::cast_precision_loss)]
4188 pub fn frame_period_seconds(&self) -> f64 {
4189 self.stft.hop_size().get() as f64 / self.sample_rate_hz
4190 }
4191
4192 /// Get the Nyquist frequency in Hz.
4193 #[inline]
4194 #[must_use]
4195 pub fn nyquist_hz(&self) -> f64 {
4196 self.sample_rate_hz * 0.5
4197 }
4198}
4199
4200/// Builder for [`SpectrogramParams`].
4201#[derive(Debug, Clone)]
4202pub struct SpectrogramParamsBuilder {
4203 sample_rate: Option<f64>,
4204 n_fft: Option<NonZeroUsize>,
4205 hop_size: Option<NonZeroUsize>,
4206 window: WindowType,
4207 centre: bool,
4208}
4209
4210impl Default for SpectrogramParamsBuilder {
4211 #[inline]
4212 fn default() -> Self {
4213 Self {
4214 sample_rate: None,
4215 n_fft: None,
4216 hop_size: None,
4217 window: WindowType::Hanning,
4218 centre: true,
4219 }
4220 }
4221}
4222
4223impl SpectrogramParamsBuilder {
4224 /// Set the sample rate in Hz.
4225 ///
4226 /// # Arguments
4227 ///
4228 /// * `sample_rate` - Sample rate in Hz.
4229 ///
4230 /// # Returns
4231 ///
4232 /// The updated builder instance.
4233 #[inline]
4234 #[must_use]
4235 pub const fn sample_rate(mut self, sample_rate: f64) -> Self {
4236 self.sample_rate = Some(sample_rate);
4237 self
4238 }
4239
4240 /// Set the FFT window size.
4241 ///
4242 /// # Arguments
4243 ///
4244 /// * `n_fft` - FFT size.
4245 ///
4246 /// # Returns
4247 ///
4248 /// The updated builder instance.
4249 #[inline]
4250 #[must_use]
4251 pub const fn n_fft(mut self, n_fft: NonZeroUsize) -> Self {
4252 self.n_fft = Some(n_fft);
4253 self
4254 }
4255
4256 /// Set the hop size (samples between successive frames).
4257 ///
4258 /// # Arguments
4259 ///
4260 /// * `hop_size` - Hop size in samples.
4261 ///
4262 /// # Returns
4263 ///
4264 /// The updated builder instance.
4265 #[inline]
4266 #[must_use]
4267 pub const fn hop_size(mut self, hop_size: NonZeroUsize) -> Self {
4268 self.hop_size = Some(hop_size);
4269 self
4270 }
4271
4272 /// Set the window function.
4273 ///
4274 /// # Arguments
4275 ///
4276 /// * `window` - Window function to apply to each frame.
4277 ///
4278 /// # Returns
4279 ///
4280 /// The updated builder instance.
4281 #[inline]
4282 #[must_use]
4283 pub fn window(mut self, window: WindowType) -> Self {
4284 self.window = window;
4285 self
4286 }
4287
4288 /// Set whether to center frames (pad input signal).
4289 ///
4290 /// # Arguments
4291 ///
4292 /// * `centre` - If true, frames are centered by padding the input signal.
4293 ///
4294 /// # Returns
4295 ///
4296 /// The updated builder instance.
4297 #[inline]
4298 #[must_use]
4299 pub const fn centre(mut self, centre: bool) -> Self {
4300 self.centre = centre;
4301 self
4302 }
4303
4304 /// Build the [`SpectrogramParams`].
4305 ///
4306 /// # Errors
4307 ///
4308 /// Returns an error if required parameters are not set or are invalid.
4309 ///
4310 /// # Returns
4311 ///
4312 /// A `SpectrogramParams` instance.
4313 #[inline]
4314 pub fn build(self) -> SpectrogramResult<SpectrogramParams> {
4315 let sample_rate = self
4316 .sample_rate
4317 .ok_or_else(|| SpectrogramError::invalid_input("sample_rate must be set"))?;
4318 let n_fft = self
4319 .n_fft
4320 .ok_or_else(|| SpectrogramError::invalid_input("n_fft must be set"))?;
4321 let hop_size = self
4322 .hop_size
4323 .ok_or_else(|| SpectrogramError::invalid_input("hop_size must be set"))?;
4324
4325 let stft = StftParams::new(n_fft, hop_size, self.window, self.centre)?;
4326 SpectrogramParams::new(stft, sample_rate)
4327 }
4328}
4329
4330//
4331// ========================
4332// Standalone FFT Functions
4333// ========================
4334//
4335
4336/// Compute the real-to-complex FFT of a real-valued signal.
4337///
4338/// This function performs a forward FFT on real-valued input, returning the
4339/// complex frequency domain representation. Only the positive frequencies
4340/// are returned (length = `n_fft/2` + 1) due to conjugate symmetry.
4341///
4342/// # Arguments
4343///
4344/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4345/// * `n_fft` - FFT size
4346///
4347/// # Returns
4348///
4349/// A vector of complex frequency bins with length `n_fft/2` + 1.
4350///
4351/// # Automatic Zero-Padding
4352///
4353/// If the input signal is shorter than `n_fft`, it will be automatically
4354/// zero-padded to the required length. This is standard DSP practice and
4355/// preserves frequency resolution (bin spacing = sample_rate / n_fft).
4356///
4357/// ```
4358/// use spectrograms::{fft, nzu};
4359/// use non_empty_slice::non_empty_vec;
4360///
4361/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4362/// let signal = non_empty_vec![1.0, 2.0, 3.0]; // Only 3 samples
4363/// let spectrum = fft(&signal, nzu!(8))?; // Automatically padded to 8
4364/// assert_eq!(spectrum.len(), 5); // Output: 8/2 + 1 = 5 bins
4365/// # Ok(())
4366/// # }
4367/// ```
4368///
4369/// # Errors
4370///
4371/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4372///
4373/// # Examples
4374///
4375/// ```
4376/// use spectrograms::*;
4377/// use non_empty_slice::non_empty_vec;
4378///
4379/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4380/// let signal = non_empty_vec![0.0; nzu!(512)];
4381/// let spectrum = fft(&signal, nzu!(512))?;
4382///
4383/// assert_eq!(spectrum.len(), 257); // 512/2 + 1
4384/// # Ok(())
4385/// # }
4386/// ```
4387#[inline]
4388pub fn fft(
4389 samples: &NonEmptySlice<f64>,
4390 n_fft: NonZeroUsize,
4391) -> SpectrogramResult<Array1<Complex<f64>>> {
4392 if samples.len() > n_fft {
4393 return Err(SpectrogramError::invalid_input(format!(
4394 "Input length ({}) exceeds FFT size ({})",
4395 samples.len(),
4396 n_fft
4397 )));
4398 }
4399
4400 let out_len = r2c_output_size(n_fft.get());
4401
4402 // Get FFT plan from global cache (or create if first use)
4403 #[cfg(feature = "realfft")]
4404 let mut fft = {
4405 use crate::fft_backend::get_or_create_r2c_plan;
4406 let plan = get_or_create_r2c_plan(n_fft.get())?;
4407 // Clone the plan to get our own mutable copy with independent scratch buffer
4408 // This is cheap - only clones the scratch buffer, not the expensive twiddle factors
4409 (*plan).clone()
4410 };
4411
4412 #[cfg(feature = "fftw")]
4413 let mut fft = {
4414 use std::sync::Arc;
4415 let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
4416 crate::FftwPlan::new(Arc::new(plan))
4417 };
4418
4419 let input = if samples.len() < n_fft {
4420 let mut padded = vec![0.0; n_fft.get()];
4421 padded[..samples.len().get()].copy_from_slice(samples);
4422 // safety: samples.len() < n_fft checked above and n_fft > 0
4423
4424 unsafe { NonEmptyVec::new_unchecked(padded) }
4425 } else {
4426 samples.to_non_empty_vec()
4427 };
4428
4429 let mut output = vec![Complex::new(0.0, 0.0); out_len];
4430 fft.process(&input, &mut output)?;
4431 let output = Array1::from_vec(output);
4432 Ok(output)
4433}
4434
4435#[inline]
4436/// Compute the real-valued fft of a signal.
4437///
4438/// # Arguments
4439/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4440/// * `n_fft` - FFT size
4441///
4442/// # Returns
4443///
4444/// An array with length `n_fft/2` + 1.
4445///
4446/// # Errors
4447///
4448/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4449///
4450/// # Examples
4451///
4452/// ```
4453/// use spectrograms::*;
4454/// use non_empty_slice::non_empty_vec;
4455///
4456/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4457/// let signal = non_empty_vec![0.0; nzu!(512)];
4458/// let rfft_result = rfft(&signal, nzu!(512))?;
4459/// // equivalent to
4460/// let fft_result = fft(&signal, nzu!(512))?;
4461/// let rfft_result = fft_result.mapv(num_complex::Complex::norm);
4462/// # Ok(())
4463/// # }
4464///
4465pub fn rfft(samples: &NonEmptySlice<f64>, n_fft: NonZeroUsize) -> SpectrogramResult<Array1<f64>> {
4466 Ok(fft(samples, n_fft)?.mapv(Complex::norm))
4467}
4468
4469/// Compute the power spectrum of a signal (|X|²).
4470///
4471/// This function applies an optional window function and computes the
4472/// power spectrum via FFT. The result contains only positive frequencies.
4473///
4474/// # Arguments
4475///
4476/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4477/// * `n_fft` - FFT size
4478/// * `window` - Optional window function (None for rectangular window)
4479///
4480/// # Returns
4481///
4482/// A vector of power values with length `n_fft/2` + 1.
4483///
4484/// # Automatic Zero-Padding
4485///
4486/// If the input signal is shorter than `n_fft`, it will be automatically
4487/// zero-padded to the required length. This is standard DSP practice and
4488/// preserves frequency resolution (bin spacing = sample_rate / n_fft).
4489///
4490/// ```
4491/// use spectrograms::{power_spectrum, nzu};
4492/// use non_empty_slice::non_empty_vec;
4493///
4494/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4495/// let signal = non_empty_vec![1.0, 2.0, 3.0]; // Only 3 samples
4496/// let power = power_spectrum(&signal, nzu!(8), None)?;
4497/// assert_eq!(power.len(), nzu!(5)); // Output: 8/2 + 1 = 5 bins
4498/// # Ok(())
4499/// # }
4500/// ```
4501///
4502/// # Errors
4503///
4504/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4505///
4506/// # Examples
4507///
4508/// ```
4509/// use spectrograms::*;
4510/// use non_empty_slice::non_empty_vec;
4511///
4512/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4513/// let signal = non_empty_vec![0.0; nzu!(512)];
4514/// let power = power_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
4515///
4516/// assert_eq!(power.len(), nzu!(257)); // 512/2 + 1
4517/// # Ok(())
4518/// # }
4519/// ```
4520#[inline]
4521pub fn power_spectrum(
4522 samples: &NonEmptySlice<f64>,
4523 n_fft: NonZeroUsize,
4524 window: Option<WindowType>,
4525) -> SpectrogramResult<NonEmptyVec<f64>> {
4526 if samples.len() > n_fft {
4527 return Err(SpectrogramError::invalid_input(format!(
4528 "Input length ({}) exceeds FFT size ({})",
4529 samples.len(),
4530 n_fft
4531 )));
4532 }
4533
4534 let mut windowed = vec![0.0; n_fft.get()];
4535 windowed[..samples.len().get()].copy_from_slice(samples);
4536
4537 if let Some(win_type) = window {
4538 let window_samples = make_window(win_type, n_fft);
4539 for i in 0..n_fft.get() {
4540 windowed[i] *= window_samples[i];
4541 }
4542 }
4543
4544 // safety: windowed is non-empty since n_fft > 0
4545 let windowed = unsafe { NonEmptySlice::new_unchecked(&windowed) };
4546 let fft_result = fft(windowed, n_fft)?;
4547 let fft_result = fft_result
4548 .iter()
4549 .map(num_complex::Complex::norm_sqr)
4550 .collect();
4551 // safety: fft_result is non-empty since fft returned successfully
4552 Ok(unsafe { NonEmptyVec::new_unchecked(fft_result) })
4553}
4554
4555/// Compute the magnitude spectrum of a signal (|X|).
4556///
4557/// This function applies an optional window function and computes the
4558/// magnitude spectrum via FFT. The result contains only positive frequencies.
4559///
4560/// # Arguments
4561///
4562/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4563/// * `n_fft` - FFT size
4564/// * `window` - Optional window function (None for rectangular window)
4565///
4566/// # Automatic Zero-Padding
4567///
4568/// If the input signal is shorter than `n_fft`, it will be automatically
4569/// zero-padded to the required length. This preserves frequency resolution.
4570///
4571/// # Errors
4572///
4573/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4574///
4575/// # Returns
4576///
4577/// A vector of magnitude values with length `n_fft/2` + 1.
4578///
4579/// # Examples
4580///
4581/// ```
4582/// use spectrograms::*;
4583/// use non_empty_slice::non_empty_vec;
4584///
4585/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4586/// let signal = non_empty_vec![0.0; nzu!(512)];
4587/// let magnitude = magnitude_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
4588///
4589/// assert_eq!(magnitude.len(), nzu!(257)); // 512/2 + 1
4590/// # Ok(())
4591/// # }
4592/// ```
4593#[inline]
4594pub fn magnitude_spectrum(
4595 samples: &NonEmptySlice<f64>,
4596 n_fft: NonZeroUsize,
4597 window: Option<WindowType>,
4598) -> SpectrogramResult<NonEmptyVec<f64>> {
4599 let power = power_spectrum(samples, n_fft, window)?;
4600 let power = power.iter().map(|&p| p.sqrt()).collect();
4601 // safety: power is non-empty since power_spectrum returned successfully
4602 Ok(unsafe { NonEmptyVec::new_unchecked(power) })
4603}
4604
4605/// Compute the Short-Time Fourier Transform (STFT) of a signal.
4606///
4607/// This function computes the STFT by applying a sliding window and FFT
4608/// to sequential frames of the input signal.
4609///
4610/// # Arguments
4611///
4612/// * `samples` - Input signal (any type that can be converted to a slice)
4613/// * `n_fft` - FFT size
4614/// * `hop_size` - Number of samples between successive frames
4615/// * `window` - Window function to apply to each frame
4616/// * `center` - If true, pad the signal to center frames
4617///
4618/// # Returns
4619///
4620/// A 2D array with shape (`frequency_bins`, `time_frames`) containing complex STFT values.
4621///
4622/// # Errors
4623///
4624/// Returns an error if:
4625/// - `hop_size` > `n_fft`
4626/// - STFT computation fails
4627///
4628/// # Examples
4629///
4630/// ```
4631/// use spectrograms::*;
4632/// use non_empty_slice::non_empty_vec;
4633///
4634/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4635/// let signal = non_empty_vec![0.0; nzu!(16000)];
4636/// let stft_matrix = stft(&signal, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4637///
4638/// println!("STFT: {} bins x {} frames", stft_matrix.nrows(), stft_matrix.ncols());
4639/// # Ok(())
4640/// # }
4641/// ```
4642#[inline]
4643pub fn stft(
4644 samples: &NonEmptySlice<f64>,
4645 n_fft: NonZeroUsize,
4646 hop_size: NonZeroUsize,
4647 window: WindowType,
4648 center: bool,
4649) -> SpectrogramResult<Array2<Complex<f64>>> {
4650 let stft_params = StftParams::new(n_fft, hop_size, window, center)?;
4651 let params = SpectrogramParams::new(stft_params, 1.0)?; // dummy sample rate
4652
4653 let planner = SpectrogramPlanner::new();
4654 let result = planner.compute_stft(samples, ¶ms)?;
4655
4656 Ok(result.data)
4657}
4658
4659/// Compute the inverse real FFT (complex-to-real IFFT).
4660///
4661/// This function performs an inverse FFT, converting frequency domain data
4662/// back to the time domain. Only the positive frequencies need to be provided
4663/// (length = `n_fft/2` + 1) due to conjugate symmetry.
4664///
4665/// # Arguments
4666///
4667/// * `spectrum` - Complex frequency bins (length should be `n_fft/2` + 1)
4668/// * `n_fft` - FFT size (length of the output signal)
4669///
4670/// # Returns
4671///
4672/// A vector of real-valued time-domain samples with length `n_fft`.
4673///
4674/// # Errors
4675///
4676/// Returns an error if:
4677/// - `spectrum` length doesn't match `n_fft/2` + 1
4678/// - Inverse FFT computation fails
4679///
4680/// # Examples
4681///
4682/// ```
4683/// use spectrograms::*;
4684/// use non_empty_slice::{non_empty_vec, NonEmptySlice};
4685/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4686/// // Forward FFT
4687/// let signal = non_empty_vec![1.0, 0.0, -1.0, 0.0, 1.0, 0.0, -1.0, 0.0];
4688/// let spectrum = fft(&signal, nzu!(8))?;
4689/// let slice = spectrum.as_slice().unwrap();
4690/// let spectrum_slice = NonEmptySlice::new(slice).unwrap();
4691/// // Inverse FFT
4692/// let reconstructed = irfft(spectrum_slice, nzu!(8))?;
4693///
4694/// assert_eq!(reconstructed.len(), nzu!(8));
4695/// # Ok(())
4696/// # }
4697/// ```
4698#[inline]
4699pub fn irfft(
4700 spectrum: &NonEmptySlice<Complex<f64>>,
4701 n_fft: NonZeroUsize,
4702) -> SpectrogramResult<NonEmptyVec<f64>> {
4703 use crate::fft_backend::{C2rPlan, r2c_output_size};
4704
4705 let n_fft = n_fft.get();
4706 let expected_len = r2c_output_size(n_fft);
4707 if spectrum.len().get() != expected_len {
4708 return Err(SpectrogramError::dimension_mismatch(
4709 expected_len,
4710 spectrum.len().get(),
4711 ));
4712 }
4713
4714 // Get inverse FFT plan from global cache (or create if first use)
4715 #[cfg(feature = "realfft")]
4716 let mut ifft = {
4717 use crate::fft_backend::get_or_create_c2r_plan;
4718 let plan = get_or_create_c2r_plan(n_fft)?;
4719 // Clone to get our own mutable copy with independent scratch buffer
4720 (*plan).clone()
4721 };
4722
4723 #[cfg(feature = "fftw")]
4724 let mut ifft = {
4725 use crate::fft_backend::C2rPlanner;
4726 let mut planner = crate::FftwPlanner::new();
4727 planner.plan_c2r(n_fft)?
4728 };
4729
4730 let mut output = vec![0.0; n_fft];
4731 ifft.process(spectrum.as_slice(), &mut output)?;
4732
4733 // Safety: output is non-empty since n_fft > 0
4734 Ok(unsafe { NonEmptyVec::new_unchecked(output) })
4735}
4736
4737/// Reconstruct a time-domain signal from its STFT using overlap-add.
4738///
4739/// This function performs the inverse Short-Time Fourier Transform, converting
4740/// a 2D complex STFT matrix back to a 1D time-domain signal using overlap-add
4741/// synthesis with the specified window function.
4742///
4743/// # Arguments
4744///
4745/// * `stft_matrix` - Complex STFT with shape (`frequency_bins`, `time_frames`)
4746/// * `n_fft` - FFT size
4747/// * `hop_size` - Number of samples between successive frames
4748/// * `window` - Window function to apply (should match forward STFT window)
4749/// * `center` - If true, assume the forward STFT was centered
4750///
4751/// # Returns
4752///
4753/// A vector of reconstructed time-domain samples.
4754///
4755/// # Errors
4756///
4757/// Returns an error if:
4758/// - `stft_matrix` dimensions are inconsistent with `n_fft`
4759/// - `hop_size` > `n_fft`
4760/// - Inverse STFT computation fails
4761///
4762/// # Examples
4763///
4764/// ```
4765/// use spectrograms::*;
4766/// use non_empty_slice::non_empty_vec;
4767///
4768/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4769/// // Generate signal
4770/// let signal = non_empty_vec![1.0; nzu!(16000)];
4771///
4772/// // Forward STFT
4773/// let stft_matrix = stft(&signal, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4774///
4775/// // Inverse STFT
4776/// let reconstructed = istft(&stft_matrix, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4777///
4778/// println!("Original: {} samples", signal.len());
4779/// println!("Reconstructed: {} samples", reconstructed.len());
4780/// # Ok(())
4781/// # }
4782/// ```
4783#[inline]
4784pub fn istft(
4785 stft_matrix: &Array2<Complex<f64>>,
4786 n_fft: NonZeroUsize,
4787 hop_size: NonZeroUsize,
4788 window: WindowType,
4789 center: bool,
4790) -> SpectrogramResult<NonEmptyVec<f64>> {
4791 use crate::fft_backend::{C2rPlan, C2rPlanner, r2c_output_size};
4792
4793 let n_bins = stft_matrix.nrows();
4794 let n_frames = stft_matrix.ncols();
4795
4796 let expected_bins = r2c_output_size(n_fft.get());
4797 if n_bins != expected_bins {
4798 return Err(SpectrogramError::dimension_mismatch(expected_bins, n_bins));
4799 }
4800 if hop_size.get() > n_fft.get() {
4801 return Err(SpectrogramError::invalid_input("hop_size must be <= n_fft"));
4802 }
4803 // Create inverse FFT plan
4804 #[cfg(feature = "realfft")]
4805 let mut ifft = {
4806 let mut planner = crate::RealFftPlanner::new();
4807 planner.plan_c2r(n_fft.get())?
4808 };
4809
4810 #[cfg(feature = "fftw")]
4811 let mut ifft = {
4812 let mut planner = crate::FftwPlanner::new();
4813 planner.plan_c2r(n_fft.get())?
4814 };
4815
4816 // Generate window
4817 let window_samples = make_window(window, n_fft);
4818 let n_fft = n_fft.get();
4819 let hop_size = hop_size.get();
4820 // Calculate output length
4821 let pad = if center { n_fft / 2 } else { 0 };
4822 let output_len = (n_frames - 1) * hop_size + n_fft;
4823 // safety: output_len > 0 since n_frames > 0 and n_fft, hop_size > 0
4824 let output_len = unsafe { NonZeroUsize::new_unchecked(output_len) };
4825 let unpadded_len = output_len.get().saturating_sub(2 * pad);
4826
4827 // Allocate output buffer and normalization buffer
4828 let mut output = non_empty_vec![0.0; output_len];
4829 let mut norm = non_empty_vec![0.0; output_len];
4830
4831 // Overlap-add synthesis
4832 let mut frame_buffer = vec![Complex::new(0.0, 0.0); n_bins];
4833 let mut time_frame = vec![0.0; n_fft];
4834
4835 for frame_idx in 0..n_frames {
4836 // Extract complex frame from STFT matrix
4837 for bin_idx in 0..n_bins {
4838 frame_buffer[bin_idx] = stft_matrix[[bin_idx, frame_idx]];
4839 }
4840
4841 // Inverse FFT
4842 ifft.process(&frame_buffer, &mut time_frame)?;
4843
4844 // Apply window
4845 for i in 0..n_fft {
4846 time_frame[i] *= window_samples[i];
4847 }
4848
4849 // Overlap-add into output buffer
4850 let start = frame_idx * hop_size;
4851 for i in 0..n_fft {
4852 let pos = start + i;
4853 if pos < output_len.get() {
4854 output[pos] += time_frame[i];
4855 norm[pos] += window_samples[i] * window_samples[i];
4856 }
4857 }
4858 }
4859
4860 // Normalize by window energy
4861 for i in 0..output_len.get() {
4862 if norm[i] > 1e-10 {
4863 output[i] /= norm[i];
4864 }
4865 }
4866
4867 // Remove padding if centered
4868 if center && unpadded_len > 0 {
4869 let start = pad;
4870 let end = start + unpadded_len;
4871 // safety: start < end <= output_len, therefore slice is non-empty
4872 output = unsafe {
4873 NonEmptySlice::new_unchecked(&output[start..end.min(output_len.get())])
4874 .to_non_empty_vec()
4875 };
4876 }
4877
4878 Ok(output)
4879}
4880
4881//
4882// ========================
4883// Reusable FFT Plans
4884// ========================
4885//
4886
/// A reusable FFT planner for repeated FFT operations.
///
/// The planner owns one backend planner (selected by the `realfft` or
/// `fftw` feature) and requests its plans from it. Keeping a single
/// `FftPlanner` around is intended to make repeated transforms of the same
/// size cheaper than calling the standalone `fft()` function each time.
///
/// # Examples
///
/// ```
/// use spectrograms::*;
/// use non_empty_slice::non_empty_vec;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let mut planner = FftPlanner::new();
///
/// // Process multiple signals of the same size efficiently
/// for _ in 0..100 {
///     let signal = non_empty_vec![0.0; nzu!(512)];
///     let spectrum = planner.fft(&signal, nzu!(512))?;
///     // ... process spectrum ...
/// }
/// # Ok(())
/// # }
/// ```
pub struct FftPlanner {
    /// Backend planner used when the `realfft` feature is enabled.
    #[cfg(feature = "realfft")]
    inner: crate::RealFftPlanner,
    /// Backend planner used when the `fftw` feature is enabled.
    #[cfg(feature = "fftw")]
    inner: crate::FftwPlanner,
}
4916
4917impl FftPlanner {
4918 /// Create a new FFT planner with empty cache.
4919 #[inline]
4920 #[must_use]
4921 pub fn new() -> Self {
4922 Self {
4923 #[cfg(feature = "realfft")]
4924 inner: crate::RealFftPlanner::new(),
4925 #[cfg(feature = "fftw")]
4926 inner: crate::FftwPlanner::new(),
4927 }
4928 }
4929
4930 /// Compute forward FFT, reusing cached plans.
4931 ///
4932 /// This is more efficient than calling the standalone `fft()` function
4933 /// repeatedly for the same FFT size.
4934 ///
4935 /// # Automatic Zero-Padding
4936 ///
4937 /// If the input signal is shorter than `n_fft`, it will be automatically
4938 /// zero-padded to the required length.
4939 ///
4940 /// # Errors
4941 ///
4942 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4943 ///
4944 /// # Examples
4945 ///
4946 /// ```
4947 /// use spectrograms::*;
4948 /// use non_empty_slice::non_empty_vec;
4949 ///
4950 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4951 /// let mut planner = FftPlanner::new();
4952 ///
4953 /// let signal = non_empty_vec![1.0; nzu!(512)];
4954 /// let spectrum = planner.fft(&signal, nzu!(512))?;
4955 ///
4956 /// assert_eq!(spectrum.len(), 257); // 512/2 + 1
4957 /// # Ok(())
4958 /// # }
4959 /// ```
4960 #[inline]
4961 pub fn fft(
4962 &mut self,
4963 samples: &NonEmptySlice<f64>,
4964 n_fft: NonZeroUsize,
4965 ) -> SpectrogramResult<Array1<Complex<f64>>> {
4966 use crate::fft_backend::{R2cPlan, R2cPlanner, r2c_output_size};
4967
4968 if samples.len() > n_fft {
4969 return Err(SpectrogramError::invalid_input(format!(
4970 "Input length ({}) exceeds FFT size ({})",
4971 samples.len(),
4972 n_fft
4973 )));
4974 }
4975
4976 let out_len = r2c_output_size(n_fft.get());
4977 let mut plan = self.inner.plan_r2c(n_fft.get())?;
4978
4979 let input = if samples.len() < n_fft {
4980 let mut padded = vec![0.0; n_fft.get()];
4981 padded[..samples.len().get()].copy_from_slice(samples);
4982
4983 // safety: samples.len() < n_fft checked above and n_fft > 0
4984 unsafe { NonEmptyVec::new_unchecked(padded) }
4985 } else {
4986 samples.to_non_empty_vec()
4987 };
4988
4989 let mut output = vec![Complex::new(0.0, 0.0); out_len];
4990 plan.process(&input, &mut output)?;
4991
4992 let output = Array1::from_vec(output);
4993 Ok(output)
4994 }
4995
4996 /// Compute forward real FFT magnitude
4997 ///
4998 /// # Errors
4999 ///
5000 /// Returns an error if:
5001 /// - `n_fft` doesn't match the samples length
5002 ///
5003 ///
5004 #[inline]
5005 pub fn rfft(
5006 &mut self,
5007 samples: &NonEmptySlice<f64>,
5008 n_fft: NonZeroUsize,
5009 ) -> SpectrogramResult<Array1<f64>> {
5010 let fft_with_complex = fft(samples, n_fft)?;
5011 Ok(fft_with_complex.mapv(Complex::norm))
5012 }
5013
5014 /// Compute inverse FFT, reusing cached plans.
5015 ///
5016 /// This is more efficient than calling the standalone `irfft()` function
5017 /// repeatedly for the same FFT size.
5018 ///
5019 /// # Errors
5020 /// Returns an error if:
5021 ///
5022 /// - The calculated expected length of `spectrum` doesn't match its actual length
5023 ///
5024 /// # Examples
5025 ///
5026 /// ```
5027 /// use spectrograms::*;
5028 /// use non_empty_slice::{non_empty_vec, NonEmptySlice};
5029 ///
5030 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
5031 /// let mut planner = FftPlanner::new();
5032 ///
5033 /// // Forward FFT
5034 /// let signal = non_empty_vec![1.0; nzu!(512)];
5035 /// let spectrum = planner.fft(&signal, nzu!(512))?;
5036 ///
5037 /// // Inverse FFT
5038 /// let spectrum_slice = NonEmptySlice::new(spectrum.as_slice().unwrap()).unwrap();
5039 /// let reconstructed = planner.irfft(spectrum_slice, nzu!(512))?;
5040 ///
5041 /// assert_eq!(reconstructed.len(), nzu!(512));
5042 /// # Ok(())
5043 /// # }
5044 /// ```
5045 #[inline]
5046 pub fn irfft(
5047 &mut self,
5048 spectrum: &NonEmptySlice<Complex<f64>>,
5049 n_fft: NonZeroUsize,
5050 ) -> SpectrogramResult<NonEmptyVec<f64>> {
5051 use crate::fft_backend::{C2rPlan, C2rPlanner, r2c_output_size};
5052
5053 let expected_len = r2c_output_size(n_fft.get());
5054 if spectrum.len().get() != expected_len {
5055 return Err(SpectrogramError::dimension_mismatch(
5056 expected_len,
5057 spectrum.len().get(),
5058 ));
5059 }
5060
5061 let mut plan = self.inner.plan_c2r(n_fft.get())?;
5062 let mut output = vec![0.0; n_fft.get()];
5063 plan.process(spectrum, &mut output)?;
5064 // Safety: output is non-empty since n_fft > 0
5065 let output = unsafe { NonEmptyVec::new_unchecked(output) };
5066 Ok(output)
5067 }
5068
5069 /// Compute power spectrum with optional windowing, reusing cached plans.
5070 ///
5071 /// # Automatic Zero-Padding
5072 ///
5073 /// If the input signal is shorter than `n_fft`, it will be automatically
5074 /// zero-padded to the required length.
5075 ///
5076 /// # Errors
5077 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
5078 ///
5079 /// # Examples
5080 ///
5081 /// ```
5082 /// use spectrograms::*;
5083 /// use non_empty_slice::non_empty_vec;
5084 ///
5085 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
5086 /// let mut planner = FftPlanner::new();
5087 ///
5088 /// let signal = non_empty_vec![1.0; nzu!(512)];
5089 /// let power = planner.power_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
5090 ///
5091 /// assert_eq!(power.len(), nzu!(257));
5092 /// # Ok(())
5093 /// # }
5094 /// ```
5095 #[inline]
5096 pub fn power_spectrum(
5097 &mut self,
5098 samples: &NonEmptySlice<f64>,
5099 n_fft: NonZeroUsize,
5100 window: Option<WindowType>,
5101 ) -> SpectrogramResult<NonEmptyVec<f64>> {
5102 if samples.len() > n_fft {
5103 return Err(SpectrogramError::invalid_input(format!(
5104 "Input length ({}) exceeds FFT size ({})",
5105 samples.len(),
5106 n_fft
5107 )));
5108 }
5109
5110 let mut windowed = vec![0.0; n_fft.get()];
5111 windowed[..samples.len().get()].copy_from_slice(samples);
5112 if let Some(win_type) = window {
5113 let window_samples = make_window(win_type, n_fft);
5114 for i in 0..n_fft.get() {
5115 windowed[i] *= window_samples[i];
5116 }
5117 }
5118
5119 // safety: windowed is non-empty since n_fft > 0
5120 let windowed = unsafe { NonEmptySlice::new_unchecked(&windowed) };
5121 let fft_result = self.fft(windowed, n_fft)?;
5122 let f = fft_result
5123 .iter()
5124 .map(num_complex::Complex::norm_sqr)
5125 .collect();
5126 // safety: fft_result is non-empty since fft returned successfully
5127 Ok(unsafe { NonEmptyVec::new_unchecked(f) })
5128 }
5129
5130 /// Compute magnitude spectrum with optional windowing, reusing cached plans.
5131 ///
5132 /// # Automatic Zero-Padding
5133 ///
5134 /// If the input signal is shorter than `n_fft`, it will be automatically
5135 /// zero-padded to the required length.
5136 ///
5137 /// # Errors
5138 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
5139 ///
5140 /// # Examples
5141 ///
5142 /// ```
5143 /// use spectrograms::*;
5144 /// use non_empty_slice::non_empty_vec;
5145 ///
5146 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
5147 /// let mut planner = FftPlanner::new();
5148 ///
5149 /// let signal = non_empty_vec![1.0; nzu!(512)];
5150 /// let magnitude = planner.magnitude_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
5151 ///
5152 /// assert_eq!(magnitude.len(), nzu!(257));
5153 /// # Ok(())
5154 /// # }
5155 /// ```
5156 #[inline]
5157 pub fn magnitude_spectrum(
5158 &mut self,
5159 samples: &NonEmptySlice<f64>,
5160 n_fft: NonZeroUsize,
5161 window: Option<WindowType>,
5162 ) -> SpectrogramResult<NonEmptyVec<f64>> {
5163 let power = self.power_spectrum(samples, n_fft, window)?;
5164 let power = power.iter().map(|&p| p.sqrt()).collect::<Vec<f64>>();
5165 // safety: power is non-empty since power_spectrum returned successfully
5166 Ok(unsafe { NonEmptyVec::new_unchecked(power) })
5167 }
5168}
5169
5170impl Default for FftPlanner {
5171 #[inline]
5172 fn default() -> Self {
5173 Self::new()
5174 }
5175}
5176
#[cfg(test)]
mod tests {
    use super::*;

    /// Multiplying a hand-built 3x5 sparse matrix by a dense vector must
    /// reproduce the manually computed row sums.
    #[test]
    fn test_sparse_matrix_basic() {
        let mut matrix = SparseMatrix::new(3, 5);

        // (row, column, value) triples, inserted in this order:
        //   row 0: column 1
        //   row 1: columns 2 and 3
        //   row 2: columns 0 and 4
        let entries = [
            (0, 1, 2.0),
            (1, 2, 0.5),
            (1, 3, 1.5),
            (2, 0, 3.0),
            (2, 4, 1.0),
        ];
        for (row, col, value) in entries {
            matrix.set(row, col, value);
        }

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut output = vec![0.0; 3];
        matrix.multiply_vec(&input, &mut output);

        // Row 0: 2.0 * 2.0             = 4.0
        // Row 1: 0.5 * 3.0 + 1.5 * 4.0 = 7.5
        // Row 2: 3.0 * 1.0 + 1.0 * 5.0 = 8.0
        assert_eq!(output, vec![4.0, 7.5, 8.0]);
    }

    /// Setting an explicit zero must not create a stored entry.
    #[test]
    fn test_sparse_matrix_zeros_ignored() {
        let mut matrix = SparseMatrix::new(2, 3);

        // The middle entry is zero and should be dropped.
        for (col, value) in [(0, 1.0), (1, 0.0), (2, 2.0)] {
            matrix.set(0, col, value);
        }

        // Row 0 keeps exactly the two non-zero entries...
        assert_eq!(matrix.values[0].len(), 2);
        assert_eq!(matrix.indices[0].len(), 2);

        // ...at columns 0 and 2, in insertion order.
        assert_eq!(matrix.indices[0], vec![0, 2]);
        assert_eq!(matrix.values[0], vec![1.0, 2.0]);
    }

    /// LogHz matrices interpolate linearly, so each row holds 1-2 non-zeros.
    #[test]
    fn test_loghz_matrix_sparsity() {
        let sample_rate = 16000.0;
        let n_fft = nzu!(512);
        let n_bins = nzu!(128);
        let f_min = 20.0;
        let f_max = sample_rate / 2.0;

        let (matrix, _freqs) =
            build_loghz_matrix(sample_rate, n_fft, n_bins, f_min, f_max).unwrap();

        let row_counts = (0..matrix.nrows()).map(|row_idx| (row_idx, matrix.values[row_idx].len()));
        for (row_idx, nnz) in row_counts {
            assert!(
                nnz <= 2,
                "Row {} has {} non-zeros, expected at most 2",
                row_idx,
                nnz
            );
            assert!(nnz >= 1, "Row {} has no non-zeros", row_idx);
        }

        // Overall count is bounded by one and two entries per row.
        let total_nnz: usize = matrix.values.iter().map(Vec::len).sum();
        assert!(total_nnz <= n_bins.get() * 2);
        assert!(total_nnz >= n_bins.get()); // At least 1 per row
    }

    /// Mel filterbanks are triangular filters, so the matrix must be sparse.
    #[test]
    fn test_mel_matrix_sparsity() {
        let sample_rate = 16000.0;
        let n_fft = nzu!(512);
        let n_mels = nzu!(40);
        let f_min = 0.0;
        let f_max = sample_rate / 2.0;

        let matrix =
            build_mel_filterbank_matrix(sample_rate, n_fft, n_mels, f_min, f_max, MelNorm::None)
                .unwrap();

        let n_fft_bins = r2c_output_size(n_fft.get());

        // Fraction of entries that are not stored.
        let total_elements = n_mels.get() * n_fft_bins;
        let total_nnz: usize = matrix.values.iter().map(Vec::len).sum();
        let sparsity = 1.0 - (total_nnz as f64 / total_elements as f64);

        // Mel filterbanks should be >80% sparse
        assert!(
            sparsity > 0.8,
            "Mel matrix sparsity is only {:.1}%, expected >80%",
            sparsity * 100.0
        );

        // Every filter must cover fewer than half of the FFT bins.
        let max_nnz = n_fft_bins / 2;
        for row_idx in 0..matrix.nrows() {
            let nnz = matrix.values[row_idx].len();
            assert!(nnz < max_nnz, "Mel filter {} is not sparse enough", row_idx);
        }
    }
}