spectrograms/spectrogram.rs
1use std::marker::PhantomData;
2use std::num::NonZeroUsize;
3use std::ops::{Deref, DerefMut};
4
5use ndarray::{Array1, Array2};
6use non_empty_slice::{NonEmptySlice, NonEmptyVec, non_empty_vec};
7use num_complex::Complex;
8
9#[cfg(feature = "python")]
10use pyo3::prelude::*;
11
12use crate::cqt::CqtKernel;
13use crate::erb::ErbFilterbank;
14use crate::{
15 CqtParams, ErbParams, R2cPlan, SpectrogramError, SpectrogramResult, WindowType,
16 min_max_single_pass, nzu,
17};
/// Module-wide numerical epsilon (`1e-12`).
///
/// NOTE(review): not referenced in this part of the file; presumably used elsewhere in the
/// module as a floor to keep numeric operations away from zero — confirm at the use sites.
const EPS: f64 = 1.0e-12;
19
20//
21// ========================
22// Sparse Matrix for efficient filterbank multiplication
23// ========================
24//
25
/// Row-wise sparse matrix optimized for matrix-vector multiplication.
///
/// Every row owns its own list of stored values together with the matching column
/// indices (a "list of rows" layout rather than packed CSR), which makes it simple to
/// build the matrix one row at a time.
///
/// The layout targets matrices with very few non-zero entries per row, such as mel
/// filterbanks (triangular filters) and logarithmic frequency mappings (linear
/// interpolation between 1-2 adjacent bins).
///
/// For typical spectrograms:
/// - `LogHz` interpolation: Only 1-2 non-zeros per row (~99% sparse)
/// - Mel filterbank: ~10-50 non-zeros per row depending on FFT size (~90-98% sparse)
///
/// Because only stored entries are visited, the product skips all multiplications by
/// zero that a dense matrix-vector product would perform.
#[derive(Debug, Clone)]
struct SparseMatrix {
    /// Logical number of rows (length of the output vector).
    nrows: usize,
    /// Logical number of columns (length of the input vector).
    ncols: usize,
    /// `values[r]` holds the stored entries of row `r`.
    values: Vec<Vec<f64>>,
    /// `indices[r][k]` is the column index of `values[r][k]`.
    indices: Vec<Vec<usize>>,
}

impl SparseMatrix {
    /// Create an all-zero `nrows` x `ncols` matrix (no stored entries).
    fn new(nrows: usize, ncols: usize) -> Self {
        let values = (0..nrows).map(|_| Vec::new()).collect();
        let indices = (0..nrows).map(|_| Vec::new()).collect();
        Self {
            nrows,
            ncols,
            values,
            indices,
        }
    }

    /// Record `value` at position (`row`, `col`).
    ///
    /// Values whose magnitude is not greater than `1e-10` (and NaN) are dropped so the
    /// matrix stays sparse. Every call appends a new entry: setting the same position
    /// twice stores two entries, which [`Self::multiply_vec`] then sums.
    ///
    /// # Panics (debug mode only)
    /// Out-of-bounds positions trip a `debug_assert!`; in release builds they are
    /// silently ignored.
    fn set(&mut self, row: usize, col: usize, value: f64) {
        let in_bounds = row < self.nrows && col < self.ncols;
        debug_assert!(
            in_bounds,
            "SparseMatrix index out of bounds: ({}, {}) for {}x{} matrix",
            row,
            col,
            self.nrows,
            self.ncols
        );
        if !in_bounds {
            return;
        }

        // Written as `> threshold` (not `<= threshold` + early return) so NaN is dropped.
        let significant = value.abs() > 1e-10;
        if significant {
            self.indices[row].push(col);
            self.values[row].push(value);
        }
    }

    /// Number of rows.
    const fn nrows(&self) -> usize {
        self.nrows
    }

    /// Number of columns.
    const fn ncols(&self) -> usize {
        self.ncols
    }

    /// Sparse matrix-vector product: `out = self * input`.
    ///
    /// Work is proportional to the number of stored entries rather than
    /// `nrows * ncols`. Rows with no stored entries produce `0.0`.
    #[inline]
    fn multiply_vec(&self, input: &[f64], out: &mut [f64]) {
        debug_assert_eq!(input.len(), self.ncols);
        debug_assert_eq!(out.len(), self.nrows);

        let rows = self.values.iter().zip(self.indices.iter());
        for (r, (weights, cols)) in rows.enumerate() {
            // Explicit `0.0` seed (rather than `.sum()`) keeps empty rows at +0.0.
            out[r] = weights
                .iter()
                .zip(cols.iter())
                .fold(0.0, |sum, (&w, &c)| sum + w * input[c]);
        }
    }
}
117
118// Linear frequency
119pub type LinearPowerSpectrogram = Spectrogram<LinearHz, Power>;
120pub type LinearMagnitudeSpectrogram = Spectrogram<LinearHz, Magnitude>;
121pub type LinearDbSpectrogram = Spectrogram<LinearHz, Decibels>;
122pub type LinearSpectrogram<AmpScale> = Spectrogram<LinearHz, AmpScale>;
123
124// Log-frequency (e.g. CQT-style)
125pub type LogHzPowerSpectrogram = Spectrogram<LogHz, Power>;
126pub type LogHzMagnitudeSpectrogram = Spectrogram<LogHz, Magnitude>;
127pub type LogHzDbSpectrogram = Spectrogram<LogHz, Decibels>;
128pub type LogHzSpectrogram<AmpScale> = Spectrogram<LogHz, AmpScale>;
129
130// ERB / gammatone
131pub type ErbPowerSpectrogram = Spectrogram<Erb, Power>;
132pub type ErbMagnitudeSpectrogram = Spectrogram<Erb, Magnitude>;
133pub type ErbDbSpectrogram = Spectrogram<Erb, Decibels>;
134pub type GammatonePowerSpectrogram = ErbPowerSpectrogram;
135pub type GammatoneMagnitudeSpectrogram = ErbMagnitudeSpectrogram;
136pub type GammatoneDbSpectrogram = ErbDbSpectrogram;
137pub type ErbSpectrogram<AmpScale> = Spectrogram<Erb, AmpScale>;
138pub type GammatoneSpectrogram<AmpScale> = ErbSpectrogram<AmpScale>;
139
140// Mel
141pub type MelMagnitudeSpectrogram = Spectrogram<Mel, Magnitude>;
142pub type MelPowerSpectrogram = Spectrogram<Mel, Power>;
143pub type MelDbSpectrogram = Spectrogram<Mel, Decibels>;
144pub type LogMelSpectrogram = MelDbSpectrogram;
145pub type MelSpectrogram<AmpScale> = Spectrogram<Mel, AmpScale>;
146
147// CQT
148pub type CqtPowerSpectrogram = Spectrogram<Cqt, Power>;
149pub type CqtMagnitudeSpectrogram = Spectrogram<Cqt, Magnitude>;
150pub type CqtDbSpectrogram = Spectrogram<Cqt, Decibels>;
151pub type CqtSpectrogram<AmpScale> = Spectrogram<Cqt, AmpScale>;
152
153use crate::fft_backend::r2c_output_size;
154
155/// A spectrogram plan is the compiled, reusable execution object.
156///
157/// It owns:
158/// - FFT plan (reusable)
159/// - window samples
160/// - mapping (identity / mel filterbank / etc.)
161/// - amplitude scaling config
162/// - workspace buffers to avoid allocations in hot loops
163///
164/// It computes one specific spectrogram type: `Spectrogram<FreqScale, AmpScale>`.
165///
166/// # Type Parameters
167///
168/// - `FreqScale`: Frequency scale type (e.g. `LinearHz`, `LogHz`, `Mel`, etc.)
169/// - `AmpScale`: Amplitude scaling type (e.g. `Power`, `Magnitude`, `Decibels`, etc.)
170pub struct SpectrogramPlan<FreqScale, AmpScale>
171where
172 AmpScale: AmpScaleSpec + 'static,
173 FreqScale: Copy + Clone + 'static,
174{
175 params: SpectrogramParams,
176
177 stft: StftPlan,
178 mapping: FrequencyMapping<FreqScale>,
179 scaling: AmplitudeScaling<AmpScale>,
180
181 freq_axis: FrequencyAxis<FreqScale>,
182 workspace: Workspace,
183
184 _amp: PhantomData<AmpScale>,
185}
186
187impl<FreqScale, AmpScale> SpectrogramPlan<FreqScale, AmpScale>
188where
189 AmpScale: AmpScaleSpec + 'static,
190 FreqScale: Copy + Clone + 'static,
191{
192 /// Get the spectrogram parameters used to create this plan.
193 ///
194 /// # Returns
195 ///
196 /// A reference to the `SpectrogramParams` used in this plan.
197 #[inline]
198 #[must_use]
199 pub const fn params(&self) -> &SpectrogramParams {
200 &self.params
201 }
202
203 /// Get the frequency axis for this spectrogram plan.
204 ///
205 /// # Returns
206 ///
207 /// A reference to the `FrequencyAxis<FreqScale>` used in this plan.
208 #[inline]
209 #[must_use]
210 pub const fn freq_axis(&self) -> &FrequencyAxis<FreqScale> {
211 &self.freq_axis
212 }
213
214 /// Compute a spectrogram for a mono signal.
215 ///
216 /// This function performs:
217 /// - framing + windowing
218 /// - FFT per frame
219 /// - magnitude/power
220 /// - frequency mapping (identity/mel/etc.)
221 /// - amplitude scaling (linear or dB)
222 ///
223 /// It allocates the output `Array2` once, but does not allocate per-frame.
224 ///
225 /// # Arguments
226 ///
227 /// * `samples` - Audio samples
228 ///
229 /// # Returns
230 ///
231 /// A `Spectrogram<FreqScale, AmpScale>` containing the computed spectrogram.
232 ///
233 /// # Errors
234 ///
235 /// Returns an error if STFT computation or mapping fails.
236 #[inline]
237 pub fn compute(
238 &mut self,
239 samples: &NonEmptySlice<f64>,
240 ) -> SpectrogramResult<Spectrogram<FreqScale, AmpScale>> {
241 let n_frames = self.stft.frame_count(samples.len())?;
242 let n_bins = self.mapping.output_bins();
243
244 // Create output matrix: (n_bins, n_frames)
245 let mut data = Array2::<f64>::zeros((n_bins.get(), n_frames.get()));
246
247 // Ensure workspace is correctly sized
248 self.workspace
249 .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
250
251 // Main loop: fill each frame (column)
252 // TODO: Parallelize with rayon once thread-safety issues are resolved
253 for frame_idx in 0..n_frames.get() {
254 self.stft
255 .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
256
257 // mapping: spectrum(out_len) -> mapped(n_bins)
258 // For CQT, this uses workspace.frame; for others, workspace.spectrum
259 // For ERB, we need the complex FFT output (fft_out)
260 // We need to borrow workspace fields separately to avoid borrow conflicts
261 let Workspace {
262 spectrum,
263 mapped,
264 frame,
265 ..
266 } = &mut self.workspace;
267
268 self.mapping.apply(spectrum, frame, mapped)?;
269
270 // amplitude scaling in-place on mapped vector
271 self.scaling.apply_in_place(mapped)?;
272
273 // write column into output
274 for (row, &val) in mapped.iter().enumerate() {
275 data[[row, frame_idx]] = val;
276 }
277 }
278
279 let times = build_time_axis_seconds(&self.params, n_frames);
280 let axes = Axes::new(self.freq_axis.clone(), times);
281
282 Ok(Spectrogram::new(data, axes, self.params))
283 }
284
285 /// Compute a single frame of the spectrogram.
286 ///
287 /// This is useful for streaming/online processing where you want to
288 /// process audio frame-by-frame without computing the entire spectrogram.
289 ///
290 /// # Arguments
291 ///
292 /// * `samples` - Audio samples (must contain at least enough samples for the requested frame)
293 /// * `frame_idx` - Frame index to compute
294 ///
295 /// # Returns
296 ///
297 /// A vector of frequency bin values for the requested frame.
298 ///
299 /// # Errors
300 ///
301 /// Returns an error if the frame index is out of bounds or if STFT computation fails.
302 ///
303 /// # Examples
304 ///
305 /// ```
306 /// use spectrograms::*;
307 /// use non_empty_slice::non_empty_vec;
308 ///
309 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
310 /// let samples = non_empty_vec![0.0; nzu!(16000)];
311 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
312 /// let params = SpectrogramParams::new(stft, 16000.0)?;
313 ///
314 /// let planner = SpectrogramPlanner::new();
315 /// let mut plan = planner.linear_plan::<Power>(¶ms, None)?;
316 ///
317 /// // Compute just the first frame
318 /// let frame = plan.compute_frame(&samples, 0)?;
319 /// assert_eq!(frame.len(), nzu!(257)); // n_fft/2 + 1
320 /// # Ok(())
321 /// # }
322 /// ```
323 #[inline]
324 pub fn compute_frame(
325 &mut self,
326 samples: &NonEmptySlice<f64>,
327 frame_idx: usize,
328 ) -> SpectrogramResult<NonEmptyVec<f64>> {
329 let n_bins = self.mapping.output_bins();
330
331 // Ensure workspace is correctly sized
332 self.workspace
333 .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
334
335 // Compute frame spectrum
336 self.stft
337 .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
338
339 // Apply mapping (using split borrows to avoid borrow conflicts)
340 let Workspace {
341 spectrum,
342 mapped,
343 frame,
344 ..
345 } = &mut self.workspace;
346
347 self.mapping.apply(spectrum, frame, mapped)?;
348
349 // Apply amplitude scaling
350 self.scaling.apply_in_place(mapped)?;
351
352 Ok(mapped.clone())
353 }
354
355 /// Compute spectrogram into a pre-allocated buffer.
356 ///
357 /// This avoids allocating the output matrix, which is useful when
358 /// you want to reuse buffers or have strict memory requirements.
359 ///
360 /// # Arguments
361 ///
362 /// * `samples` - Audio samples
363 /// * `output` - Pre-allocated output matrix (must be correct size: `n_bins` x `n_frames`)
364 ///
365 /// # Returns
366 ///
367 /// An empty result on success.
368 ///
369 /// # Errors
370 ///
371 /// Returns an error if the output buffer dimensions don't match the expected size.
372 ///
373 /// # Examples
374 ///
375 /// ```
376 /// use spectrograms::*;
377 /// use ndarray::Array2;
378 /// use non_empty_slice::non_empty_vec;
379 ///
380 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
381 /// let samples = non_empty_vec![0.0; nzu!(16000)];
382 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
383 /// let params = SpectrogramParams::new(stft, 16000.0)?;
384 ///
385 /// let planner = SpectrogramPlanner::new();
386 /// let mut plan = planner.linear_plan::<Power>(¶ms, None)?;
387 ///
388 /// // Pre-allocate output buffer
389 /// let mut output = Array2::<f64>::zeros((257, 63));
390 /// plan.compute_into(&samples, &mut output)?;
391 /// # Ok(())
392 /// # }
393 /// ```
394 #[inline]
395 pub fn compute_into(
396 &mut self,
397 samples: &NonEmptySlice<f64>,
398 output: &mut Array2<f64>,
399 ) -> SpectrogramResult<()> {
400 let n_frames = self.stft.frame_count(samples.len())?;
401 let n_bins = self.mapping.output_bins();
402
403 // Validate output dimensions
404 if output.nrows() != n_bins.get() {
405 return Err(SpectrogramError::dimension_mismatch(
406 n_bins.get(),
407 output.nrows(),
408 ));
409 }
410 if output.ncols() != n_frames.get() {
411 return Err(SpectrogramError::dimension_mismatch(
412 n_frames.get(),
413 output.ncols(),
414 ));
415 }
416
417 // Ensure workspace is correctly sized
418 self.workspace
419 .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
420
421 // Main loop: fill each frame (column)
422 for frame_idx in 0..n_frames.get() {
423 self.stft
424 .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
425
426 // mapping: spectrum(out_len) -> mapped(n_bins)
427 // For CQT, this uses workspace.frame; for others, workspace.spectrum
428 // For ERB, we need the complex FFT output (fft_out)
429 // We need to borrow workspace fields separately to avoid borrow conflicts
430 let Workspace {
431 spectrum,
432 mapped,
433 frame,
434 ..
435 } = &mut self.workspace;
436
437 self.mapping.apply(spectrum, frame, mapped)?;
438
439 // amplitude scaling in-place on mapped vector
440 self.scaling.apply_in_place(mapped)?;
441
442 // write column into output
443 for (row, &val) in mapped.iter().enumerate() {
444 output[[row, frame_idx]] = val;
445 }
446 }
447
448 Ok(())
449 }
450
451 /// Get the expected output dimensions for a given signal length.
452 ///
453 /// # Arguments
454 ///
455 /// * `signal_length` - Length of the input signal in samples.
456 ///
457 /// # Returns
458 ///
459 /// A tuple `(n_bins, n_frames)` representing the output spectrogram shape.
460 ///
461 /// # Errors
462 ///
463 /// Returns an error if the signal length is too short to produce any frames.
464 ///
465 /// # Examples
466 ///
467 /// ```
468 /// use spectrograms::*;
469 ///
470 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
471 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
472 /// let params = SpectrogramParams::new(stft, 16000.0)?;
473 ///
474 /// let planner = SpectrogramPlanner::new();
475 /// let plan = planner.linear_plan::<Power>(¶ms, None)?;
476 ///
477 /// let (n_bins, n_frames) = plan.output_shape(nzu!(16000))?;
478 /// assert_eq!(n_bins, nzu!(257));
479 /// assert_eq!(n_frames, nzu!(63));
480 /// # Ok(())
481 /// # }
482 /// ```
483 #[inline]
484 pub fn output_shape(
485 &self,
486 signal_length: NonZeroUsize,
487 ) -> SpectrogramResult<(NonZeroUsize, NonZeroUsize)> {
488 let n_frames = self.stft.frame_count(signal_length)?;
489 let n_bins = self.mapping.output_bins();
490 Ok((n_bins, n_frames))
491 }
492}
493
494/// STFT (Short-Time Fourier Transform) result containing complex frequency bins.
495///
496/// This is the raw STFT output before any frequency mapping or amplitude scaling.
497///
498/// # Fields
499///
500/// - `data`: Complex STFT matrix with shape (`frequency_bins`, `time_frames`)
501/// - `frequencies`: Frequency axis in Hz
502/// - `sample_rate`: Sample rate in Hz
503/// - `params`: STFT computation parameters
504#[derive(Debug, Clone)]
505#[non_exhaustive]
506pub struct StftResult {
507 /// Complex STFT matrix with shape (`frequency_bins`, `time_frames`)
508 pub data: Array2<Complex<f64>>,
509 /// Frequency axis in Hz
510 pub frequencies: NonEmptyVec<f64>,
511 /// Sample rate in Hz
512 pub sample_rate: f64,
513 pub params: StftParams,
514}
515
516impl StftResult {
517 /// Get the number of frequency bins.
518 ///
519 /// # Returns
520 ///
521 /// Number of frequency bins in the STFT result.
522 #[inline]
523 #[must_use]
524 pub fn n_bins(&self) -> NonZeroUsize {
525 // safety: nrows() > 0 for NonEmptyVec
526 unsafe { NonZeroUsize::new_unchecked(self.data.nrows()) }
527 }
528
529 /// Get the number of time frames.
530 ///
531 /// # Returns
532 ///
533 /// Number of time frames in the STFT result.
534 #[inline]
535 #[must_use]
536 pub fn n_frames(&self) -> NonZeroUsize {
537 // safety: ncols() > 0 for NonEmptyVec
538 unsafe { NonZeroUsize::new_unchecked(self.data.ncols()) }
539 }
540
541 /// Get the frequency resolution in Hz
542 ///
543 /// # Returns
544 ///
545 /// Frequency bin width in Hz.
546 #[inline]
547 #[must_use]
548 pub fn frequency_resolution(&self) -> f64 {
549 self.sample_rate / self.params.n_fft().get() as f64
550 }
551
552 /// Get the time resolution in seconds.
553 ///
554 /// # Returns
555 ///
556 /// Time between successive frames in seconds.
557 #[inline]
558 #[must_use]
559 pub fn time_resolution(&self) -> f64 {
560 self.params.hop_size().get() as f64 / self.sample_rate
561 }
562
563 /// Normalizes self.data to remove the complex aspect of it.
564 ///
565 /// # Returns
566 ///
567 /// An Array2<f64> containing the norms of each complex number in self.data.
568 #[inline]
569 pub fn norm(&self) -> Array2<f64> {
570 self.as_ref().mapv(Complex::norm)
571 }
572}
573
574impl AsRef<Array2<Complex<f64>>> for StftResult {
575 #[inline]
576 fn as_ref(&self) -> &Array2<Complex<f64>> {
577 &self.data
578 }
579}
580
581impl AsMut<Array2<Complex<f64>>> for StftResult {
582 #[inline]
583 fn as_mut(&mut self) -> &mut Array2<Complex<f64>> {
584 &mut self.data
585 }
586}
587
588impl Deref for StftResult {
589 type Target = Array2<Complex<f64>>;
590
591 #[inline]
592 fn deref(&self) -> &Self::Target {
593 &self.data
594 }
595}
596
597impl DerefMut for StftResult {
598 #[inline]
599 fn deref_mut(&mut self) -> &mut Self::Target {
600 &mut self.data
601 }
602}
603
/// Factory for spectrogram plans.
///
/// The planner is where the one-off setup work happens:
/// - FFT plans are created,
/// - frequency-mapping matrices (e.g. mel filterbanks) are compiled,
/// - frequency axes are computed.
///
/// This keeps plan construction separate from the output types. The planner itself is
/// a stateless unit struct; the `SpectrogramPlan` it returns is what gets reused.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct SpectrogramPlanner;
615
616impl SpectrogramPlanner {
617 /// Create a new spectrogram planner.
618 ///
619 /// # Returns
620 ///
621 /// A new `SpectrogramPlanner` instance.
622 #[inline]
623 #[must_use]
624 pub const fn new() -> Self {
625 Self
626 }
627
628 /// Compute the Short-Time Fourier Transform (STFT) of a signal.
629 ///
630 /// This returns the raw complex STFT matrix before any frequency mapping
631 /// or amplitude scaling. Useful for applications that need the full complex
632 /// spectrum or custom processing.
633 ///
634 /// # Arguments
635 ///
636 /// * `samples` - Audio samples (any type that can be converted to a slice)
637 /// * `params` - STFT computation parameters
638 ///
639 /// # Returns
640 ///
641 /// An `StftResult` containing the complex STFT matrix and metadata.
642 ///
643 /// # Errors
644 ///
645 /// Returns an error if STFT computation fails.
646 ///
647 /// # Examples
648 ///
649 /// ```
650 /// use spectrograms::*;
651 /// use non_empty_slice::non_empty_vec;
652 ///
653 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
654 /// let samples = non_empty_vec![0.0; nzu!(16000)];
655 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
656 /// let params = SpectrogramParams::new(stft, 16000.0)?;
657 ///
658 /// let planner = SpectrogramPlanner::new();
659 /// let stft_result = planner.compute_stft(&samples, ¶ms)?;
660 ///
661 /// println!("STFT: {} bins x {} frames", stft_result.n_bins(), stft_result.n_frames());
662 /// # Ok(())
663 /// # }
664 /// ```
665 ///
666 /// # Performance Note
667 ///
668 /// This method creates a new FFT plan each time. For processing multiple
669 /// signals, create a reusable plan with `StftPlan::new()` instead.
670 ///
671 /// # Examples
672 ///
673 /// ```rust
674 /// use spectrograms::*;
675 /// use non_empty_slice::non_empty_vec;
676 ///
677 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
678 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
679 /// let params = SpectrogramParams::new(stft, 16000.0)?;
680 ///
681 /// // One-shot (convenient)
682 /// let planner = SpectrogramPlanner::new();
683 /// let stft_result = planner.compute_stft(&non_empty_vec![0.0; nzu!(16000)], ¶ms)?;
684 ///
685 /// // Reusable plan (efficient for batch)
686 /// let mut plan = StftPlan::new(¶ms)?;
687 /// for signal in &[non_empty_vec![0.0; nzu!(16000)], non_empty_vec![1.0; nzu!(16000)]] {
688 /// let stft = plan.compute(&signal, ¶ms)?;
689 /// }
690 /// # Ok(())
691 /// # }
692 /// ```
693 #[inline]
694 pub fn compute_stft(
695 &self,
696 samples: &NonEmptySlice<f64>,
697 params: &SpectrogramParams,
698 ) -> SpectrogramResult<StftResult> {
699 let mut plan = StftPlan::new(params)?;
700 plan.compute(samples, params)
701 }
702
703 /// Compute the power spectrum of a single audio frame.
704 ///
705 /// This is useful for real-time processing or analyzing individual frames.
706 ///
707 /// # Arguments
708 ///
709 /// * `samples` - Audio frame (length ≤ n_fft, will be zero-padded if shorter)
710 /// * `n_fft` - FFT size
711 /// * `window` - Window type to apply
712 ///
713 /// # Returns
714 ///
715 /// A vector of power values (|X|²) with length `n_fft/2` + 1.
716 ///
717 /// # Automatic Zero-Padding
718 ///
719 /// If the input signal is shorter than `n_fft`, it will be automatically
720 /// zero-padded to the required length.
721 ///
722 /// # Errors
723 ///
724 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
725 ///
726 /// # Examples
727 ///
728 /// ```
729 /// use spectrograms::*;
730 /// use non_empty_slice::non_empty_vec;
731 ///
732 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
733 /// let frame = non_empty_vec![0.0; nzu!(512)];
734 ///
735 /// let planner = SpectrogramPlanner::new();
736 /// let power = planner.compute_power_spectrum(frame.as_ref(), nzu!(512), WindowType::Hanning)?;
737 ///
738 /// assert_eq!(power.len(), nzu!(257)); // 512/2 + 1
739 /// # Ok(())
740 /// # }
741 /// ```
742 #[inline]
743 pub fn compute_power_spectrum(
744 &self,
745 samples: &NonEmptySlice<f64>,
746 n_fft: NonZeroUsize,
747 window: WindowType,
748 ) -> SpectrogramResult<NonEmptyVec<f64>> {
749 if samples.len() > n_fft {
750 return Err(SpectrogramError::invalid_input(format!(
751 "Input length ({}) exceeds FFT size ({})",
752 samples.len(),
753 n_fft
754 )));
755 }
756
757 let window_samples = make_window(window, n_fft);
758 let out_len = r2c_output_size(n_fft.get());
759
760 // Create FFT plan
761 #[cfg(feature = "realfft")]
762 let mut fft = {
763 let mut planner = crate::RealFftPlanner::new();
764 let plan = planner.get_or_create(n_fft.get());
765 crate::RealFftPlan::new(n_fft.get(), plan)
766 };
767
768 #[cfg(feature = "fftw")]
769 let mut fft = {
770 use std::sync::Arc;
771 let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
772 crate::FftwPlan::new(Arc::new(plan))
773 };
774
775 // Apply window and compute FFT
776 let mut windowed = vec![0.0; n_fft.get()];
777 for i in 0..samples.len().get() {
778 windowed[i] = samples[i] * window_samples[i];
779 }
780 // The rest is already zero-padded
781 let mut fft_out = vec![Complex::new(0.0, 0.0); out_len];
782 fft.process(&windowed, &mut fft_out)?;
783
784 // Convert to power
785 let power: Vec<f64> = fft_out.iter().map(num_complex::Complex::norm_sqr).collect();
786 // safety: power is non-empty since n_fft > 0
787 Ok(unsafe { NonEmptyVec::new_unchecked(power) })
788 }
789
790 /// Compute the magnitude spectrum of a single audio frame.
791 ///
792 /// This is useful for real-time processing or analyzing individual frames.
793 ///
794 /// # Arguments
795 ///
796 /// * `samples` - Audio frame (length ≤ n_fft, will be zero-padded if shorter)
797 /// * `n_fft` - FFT size
798 /// * `window` - Window type to apply
799 ///
800 /// # Returns
801 ///
802 /// A vector of magnitude values (|X|) with length `n_fft/2` + 1.
803 ///
804 /// # Automatic Zero-Padding
805 ///
806 /// If the input signal is shorter than `n_fft`, it will be automatically
807 /// zero-padded to the required length.
808 ///
809 /// # Errors
810 ///
811 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
812 ///
813 /// # Examples
814 ///
815 /// ```
816 /// use spectrograms::*;
817 /// use non_empty_slice::non_empty_vec;
818 ///
819 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
820 /// let frame = non_empty_vec![0.0; nzu!(512)];
821 ///
822 /// let planner = SpectrogramPlanner::new();
823 /// let magnitude = planner.compute_magnitude_spectrum(frame.as_ref(), nzu!(512), WindowType::Hanning)?;
824 ///
825 /// assert_eq!(magnitude.len(), nzu!(257)); // 512/2 + 1
826 /// # Ok(())
827 /// # }
828 /// ```
829 #[inline]
830 pub fn compute_magnitude_spectrum(
831 &self,
832 samples: &NonEmptySlice<f64>,
833 n_fft: NonZeroUsize,
834 window: WindowType,
835 ) -> SpectrogramResult<NonEmptyVec<f64>> {
836 let power = self.compute_power_spectrum(samples, n_fft, window)?;
837 let power = power.iter().map(|&p| p.sqrt()).collect::<Vec<f64>>();
838 // safety: power is non-empty since power_spectrum returned successfully
839 Ok(unsafe { NonEmptyVec::new_unchecked(power) })
840 }
841
842 /// Build a linear-frequency spectrogram plan.
843 ///
844 /// # Type Parameters
845 ///
846 /// `AmpScale` determines whether output is:
847 /// - Magnitude
848 /// - Power
849 /// - Decibels
850 ///
851 /// # Arguments
852 ///
853 /// * `params` - Spectrogram parameters
854 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale
855 /// is `Decibels`)
856 ///
857 /// # Returns
858 ///
859 /// A `SpectrogramPlan` configured for linear-frequency spectrogram computation.
860 ///
861 /// # Errors
862 ///
863 /// Returns an error if the plan cannot be created due to invalid parameters.
864 #[inline]
865 pub fn linear_plan<AmpScale>(
866 &self,
867 params: &SpectrogramParams,
868 db: Option<&LogParams>, // only used when AmpScale = Decibels
869 ) -> SpectrogramResult<SpectrogramPlan<LinearHz, AmpScale>>
870 where
871 AmpScale: AmpScaleSpec + 'static,
872 {
873 let stft = StftPlan::new(params)?;
874 let mapping = FrequencyMapping::<LinearHz>::new(params)?;
875 let scaling = AmplitudeScaling::<AmpScale>::new(db);
876 let freq_axis = build_frequency_axis::<LinearHz>(params, &mapping);
877
878 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
879
880 Ok(SpectrogramPlan {
881 params: *params,
882 stft,
883 mapping,
884 scaling,
885 freq_axis,
886 workspace,
887 _amp: PhantomData,
888 })
889 }
890
891 /// Build a mel-frequency spectrogram plan.
892 ///
893 /// This compiles a mel filterbank matrix and caches it inside the plan.
894 ///
895 /// # Type Parameters
896 ///
897 /// `AmpScale`: determines whether output is:
898 /// - Magnitude
899 /// - Power
900 /// - Decibels
901 ///
902 /// # Arguments
903 ///
904 /// * `params` - Spectrogram parameters
905 /// * `mel` - Mel-specific parameters
906 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
907 ///
908 /// # Returns
909 ///
910 /// A `SpectrogramPlan` configured for mel spectrogram computation.
911 ///
912 /// # Errors
913 ///
914 /// Returns an error if the plan cannot be created due to invalid parameters.
915 #[inline]
916 pub fn mel_plan<AmpScale>(
917 &self,
918 params: &SpectrogramParams,
919 mel: &MelParams,
920 db: Option<&LogParams>, // only used when AmpScale = Decibels
921 ) -> SpectrogramResult<SpectrogramPlan<Mel, AmpScale>>
922 where
923 AmpScale: AmpScaleSpec + 'static,
924 {
925 // cross-validation: mel range must be compatible with sample rate
926 let nyquist = params.nyquist_hz();
927 if mel.f_max() > nyquist {
928 return Err(SpectrogramError::invalid_input(
929 "mel f_max must be <= Nyquist",
930 ));
931 }
932
933 let stft = StftPlan::new(params)?;
934 let mapping = FrequencyMapping::<Mel>::new_mel(params, mel)?;
935 let scaling = AmplitudeScaling::<AmpScale>::new(db);
936 let freq_axis = build_frequency_axis::<Mel>(params, &mapping);
937
938 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
939
940 Ok(SpectrogramPlan {
941 params: *params,
942 stft,
943 mapping,
944 scaling,
945 freq_axis,
946 workspace,
947 _amp: PhantomData,
948 })
949 }
950
951 /// Build an ERB-scale spectrogram plan.
952 ///
953 /// This creates a spectrogram with ERB-spaced frequency bands using gammatone
954 /// filterbank approximation in the frequency domain.
955 ///
956 /// # Type Parameters
957 ///
958 /// `AmpScale`: determines whether output is:
959 /// - Magnitude
960 /// - Power
961 /// - Decibels
962 ///
963 /// # Arguments
964 ///
965 /// * `params` - Spectrogram parameters
966 /// * `erb` - ERB-specific parameters
967 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
968 ///
969 /// # Returns
970 ///
971 /// A `SpectrogramPlan` configured for ERB spectrogram computation.
972 ///
973 /// # Errors
974 ///
975 /// Returns an error if the plan cannot be created due to invalid parameters.
976 #[inline]
977 pub fn erb_plan<AmpScale>(
978 &self,
979 params: &SpectrogramParams,
980 erb: &ErbParams,
981 db: Option<&LogParams>,
982 ) -> SpectrogramResult<SpectrogramPlan<Erb, AmpScale>>
983 where
984 AmpScale: AmpScaleSpec + 'static,
985 {
986 // cross-validation: erb range must be compatible with sample rate
987 let nyquist = params.nyquist_hz();
988 if erb.f_max() > nyquist {
989 return Err(SpectrogramError::invalid_input(format!(
990 "f_max={} exceeds Nyquist={}",
991 erb.f_max(),
992 nyquist
993 )));
994 }
995
996 let stft = StftPlan::new(params)?;
997 let mapping = FrequencyMapping::<Erb>::new_erb(params, erb)?;
998 let scaling = AmplitudeScaling::<AmpScale>::new(db);
999 let freq_axis = build_frequency_axis::<Erb>(params, &mapping);
1000
1001 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1002
1003 Ok(SpectrogramPlan {
1004 params: *params,
1005 stft,
1006 mapping,
1007 scaling,
1008 freq_axis,
1009 workspace,
1010 _amp: PhantomData,
1011 })
1012 }
1013
1014 /// Build a log-frequency plan.
1015 ///
1016 /// This creates a spectrogram with logarithmically-spaced frequency bins.
1017 ///
1018 /// # Type Parameters
1019 ///
1020 /// `AmpScale`: determines whether output is:
1021 /// - Magnitude
1022 /// - Power
1023 /// - Decibels
1024 ///
1025 /// # Arguments
1026 ///
1027 /// * `params` - Spectrogram parameters
1028 /// * `loghz` - LogHz-specific parameters
1029 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
1030 ///
1031 /// # Returns
1032 ///
1033 /// A `SpectrogramPlan` configured for log-frequency spectrogram computation.
1034 ///
1035 /// # Errors
1036 ///
1037 /// Returns an error if the plan cannot be created due to invalid parameters.
1038 #[inline]
1039 pub fn log_hz_plan<AmpScale>(
1040 &self,
1041 params: &SpectrogramParams,
1042 loghz: &LogHzParams,
1043 db: Option<&LogParams>,
1044 ) -> SpectrogramResult<SpectrogramPlan<LogHz, AmpScale>>
1045 where
1046 AmpScale: AmpScaleSpec + 'static,
1047 {
1048 // cross-validation: loghz range must be compatible with sample rate
1049 let nyquist = params.nyquist_hz();
1050 if loghz.f_max() > nyquist {
1051 return Err(SpectrogramError::invalid_input(format!(
1052 "f_max={} exceeds Nyquist={}",
1053 loghz.f_max(),
1054 nyquist
1055 )));
1056 }
1057
1058 let stft = StftPlan::new(params)?;
1059 let mapping = FrequencyMapping::<LogHz>::new_loghz(params, loghz)?;
1060 let scaling = AmplitudeScaling::<AmpScale>::new(db);
1061 let freq_axis = build_frequency_axis::<LogHz>(params, &mapping);
1062
1063 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1064
1065 Ok(SpectrogramPlan {
1066 params: *params,
1067 stft,
1068 mapping,
1069 scaling,
1070 freq_axis,
1071 workspace,
1072 _amp: PhantomData,
1073 })
1074 }
1075
1076 /// Build a cqt spectrogram plan.
1077 ///
1078 /// # Type Parameters
1079 ///
1080 /// `AmpScale`: determines whether output is:
1081 /// - Magnitude
1082 /// - Power
1083 /// - Decibels
1084 ///
1085 /// # Arguments
1086 ///
1087 /// * `params` - Spectrogram parameters
1088 /// * `cqt` - CQT-specific parameters
1089 /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
1090 ///
1091 /// # Returns
1092 ///
1093 /// A `SpectrogramPlan` configured for CQT spectrogram computation.
1094 ///
1095 /// # Errors
1096 ///
1097 /// Returns an error if the plan cannot be created due to invalid parameters.
1098 #[inline]
1099 pub fn cqt_plan<AmpScale>(
1100 &self,
1101 params: &SpectrogramParams,
1102 cqt: &CqtParams,
1103 db: Option<&LogParams>, // only used when AmpScale = Decibels
1104 ) -> SpectrogramResult<SpectrogramPlan<Cqt, AmpScale>>
1105 where
1106 AmpScale: AmpScaleSpec + 'static,
1107 {
1108 let stft = StftPlan::new(params)?;
1109 let mapping = FrequencyMapping::<Cqt>::new(params, cqt)?;
1110 let scaling = AmplitudeScaling::<AmpScale>::new(db);
1111 let freq_axis = build_frequency_axis::<Cqt>(params, &mapping);
1112
1113 let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1114
1115 Ok(SpectrogramPlan {
1116 params: *params,
1117 stft,
1118 mapping,
1119 scaling,
1120 freq_axis,
1121 workspace,
1122 _amp: PhantomData,
1123 })
1124 }
1125}
1126
1127/// STFT plan containing reusable FFT plan and buffers.
1128///
1129/// This struct is responsible for performing the Short-Time Fourier Transform (STFT)
1130/// on audio signals based on the provided parameters.
1131///
1132/// It encapsulates the FFT plan, windowing function, and internal buffers to efficiently
1133/// compute the STFT for multiple frames of audio data.
1134///
1135/// # Fields
1136///
1137/// - `n_fft`: Size of the FFT.
1138/// - `hop_size`: Hop size between consecutive frames.
1139/// - `window`: Windowing function samples.
1140/// - `centre`: Whether to centre the frames with padding.
1141/// - `out_len`: Length of the FFT output.
1142/// - `fft`: Boxed FFT plan for real-to-complex transformation.
1143/// - `fft_out`: Internal buffer for FFT output.
1144/// - `frame`: Internal buffer for windowed audio frames.
1145pub struct StftPlan {
1146 n_fft: NonZeroUsize,
1147 hop_size: NonZeroUsize,
1148 window: NonEmptyVec<f64>,
1149 centre: bool,
1150
1151 out_len: NonZeroUsize,
1152
1153 // FFT plan (reused for all frames)
1154 fft: Box<dyn R2cPlan>,
1155
1156 // internal scratch
1157 fft_out: NonEmptyVec<Complex<f64>>,
1158 frame: NonEmptyVec<f64>,
1159}
1160
1161impl StftPlan {
1162 /// Create a new STFT plan from parameters.
1163 ///
1164 /// # Arguments
1165 ///
1166 /// * `params` - Spectrogram parameters containing STFT config
1167 ///
1168 /// # Returns
1169 ///
1170 /// A new `StftPlan` instance.
1171 ///
1172 /// # Errors
1173 ///
1174 /// Returns an error if the FFT plan cannot be created.
1175 #[inline]
1176 pub fn new(params: &SpectrogramParams) -> SpectrogramResult<Self> {
1177 let stft = params.stft();
1178 let n_fft = stft.n_fft();
1179 let hop_size = stft.hop_size();
1180 let centre = stft.centre();
1181
1182 let window = make_window(stft.window(), n_fft);
1183
1184 let out_len = r2c_output_size(n_fft.get());
1185 let out_len = NonZeroUsize::new(out_len)
1186 .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1187
1188 #[cfg(feature = "realfft")]
1189 let fft = {
1190 let mut planner = crate::RealFftPlanner::new();
1191 let plan = planner.get_or_create(n_fft.get());
1192 let plan = crate::RealFftPlan::new(n_fft.get(), plan);
1193 Box::new(plan)
1194 };
1195
1196 #[cfg(feature = "fftw")]
1197 let fft = {
1198 use std::sync::Arc;
1199 let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
1200 Box::new(crate::FftwPlan::new(Arc::new(plan)))
1201 };
1202
1203 Ok(Self {
1204 n_fft,
1205 hop_size,
1206 window,
1207 centre,
1208 out_len,
1209 fft,
1210 fft_out: non_empty_vec![Complex::new(0.0, 0.0); out_len],
1211 frame: non_empty_vec![0.0; n_fft],
1212 })
1213 }
1214
1215 fn frame_count(&self, n_samples: NonZeroUsize) -> SpectrogramResult<NonZeroUsize> {
1216 // Framing policy:
1217 // - centre = true: implicit padding of n_fft/2 on both sides
1218 // - centre = false: no padding
1219 //
1220 // Define the number of frames such that each frame has a valid centre sample position.
1221 let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1222 let padded_len = n_samples.get() + 2 * pad;
1223
1224 if padded_len < self.n_fft.get() {
1225 // still produce 1 frame (all padding / partial)
1226 return Ok(nzu!(1));
1227 }
1228
1229 let remaining = padded_len - self.n_fft.get();
1230 let n_frames = remaining / self.hop_size().get() + 1;
1231 let n_frames = NonZeroUsize::new(n_frames).ok_or_else(|| {
1232 SpectrogramError::invalid_input("computed number of frames must be non-zero")
1233 })?;
1234 Ok(n_frames)
1235 }
1236
1237 /// Compute one frame FFT using internal buffers only.
1238 fn compute_frame_fft_simple(
1239 &mut self,
1240 samples: &NonEmptySlice<f64>,
1241 frame_idx: usize,
1242 ) -> SpectrogramResult<()> {
1243 let out = self.frame.as_mut_slice();
1244 debug_assert_eq!(out.len(), self.n_fft.get());
1245
1246 let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1247 let start = frame_idx
1248 .checked_mul(self.hop_size.get())
1249 .ok_or_else(|| SpectrogramError::invalid_input("frame index overflow"))?;
1250
1251 // Fill windowed frame
1252 for (i, sample) in out.iter_mut().enumerate().take(self.n_fft.get()) {
1253 let v_idx = start + i;
1254 let s_idx = v_idx as isize - pad as isize;
1255
1256 let sample_val = if s_idx < 0 || (s_idx as usize) >= samples.len().get() {
1257 0.0
1258 } else {
1259 samples[s_idx as usize]
1260 };
1261 *sample = sample_val * self.window[i];
1262 }
1263
1264 // Compute FFT
1265 let fft_out = self.fft_out.as_mut_slice();
1266 self.fft.process(out, fft_out)?;
1267
1268 Ok(())
1269 }
1270
1271 /// Compute one frame spectrum into workspace:
1272 /// - fills windowed frame
1273 /// - runs FFT
1274 /// - converts to magnitude/power based on `AmpScale` later
1275 fn compute_frame_spectrum(
1276 &mut self,
1277 samples: &NonEmptySlice<f64>,
1278 frame_idx: usize,
1279 workspace: &mut Workspace,
1280 ) -> SpectrogramResult<()> {
1281 let out = workspace.frame.as_mut_slice();
1282
1283 // self.fill_frame(samples, frame_idx, frame)?;
1284 debug_assert_eq!(out.len(), self.n_fft.get());
1285
1286 let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1287 let start = frame_idx
1288 .checked_mul(self.hop_size().get())
1289 .ok_or_else(|| SpectrogramError::invalid_input("frame index overflow"))?;
1290
1291 // The "virtual" signal is samples with pad zeros on both sides.
1292 // Virtual index 0..padded_len
1293 // Map virtual index to original samples by subtracting pad.
1294 for (i, sample) in out.iter_mut().enumerate().take(self.n_fft.get()) {
1295 let v_idx = start + i;
1296 let s_idx = v_idx as isize - pad as isize;
1297
1298 let sample_val = if s_idx < 0 || (s_idx as usize) >= samples.len().get() {
1299 0.0
1300 } else {
1301 samples[s_idx as usize]
1302 };
1303
1304 *sample = sample_val * self.window[i];
1305 }
1306 let fft_out = workspace.fft_out.as_mut_slice();
1307 // FFT
1308 self.fft.process(out, fft_out)?;
1309
1310 // Convert complex spectrum to linear magnitude OR power here? No:
1311 // Keep "spectrum" as power by default? That would entangle semantics.
1312 //
1313 // Instead, we store magnitude^2 (power) as the canonical intermediate,
1314 // and let AmpScale decide later whether output is magnitude or power.
1315 //
1316 // This is consistent and avoids recomputing norms multiple times.
1317 for (i, c) in workspace.fft_out.iter().enumerate() {
1318 workspace.spectrum[i] = c.norm_sqr();
1319 }
1320
1321 Ok(())
1322 }
1323
1324 /// Compute the full STFT for a signal, returning an `StftResult`.
1325 ///
1326 /// This is a convenience method that handles frame iteration and
1327 /// builds the complete STFT matrix.
1328 ///
1329 ///
1330 /// # Arguments
1331 ///
1332 /// * `samples` - Input audio samples
1333 /// * `params` - STFT computation parameters
1334 ///
1335 /// # Returns
1336 ///
1337 /// An `StftResult` containing the complex STFT matrix and metadata.
1338 ///
1339 /// # Errors
1340 ///
1341 /// Returns an error if computation fails.
1342 ///
1343 /// # Examples
1344 ///
1345 /// ```rust
1346 /// use spectrograms::*;
1347 /// use non_empty_slice::non_empty_vec;
1348 ///
1349 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1350 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1351 /// let params = SpectrogramParams::new(stft, 16000.0)?;
1352 /// let mut plan = StftPlan::new(¶ms)?;
1353 ///
1354 /// let samples = non_empty_vec![0.0; nzu!(16000)];
1355 /// let stft_result = plan.compute(&samples, ¶ms)?;
1356 ///
1357 /// println!("STFT: {} bins x {} frames", stft_result.n_bins(), stft_result.n_frames());
1358 /// # Ok(())
1359 /// # }
1360 /// ```
1361 #[inline]
1362 pub fn compute(
1363 &mut self,
1364 samples: &NonEmptySlice<f64>,
1365 params: &SpectrogramParams,
1366 ) -> SpectrogramResult<StftResult> {
1367 let n_frames = self.frame_count(samples.len())?;
1368 let n_bins = self.out_len;
1369
1370 // Allocate output matrix (frequency_bins x time_frames)
1371 let mut data = Array2::<Complex<f64>>::zeros((n_bins.get(), n_frames.get()));
1372
1373 // Compute each frame
1374 for frame_idx in 0..n_frames.get() {
1375 self.compute_frame_fft_simple(samples, frame_idx)?;
1376
1377 // Copy from internal buffer to output
1378 for (bin_idx, &value) in self.fft_out.iter().enumerate() {
1379 data[[bin_idx, frame_idx]] = value;
1380 }
1381 }
1382
1383 // Build frequency axis
1384 let frequencies: Vec<f64> = (0..n_bins.get())
1385 .map(|k| k as f64 * params.sample_rate_hz() / params.stft().n_fft().get() as f64)
1386 .collect();
1387 // SAFETY: n_bins > 0
1388 let frequencies = unsafe { NonEmptyVec::new_unchecked(frequencies) };
1389
1390 Ok(StftResult {
1391 data,
1392 frequencies,
1393 sample_rate: params.sample_rate_hz(),
1394 params: *params.stft(),
1395 })
1396 }
1397
1398 /// Compute a single frame of STFT, returning the complex spectrum.
1399 ///
1400 /// This is useful for streaming/online processing where you want to
1401 /// process audio frame-by-frame.
1402 ///
1403 /// # Arguments
1404 ///
1405 /// * `samples` - Input audio samples
1406 /// * `frame_idx` - Index of the frame to compute
1407 ///
1408 /// # Returns
1409 ///
1410 /// A `NonEmptyVec` containing the complex spectrum for the specified frame.
1411 ///
1412 /// # Errors
1413 ///
1414 /// Returns an error if computation fails.
1415 ///
1416 /// # Examples
1417 ///
1418 /// ```rust
1419 /// use spectrograms::*;
1420 /// use non_empty_slice::non_empty_vec;
1421 ///
1422 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1423 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1424 /// let params = SpectrogramParams::new(stft, 16000.0)?;
1425 /// let mut plan = StftPlan::new(¶ms)?;
1426 ///
1427 /// let samples = non_empty_vec![0.0; nzu!(16000)];
1428 /// let (_, n_frames) = plan.output_shape(samples.len())?;
1429 ///
1430 /// for frame_idx in 0..n_frames.get() {
1431 /// let spectrum = plan.compute_frame_simple(&samples, frame_idx)?;
1432 /// // Process spectrum...
1433 /// }
1434 /// # Ok(())
1435 /// # }
1436 /// ```
1437 #[inline]
1438 pub fn compute_frame_simple(
1439 &mut self,
1440 samples: &NonEmptySlice<f64>,
1441 frame_idx: usize,
1442 ) -> SpectrogramResult<NonEmptyVec<Complex<f64>>> {
1443 self.compute_frame_fft_simple(samples, frame_idx)?;
1444 Ok(self.fft_out.clone())
1445 }
1446
1447 /// Compute STFT into a pre-allocated buffer.
1448 ///
1449 /// This avoids allocating the output matrix, useful for reusing buffers.
1450 ///
1451 /// # Arguments
1452 ///
1453 /// * `samples` - Input audio samples
1454 /// * `output` - Pre-allocated output buffer (shape: `n_bins` x `n_frames`)
1455 ///
1456 /// # Returns
1457 ///
1458 /// `Ok(())` on success, or an error if dimensions mismatch.
1459 ///
1460 /// # Errors
1461 ///
1462 /// Returns `DimensionMismatch` error if the output buffer has incorrect shape.
1463 ///
1464 /// # Examples
1465 ///
1466 /// ```rust
1467 /// use spectrograms::*;
1468 /// use ndarray::Array2;
1469 /// use num_complex::Complex;
1470 /// use non_empty_slice::non_empty_vec;
1471 ///
1472 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1473 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1474 /// let params = SpectrogramParams::new(stft, 16000.0)?;
1475 /// let mut plan = StftPlan::new(¶ms)?;
1476 ///
1477 /// let samples = non_empty_vec![0.0; nzu!(16000)];
1478 /// let (n_bins, n_frames) = plan.output_shape(samples.len())?;
1479 /// let mut output = Array2::<Complex<f64>>::zeros((n_bins.get(), n_frames.get()));
1480 ///
1481 /// plan.compute_into(&samples, &mut output)?;
1482 /// # Ok(())
1483 /// # }
1484 /// ```
1485 #[inline]
1486 pub fn compute_into(
1487 &mut self,
1488 samples: &NonEmptySlice<f64>,
1489 output: &mut Array2<Complex<f64>>,
1490 ) -> SpectrogramResult<()> {
1491 let n_frames = self.frame_count(samples.len())?;
1492 let n_bins = self.out_len;
1493
1494 // Validate output dimensions
1495 if output.nrows() != n_bins.get() {
1496 return Err(SpectrogramError::dimension_mismatch(
1497 n_bins.get(),
1498 output.nrows(),
1499 ));
1500 }
1501 if output.ncols() != n_frames.get() {
1502 return Err(SpectrogramError::dimension_mismatch(
1503 n_frames.get(),
1504 output.ncols(),
1505 ));
1506 }
1507
1508 // Compute into pre-allocated buffer
1509 for frame_idx in 0..n_frames.get() {
1510 self.compute_frame_fft_simple(samples, frame_idx)?;
1511
1512 for (bin_idx, &value) in self.fft_out.iter().enumerate() {
1513 output[[bin_idx, frame_idx]] = value;
1514 }
1515 }
1516
1517 Ok(())
1518 }
1519
1520 /// Get the expected output dimensions for a given signal length.
1521 ///
1522 /// # Arguments
1523 ///
1524 /// * `signal_length` - Length of the input signal in samples
1525 ///
1526 /// # Returns
1527 ///
1528 /// A tuple `(n_frequency_bins, n_time_frames)`.
1529 ///
1530 /// # Errors
1531 ///
1532 /// Returns an error if the computed number of frames is invalid.
1533 #[inline]
1534 pub fn output_shape(
1535 &self,
1536 signal_length: NonZeroUsize,
1537 ) -> SpectrogramResult<(NonZeroUsize, NonZeroUsize)> {
1538 let n_frames = self.frame_count(signal_length)?;
1539 Ok((self.out_len, n_frames))
1540 }
1541
1542 /// Get the number of frequency bins in the output.
1543 ///
1544 /// # Returns
1545 ///
1546 /// The number of frequency bins.
1547 #[inline]
1548 #[must_use]
1549 pub const fn n_bins(&self) -> NonZeroUsize {
1550 self.out_len
1551 }
1552
1553 /// Get the FFT size.
1554 ///
1555 /// # Returns
1556 ///
1557 /// The FFT size.
1558 #[inline]
1559 #[must_use]
1560 pub const fn n_fft(&self) -> NonZeroUsize {
1561 self.n_fft
1562 }
1563
1564 /// Get the hop size.
1565 ///
1566 /// # Returns
1567 ///
1568 /// The hop size.
1569 #[inline]
1570 #[must_use]
1571 pub const fn hop_size(&self) -> NonZeroUsize {
1572 self.hop_size
1573 }
1574}
1575
1576#[derive(Debug, Clone)]
1577enum MappingKind {
1578 Identity {
1579 out_len: NonZeroUsize,
1580 },
1581 Mel {
1582 matrix: SparseMatrix,
1583 }, // shape: (n_mels, out_len)
1584 LogHz {
1585 matrix: SparseMatrix,
1586 frequencies: NonEmptyVec<f64>,
1587 }, // shape: (n_bins, out_len)
1588 Erb {
1589 filterbank: ErbFilterbank,
1590 },
1591 Cqt {
1592 kernel: CqtKernel,
1593 },
1594}
1595
1596/// Typed mapping wrapper.
1597#[derive(Debug, Clone)]
1598struct FrequencyMapping<FreqScale> {
1599 kind: MappingKind,
1600 _marker: PhantomData<FreqScale>,
1601}
1602
1603impl FrequencyMapping<LinearHz> {
1604 fn new(params: &SpectrogramParams) -> SpectrogramResult<Self> {
1605 let out_len = r2c_output_size(params.stft().n_fft().get());
1606 let out_len = NonZeroUsize::new(out_len)
1607 .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1608 Ok(Self {
1609 kind: MappingKind::Identity { out_len },
1610 _marker: PhantomData,
1611 })
1612 }
1613}
1614
1615impl FrequencyMapping<Mel> {
1616 fn new_mel(params: &SpectrogramParams, mel: &MelParams) -> SpectrogramResult<Self> {
1617 let n_fft = params.stft().n_fft();
1618 let out_len = r2c_output_size(n_fft.get());
1619 let out_len = NonZeroUsize::new(out_len)
1620 .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1621
1622 // Validate: mel bins must be <= something sensible
1623 if mel.n_mels() > nzu!(10_000) {
1624 return Err(SpectrogramError::invalid_input(
1625 "n_mels is unreasonably large",
1626 ));
1627 }
1628
1629 let matrix = build_mel_filterbank_matrix(
1630 params.sample_rate_hz(),
1631 n_fft,
1632 mel.n_mels(),
1633 mel.f_min(),
1634 mel.f_max(),
1635 mel.norm(),
1636 )?;
1637
1638 // matrix must be (n_mels, out_len)
1639 if matrix.nrows() != mel.n_mels().get() || matrix.ncols() != out_len.get() {
1640 return Err(SpectrogramError::invalid_input(
1641 "mel filterbank matrix shape mismatch",
1642 ));
1643 }
1644
1645 Ok(Self {
1646 kind: MappingKind::Mel { matrix },
1647 _marker: PhantomData,
1648 })
1649 }
1650}
1651
1652impl FrequencyMapping<LogHz> {
1653 fn new_loghz(params: &SpectrogramParams, loghz: &LogHzParams) -> SpectrogramResult<Self> {
1654 let n_fft = params.stft().n_fft();
1655 let out_len = r2c_output_size(n_fft.get());
1656 let out_len = NonZeroUsize::new(out_len)
1657 .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1658 // Validate: n_bins must be <= something sensible
1659 if loghz.n_bins() > nzu!(10_000) {
1660 return Err(SpectrogramError::invalid_input(
1661 "n_bins is unreasonably large",
1662 ));
1663 }
1664
1665 let (matrix, frequencies) = build_loghz_matrix(
1666 params.sample_rate_hz(),
1667 n_fft,
1668 loghz.n_bins(),
1669 loghz.f_min(),
1670 loghz.f_max(),
1671 )?;
1672
1673 // matrix must be (n_bins, out_len)
1674 if matrix.nrows() != loghz.n_bins().get() || matrix.ncols() != out_len.get() {
1675 return Err(SpectrogramError::invalid_input(
1676 "loghz matrix shape mismatch",
1677 ));
1678 }
1679
1680 Ok(Self {
1681 kind: MappingKind::LogHz {
1682 matrix,
1683 frequencies,
1684 },
1685 _marker: PhantomData,
1686 })
1687 }
1688}
1689
1690impl FrequencyMapping<Erb> {
1691 fn new_erb(params: &SpectrogramParams, erb: &crate::erb::ErbParams) -> SpectrogramResult<Self> {
1692 let n_fft = params.stft().n_fft();
1693 let sample_rate = params.sample_rate_hz();
1694
1695 // Validate: n_filters must be <= something sensible
1696 if erb.n_filters() > nzu!(10_000) {
1697 return Err(SpectrogramError::invalid_input(
1698 "n_filters is unreasonably large",
1699 ));
1700 }
1701
1702 // Generate ERB filterbank with pre-computed frequency responses
1703 let filterbank = crate::erb::ErbFilterbank::generate(erb, sample_rate, n_fft)?;
1704
1705 Ok(Self {
1706 kind: MappingKind::Erb { filterbank },
1707 _marker: PhantomData,
1708 })
1709 }
1710}
1711
1712impl FrequencyMapping<Cqt> {
1713 fn new(params: &SpectrogramParams, cqt: &CqtParams) -> SpectrogramResult<Self> {
1714 let sample_rate = params.sample_rate_hz();
1715 let n_fft = params.stft().n_fft();
1716
1717 // Validate that frequency range is reasonable
1718 let f_max = cqt.bin_frequency(cqt.num_bins().get().saturating_sub(1));
1719 if f_max >= sample_rate / 2.0 {
1720 return Err(SpectrogramError::invalid_input(
1721 "CQT maximum frequency must be below Nyquist frequency",
1722 ));
1723 }
1724
1725 // Generate CQT kernel using n_fft as the signal length for kernel generation
1726 let kernel = CqtKernel::generate(cqt, sample_rate, n_fft);
1727
1728 Ok(Self {
1729 kind: MappingKind::Cqt { kernel },
1730 _marker: PhantomData,
1731 })
1732 }
1733}
1734
1735impl<FreqScale> FrequencyMapping<FreqScale> {
1736 const fn output_bins(&self) -> NonZeroUsize {
1737 // safety: all variants ensure output bins > 0 OR rely on a matrix that is guaranteed to have rows > 0
1738 match &self.kind {
1739 MappingKind::Identity { out_len } => *out_len,
1740 // safety: matrix.nrows() > 0
1741 MappingKind::LogHz { matrix, .. } | MappingKind::Mel { matrix } => unsafe {
1742 NonZeroUsize::new_unchecked(matrix.nrows())
1743 },
1744 MappingKind::Erb { filterbank, .. } => filterbank.num_filters(),
1745 MappingKind::Cqt { kernel, .. } => kernel.num_bins(),
1746 }
1747 }
1748
1749 fn apply(
1750 &self,
1751 spectrum: &NonEmptySlice<f64>,
1752 frame: &NonEmptySlice<f64>,
1753 out: &mut NonEmptySlice<f64>,
1754 ) -> SpectrogramResult<()> {
1755 match &self.kind {
1756 MappingKind::Identity { out_len } => {
1757 if spectrum.len() != *out_len {
1758 return Err(SpectrogramError::dimension_mismatch(
1759 (*out_len).get(),
1760 spectrum.len().get(),
1761 ));
1762 }
1763 if out.len() != *out_len {
1764 return Err(SpectrogramError::dimension_mismatch(
1765 (*out_len).get(),
1766 out.len().get(),
1767 ));
1768 }
1769 out.copy_from_slice(spectrum);
1770 Ok(())
1771 }
1772 MappingKind::LogHz { matrix, .. } | MappingKind::Mel { matrix } => {
1773 let out_bins = matrix.nrows();
1774 let in_bins = matrix.ncols();
1775
1776 if spectrum.len().get() != in_bins {
1777 return Err(SpectrogramError::dimension_mismatch(
1778 in_bins,
1779 spectrum.len().get(),
1780 ));
1781 }
1782 if out.len().get() != out_bins {
1783 return Err(SpectrogramError::dimension_mismatch(
1784 out_bins,
1785 out.len().get(),
1786 ));
1787 }
1788
1789 // Sparse matrix-vector multiplication: out = matrix * spectrum
1790 matrix.multiply_vec(spectrum, out);
1791 Ok(())
1792 }
1793 MappingKind::Erb { filterbank } => {
1794 // Apply ERB filterbank using pre-computed frequency responses
1795 // The filterbank already has |H(f)|^2 pre-computed, so we just
1796 // apply it to the power spectrum
1797 let erb_out = filterbank.apply_to_power_spectrum(spectrum)?;
1798
1799 if out.len().get() != erb_out.len().get() {
1800 return Err(SpectrogramError::dimension_mismatch(
1801 erb_out.len().get(),
1802 out.len().get(),
1803 ));
1804 }
1805
1806 out.copy_from_slice(&erb_out);
1807 Ok(())
1808 }
1809 MappingKind::Cqt { kernel } => {
1810 // CQT works on time-domain windowed frame, not FFT spectrum
1811 // Apply CQT kernel to get complex coefficients
1812 let cqt_complex = kernel.apply(frame)?;
1813
1814 if out.len().get() != cqt_complex.len().get() {
1815 return Err(SpectrogramError::dimension_mismatch(
1816 cqt_complex.len().get(),
1817 out.len().get(),
1818 ));
1819 }
1820
1821 // Convert complex coefficients to power (|z|^2)
1822 // This matches the convention where intermediate values are in power domain
1823 for (i, c) in cqt_complex.iter().enumerate() {
1824 out[i] = c.norm_sqr();
1825 }
1826
1827 Ok(())
1828 }
1829 }
1830 }
1831
1832 fn frequencies_hz(&self, params: &SpectrogramParams) -> NonEmptyVec<f64> {
1833 match &self.kind {
1834 MappingKind::Identity { out_len } => {
1835 // Standard R2C bins: k * sr / n_fft
1836 let n_fft = params.stft().n_fft().get() as f64;
1837 let sr = params.sample_rate_hz();
1838 let df = sr / n_fft;
1839
1840 let mut f = Vec::with_capacity((*out_len).get());
1841 for k in 0..(*out_len).get() {
1842 f.push(k as f64 * df);
1843 }
1844 // safety: out_len > 0
1845 unsafe { NonEmptyVec::new_unchecked(f) }
1846 }
1847 MappingKind::Mel { matrix } => {
1848 // For mel, the axis is defined by the mel band centre frequencies.
1849 // We compute and store them consistently with how we built the filterbank.
1850 let n_mels = matrix.nrows();
1851 // safety: n_mels > 0
1852 let n_mels = unsafe { NonZeroUsize::new_unchecked(n_mels) };
1853 mel_band_centres_hz(n_mels, params.sample_rate_hz(), params.nyquist_hz())
1854 }
1855 MappingKind::LogHz { frequencies, .. } => {
1856 // Frequencies are stored when the mapping is created
1857 frequencies.clone()
1858 }
1859 MappingKind::Erb { filterbank, .. } => {
1860 // ERB center frequencies
1861 filterbank.center_frequencies().to_non_empty_vec()
1862 }
1863 MappingKind::Cqt { kernel, .. } => {
1864 // CQT center frequencies from the kernel
1865 kernel.frequencies().to_non_empty_vec()
1866 }
1867 }
1868 }
1869}
1870
1871//
1872// ========================
1873// Amplitude scaling
1874// ========================
1875//
1876
1877/// Marker trait so we can specialise behaviour by `AmpScale`.
1878pub trait AmpScaleSpec: Sized + Send + Sync {
1879 /// Apply conversion from power-domain value to the desired amplitude scale.
1880 ///
1881 /// # Arguments
1882 ///
1883 /// - `power`: input power-domain value.
1884 ///
1885 /// # Returns
1886 ///
1887 /// Converted amplitude value.
1888 fn apply_from_power(power: f64) -> f64;
1889
1890 /// Apply dB conversion in-place on a power-domain vector.
1891 ///
1892 /// This is a no-op for Power and Magnitude scales.
1893 ///
1894 /// # Arguments
1895 ///
1896 /// - `x`: power-domain values to convert to dB in-place.
1897 /// - `floor_db`: dB floor value to apply.
1898 ///
1899 /// # Returns
1900 ///
1901 /// `SpectrogramResult<()>`: Ok on success, error on invalid input.
1902 ///
1903 /// # Errors
1904 ///
1905 /// - If `floor_db` is not finite.
1906 fn apply_db_in_place(x: &mut [f64], floor_db: f64) -> SpectrogramResult<()>;
1907}
1908
1909impl AmpScaleSpec for Power {
1910 #[inline]
1911 fn apply_from_power(power: f64) -> f64 {
1912 power
1913 }
1914
1915 #[inline]
1916 fn apply_db_in_place(_x: &mut [f64], _floor_db: f64) -> SpectrogramResult<()> {
1917 Ok(())
1918 }
1919}
1920
1921impl AmpScaleSpec for Magnitude {
1922 #[inline]
1923 fn apply_from_power(power: f64) -> f64 {
1924 power.sqrt()
1925 }
1926
1927 #[inline]
1928 fn apply_db_in_place(_x: &mut [f64], _floor_db: f64) -> SpectrogramResult<()> {
1929 Ok(())
1930 }
1931}
1932
1933impl AmpScaleSpec for Decibels {
1934 #[inline]
1935 fn apply_from_power(power: f64) -> f64 {
1936 // dB conversion is applied in batch, not here.
1937 power
1938 }
1939
1940 #[inline]
1941 fn apply_db_in_place(x: &mut [f64], floor_db: f64) -> SpectrogramResult<()> {
1942 // Convert power -> dB: 10*log10(max(power, eps))
1943 // where eps is derived from floor_db to ensure consistency
1944 if !floor_db.is_finite() {
1945 return Err(SpectrogramError::invalid_input("floor_db must be finite"));
1946 }
1947
1948 // Convert floor_db to linear scale to get epsilon
1949 // e.g., floor_db = -80 dB -> eps = 10^(-80/10) = 1e-8
1950 let eps = 10.0_f64.powf(floor_db / 10.0);
1951
1952 for v in x.iter_mut() {
1953 // Clamp power to epsilon before log to avoid log(0) and ensure floor
1954 *v = 10.0 * v.max(eps).log10();
1955 }
1956 Ok(())
1957 }
1958}
1959
/// Final amplitude stage of the spectrogram pipeline.
///
/// Converts the power-domain intermediate into the representation selected by
/// `AmpScale` (Power, Magnitude or Decibels).
#[derive(Debug, Clone)]
struct AmplitudeScaling<AmpScale> {
    /// dB floor taken from `LogParams`, when they were supplied.
    db_floor: Option<f64>,
    /// Ties the stage to its amplitude scale at the type level.
    _marker: PhantomData<AmpScale>,
}
1968
1969impl<AmpScale> AmplitudeScaling<AmpScale>
1970where
1971 AmpScale: AmpScaleSpec + 'static,
1972{
1973 fn new(db: Option<&LogParams>) -> Self {
1974 let db_floor = db.map(LogParams::floor_db);
1975 Self {
1976 db_floor,
1977 _marker: PhantomData,
1978 }
1979 }
1980
1981 /// Apply amplitude scaling in-place on a mapped spectrum vector.
1982 ///
1983 /// The input vector is assumed to be in the *power* domain (|X|^2),
1984 /// because the STFT stage produces power as the canonical intermediate.
1985 ///
1986 /// - Power: leaves values unchanged.
1987 /// - Magnitude: sqrt(power).
1988 /// - Decibels: converts power -> dB and floors at `db_floor`.
1989 pub fn apply_in_place(&self, x: &mut [f64]) -> SpectrogramResult<()> {
1990 // Convert from canonical power-domain intermediate into the requested linear domain.
1991 for v in x.iter_mut() {
1992 *v = AmpScale::apply_from_power(*v);
1993 }
1994
1995 // Apply dB conversion if configured (no-op for Power/Magnitude via trait impls).
1996 if let Some(floor_db) = self.db_floor {
1997 AmpScale::apply_db_in_place(x, floor_db)?;
1998 }
1999
2000 Ok(())
2001 }
2002}
2003
2004#[derive(Debug, Clone)]
2005struct Workspace {
2006 spectrum: NonEmptyVec<f64>, // out_len (power spectrum)
2007 mapped: NonEmptyVec<f64>, // n_bins (after mapping)
2008 frame: NonEmptyVec<f64>, // n_fft (windowed frame for FFT)
2009 fft_out: NonEmptyVec<Complex<f64>>, // out_len (FFT output)
2010}
2011
2012impl Workspace {
2013 fn new(n_fft: NonZeroUsize, out_len: NonZeroUsize, n_bins: NonZeroUsize) -> Self {
2014 Self {
2015 spectrum: non_empty_vec![0.0; out_len],
2016 mapped: non_empty_vec![0.0; n_bins],
2017 frame: non_empty_vec![0.0; n_fft],
2018 fft_out: non_empty_vec![Complex::new(0.0, 0.0); out_len],
2019 }
2020 }
2021
2022 fn ensure_sizes(&mut self, n_fft: NonZeroUsize, out_len: NonZeroUsize, n_bins: NonZeroUsize) {
2023 if self.spectrum.len() != out_len {
2024 self.spectrum.resize(out_len, 0.0);
2025 }
2026 if self.mapped.len() != n_bins {
2027 self.mapped.resize(n_bins, 0.0);
2028 }
2029 if self.frame.len() != n_fft {
2030 self.frame.resize(n_fft, 0.0);
2031 }
2032 if self.fft_out.len() != out_len {
2033 self.fft_out.resize(out_len, Complex::new(0.0, 0.0));
2034 }
2035 }
2036}
2037
2038fn build_frequency_axis<FreqScale>(
2039 params: &SpectrogramParams,
2040 mapping: &FrequencyMapping<FreqScale>,
2041) -> FrequencyAxis<FreqScale>
2042where
2043 FreqScale: Copy + Clone + 'static,
2044{
2045 let frequencies = mapping.frequencies_hz(params);
2046 FrequencyAxis::new(frequencies)
2047}
2048
2049fn build_time_axis_seconds(params: &SpectrogramParams, n_frames: NonZeroUsize) -> NonEmptyVec<f64> {
2050 let dt = params.frame_period_seconds();
2051 let mut times = Vec::with_capacity(n_frames.get());
2052
2053 for i in 0..n_frames.get() {
2054 times.push(i as f64 * dt);
2055 }
2056
2057 // safety: times is guaranteed non-empty since n_frames > 0
2058
2059 unsafe { NonEmptyVec::new_unchecked(times) }
2060}
2061
2062/// Generate window function samples.
2063///
2064/// Supports various window types including Rectangular, Hanning, Hamming, Blackman, Kaiser, and Gaussian.
2065///
2066/// # Arguments
2067///
2068/// * `window` - The type of window function to generate.
2069/// * `n_fft` - The size of the FFT, which determines the length of the window.
2070///
2071/// # Returns
2072///
2073/// A `NonEmptyVec<f64>` containing the window function samples.
2074#[inline]
2075#[must_use]
2076pub fn make_window(window: WindowType, n_fft: NonZeroUsize) -> NonEmptyVec<f64> {
2077 let n_fft = n_fft.get();
2078 let mut w = vec![0.0; n_fft];
2079
2080 match window {
2081 WindowType::Rectangular => {
2082 w.fill(1.0);
2083 }
2084 WindowType::Hanning => {
2085 // Hann: 0.5 - 0.5*cos(2Ï€n/(N-1))
2086 let n1 = (n_fft - 1) as f64;
2087 for (n, v) in w.iter_mut().enumerate() {
2088 *v = 0.5f64.mul_add(-(2.0 * std::f64::consts::PI * (n as f64) / n1).cos(), 0.5);
2089 }
2090 }
2091 WindowType::Hamming => {
2092 // Hamming: 0.54 - 0.46*cos(2Ï€n/(N-1))
2093 let n1 = (n_fft - 1) as f64;
2094 for (n, v) in w.iter_mut().enumerate() {
2095 *v = 0.46f64.mul_add(-(2.0 * std::f64::consts::PI * (n as f64) / n1).cos(), 0.54);
2096 }
2097 }
2098 WindowType::Blackman => {
2099 // Blackman: 0.42 - 0.5*cos(2Ï€n/(N-1)) + 0.08*cos(4Ï€n/(N-1))
2100 let n1 = (n_fft - 1) as f64;
2101 for (n, v) in w.iter_mut().enumerate() {
2102 let a = 2.0 * std::f64::consts::PI * (n as f64) / n1;
2103 *v = 0.08f64.mul_add((2.0 * a).cos(), 0.5f64.mul_add(-a.cos(), 0.42));
2104 }
2105 }
2106 WindowType::Kaiser { beta } => {
2107 (0..n_fft).for_each(|i| {
2108 let n = i as f64;
2109 let n_max: f64 = (n_fft - 1) as f64;
2110 let alpha: f64 = (n - n_max / 2.0) / (n_max / 2.0);
2111 let bessel_arg = beta * alpha.mul_add(-alpha, 1.0).sqrt();
2112 // Simplified approximation of modified Bessel function
2113 let x = 1.0
2114 + bessel_arg / 2.0
2115 // Normalize by I0(beta) approximation
2116 / (1.0 + beta / 2.0);
2117 w[i] = x;
2118 });
2119 }
2120 WindowType::Gaussian { std } => (0..n_fft).for_each(|i| {
2121 let n = i as f64;
2122 let center: f64 = (n_fft - 1) as f64 / 2.0;
2123 let exponent: f64 = -0.5 * ((n - center) / std).powi(2);
2124 w[i] = exponent.exp();
2125 }),
2126 }
2127
2128 // safety: window is guaranteed non-empty since n_fft > 0
2129 unsafe { NonEmptyVec::new_unchecked(w) }
2130}
2131
/// Convert Hz to mel scale using Slaney formula (librosa default, htk=False).
///
/// Uses a hybrid scale:
/// - Linear below 1000 Hz: mel = hz / (200/3)
/// - Logarithmic from 1000 Hz up: mel = 15 + ln(hz/1000) / log_step, where
///   log_step = ln(6.4) / 27
///
/// This matches librosa's default behavior.
fn hz_to_mel(hz: f64) -> f64 {
    const F_MIN: f64 = 0.0;
    const F_SP: f64 = 200.0 / 3.0; // ~66.667 Hz per mel in the linear region
    const MIN_LOG_HZ: f64 = 1000.0;
    const MIN_LOG_MEL: f64 = (MIN_LOG_HZ - F_MIN) / F_SP; // ~15.0 (exactly 15 in real arithmetic)
    // ln(6.4) / 27, identical to librosa's `np.log(6.4) / 27.0`.
    // Must stay equal to the constant in `mel_to_hz` so the two remain inverses.
    const LOGSTEP: f64 = 0.06875177742094912;

    if hz >= MIN_LOG_HZ {
        // Logarithmic region
        MIN_LOG_MEL + (hz / MIN_LOG_HZ).ln() / LOGSTEP
    } else {
        // Linear region
        (hz - F_MIN) / F_SP
    }
}

/// Convert mel to Hz using Slaney formula (librosa default, htk=False).
///
/// Inverse of `hz_to_mel`:
/// - Linear below mel 15: hz = (200/3) * mel
/// - Logarithmic from mel 15 up: hz = 1000 * exp(log_step * (mel - 15)), where
///   log_step = ln(6.4) / 27
fn mel_to_hz(mel: f64) -> f64 {
    const F_MIN: f64 = 0.0;
    const F_SP: f64 = 200.0 / 3.0; // ~66.667 Hz per mel in the linear region
    const MIN_LOG_HZ: f64 = 1000.0;
    const MIN_LOG_MEL: f64 = (MIN_LOG_HZ - F_MIN) / F_SP; // ~15.0 (exactly 15 in real arithmetic)
    // ln(6.4) / 27, identical to librosa's `np.log(6.4) / 27.0`.
    // Must stay equal to the constant in `hz_to_mel` so the two remain inverses.
    const LOGSTEP: f64 = 0.06875177742094912;

    if mel >= MIN_LOG_MEL {
        // Logarithmic region
        MIN_LOG_HZ * (LOGSTEP * (mel - MIN_LOG_MEL)).exp()
    } else {
        // Linear region
        F_MIN + F_SP * mel
    }
}
2173
2174fn build_mel_filterbank_matrix(
2175 sample_rate_hz: f64,
2176 n_fft: NonZeroUsize,
2177 n_mels: NonZeroUsize,
2178 f_min: f64,
2179 f_max: f64,
2180 norm: MelNorm,
2181) -> SpectrogramResult<SparseMatrix> {
2182 if sample_rate_hz <= 0.0 || !sample_rate_hz.is_finite() {
2183 return Err(SpectrogramError::invalid_input(
2184 "sample_rate_hz must be finite and > 0",
2185 ));
2186 }
2187 if f_min < 0.0 || f_min.is_infinite() {
2188 return Err(SpectrogramError::invalid_input("f_min must be >= 0"));
2189 }
2190 if f_max <= f_min {
2191 return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
2192 }
2193 if f_max > sample_rate_hz * 0.5 {
2194 return Err(SpectrogramError::invalid_input("f_max must be <= Nyquist"));
2195 }
2196 let n_mels = n_mels.get();
2197 let n_fft = n_fft.get();
2198 let out_len = r2c_output_size(n_fft);
2199
2200 // FFT bin frequencies
2201 let df = sample_rate_hz / n_fft as f64;
2202
2203 // Mel points: n_mels + 2 (for triangular edges)
2204 let mel_min = hz_to_mel(f_min);
2205 let mel_max = hz_to_mel(f_max);
2206
2207 let n_points = n_mels + 2;
2208 let step = (mel_max - mel_min) / (n_points - 1) as f64;
2209
2210 let mut mel_points = Vec::with_capacity(n_points);
2211 for i in 0..n_points {
2212 mel_points.push((i as f64).mul_add(step, mel_min));
2213 }
2214
2215 let mut hz_points = Vec::with_capacity(n_points);
2216 for m in &mel_points {
2217 hz_points.push(mel_to_hz(*m));
2218 }
2219
2220 // Convert Hz points into FFT bin indices
2221 let mut bin_points = Vec::with_capacity(n_points);
2222 for hz in hz_points {
2223 let b = (hz / df).floor() as isize;
2224 let b = b.clamp(0, (out_len - 1) as isize);
2225 bin_points.push(b as usize);
2226 }
2227
2228 // Build filterbank as sparse matrix
2229 let mut fb = SparseMatrix::new(n_mels, out_len);
2230
2231 for m in 0..n_mels {
2232 let left = bin_points[m];
2233 let centre = bin_points[m + 1];
2234 let right = bin_points[m + 2];
2235
2236 if left == centre || centre == right {
2237 // Degenerate triangle, skip (should be rare if params are sensible)
2238 continue;
2239 }
2240
2241 // Rising slope: left -> centre
2242 for k in left..centre {
2243 let v = (k - left) as f64 / (centre - left) as f64;
2244 fb.set(m, k, v);
2245 }
2246
2247 // Falling slope: centre -> right
2248 for k in centre..right {
2249 let v = (right - k) as f64 / (right - centre) as f64;
2250 fb.set(m, k, v);
2251 }
2252 }
2253
2254 // Apply normalization
2255 match norm {
2256 MelNorm::None => {
2257 // No normalization needed
2258 }
2259 MelNorm::Slaney => {
2260 // Slaney-style area normalization: 2 / (hz_max - hz_min) for each triangle
2261 // NOTE: Uses Hz bandwidth, not mel bandwidth (to match librosa's implementation)
2262 for m in 0..n_mels {
2263 let mel_left = mel_points[m];
2264 let mel_right = mel_points[m + 2];
2265 let hz_left = mel_to_hz(mel_left);
2266 let hz_right = mel_to_hz(mel_right);
2267 let enorm = 2.0 / (hz_right - hz_left);
2268
2269 // Normalize all values in this row
2270 for val in &mut fb.values[m] {
2271 *val *= enorm;
2272 }
2273 }
2274 }
2275 MelNorm::L1 => {
2276 // L1 normalization: sum of weights = 1.0
2277 for m in 0..n_mels {
2278 let sum: f64 = fb.values[m].iter().sum();
2279 if sum > 0.0 {
2280 let normalizer = 1.0 / sum;
2281 for val in &mut fb.values[m] {
2282 *val *= normalizer;
2283 }
2284 }
2285 }
2286 }
2287 MelNorm::L2 => {
2288 // L2 normalization: L2 norm = 1.0
2289 for m in 0..n_mels {
2290 let norm_val: f64 = fb.values[m].iter().map(|&v| v * v).sum::<f64>().sqrt();
2291 if norm_val > 0.0 {
2292 let normalizer = 1.0 / norm_val;
2293 for val in &mut fb.values[m] {
2294 *val *= normalizer;
2295 }
2296 }
2297 }
2298 }
2299 }
2300
2301 Ok(fb)
2302}
2303
2304/// Build a logarithmic frequency interpolation matrix.
2305///
2306/// Maps linearly-spaced FFT bins to logarithmically-spaced frequency bins
2307/// using linear interpolation.
2308fn build_loghz_matrix(
2309 sample_rate_hz: f64,
2310 n_fft: NonZeroUsize,
2311 n_bins: NonZeroUsize,
2312 f_min: f64,
2313 f_max: f64,
2314) -> SpectrogramResult<(SparseMatrix, NonEmptyVec<f64>)> {
2315 if sample_rate_hz <= 0.0 || !sample_rate_hz.is_finite() {
2316 return Err(SpectrogramError::invalid_input(
2317 "sample_rate_hz must be finite and > 0",
2318 ));
2319 }
2320 if f_min <= 0.0 || f_min.is_infinite() {
2321 return Err(SpectrogramError::invalid_input(
2322 "f_min must be finite and > 0",
2323 ));
2324 }
2325 if f_max <= f_min {
2326 return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
2327 }
2328 if f_max > sample_rate_hz * 0.5 {
2329 return Err(SpectrogramError::invalid_input("f_max must be <= Nyquist"));
2330 }
2331
2332 let n_bins = n_bins.get();
2333 let n_fft = n_fft.get();
2334
2335 let out_len = r2c_output_size(n_fft);
2336 let df = sample_rate_hz / n_fft as f64;
2337
2338 // Generate logarithmically-spaced frequencies
2339 let log_f_min = f_min.ln();
2340 let log_f_max = f_max.ln();
2341 let log_step = (log_f_max - log_f_min) / (n_bins - 1) as f64;
2342
2343 let mut log_frequencies = Vec::with_capacity(n_bins);
2344 for i in 0..n_bins {
2345 let log_f = (i as f64).mul_add(log_step, log_f_min);
2346 log_frequencies.push(log_f.exp());
2347 }
2348 // safety: n_bins > 0
2349 let log_frequencies = unsafe { NonEmptyVec::new_unchecked(log_frequencies) };
2350
2351 // Build interpolation matrix as sparse matrix
2352 let mut matrix = SparseMatrix::new(n_bins, out_len);
2353
2354 for (bin_idx, &target_freq) in log_frequencies.iter().enumerate() {
2355 // Find the two FFT bins that bracket this frequency
2356 let exact_bin = target_freq / df;
2357 let lower_bin = exact_bin.floor() as usize;
2358 let upper_bin = (exact_bin.ceil() as usize).min(out_len - 1);
2359
2360 if lower_bin >= out_len {
2361 continue;
2362 }
2363
2364 if lower_bin == upper_bin {
2365 // Exact match
2366 matrix.set(bin_idx, lower_bin, 1.0);
2367 } else {
2368 // Linear interpolation
2369 let frac = exact_bin - lower_bin as f64;
2370 matrix.set(bin_idx, lower_bin, 1.0 - frac);
2371 if upper_bin < out_len {
2372 matrix.set(bin_idx, upper_bin, frac);
2373 }
2374 }
2375 }
2376
2377 Ok((matrix, log_frequencies))
2378}
2379
2380fn mel_band_centres_hz(
2381 n_mels: NonZeroUsize,
2382 sample_rate_hz: f64,
2383 nyquist_hz: f64,
2384) -> NonEmptyVec<f64> {
2385 let f_min = 0.0;
2386 let f_max = nyquist_hz.min(sample_rate_hz * 0.5);
2387
2388 let mel_min = hz_to_mel(f_min);
2389 let mel_max = hz_to_mel(f_max);
2390 let n_mels = n_mels.get();
2391 let step = (mel_max - mel_min) / (n_mels + 1) as f64;
2392
2393 let mut centres = Vec::with_capacity(n_mels);
2394 for i in 0..n_mels {
2395 let mel = (i as f64 + 1.0).mul_add(step, mel_min);
2396 centres.push(mel_to_hz(mel));
2397 }
2398 // safety: centres is guaranteed non-empty since n_mels > 0
2399 unsafe { NonEmptyVec::new_unchecked(centres) }
2400}
2401
/// Spectrogram structure holding the computed spectrogram data and metadata.
///
/// # Type Parameters
///
/// * `FreqScale`: The frequency scale type (e.g., `LinearHz`, `Mel`, `LogHz`, etc.).
/// * `AmpScale`: The amplitude scale type (e.g., `Power`, `Magnitude`, `Decibels`).
///
/// # Fields
///
/// * `data`: A 2D array containing the spectrogram data, with one row per frequency
///   bin and one column per time frame.
/// * `axes`: The axes of the spectrogram (frequency and time).
/// * `params`: The parameters used to compute the spectrogram.
/// * `_amp`: A phantom data marker for the amplitude scale type.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Spectrogram<FreqScale, AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
    FreqScale: Copy + Clone + 'static,
{
    /// Spectrogram values: rows are frequency bins, columns are time frames.
    data: Array2<f64>,
    /// Frequency and time axes matching the rows and columns of `data`.
    axes: Axes<FreqScale>,
    /// Parameters the spectrogram was computed with.
    params: SpectrogramParams,
    /// Type-level tag for the amplitude scale; carries no data and is not serialized.
    #[cfg_attr(feature = "serde", serde(skip))]
    _amp: PhantomData<AmpScale>,
}
2428
2429impl<FreqScale, AmpScale> Spectrogram<FreqScale, AmpScale>
2430where
2431 AmpScale: AmpScaleSpec + 'static,
2432 FreqScale: Copy + Clone + 'static,
2433{
2434 /// Get the X-axis label.
2435 ///
2436 /// # Returns
2437 ///
2438 /// A static string slice representing the X-axis label.
2439 #[inline]
2440 #[must_use]
2441 pub const fn x_axis_label() -> &'static str {
2442 "Time (s)"
2443 }
2444
2445 /// Get the Y-axis label based on the frequency scale type.
2446 ///
2447 /// # Returns
2448 ///
2449 /// A static string slice representing the Y-axis label.
2450 #[inline]
2451 #[must_use]
2452 pub fn y_axis_label() -> &'static str {
2453 match std::any::TypeId::of::<FreqScale>() {
2454 id if id == std::any::TypeId::of::<LinearHz>() => "Frequency (Hz)",
2455 id if id == std::any::TypeId::of::<Mel>() => "Frequency (Mel)",
2456 id if id == std::any::TypeId::of::<LogHz>() => "Frequency (Log Hz)",
2457 id if id == std::any::TypeId::of::<Erb>() => "Frequency (ERB)",
2458 id if id == std::any::TypeId::of::<Cqt>() => "Frequency (CQT Bins)",
2459 _ => "Frequency",
2460 }
2461 }
2462
2463 /// Internal constructor. Only callable inside the crate.
2464 ///
2465 /// All inputs must already be validated and consistent.
2466 pub(crate) fn new(data: Array2<f64>, axes: Axes<FreqScale>, params: SpectrogramParams) -> Self {
2467 debug_assert_eq!(data.nrows(), axes.frequencies().len().get());
2468 debug_assert_eq!(data.ncols(), axes.times().len().get());
2469
2470 Self {
2471 data,
2472 axes,
2473 params,
2474 _amp: PhantomData,
2475 }
2476 }
2477
2478 /// Set spectrogram data matrix
2479 ///
2480 /// # Arguments
2481 ///
2482 /// * `data` - The new spectrogram data matrix.
2483 #[inline]
2484 pub fn set_data(&mut self, data: Array2<f64>) {
2485 self.data = data;
2486 }
2487
2488 /// Spectrogram data matrix
2489 ///
2490 /// # Returns
2491 ///
2492 /// A reference to the spectrogram data matrix.
2493 #[inline]
2494 #[must_use]
2495 pub const fn data(&self) -> &Array2<f64> {
2496 &self.data
2497 }
2498
2499 /// Axes of the spectrogram
2500 ///
2501 /// # Returns
2502 ///
2503 /// A reference to the axes of the spectrogram.
2504 #[inline]
2505 #[must_use]
2506 pub const fn axes(&self) -> &Axes<FreqScale> {
2507 &self.axes
2508 }
2509
2510 /// Frequency axis in Hz
2511 ///
2512 /// # Returns
2513 ///
2514 /// A reference to the frequency axis in Hz.
2515 #[inline]
2516 #[must_use]
2517 pub fn frequencies(&self) -> &NonEmptySlice<f64> {
2518 self.axes.frequencies()
2519 }
2520
2521 /// Frequency range in Hz (min, max)
2522 ///
2523 /// # Returns
2524 ///
2525 /// A tuple containing the minimum and maximum frequencies in Hz.
2526 #[inline]
2527 #[must_use]
2528 pub const fn frequency_range(&self) -> (f64, f64) {
2529 self.axes.frequency_range()
2530 }
2531
2532 /// Time axis in seconds
2533 ///
2534 /// # Returns
2535 ///
2536 /// A reference to the time axis in seconds.
2537 #[inline]
2538 #[must_use]
2539 pub fn times(&self) -> &NonEmptySlice<f64> {
2540 self.axes.times()
2541 }
2542
2543 /// Spectrogram computation parameters
2544 ///
2545 /// # Returns
2546 ///
2547 /// A reference to the spectrogram computation parameters.
2548 #[inline]
2549 #[must_use]
2550 pub const fn params(&self) -> &SpectrogramParams {
2551 &self.params
2552 }
2553
2554 /// Duration of the spectrogram in seconds
2555 ///
2556 /// # Returns
2557 ///
2558 /// The duration of the spectrogram in seconds.
2559 #[inline]
2560 #[must_use]
2561 pub fn duration(&self) -> f64 {
2562 self.axes.duration()
2563 }
2564
2565 /// If this is a dB spectrogram, return the (min, max) dB values. otherwise do the maths to compute dB range.
2566 ///
2567 /// # Returns
2568 ///
2569 /// The (min, max) dB values of the spectrogram, or `None` if the amplitude scale is unknown.
2570 #[inline]
2571 #[must_use]
2572 pub fn db_range(&self) -> Option<(f64, f64)> {
2573 let type_self = std::any::TypeId::of::<AmpScale>();
2574
2575 if type_self == std::any::TypeId::of::<Decibels>() {
2576 let (min, max) = min_max_single_pass(self.data.as_slice()?);
2577 Some((min, max))
2578 } else if type_self == std::any::TypeId::of::<Power>() {
2579 // Not a dB spectrogram; compute dB range from power values
2580 let mut min_db = f64::INFINITY;
2581 let mut max_db = f64::NEG_INFINITY;
2582 for &v in &self.data {
2583 let db = 10.0 * (v + EPS).log10();
2584 if db < min_db {
2585 min_db = db;
2586 }
2587 if db > max_db {
2588 max_db = db;
2589 }
2590 }
2591 Some((min_db, max_db))
2592 } else if type_self == std::any::TypeId::of::<Magnitude>() {
2593 // Not a dB spectrogram; compute dB range from magnitude values
2594 let mut min_db = f64::INFINITY;
2595 let mut max_db = f64::NEG_INFINITY;
2596
2597 for &v in &self.data {
2598 let power = v * v;
2599 let db = 10.0 * (power + EPS).log10();
2600 if db < min_db {
2601 min_db = db;
2602 }
2603 if db > max_db {
2604 max_db = db;
2605 }
2606 }
2607
2608 Some((min_db, max_db))
2609 } else {
2610 // Unknown AmpScale type; return dummy values
2611 None
2612 }
2613 }
2614
2615 /// Number of frequency bins
2616 ///
2617 /// # Returns
2618 ///
2619 /// The number of frequency bins in the spectrogram.
2620 #[inline]
2621 #[must_use]
2622 pub fn n_bins(&self) -> NonZeroUsize {
2623 // safety: data.nrows() > 0 is guaranteed by construction
2624 unsafe { NonZeroUsize::new_unchecked(self.data.nrows()) }
2625 }
2626
2627 /// Number of time frames in the spectrogram
2628 ///
2629 /// # Returns
2630 ///
2631 /// The number of time frames (columns) in the spectrogram.
2632 #[inline]
2633 #[must_use]
2634 pub fn n_frames(&self) -> NonZeroUsize {
2635 // safety: data.ncols() > 0 is guaranteed by construction
2636 unsafe { NonZeroUsize::new_unchecked(self.data.ncols()) }
2637 }
2638}
2639
2640impl AsRef<Array2<f64>> for Spectrogram<LinearHz, Power> {
2641 #[inline]
2642 fn as_ref(&self) -> &Array2<f64> {
2643 &self.data
2644 }
2645}
2646
2647impl Deref for Spectrogram<LinearHz, Power> {
2648 type Target = Array2<f64>;
2649
2650 #[inline]
2651 fn deref(&self) -> &Self::Target {
2652 &self.data
2653 }
2654}
2655
2656impl<AmpScale> Spectrogram<LinearHz, AmpScale>
2657where
2658 AmpScale: AmpScaleSpec + 'static,
2659{
2660 /// Compute a linear-frequency spectrogram from audio samples.
2661 ///
2662 /// This is a convenience method that creates a planner internally and computes
2663 /// the spectrogram in one call. For processing multiple signals with the same
2664 /// parameters, use [`SpectrogramPlanner::linear_plan`] to create a reusable plan.
2665 ///
2666 /// # Arguments
2667 ///
2668 /// * `samples` - Audio samples (any type that can be converted to a slice)
2669 /// * `params` - Spectrogram computation parameters
2670 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2671 ///
2672 /// # Returns
2673 ///
2674 /// A linear-frequency spectrogram with the specified amplitude scale.
2675 ///
2676 /// # Errors
2677 ///
2678 /// Returns an error if:
2679 /// - The samples slice is empty
2680 /// - Parameters are invalid
2681 /// - FFT computation fails
2682 ///
2683 /// # Examples
2684 ///
2685 /// ```
2686 /// use spectrograms::*;
2687 /// use non_empty_slice::non_empty_vec;
2688 ///
2689 /// # fn example() -> SpectrogramResult<()> {
2690 /// // Create a simple test signal
2691 /// let sample_rate = 16000.0;
2692 /// let samples_vec: Vec<f64> = (0..16000).map(|i| {
2693 /// (2.0 * std::f64::consts::PI * 440.0 * i as f64 / sample_rate).sin()
2694 /// }).collect();
2695 /// let samples = non_empty_slice::NonEmptyVec::new(samples_vec).unwrap();
2696 ///
2697 /// // Set up parameters
2698 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2699 /// let params = SpectrogramParams::new(stft, sample_rate)?;
2700 ///
2701 /// // Compute power spectrogram
2702 /// let spec = LinearPowerSpectrogram::compute(&samples, ¶ms, None)?;
2703 ///
2704 /// println!("Computed spectrogram: {} bins x {} frames", spec.n_bins(), spec.n_frames());
2705 /// # Ok(())
2706 /// # }
2707 /// ```
2708 #[inline]
2709 pub fn compute(
2710 samples: &NonEmptySlice<f64>,
2711 params: &SpectrogramParams,
2712 db: Option<&LogParams>,
2713 ) -> SpectrogramResult<Self> {
2714 let planner = SpectrogramPlanner::new();
2715 let mut plan = planner.linear_plan(params, db)?;
2716 plan.compute(samples)
2717 }
2718}
2719
2720impl<AmpScale> Spectrogram<Mel, AmpScale>
2721where
2722 AmpScale: AmpScaleSpec + 'static,
2723{
2724 /// Compute a mel-frequency spectrogram from audio samples.
2725 ///
2726 /// This is a convenience method that creates a planner internally and computes
2727 /// the spectrogram in one call. For processing multiple signals with the same
2728 /// parameters, use [`SpectrogramPlanner::mel_plan`] to create a reusable plan.
2729 ///
2730 /// # Arguments
2731 ///
2732 /// * `samples` - Audio samples (any type that can be converted to a slice)
2733 /// * `params` - Spectrogram computation parameters
2734 /// * `mel` - Mel filterbank parameters
2735 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2736 ///
2737 /// # Returns
2738 ///
2739 /// A mel-frequency spectrogram with the specified amplitude scale.
2740 ///
2741 /// # Errors
2742 ///
2743 /// Returns an error if:
2744 /// - The samples slice is empty
2745 /// - Parameters are invalid
2746 /// - Mel `f_max` exceeds Nyquist frequency
2747 /// - FFT computation fails
2748 ///
2749 /// # Examples
2750 ///
2751 /// ```
2752 /// use spectrograms::*;
2753 /// use non_empty_slice::non_empty_vec;
2754 ///
2755 /// # fn example() -> SpectrogramResult<()> {
2756 /// // Create a simple test signal
2757 /// let sample_rate = 16000.0;
2758 /// let samples_vec: Vec<f64> = (0..16000).map(|i| {
2759 /// (2.0 * std::f64::consts::PI * 440.0 * i as f64 / sample_rate).sin()
2760 /// }).collect();
2761 /// let samples = non_empty_slice::NonEmptyVec::new(samples_vec).unwrap();
2762 ///
2763 /// // Set up parameters
2764 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2765 /// let params = SpectrogramParams::new(stft, sample_rate)?;
2766 /// let mel = MelParams::new(nzu!(80), 0.0, 8000.0)?;
2767 ///
2768 /// // Compute mel spectrogram in dB scale
2769 /// let db = LogParams::new(-80.0)?;
2770 /// let spec = MelDbSpectrogram::compute(&samples, ¶ms, &mel, Some(&db))?;
2771 ///
2772 /// println!("Computed mel spectrogram: {} mels x {} frames", spec.n_bins(), spec.n_frames());
2773 /// # Ok(())
2774 /// # }
2775 /// ```
2776 #[inline]
2777 pub fn compute(
2778 samples: &NonEmptySlice<f64>,
2779 params: &SpectrogramParams,
2780 mel: &MelParams,
2781 db: Option<&LogParams>,
2782 ) -> SpectrogramResult<Self> {
2783 let planner = SpectrogramPlanner::new();
2784 let mut plan = planner.mel_plan(params, mel, db)?;
2785 plan.compute(samples)
2786 }
2787}
2788
2789impl<AmpScale> Spectrogram<Erb, AmpScale>
2790where
2791 AmpScale: AmpScaleSpec + 'static,
2792{
2793 /// Compute an ERB-frequency spectrogram from audio samples.
2794 ///
2795 /// This is a convenience method that creates a planner internally and computes
2796 /// the spectrogram in one call. For processing multiple signals with the same
2797 /// parameters, use [`SpectrogramPlanner::erb_plan`] to create a reusable plan.
2798 ///
2799 /// # Arguments
2800 ///
2801 /// * `samples` - Audio samples (any type that can be converted to a slice)
2802 /// * `params` - Spectrogram computation parameters
2803 /// * `erb` - ERB frequency scale parameters
2804 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2805 ///
2806 /// # Returns
2807 ///
2808 /// An ERB-scale spectrogram with the specified amplitude scale.
2809 ///
2810 /// # Errors
2811 ///
2812 /// Returns an error if:
2813 /// - The samples slice is empty
2814 /// - Parameters are invalid
2815 /// - FFT computation fails
2816 ///
2817 /// # Examples
2818 ///
2819 /// ```
2820 /// use spectrograms::*;
2821 /// use non_empty_slice::non_empty_vec;
2822 ///
2823 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
2824 /// let samples = non_empty_vec![0.0; nzu!(16000)];
2825 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2826 /// let params = SpectrogramParams::new(stft, 16000.0)?;
2827 /// let erb = ErbParams::speech_standard();
2828 ///
2829 /// let spec = ErbPowerSpectrogram::compute(&samples, ¶ms, &erb, None)?;
2830 /// assert_eq!(spec.n_bins(), nzu!(40));
2831 /// # Ok(())
2832 /// # }
2833 /// ```
2834 #[inline]
2835 pub fn compute(
2836 samples: &NonEmptySlice<f64>,
2837 params: &SpectrogramParams,
2838 erb: &ErbParams,
2839 db: Option<&LogParams>,
2840 ) -> SpectrogramResult<Self> {
2841 let planner = SpectrogramPlanner::new();
2842 let mut plan = planner.erb_plan(params, erb, db)?;
2843 plan.compute(samples)
2844 }
2845}
2846
2847impl<AmpScale> Spectrogram<LogHz, AmpScale>
2848where
2849 AmpScale: AmpScaleSpec + 'static,
2850{
2851 /// Compute a logarithmic-frequency spectrogram from audio samples.
2852 ///
2853 /// This is a convenience method that creates a planner internally and computes
2854 /// the spectrogram in one call. For processing multiple signals with the same
2855 /// parameters, use [`SpectrogramPlanner::log_hz_plan`] to create a reusable plan.
2856 ///
2857 /// # Arguments
2858 ///
2859 /// * `samples` - Audio samples (any type that can be converted to a slice)
2860 /// * `params` - Spectrogram computation parameters
2861 /// * `loghz` - Logarithmic frequency scale parameters
2862 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2863 ///
2864 /// # Returns
2865 ///
2866 /// A logarithmic-frequency spectrogram with the specified amplitude scale.
2867 ///
2868 /// # Errors
2869 ///
2870 /// Returns an error if:
2871 /// - The samples slice is empty
2872 /// - Parameters are invalid
2873 /// - FFT computation fails
2874 ///
2875 /// # Examples
2876 ///
2877 /// ```
2878 /// use spectrograms::*;
2879 /// use non_empty_slice::non_empty_vec;
2880 ///
2881 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
2882 /// let samples = non_empty_vec![0.0; nzu!(16000)];
2883 /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2884 /// let params = SpectrogramParams::new(stft, 16000.0)?;
2885 /// let loghz = LogHzParams::new(nzu!(128), 20.0, 8000.0)?;
2886 ///
2887 /// let spec = LogHzPowerSpectrogram::compute(&samples, ¶ms, &loghz, None)?;
2888 /// assert_eq!(spec.n_bins(), nzu!(128));
2889 /// # Ok(())
2890 /// # }
2891 /// ```
2892 #[inline]
2893 pub fn compute(
2894 samples: &NonEmptySlice<f64>,
2895 params: &SpectrogramParams,
2896 loghz: &LogHzParams,
2897 db: Option<&LogParams>,
2898 ) -> SpectrogramResult<Self> {
2899 let planner = SpectrogramPlanner::new();
2900 let mut plan = planner.log_hz_plan(params, loghz, db)?;
2901 plan.compute(samples)
2902 }
2903}
2904
2905impl<AmpScale> Spectrogram<Cqt, AmpScale>
2906where
2907 AmpScale: AmpScaleSpec + 'static,
2908{
2909 /// Compute a constant-Q transform (CQT) spectrogram from audio samples.
2910 ///
2911 /// # Arguments
2912 ///
2913 /// * `samples` - Audio samples (any type that can be converted to a slice)
2914 /// * `params` - Spectrogram computation parameters
2915 /// * `cqt` - CQT parameters
2916 /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2917 ///
2918 /// # Returns
2919 ///
2920 /// A CQT spectrogram with the specified amplitude scale.
2921 ///
2922 /// # Errors
2923 ///
2924 /// Returns an error if:
2925 ///
2926 #[inline]
2927 pub fn compute(
2928 samples: &NonEmptySlice<f64>,
2929 params: &SpectrogramParams,
2930 cqt: &CqtParams,
2931 db: Option<&LogParams>,
2932 ) -> SpectrogramResult<Self> {
2933 let planner = SpectrogramPlanner::new();
2934 let mut plan = planner.cqt_plan(params, cqt, db)?;
2935 plan.compute(samples)
2936 }
2937}
2938
2939// ========================
2940// Display implementations
2941// ========================
2942
2943/// Helper function to get amplitude scale name
2944fn amp_scale_name<AmpScale>() -> &'static str
2945where
2946 AmpScale: AmpScaleSpec + 'static,
2947{
2948 match std::any::TypeId::of::<AmpScale>() {
2949 id if id == std::any::TypeId::of::<Power>() => "Power",
2950 id if id == std::any::TypeId::of::<Magnitude>() => "Magnitude",
2951 id if id == std::any::TypeId::of::<Decibels>() => "Decibels",
2952 _ => "Unknown",
2953 }
2954}
2955
2956/// Helper function to get frequency scale name
2957fn freq_scale_name<FreqScale>() -> &'static str
2958where
2959 FreqScale: Copy + Clone + 'static,
2960{
2961 match std::any::TypeId::of::<FreqScale>() {
2962 id if id == std::any::TypeId::of::<LinearHz>() => "Linear Hz",
2963 id if id == std::any::TypeId::of::<LogHz>() => "Log Hz",
2964 id if id == std::any::TypeId::of::<Mel>() => "Mel",
2965 id if id == std::any::TypeId::of::<Erb>() => "ERB",
2966 id if id == std::any::TypeId::of::<Cqt>() => "CQT",
2967 _ => "Unknown",
2968 }
2969}
2970
2971impl<FreqScale, AmpScale> core::fmt::Display for Spectrogram<FreqScale, AmpScale>
2972where
2973 AmpScale: AmpScaleSpec + 'static,
2974 FreqScale: Copy + Clone + 'static,
2975{
2976 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2977 let (freq_min, freq_max) = self.frequency_range();
2978 let duration = self.duration();
2979 let (rows, cols) = self.data.dim();
2980
2981 // Alternative formatting (#) provides more detailed output with data
2982 if f.alternate() {
2983 writeln!(f, "Spectrogram {{")?;
2984 writeln!(f, " Frequency Scale: {}", freq_scale_name::<FreqScale>())?;
2985 writeln!(f, " Amplitude Scale: {}", amp_scale_name::<AmpScale>())?;
2986 writeln!(f, " Shape: {} frequency bins × {} time frames", rows, cols)?;
2987 writeln!(
2988 f,
2989 " Frequency Range: {:.2} Hz - {:.2} Hz",
2990 freq_min, freq_max
2991 )?;
2992 writeln!(f, " Duration: {:.3} s", duration)?;
2993 writeln!(f)?;
2994
2995 // Display parameters
2996 writeln!(f, " Parameters:")?;
2997 writeln!(f, " Sample Rate: {} Hz", self.params.sample_rate_hz())?;
2998 writeln!(f, " FFT Size: {}", self.params.stft().n_fft())?;
2999 writeln!(f, " Hop Size: {}", self.params.stft().hop_size())?;
3000 writeln!(f, " Window: {:?}", self.params.stft().window())?;
3001 writeln!(f, " Centered: {}", self.params.stft().centre())?;
3002 writeln!(f)?;
3003
3004 // Display data statistics
3005 let data_slice = self.data.as_slice().unwrap_or(&[]);
3006 if !data_slice.is_empty() {
3007 let (min_val, max_val) = min_max_single_pass(data_slice);
3008 let mean = data_slice.iter().sum::<f64>() / data_slice.len() as f64;
3009 writeln!(f, " Data Statistics:")?;
3010 writeln!(f, " Min: {:.6}", min_val)?;
3011 writeln!(f, " Max: {:.6}", max_val)?;
3012 writeln!(f, " Mean: {:.6}", mean)?;
3013 writeln!(f)?;
3014 }
3015
3016 // Display actual data (truncated if too large)
3017 writeln!(f, " Data Matrix:")?;
3018 let max_rows_to_display = 8;
3019 let max_cols_to_display = 8;
3020
3021 for i in 0..rows.min(max_rows_to_display) {
3022 write!(f, " [")?;
3023 for j in 0..cols.min(max_cols_to_display) {
3024 if j > 0 {
3025 write!(f, ", ")?;
3026 }
3027 write!(f, "{:9.4}", self.data[[i, j]])?;
3028 }
3029 if cols > max_cols_to_display {
3030 write!(f, ", ... ({} more)", cols - max_cols_to_display)?;
3031 }
3032 writeln!(f, "]")?;
3033 }
3034
3035 if rows > max_rows_to_display {
3036 writeln!(f, " ... ({} more rows)", rows - max_rows_to_display)?;
3037 }
3038
3039 write!(f, "}}")?;
3040 } else {
3041 // Default formatting: compact summary
3042 write!(
3043 f,
3044 "Spectrogram<{}, {}>[{}×{}] ({:.2}-{:.2} Hz, {:.3}s)",
3045 freq_scale_name::<FreqScale>(),
3046 amp_scale_name::<AmpScale>(),
3047 rows,
3048 cols,
3049 freq_min,
3050 freq_max,
3051 duration
3052 )?;
3053 }
3054
3055 Ok(())
3056 }
3057}
3058
/// Frequency axis of a spectrogram: one frequency value (in Hz) per frequency bin.
///
/// `FreqScale` is a type-level marker (e.g. `LinearHz`, `Mel`, `LogHz`) recording
/// which frequency scale the bins belong to; it carries no data.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FrequencyAxis<FreqScale>
where
    FreqScale: Copy + Clone + 'static,
{
    /// Frequency of each bin in Hz, in bin order.
    frequencies: NonEmptyVec<f64>,
    /// Type-level tag for the frequency scale; not serialized.
    #[cfg_attr(feature = "serde", serde(skip))]
    _marker: PhantomData<FreqScale>,
}
3069
3070impl<FreqScale> FrequencyAxis<FreqScale>
3071where
3072 FreqScale: Copy + Clone + 'static,
3073{
3074 pub(crate) const fn new(frequencies: NonEmptyVec<f64>) -> Self {
3075 Self {
3076 frequencies,
3077 _marker: PhantomData,
3078 }
3079 }
3080
3081 /// Get the frequency values in Hz.
3082 ///
3083 /// # Returns
3084 ///
3085 /// Returns a non-empty slice of frequencies.
3086 #[inline]
3087 #[must_use]
3088 pub fn frequencies(&self) -> &NonEmptySlice<f64> {
3089 &self.frequencies
3090 }
3091
3092 /// Get the frequency range (min, max) in Hz.
3093 ///
3094 /// # Returns
3095 ///
3096 /// Returns a tuple containing the minimum and maximum frequency.
3097 #[inline]
3098 #[must_use]
3099 pub const fn frequency_range(&self) -> (f64, f64) {
3100 let data = self.frequencies.as_slice();
3101 let min = data[0];
3102 let max_idx = data.len().saturating_sub(1); // safe for non-empty
3103 let max = data[max_idx];
3104 (min, max)
3105 }
3106
3107 /// Get the number of frequency bins.
3108 ///
3109 /// # Returns
3110 ///
3111 /// Returns the number of frequency bins as a NonZeroUsize.
3112 #[inline]
3113 #[must_use]
3114 pub const fn len(&self) -> NonZeroUsize {
3115 self.frequencies.len()
3116 }
3117}
3118
/// Spectrogram axes container.
///
/// Holds the frequency axis (one entry per row of the spectrogram data) and the
/// time axis (one entry per column).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Axes<FreqScale>
where
    FreqScale: Copy + Clone + 'static,
{
    /// Frequency of each bin, tagged with the frequency-scale marker type.
    freq: FrequencyAxis<FreqScale>,
    /// Time of each frame, in seconds.
    times: NonEmptyVec<f64>,
}
3131
3132impl<FreqScale> Axes<FreqScale>
3133where
3134 FreqScale: Copy + Clone + 'static,
3135{
3136 pub(crate) const fn new(freq: FrequencyAxis<FreqScale>, times: NonEmptyVec<f64>) -> Self {
3137 Self { freq, times }
3138 }
3139
3140 /// Get the frequency values in Hz.
3141 ///
3142 /// # Returns
3143 ///
3144 /// Returns a non-empty slice of frequencies.
3145 #[inline]
3146 #[must_use]
3147 pub fn frequencies(&self) -> &NonEmptySlice<f64> {
3148 self.freq.frequencies()
3149 }
3150
3151 /// Get the time values in seconds.
3152 ///
3153 /// # Returns
3154 ///
3155 /// Returns a non-empty slice of time values.
3156 #[inline]
3157 #[must_use]
3158 pub fn times(&self) -> &NonEmptySlice<f64> {
3159 &self.times
3160 }
3161
3162 /// Get the frequency range (min, max) in Hz.
3163 ///
3164 /// # Returns
3165 ///
3166 /// Returns a tuple containing the minimum and maximum frequency.
3167 #[inline]
3168 #[must_use]
3169 pub const fn frequency_range(&self) -> (f64, f64) {
3170 self.freq.frequency_range()
3171 }
3172
3173 /// Get the duration of the spectrogram in seconds.
3174 ///
3175 /// # Returns
3176 ///
3177 /// Returns the duration in seconds.
3178 #[inline]
3179 #[must_use]
3180 pub fn duration(&self) -> f64 {
3181 *self.times.last()
3182 }
3183}
3184
3185// Enum types for frequency and amplitude scales
3186
/// Linear frequency scale.
///
/// Type-level marker, intended to be used as a scale parameter (such as the
/// `FreqScale` parameter of [`FrequencyAxis`] and [`Axes`]) rather than as a
/// meaningful runtime value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LinearHz {
    /// Placeholder variant; carries no data.
    _Phantom,
}

/// Logarithmic frequency scale.
///
/// Type-level marker, intended to be used as a scale parameter (such as the
/// `FreqScale` parameter of [`FrequencyAxis`] and [`Axes`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LogHz {
    /// Placeholder variant; carries no data.
    _Phantom,
}

/// Mel frequency scale.
///
/// Type-level marker, intended to be used as a scale parameter (such as the
/// `FreqScale` parameter of [`FrequencyAxis`] and [`Axes`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Mel {
    /// Placeholder variant; carries no data.
    _Phantom,
}

/// ERB/gammatone frequency scale.
///
/// Type-level marker, intended to be used as a scale parameter (such as the
/// `FreqScale` parameter of [`FrequencyAxis`] and [`Axes`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Erb {
    /// Placeholder variant; carries no data.
    _Phantom,
}
/// Alias for [`Erb`]: the gammatone scale uses the same marker type.
pub type Gammatone = Erb;

/// Constant-Q Transform frequency scale.
///
/// Type-level marker, intended to be used as a scale parameter (such as the
/// `FreqScale` parameter of [`FrequencyAxis`] and [`Axes`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Cqt {
    /// Placeholder variant; carries no data.
    _Phantom,
}
3232
3233// Amplitude scales
3234
/// Power amplitude scale.
///
/// Type-level marker for values expressed as power; not meant to be used as a
/// meaningful runtime value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Power {
    /// Placeholder variant; carries no data.
    _Phantom,
}

/// Decibel amplitude scale.
///
/// Type-level marker for values expressed in decibels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Decibels {
    /// Placeholder variant; carries no data.
    _Phantom,
}

/// Magnitude amplitude scale.
///
/// Type-level marker for values expressed as magnitudes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Magnitude {
    /// Placeholder variant; carries no data.
    _Phantom,
}
3261
/// STFT parameters for spectrogram computation.
///
/// * `n_fft`: Size of the FFT window.
/// * `hop_size`: Number of samples between successive frames.
/// * `window`: Window function to apply to each frame.
/// * `centre`: Whether to pad the input signal so that frames are centered.
///
/// Invariant: `hop_size <= n_fft`, enforced by [`StftParams::new`].
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StftParams {
    /// FFT window size in samples.
    n_fft: NonZeroUsize,
    /// Samples between the starts of successive frames (never larger than `n_fft`).
    hop_size: NonZeroUsize,
    /// Window function applied to each frame.
    window: WindowType,
    /// Whether the input signal is padded so frames are centered.
    centre: bool,
}
3276
3277impl StftParams {
3278 /// Create new STFT parameters.
3279 ///
3280 /// # Arguments
3281 ///
3282 /// * `n_fft` - Size of the FFT window
3283 /// * `hop_size` - Number of samples between successive frames
3284 /// * `window` - Window function to apply to each frame
3285 /// * `centre` - Whether to pad the input signal so that frames are centered
3286 ///
3287 /// # Errors
3288 ///
3289 /// Returns an error if `hop_size` > `n_fft`.
3290 ///
3291 /// # Returns
3292 ///
3293 /// New `StftParams` instance.
3294 #[inline]
3295 pub fn new(
3296 n_fft: NonZeroUsize,
3297 hop_size: NonZeroUsize,
3298 window: WindowType,
3299 centre: bool,
3300 ) -> SpectrogramResult<Self> {
3301 if hop_size.get() > n_fft.get() {
3302 return Err(SpectrogramError::invalid_input("hop_size must be <= n_fft"));
3303 }
3304
3305 Ok(Self {
3306 n_fft,
3307 hop_size,
3308 window,
3309 centre,
3310 })
3311 }
3312
3313 const unsafe fn new_unchecked(
3314 n_fft: NonZeroUsize,
3315 hop_size: NonZeroUsize,
3316 window: WindowType,
3317 centre: bool,
3318 ) -> Self {
3319 Self {
3320 n_fft,
3321 hop_size,
3322 window,
3323 centre,
3324 }
3325 }
3326
3327 /// Get the FFT window size.
3328 ///
3329 /// # Returns
3330 ///
3331 /// The FFT window size.
3332 #[inline]
3333 #[must_use]
3334 pub const fn n_fft(&self) -> NonZeroUsize {
3335 self.n_fft
3336 }
3337
3338 /// Get the hop size (samples between successive frames).
3339 ///
3340 /// # Returns
3341 ///
3342 /// The hop size.
3343 #[inline]
3344 #[must_use]
3345 pub const fn hop_size(&self) -> NonZeroUsize {
3346 self.hop_size
3347 }
3348
3349 /// Get the window function.
3350 ///
3351 /// # Returns
3352 ///
3353 /// The window function.
3354 #[inline]
3355 #[must_use]
3356 pub const fn window(&self) -> WindowType {
3357 self.window
3358 }
3359
3360 /// Get whether frames are centered (input signal is padded).
3361 ///
3362 /// # Returns
3363 ///
3364 /// `true` if frames are centered, `false` otherwise.
3365 #[inline]
3366 #[must_use]
3367 pub const fn centre(&self) -> bool {
3368 self.centre
3369 }
3370
3371 /// Create a builder for STFT parameters.
3372 ///
3373 /// # Returns
3374 ///
3375 /// A `StftParamsBuilder` instance.
3376 ///
3377 /// # Examples
3378 ///
3379 /// ```
3380 /// use spectrograms::{StftParams, WindowType, nzu};
3381 ///
3382 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
3383 /// let stft = StftParams::builder()
3384 /// .n_fft(nzu!(2048))
3385 /// .hop_size(nzu!(512))
3386 /// .window(WindowType::Hanning)
3387 /// .centre(true)
3388 /// .build()?;
3389 ///
3390 /// assert_eq!(stft.n_fft(), nzu!(2048));
3391 /// assert_eq!(stft.hop_size(), nzu!(512));
3392 /// # Ok(())
3393 /// # }
3394 /// ```
3395 #[inline]
3396 #[must_use]
3397 pub fn builder() -> StftParamsBuilder {
3398 StftParamsBuilder::default()
3399 }
3400}
3401
/// Builder for [`StftParams`].
///
/// `n_fft` and `hop_size` must be set before building. By default `window`
/// is `WindowType::Hanning` and `centre` is `true`.
#[derive(Debug, Clone)]
pub struct StftParamsBuilder {
    /// FFT size; required.
    n_fft: Option<NonZeroUsize>,
    /// Hop size in samples; required.
    hop_size: Option<NonZeroUsize>,
    /// Window function applied to each frame.
    window: WindowType,
    /// Whether frames are centered by padding the input signal.
    centre: bool,
}
3410
3411impl Default for StftParamsBuilder {
3412 #[inline]
3413 fn default() -> Self {
3414 Self {
3415 n_fft: None,
3416 hop_size: None,
3417 window: WindowType::Hanning,
3418 centre: true,
3419 }
3420 }
3421}
3422
3423impl StftParamsBuilder {
3424 /// Set the FFT window size.
3425 ///
3426 /// # Arguments
3427 ///
3428 /// * `n_fft` - Size of the FFT window
3429 ///
3430 /// # Returns
3431 ///
3432 /// The builder with the updated FFT window size.
3433 #[inline]
3434 #[must_use]
3435 pub const fn n_fft(mut self, n_fft: NonZeroUsize) -> Self {
3436 self.n_fft = Some(n_fft);
3437 self
3438 }
3439
3440 /// Set the hop size (samples between successive frames).
3441 ///
3442 /// # Arguments
3443 ///
3444 /// * `hop_size` - Number of samples between successive frames
3445 ///
3446 /// # Returns
3447 ///
3448 /// The builder with the updated hop size.
3449 #[inline]
3450 #[must_use]
3451 pub const fn hop_size(mut self, hop_size: NonZeroUsize) -> Self {
3452 self.hop_size = Some(hop_size);
3453 self
3454 }
3455
3456 /// Set the window function.
3457 ///
3458 /// # Arguments
3459 ///
3460 /// * `window` - Window function to apply to each frame
3461 ///
3462 /// # Returns
3463 ///
3464 /// The builder with the updated window function.
3465 #[inline]
3466 #[must_use]
3467 pub const fn window(mut self, window: WindowType) -> Self {
3468 self.window = window;
3469 self
3470 }
3471
3472 /// Set whether to center frames (pad input signal).
3473 #[inline]
3474 #[must_use]
3475 pub const fn centre(mut self, centre: bool) -> Self {
3476 self.centre = centre;
3477 self
3478 }
3479
3480 /// Build the [`StftParams`].
3481 ///
3482 /// # Errors
3483 ///
3484 /// Returns an error if:
3485 /// - `n_fft` or `hop_size` are not set or are zero
3486 /// - `hop_size` > `n_fft`
3487 #[inline]
3488 pub fn build(self) -> SpectrogramResult<StftParams> {
3489 let n_fft = self
3490 .n_fft
3491 .ok_or_else(|| SpectrogramError::invalid_input("n_fft must be set"))?;
3492 let hop_size = self
3493 .hop_size
3494 .ok_or_else(|| SpectrogramError::invalid_input("hop_size must be set"))?;
3495
3496 StftParams::new(n_fft, hop_size, self.window, self.centre)
3497 }
3498}
3499
3500//
3501// ========================
3502// Mel parameters
3503// ========================
3504//
3505
/// Mel filterbank normalization strategy.
///
/// Determines how the triangular mel filters are normalized after construction.
/// The default is [`MelNorm::None`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum MelNorm {
    /// No normalization (triangular filters with peak = 1.0).
    ///
    /// This is the default and fastest option.
    #[default]
    None,

    /// Slaney-style area normalization (librosa default).
    ///
    /// Each mel filter is scaled by `2.0 / (f_max - f_min)`, where `f_min` and
    /// `f_max` are that filter's own lower and upper edge frequencies in Hz
    /// (not the overall filterbank range). In other words, each filter is divided
    /// by half its bandwidth. This ensures constant energy per mel band
    /// regardless of bandwidth.
    ///
    /// Use this for compatibility with librosa's default behavior.
    Slaney,

    /// L1 normalization (sum of weights = 1.0).
    ///
    /// Each mel filter's weights are divided by their sum.
    /// Useful when you want each filter to act as a weighted average.
    L1,

    /// L2 normalization (Euclidean norm = 1.0).
    ///
    /// Each mel filter's weights are divided by their L2 norm.
    /// Provides unit-norm filters in the L2 sense.
    L2,
}
3543
/// Mel filter bank parameters
///
/// * `n_mels`: Number of mel bands
/// * `f_min`: Minimum frequency (Hz)
/// * `f_max`: Maximum frequency (Hz)
/// * `norm`: Filterbank normalization strategy
///
/// The validating constructors (`new`, `with_norm`) require at least
/// `f_min >= 0` and `f_max > f_min`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MelParams {
    /// Number of mel bands.
    n_mels: NonZeroUsize,
    /// Lower frequency bound in Hz.
    f_min: f64,
    /// Upper frequency bound in Hz.
    f_max: f64,
    /// Filterbank normalization strategy.
    norm: MelNorm,
}
3558
3559impl MelParams {
3560 /// Create new mel filter bank parameters.
3561 ///
3562 /// # Arguments
3563 ///
3564 /// * `n_mels` - Number of mel bands
3565 /// * `f_min` - Minimum frequency (Hz)
3566 /// * `f_max` - Maximum frequency (Hz)
3567 ///
3568 /// # Errors
3569 ///
3570 /// Returns an error if:
3571 /// - `f_min` is not >= 0
3572 /// - `f_max` is not > `f_min`
3573 ///
3574 /// # Returns
3575 ///
3576 /// A `MelParams` instance with no normalization (default).
3577 #[inline]
3578 pub fn new(n_mels: NonZeroUsize, f_min: f64, f_max: f64) -> SpectrogramResult<Self> {
3579 Self::with_norm(n_mels, f_min, f_max, MelNorm::None)
3580 }
3581
3582 /// Create new mel filter bank parameters with specified normalization.
3583 ///
3584 /// # Arguments
3585 ///
3586 /// * `n_mels` - Number of mel bands
3587 /// * `f_min` - Minimum frequency (Hz)
3588 /// * `f_max` - Maximum frequency (Hz)
3589 /// * `norm` - Filterbank normalization strategy
3590 ///
3591 /// # Errors
3592 ///
3593 /// Returns an error if:
3594 /// - `f_min` is not >= 0
3595 /// - `f_max` is not > `f_min`
3596 ///
3597 /// # Returns
3598 ///
3599 /// A `MelParams` instance.
3600 #[inline]
3601 pub fn with_norm(
3602 n_mels: NonZeroUsize,
3603 f_min: f64,
3604 f_max: f64,
3605 norm: MelNorm,
3606 ) -> SpectrogramResult<Self> {
3607 if f_min < 0.0 {
3608 return Err(SpectrogramError::invalid_input("f_min must be >= 0"));
3609 }
3610
3611 if f_max <= f_min {
3612 return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
3613 }
3614
3615 Ok(Self {
3616 n_mels,
3617 f_min,
3618 f_max,
3619 norm,
3620 })
3621 }
3622
3623 const unsafe fn new_unchecked(n_mels: NonZeroUsize, f_min: f64, f_max: f64) -> Self {
3624 Self {
3625 n_mels,
3626 f_min,
3627 f_max,
3628 norm: MelNorm::None,
3629 }
3630 }
3631
3632 /// Get the number of mel bands.
3633 ///
3634 /// # Returns
3635 ///
3636 /// The number of mel bands.
3637 #[inline]
3638 #[must_use]
3639 pub const fn n_mels(&self) -> NonZeroUsize {
3640 self.n_mels
3641 }
3642
3643 /// Get the minimum frequency (Hz).
3644 ///
3645 /// # Returns
3646 ///
3647 /// The minimum frequency in Hz.
3648 #[inline]
3649 #[must_use]
3650 pub const fn f_min(&self) -> f64 {
3651 self.f_min
3652 }
3653
3654 /// Get the maximum frequency (Hz).
3655 ///
3656 /// # Returns
3657 ///
3658 /// The maximum frequency in Hz.
3659 #[inline]
3660 #[must_use]
3661 pub const fn f_max(&self) -> f64 {
3662 self.f_max
3663 }
3664
3665 /// Get the filterbank normalization strategy.
3666 ///
3667 /// # Returns
3668 ///
3669 /// The normalization strategy.
3670 #[inline]
3671 #[must_use]
3672 pub const fn norm(&self) -> MelNorm {
3673 self.norm
3674 }
3675
3676 /// Create standard mel filterbank parameters.
3677 ///
3678 /// Uses 128 mel bands from 0 Hz to the Nyquist frequency.
3679 ///
3680 /// # Arguments
3681 ///
3682 /// * `sample_rate` - Sample rate in Hz (used to determine `f_max`)
3683 ///
3684 /// # Returns
3685 ///
3686 /// A `MelParams` instance with standard settings.
3687 ///
3688 /// # Panics
3689 ///
3690 /// Panics if `sample_rate` is not greater than 0.
3691 #[inline]
3692 #[must_use]
3693 pub const fn standard(sample_rate: f64) -> Self {
3694 assert!(sample_rate > 0.0);
3695 // safety: parameters are known to be valid
3696 unsafe { Self::new_unchecked(nzu!(128), 0.0, sample_rate / 2.0) }
3697 }
3698
3699 /// Create mel filterbank parameters optimized for speech.
3700 ///
3701 /// Uses 40 mel bands from 0 Hz to 8000 Hz (typical speech bandwidth).
3702 ///
3703 /// # Returns
3704 ///
3705 /// A `MelParams` instance with speech-optimized settings.
3706 #[inline]
3707 #[must_use]
3708 pub const fn speech_standard() -> Self {
3709 // safety: parameters are known to be valid
3710 unsafe { Self::new_unchecked(nzu!(40), 0.0, 8000.0) }
3711 }
3712}
3713
3714//
3715// ========================
3716// LogHz parameters
3717// ========================
3718//
3719
/// Logarithmic frequency scale parameters
///
/// * `n_bins`: Number of logarithmically-spaced frequency bins
/// * `f_min`: Minimum frequency (Hz)
/// * `f_max`: Maximum frequency (Hz)
///
/// The validating constructor `new` requires at least a finite `f_min > 0`
/// and `f_max > f_min`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LogHzParams {
    /// Number of logarithmically-spaced bins.
    n_bins: NonZeroUsize,
    /// Lower frequency bound in Hz.
    f_min: f64,
    /// Upper frequency bound in Hz.
    f_max: f64,
}
3732
3733impl LogHzParams {
3734 /// Create new logarithmic frequency scale parameters.
3735 ///
3736 /// # Arguments
3737 ///
3738 /// * `n_bins` - Number of logarithmically-spaced frequency bins
3739 /// * `f_min` - Minimum frequency (Hz)
3740 /// * `f_max` - Maximum frequency (Hz)
3741 ///
3742 /// # Errors
3743 ///
3744 /// Returns an error if:
3745 /// - `f_min` is not finite and > 0
3746 /// - `f_max` is not > `f_min`
3747 ///
3748 /// # Returns
3749 ///
3750 /// A `LogHzParams` instance.
3751 #[inline]
3752 pub fn new(n_bins: NonZeroUsize, f_min: f64, f_max: f64) -> SpectrogramResult<Self> {
3753 if !(f_min > 0.0 && f_min.is_finite()) {
3754 return Err(SpectrogramError::invalid_input(
3755 "f_min must be finite and > 0",
3756 ));
3757 }
3758
3759 if f_max <= f_min {
3760 return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
3761 }
3762
3763 Ok(Self {
3764 n_bins,
3765 f_min,
3766 f_max,
3767 })
3768 }
3769
3770 const unsafe fn new_unchecked(n_bins: NonZeroUsize, f_min: f64, f_max: f64) -> Self {
3771 Self {
3772 n_bins,
3773 f_min,
3774 f_max,
3775 }
3776 }
3777
3778 /// Get the number of frequency bins.
3779 ///
3780 /// # Returns
3781 ///
3782 /// The number of frequency bins.
3783 #[inline]
3784 #[must_use]
3785 pub const fn n_bins(&self) -> NonZeroUsize {
3786 self.n_bins
3787 }
3788
3789 /// Get the minimum frequency (Hz).
3790 ///
3791 /// # Returns
3792 ///
3793 /// The minimum frequency in Hz.
3794 #[inline]
3795 #[must_use]
3796 pub const fn f_min(&self) -> f64 {
3797 self.f_min
3798 }
3799
3800 /// Get the maximum frequency (Hz).
3801 ///
3802 /// # Returns
3803 ///
3804 /// The maximum frequency in Hz.
3805 #[inline]
3806 #[must_use]
3807 pub const fn f_max(&self) -> f64 {
3808 self.f_max
3809 }
3810
3811 /// Create standard logarithmic frequency parameters.
3812 ///
3813 /// Uses 128 log bins from 20 Hz to the Nyquist frequency.
3814 ///
3815 /// # Arguments
3816 ///
3817 /// * `sample_rate` - Sample rate in Hz (used to determine `f_max`)
3818 #[inline]
3819 #[must_use]
3820 pub fn standard(sample_rate: f64) -> Self {
3821 // safety: parameters are known to be valid
3822 unsafe { Self::new_unchecked(nzu!(128), 20.0, sample_rate / 2.0) }
3823 }
3824
3825 /// Create logarithmic frequency parameters optimized for music.
3826 ///
3827 /// Uses 84 bins (7 octaves * 12 bins/octave) from 27.5 Hz (A0) to 4186 Hz (C8).
3828 #[inline]
3829 #[must_use]
3830 pub const fn music_standard() -> Self {
3831 // safety: parameters are known to be valid
3832 unsafe { Self::new_unchecked(nzu!(84), 27.5, 4186.0) }
3833 }
3834}
3835
3836//
3837// ========================
3838// Log scaling parameters
3839// ========================
3840//
3841
/// Logarithmic (dB) scaling parameters.
///
/// * `floor_db`: Minimum dB value (floor) for logarithmic scaling; must be
///   finite when constructed through `LogParams::new`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LogParams {
    /// Minimum dB value (floor).
    floor_db: f64,
}
3847
3848impl LogParams {
3849 /// Create new logarithmic scaling parameters.
3850 ///
3851 /// # Arguments
3852 ///
3853 /// * `floor_db` - Minimum dB value (floor) for logarithmic scaling
3854 ///
3855 /// # Errors
3856 ///
3857 /// Returns an error if `floor_db` is not finite.
3858 ///
3859 /// # Returns
3860 ///
3861 /// A `LogParams` instance.
3862 #[inline]
3863 pub fn new(floor_db: f64) -> SpectrogramResult<Self> {
3864 if !floor_db.is_finite() {
3865 return Err(SpectrogramError::invalid_input("floor_db must be finite"));
3866 }
3867
3868 Ok(Self { floor_db })
3869 }
3870
3871 /// Get the floor dB value.
3872 #[inline]
3873 #[must_use]
3874 pub const fn floor_db(&self) -> f64 {
3875 self.floor_db
3876 }
3877}
3878
/// Spectrogram computation parameters.
///
/// * `stft`: STFT parameters
/// * `sample_rate_hz`: Sample rate in Hz
///
/// The validating constructor `new` requires a finite, positive sample rate.
#[derive(Debug, Copy, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SpectrogramParams {
    /// Framing/windowing parameters for the STFT.
    stft: StftParams,
    /// Sample rate of the input signal, in Hz.
    sample_rate_hz: f64,
}
3889
3890impl SpectrogramParams {
3891 /// Create new spectrogram parameters.
3892 ///
3893 /// # Arguments
3894 ///
3895 /// * `stft` - STFT parameters
3896 /// * `sample_rate_hz` - Sample rate in Hz
3897 ///
3898 /// # Errors
3899 ///
3900 /// Returns an error if the sample rate is not positive and finite.
3901 ///
3902 /// # Returns
3903 ///
3904 /// A `SpectrogramParams` instance.
3905 #[inline]
3906 pub fn new(stft: StftParams, sample_rate_hz: f64) -> SpectrogramResult<Self> {
3907 if !(sample_rate_hz > 0.0 && sample_rate_hz.is_finite()) {
3908 return Err(SpectrogramError::invalid_input(
3909 "sample_rate_hz must be finite and > 0",
3910 ));
3911 }
3912
3913 Ok(Self {
3914 stft,
3915 sample_rate_hz,
3916 })
3917 }
3918
3919 /// Create a builder for spectrogram parameters.
3920 ///
3921 /// # Errors
3922 ///
3923 /// Returns an error if required parameters are not set or are invalid.
3924 ///
3925 /// # Returns
3926 ///
3927 /// A builder for [`SpectrogramParams`].
3928 ///
3929 /// # Examples
3930 ///
3931 /// ```
3932 /// use spectrograms::{SpectrogramParams, WindowType, nzu};
3933 ///
3934 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
3935 /// let params = SpectrogramParams::builder()
3936 /// .sample_rate(16000.0)
3937 /// .n_fft(nzu!(512))
3938 /// .hop_size(nzu!(256))
3939 /// .window(WindowType::Hanning)
3940 /// .centre(true)
3941 /// .build()?;
3942 ///
3943 /// assert_eq!(params.sample_rate_hz(), 16000.0);
3944 /// # Ok(())
3945 /// # }
3946 /// ```
3947 #[inline]
3948 #[must_use]
3949 pub fn builder() -> SpectrogramParamsBuilder {
3950 SpectrogramParamsBuilder::default()
3951 }
3952
3953 /// Create default parameters for speech processing.
3954 ///
3955 /// # Arguments
3956 ///
3957 /// * `sample_rate_hz` - Sample rate in Hz
3958 ///
3959 /// # Returns
3960 ///
3961 /// A `SpectrogramParams` instance with default settings for music analysis.
3962 ///
3963 /// # Errors
3964 ///
3965 /// Returns an error if the sample rate is not positive and finite.
3966 ///
3967 /// Uses:
3968 /// - `n_fft`: 512 (32ms at 16kHz)
3969 /// - `hop_size`: 160 (10ms at 16kHz)
3970 /// - window: Hanning
3971 /// - centre: true
3972 #[inline]
3973 pub fn speech_default(sample_rate_hz: f64) -> SpectrogramResult<Self> {
3974 // safety: parameters are known to be valid
3975 let stft =
3976 unsafe { StftParams::new_unchecked(nzu!(512), nzu!(160), WindowType::Hanning, true) };
3977
3978 Self::new(stft, sample_rate_hz)
3979 }
3980
3981 /// Create default parameters for music processing.
3982 ///
3983 /// # Arguments
3984 ///
3985 /// * `sample_rate_hz` - Sample rate in Hz
3986 ///
3987 /// # Returns
3988 ///
3989 /// A `SpectrogramParams` instance with default settings for music analysis.
3990 ///
3991 /// # Errors
3992 ///
3993 /// Returns an error if the sample rate is not positive and finite.
3994 ///
3995 /// Uses:
3996 /// - `n_fft`: 2048 (46ms at 44.1kHz)
3997 /// - `hop_size`: 512 (11.6ms at 44.1kHz)
3998 /// - window: Hanning
3999 /// - centre: true
4000 #[inline]
4001 pub fn music_default(sample_rate_hz: f64) -> SpectrogramResult<Self> {
4002 // safety: parameters are known to be valid
4003 let stft =
4004 unsafe { StftParams::new_unchecked(nzu!(2048), nzu!(512), WindowType::Hanning, true) };
4005 Self::new(stft, sample_rate_hz)
4006 }
4007
4008 /// Get the STFT parameters.
4009 #[inline]
4010 #[must_use]
4011 pub const fn stft(&self) -> &StftParams {
4012 &self.stft
4013 }
4014
4015 /// Get the sample rate in Hz.
4016 #[inline]
4017 #[must_use]
4018 pub const fn sample_rate_hz(&self) -> f64 {
4019 self.sample_rate_hz
4020 }
4021
4022 /// Get the frame period in seconds.
4023 #[inline]
4024 #[must_use]
4025 #[allow(clippy::cast_precision_loss)]
4026 pub fn frame_period_seconds(&self) -> f64 {
4027 self.stft.hop_size().get() as f64 / self.sample_rate_hz
4028 }
4029
4030 /// Get the Nyquist frequency in Hz.
4031 #[inline]
4032 #[must_use]
4033 pub fn nyquist_hz(&self) -> f64 {
4034 self.sample_rate_hz * 0.5
4035 }
4036}
4037
/// Builder for [`SpectrogramParams`].
///
/// `sample_rate`, `n_fft` and `hop_size` must be set before building. By
/// default `window` is `WindowType::Hanning` and `centre` is `true`.
#[derive(Debug, Clone)]
pub struct SpectrogramParamsBuilder {
    /// Sample rate in Hz; required.
    sample_rate: Option<f64>,
    /// FFT size; required.
    n_fft: Option<NonZeroUsize>,
    /// Hop size in samples; required.
    hop_size: Option<NonZeroUsize>,
    /// Window function applied to each frame.
    window: WindowType,
    /// Whether frames are centered by padding the input signal.
    centre: bool,
}
4047
4048impl Default for SpectrogramParamsBuilder {
4049 #[inline]
4050 fn default() -> Self {
4051 Self {
4052 sample_rate: None,
4053 n_fft: None,
4054 hop_size: None,
4055 window: WindowType::Hanning,
4056 centre: true,
4057 }
4058 }
4059}
4060
4061impl SpectrogramParamsBuilder {
4062 /// Set the sample rate in Hz.
4063 ///
4064 /// # Arguments
4065 ///
4066 /// * `sample_rate` - Sample rate in Hz.
4067 ///
4068 /// # Returns
4069 ///
4070 /// The updated builder instance.
4071 #[inline]
4072 #[must_use]
4073 pub const fn sample_rate(mut self, sample_rate: f64) -> Self {
4074 self.sample_rate = Some(sample_rate);
4075 self
4076 }
4077
4078 /// Set the FFT window size.
4079 ///
4080 /// # Arguments
4081 ///
4082 /// * `n_fft` - FFT size.
4083 ///
4084 /// # Returns
4085 ///
4086 /// The updated builder instance.
4087 #[inline]
4088 #[must_use]
4089 pub const fn n_fft(mut self, n_fft: NonZeroUsize) -> Self {
4090 self.n_fft = Some(n_fft);
4091 self
4092 }
4093
4094 /// Set the hop size (samples between successive frames).
4095 ///
4096 /// # Arguments
4097 ///
4098 /// * `hop_size` - Hop size in samples.
4099 ///
4100 /// # Returns
4101 ///
4102 /// The updated builder instance.
4103 #[inline]
4104 #[must_use]
4105 pub const fn hop_size(mut self, hop_size: NonZeroUsize) -> Self {
4106 self.hop_size = Some(hop_size);
4107 self
4108 }
4109
4110 /// Set the window function.
4111 ///
4112 /// # Arguments
4113 ///
4114 /// * `window` - Window function to apply to each frame.
4115 ///
4116 /// # Returns
4117 ///
4118 /// The updated builder instance.
4119 #[inline]
4120 #[must_use]
4121 pub const fn window(mut self, window: WindowType) -> Self {
4122 self.window = window;
4123 self
4124 }
4125
4126 /// Set whether to center frames (pad input signal).
4127 ///
4128 /// # Arguments
4129 ///
4130 /// * `centre` - If true, frames are centered by padding the input signal.
4131 ///
4132 /// # Returns
4133 ///
4134 /// The updated builder instance.
4135 #[inline]
4136 #[must_use]
4137 pub const fn centre(mut self, centre: bool) -> Self {
4138 self.centre = centre;
4139 self
4140 }
4141
4142 /// Build the [`SpectrogramParams`].
4143 ///
4144 /// # Errors
4145 ///
4146 /// Returns an error if required parameters are not set or are invalid.
4147 ///
4148 /// # Returns
4149 ///
4150 /// A `SpectrogramParams` instance.
4151 #[inline]
4152 pub fn build(self) -> SpectrogramResult<SpectrogramParams> {
4153 let sample_rate = self
4154 .sample_rate
4155 .ok_or_else(|| SpectrogramError::invalid_input("sample_rate must be set"))?;
4156 let n_fft = self
4157 .n_fft
4158 .ok_or_else(|| SpectrogramError::invalid_input("n_fft must be set"))?;
4159 let hop_size = self
4160 .hop_size
4161 .ok_or_else(|| SpectrogramError::invalid_input("hop_size must be set"))?;
4162
4163 let stft = StftParams::new(n_fft, hop_size, self.window, self.centre)?;
4164 SpectrogramParams::new(stft, sample_rate)
4165 }
4166}
4167
4168//
4169// ========================
4170// Standalone FFT Functions
4171// ========================
4172//
4173
/// Compute the real-to-complex FFT of a real-valued signal.
///
/// This function performs a forward FFT on real-valued input, returning the
/// complex frequency domain representation. Only the positive frequencies
/// are returned (length = `n_fft/2` + 1) due to conjugate symmetry.
///
/// # Arguments
///
/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
/// * `n_fft` - FFT size
///
/// # Returns
///
/// A vector of complex frequency bins with length `n_fft/2` + 1.
///
/// # Automatic Zero-Padding
///
/// If the input signal is shorter than `n_fft`, it will be automatically
/// zero-padded to the required length. This is standard DSP practice and
/// preserves frequency resolution (bin spacing = sample_rate / n_fft).
///
/// ```
/// use spectrograms::{fft, nzu};
/// use non_empty_slice::non_empty_vec;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let signal = non_empty_vec![1.0, 2.0, 3.0]; // Only 3 samples
/// let spectrum = fft(&signal, nzu!(8))?; // Automatically padded to 8
/// assert_eq!(spectrum.len(), 5); // Output: 8/2 + 1 = 5 bins
/// # Ok(())
/// # }
/// ```
///
/// # Errors
///
/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
/// Errors from creating the FFT plan or executing the transform are
/// propagated unchanged.
///
/// # Examples
///
/// ```
/// use spectrograms::*;
/// use non_empty_slice::non_empty_vec;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let signal = non_empty_vec![0.0; nzu!(512)];
/// let spectrum = fft(&signal, nzu!(512))?;
///
/// assert_eq!(spectrum.len(), 257); // 512/2 + 1
/// # Ok(())
/// # }
/// ```
#[inline]
pub fn fft(
    samples: &NonEmptySlice<f64>,
    n_fft: NonZeroUsize,
) -> SpectrogramResult<Array1<Complex<f64>>> {
    if samples.len() > n_fft {
        return Err(SpectrogramError::invalid_input(format!(
            "Input length ({}) exceeds FFT size ({})",
            samples.len(),
            n_fft
        )));
    }

    // Number of one-sided output bins for a real-input transform (n_fft/2 + 1).
    let out_len = r2c_output_size(n_fft.get());

    // Get FFT plan from global cache (or create if first use)
    #[cfg(feature = "realfft")]
    let mut fft = {
        use crate::fft_backend::get_or_create_r2c_plan;
        let plan = get_or_create_r2c_plan(n_fft.get())?;
        // Clone the plan to get our own mutable copy with independent scratch buffer
        // This is cheap - only clones the scratch buffer, not the expensive twiddle factors
        (*plan).clone()
    };

    // NOTE(review): unlike the realfft branch, no cache lookup is visible here;
    // a plan is built on every call. Confirm whether `FftwPlanner::build_plan`
    // caches internally.
    #[cfg(feature = "fftw")]
    let mut fft = {
        use std::sync::Arc;
        let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
        crate::FftwPlan::new(Arc::new(plan))
    };

    // Build an owned input buffer of exactly n_fft samples: short inputs are
    // zero-padded, full-length inputs are copied as-is.
    let input = if samples.len() < n_fft {
        let mut padded = vec![0.0; n_fft.get()];
        padded[..samples.len().get()].copy_from_slice(samples);
        // safety: `padded` was allocated with length n_fft, and n_fft > 0

        unsafe { NonEmptyVec::new_unchecked(padded) }
    } else {
        samples.to_non_empty_vec()
    };

    let mut output = vec![Complex::new(0.0, 0.0); out_len];
    fft.process(&input, &mut output)?;
    let output = Array1::from_vec(output);
    Ok(output)
}
4272
4273#[inline]
4274/// Compute the real-valued fft of a signal.
4275///
4276/// # Arguments
4277/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4278/// * `n_fft` - FFT size
4279///
4280/// # Returns
4281///
4282/// An array with length `n_fft/2` + 1.
4283///
4284/// # Errors
4285///
4286/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4287///
4288/// # Examples
4289///
4290/// ```
4291/// use spectrograms::*;
4292/// use non_empty_slice::non_empty_vec;
4293///
4294/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4295/// let signal = non_empty_vec![0.0; nzu!(512)];
4296/// let rfft_result = rfft(&signal, nzu!(512))?;
4297/// // equivalent to
4298/// let fft_result = fft(&signal, nzu!(512))?;
4299/// let rfft_result = fft_result.mapv(num_complex::Complex::norm);
4300/// # Ok(())
4301/// # }
4302///
4303pub fn rfft(samples: &NonEmptySlice<f64>, n_fft: NonZeroUsize) -> SpectrogramResult<Array1<f64>> {
4304 Ok(fft(samples, n_fft)?.mapv(Complex::norm))
4305}
4306
4307/// Compute the power spectrum of a signal (|X|²).
4308///
4309/// This function applies an optional window function and computes the
4310/// power spectrum via FFT. The result contains only positive frequencies.
4311///
4312/// # Arguments
4313///
4314/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4315/// * `n_fft` - FFT size
4316/// * `window` - Optional window function (None for rectangular window)
4317///
4318/// # Returns
4319///
4320/// A vector of power values with length `n_fft/2` + 1.
4321///
4322/// # Automatic Zero-Padding
4323///
4324/// If the input signal is shorter than `n_fft`, it will be automatically
4325/// zero-padded to the required length. This is standard DSP practice and
4326/// preserves frequency resolution (bin spacing = sample_rate / n_fft).
4327///
4328/// ```
4329/// use spectrograms::{power_spectrum, nzu};
4330/// use non_empty_slice::non_empty_vec;
4331///
4332/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4333/// let signal = non_empty_vec![1.0, 2.0, 3.0]; // Only 3 samples
4334/// let power = power_spectrum(&signal, nzu!(8), None)?;
4335/// assert_eq!(power.len(), nzu!(5)); // Output: 8/2 + 1 = 5 bins
4336/// # Ok(())
4337/// # }
4338/// ```
4339///
4340/// # Errors
4341///
4342/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4343///
4344/// # Examples
4345///
4346/// ```
4347/// use spectrograms::*;
4348/// use non_empty_slice::non_empty_vec;
4349///
4350/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4351/// let signal = non_empty_vec![0.0; nzu!(512)];
4352/// let power = power_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
4353///
4354/// assert_eq!(power.len(), nzu!(257)); // 512/2 + 1
4355/// # Ok(())
4356/// # }
4357/// ```
4358#[inline]
4359pub fn power_spectrum(
4360 samples: &NonEmptySlice<f64>,
4361 n_fft: NonZeroUsize,
4362 window: Option<WindowType>,
4363) -> SpectrogramResult<NonEmptyVec<f64>> {
4364 if samples.len() > n_fft {
4365 return Err(SpectrogramError::invalid_input(format!(
4366 "Input length ({}) exceeds FFT size ({})",
4367 samples.len(),
4368 n_fft
4369 )));
4370 }
4371
4372 let mut windowed = vec![0.0; n_fft.get()];
4373 windowed[..samples.len().get()].copy_from_slice(samples);
4374
4375 if let Some(win_type) = window {
4376 let window_samples = make_window(win_type, n_fft);
4377 for i in 0..n_fft.get() {
4378 windowed[i] *= window_samples[i];
4379 }
4380 }
4381
4382 // safety: windowed is non-empty since n_fft > 0
4383 let windowed = unsafe { NonEmptySlice::new_unchecked(&windowed) };
4384 let fft_result = fft(windowed, n_fft)?;
4385 let fft_result = fft_result
4386 .iter()
4387 .map(num_complex::Complex::norm_sqr)
4388 .collect();
4389 // safety: fft_result is non-empty since fft returned successfully
4390 Ok(unsafe { NonEmptyVec::new_unchecked(fft_result) })
4391}
4392
4393/// Compute the magnitude spectrum of a signal (|X|).
4394///
4395/// This function applies an optional window function and computes the
4396/// magnitude spectrum via FFT. The result contains only positive frequencies.
4397///
4398/// # Arguments
4399///
4400/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4401/// * `n_fft` - FFT size
4402/// * `window` - Optional window function (None for rectangular window)
4403///
4404/// # Automatic Zero-Padding
4405///
4406/// If the input signal is shorter than `n_fft`, it will be automatically
4407/// zero-padded to the required length. This preserves frequency resolution.
4408///
4409/// # Errors
4410///
4411/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4412///
4413/// # Returns
4414///
4415/// A vector of magnitude values with length `n_fft/2` + 1.
4416///
4417/// # Examples
4418///
4419/// ```
4420/// use spectrograms::*;
4421/// use non_empty_slice::non_empty_vec;
4422///
4423/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4424/// let signal = non_empty_vec![0.0; nzu!(512)];
4425/// let magnitude = magnitude_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
4426///
4427/// assert_eq!(magnitude.len(), nzu!(257)); // 512/2 + 1
4428/// # Ok(())
4429/// # }
4430/// ```
4431#[inline]
4432pub fn magnitude_spectrum(
4433 samples: &NonEmptySlice<f64>,
4434 n_fft: NonZeroUsize,
4435 window: Option<WindowType>,
4436) -> SpectrogramResult<NonEmptyVec<f64>> {
4437 let power = power_spectrum(samples, n_fft, window)?;
4438 let power = power.iter().map(|&p| p.sqrt()).collect();
4439 // safety: power is non-empty since power_spectrum returned successfully
4440 Ok(unsafe { NonEmptyVec::new_unchecked(power) })
4441}
4442
4443/// Compute the Short-Time Fourier Transform (STFT) of a signal.
4444///
4445/// This function computes the STFT by applying a sliding window and FFT
4446/// to sequential frames of the input signal.
4447///
4448/// # Arguments
4449///
4450/// * `samples` - Input signal (any type that can be converted to a slice)
4451/// * `n_fft` - FFT size
4452/// * `hop_size` - Number of samples between successive frames
4453/// * `window` - Window function to apply to each frame
4454/// * `center` - If true, pad the signal to center frames
4455///
4456/// # Returns
4457///
4458/// A 2D array with shape (`frequency_bins`, `time_frames`) containing complex STFT values.
4459///
4460/// # Errors
4461///
4462/// Returns an error if:
4463/// - `hop_size` > `n_fft`
4464/// - STFT computation fails
4465///
4466/// # Examples
4467///
4468/// ```
4469/// use spectrograms::*;
4470/// use non_empty_slice::non_empty_vec;
4471///
4472/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4473/// let signal = non_empty_vec![0.0; nzu!(16000)];
4474/// let stft_matrix = stft(&signal, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4475///
4476/// println!("STFT: {} bins x {} frames", stft_matrix.nrows(), stft_matrix.ncols());
4477/// # Ok(())
4478/// # }
4479/// ```
4480#[inline]
4481pub fn stft(
4482 samples: &NonEmptySlice<f64>,
4483 n_fft: NonZeroUsize,
4484 hop_size: NonZeroUsize,
4485 window: WindowType,
4486 center: bool,
4487) -> SpectrogramResult<Array2<Complex<f64>>> {
4488 let stft_params = StftParams::new(n_fft, hop_size, window, center)?;
4489 let params = SpectrogramParams::new(stft_params, 1.0)?; // dummy sample rate
4490
4491 let planner = SpectrogramPlanner::new();
4492 let result = planner.compute_stft(samples, ¶ms)?;
4493
4494 Ok(result.data)
4495}
4496
4497/// Compute the inverse real FFT (complex-to-real IFFT).
4498///
4499/// This function performs an inverse FFT, converting frequency domain data
4500/// back to the time domain. Only the positive frequencies need to be provided
4501/// (length = `n_fft/2` + 1) due to conjugate symmetry.
4502///
4503/// # Arguments
4504///
4505/// * `spectrum` - Complex frequency bins (length should be `n_fft/2` + 1)
4506/// * `n_fft` - FFT size (length of the output signal)
4507///
4508/// # Returns
4509///
4510/// A vector of real-valued time-domain samples with length `n_fft`.
4511///
4512/// # Errors
4513///
4514/// Returns an error if:
4515/// - `spectrum` length doesn't match `n_fft/2` + 1
4516/// - Inverse FFT computation fails
4517///
4518/// # Examples
4519///
4520/// ```
4521/// use spectrograms::*;
4522/// use non_empty_slice::{non_empty_vec, NonEmptySlice};
4523/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4524/// // Forward FFT
4525/// let signal = non_empty_vec![1.0, 0.0, -1.0, 0.0, 1.0, 0.0, -1.0, 0.0];
4526/// let spectrum = fft(&signal, nzu!(8))?;
4527/// let slice = spectrum.as_slice().unwrap();
4528/// let spectrum_slice = NonEmptySlice::new(slice).unwrap();
4529/// // Inverse FFT
4530/// let reconstructed = irfft(spectrum_slice, nzu!(8))?;
4531///
4532/// assert_eq!(reconstructed.len(), nzu!(8));
4533/// # Ok(())
4534/// # }
4535/// ```
4536#[inline]
4537pub fn irfft(
4538 spectrum: &NonEmptySlice<Complex<f64>>,
4539 n_fft: NonZeroUsize,
4540) -> SpectrogramResult<NonEmptyVec<f64>> {
4541 use crate::fft_backend::{C2rPlan, r2c_output_size};
4542
4543 let n_fft = n_fft.get();
4544 let expected_len = r2c_output_size(n_fft);
4545 if spectrum.len().get() != expected_len {
4546 return Err(SpectrogramError::dimension_mismatch(
4547 expected_len,
4548 spectrum.len().get(),
4549 ));
4550 }
4551
4552 // Get inverse FFT plan from global cache (or create if first use)
4553 #[cfg(feature = "realfft")]
4554 let mut ifft = {
4555 use crate::fft_backend::get_or_create_c2r_plan;
4556 let plan = get_or_create_c2r_plan(n_fft)?;
4557 // Clone to get our own mutable copy with independent scratch buffer
4558 (*plan).clone()
4559 };
4560
4561 #[cfg(feature = "fftw")]
4562 let mut ifft = {
4563 use crate::fft_backend::C2rPlanner;
4564 let mut planner = crate::FftwPlanner::new();
4565 planner.plan_c2r(n_fft)?
4566 };
4567
4568 let mut output = vec![0.0; n_fft];
4569 ifft.process(spectrum.as_slice(), &mut output)?;
4570
4571 // Safety: output is non-empty since n_fft > 0
4572 Ok(unsafe { NonEmptyVec::new_unchecked(output) })
4573}
4574
4575/// Reconstruct a time-domain signal from its STFT using overlap-add.
4576///
4577/// This function performs the inverse Short-Time Fourier Transform, converting
4578/// a 2D complex STFT matrix back to a 1D time-domain signal using overlap-add
4579/// synthesis with the specified window function.
4580///
4581/// # Arguments
4582///
4583/// * `stft_matrix` - Complex STFT with shape (`frequency_bins`, `time_frames`)
4584/// * `n_fft` - FFT size
4585/// * `hop_size` - Number of samples between successive frames
4586/// * `window` - Window function to apply (should match forward STFT window)
4587/// * `center` - If true, assume the forward STFT was centered
4588///
4589/// # Returns
4590///
4591/// A vector of reconstructed time-domain samples.
4592///
4593/// # Errors
4594///
4595/// Returns an error if:
4596/// - `stft_matrix` dimensions are inconsistent with `n_fft`
4597/// - `hop_size` > `n_fft`
4598/// - Inverse STFT computation fails
4599///
4600/// # Examples
4601///
4602/// ```
4603/// use spectrograms::*;
4604/// use non_empty_slice::non_empty_vec;
4605///
4606/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4607/// // Generate signal
4608/// let signal = non_empty_vec![1.0; nzu!(16000)];
4609///
4610/// // Forward STFT
4611/// let stft_matrix = stft(&signal, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4612///
4613/// // Inverse STFT
4614/// let reconstructed = istft(&stft_matrix, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4615///
4616/// println!("Original: {} samples", signal.len());
4617/// println!("Reconstructed: {} samples", reconstructed.len());
4618/// # Ok(())
4619/// # }
4620/// ```
4621#[inline]
4622pub fn istft(
4623 stft_matrix: &Array2<Complex<f64>>,
4624 n_fft: NonZeroUsize,
4625 hop_size: NonZeroUsize,
4626 window: WindowType,
4627 center: bool,
4628) -> SpectrogramResult<Vec<f64>> {
4629 use crate::fft_backend::{C2rPlan, C2rPlanner, r2c_output_size};
4630
4631 let n_bins = stft_matrix.nrows();
4632 let n_frames = stft_matrix.ncols();
4633
4634 let expected_bins = r2c_output_size(n_fft.get());
4635 if n_bins != expected_bins {
4636 return Err(SpectrogramError::dimension_mismatch(expected_bins, n_bins));
4637 }
4638 if hop_size.get() > n_fft.get() {
4639 return Err(SpectrogramError::invalid_input("hop_size must be <= n_fft"));
4640 }
4641 // Create inverse FFT plan
4642 #[cfg(feature = "realfft")]
4643 let mut ifft = {
4644 let mut planner = crate::RealFftPlanner::new();
4645 planner.plan_c2r(n_fft.get())?
4646 };
4647
4648 #[cfg(feature = "fftw")]
4649 let mut ifft = {
4650 let mut planner = crate::FftwPlanner::new();
4651 planner.plan_c2r(n_fft.get())?
4652 };
4653
4654 // Generate window
4655 let window_samples = make_window(window, n_fft);
4656 let n_fft = n_fft.get();
4657 let hop_size = hop_size.get();
4658 // Calculate output length
4659 let pad = if center { n_fft / 2 } else { 0 };
4660 let output_len = (n_frames - 1) * hop_size + n_fft;
4661 let unpadded_len = output_len.saturating_sub(2 * pad);
4662
4663 // Allocate output buffer and normalization buffer
4664 let mut output = vec![0.0; output_len];
4665 let mut norm = vec![0.0; output_len];
4666
4667 // Overlap-add synthesis
4668 let mut frame_buffer = vec![Complex::new(0.0, 0.0); n_bins];
4669 let mut time_frame = vec![0.0; n_fft];
4670
4671 for frame_idx in 0..n_frames {
4672 // Extract complex frame from STFT matrix
4673 for bin_idx in 0..n_bins {
4674 frame_buffer[bin_idx] = stft_matrix[[bin_idx, frame_idx]];
4675 }
4676
4677 // Inverse FFT
4678 ifft.process(&frame_buffer, &mut time_frame)?;
4679
4680 // Apply window
4681 for i in 0..n_fft {
4682 time_frame[i] *= window_samples[i];
4683 }
4684
4685 // Overlap-add into output buffer
4686 let start = frame_idx * hop_size;
4687 for i in 0..n_fft {
4688 let pos = start + i;
4689 if pos < output_len {
4690 output[pos] += time_frame[i];
4691 norm[pos] += window_samples[i] * window_samples[i];
4692 }
4693 }
4694 }
4695
4696 // Normalize by window energy
4697 for i in 0..output_len {
4698 if norm[i] > 1e-10 {
4699 output[i] /= norm[i];
4700 }
4701 }
4702
4703 // Remove padding if centered
4704 if center && unpadded_len > 0 {
4705 let start = pad;
4706 let end = start + unpadded_len;
4707 output = output[start..end.min(output_len)].to_vec();
4708 }
4709
4710 Ok(output)
4711}
4712
4713//
4714// ========================
4715// Reusable FFT Plans
4716// ========================
4717//
4718
/// A reusable FFT planner for efficient repeated FFT operations.
///
/// This planner caches FFT plans internally, making repeated FFT operations
/// of the same size much more efficient than calling `fft()` repeatedly.
///
/// The backend is chosen at compile time through cargo features. With
/// `realfft` it wraps `crate::RealFftPlanner`; with `fftw` it wraps
/// `crate::FftwPlanner`.
///
/// # Examples
///
/// ```
/// use spectrograms::*;
/// use non_empty_slice::non_empty_vec;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let mut planner = FftPlanner::new();
///
/// // Process multiple signals of the same size efficiently
/// for _ in 0..100 {
///     let signal = non_empty_vec![0.0; nzu!(512)];
///     let spectrum = planner.fft(&signal, nzu!(512))?;
///     // ... process spectrum ...
/// }
/// # Ok(())
/// # }
/// ```
pub struct FftPlanner {
    /// Backend planner used when the `realfft` feature is enabled.
    #[cfg(feature = "realfft")]
    inner: crate::RealFftPlanner,
    /// Backend planner used when the `fftw` feature is enabled.
    #[cfg(feature = "fftw")]
    inner: crate::FftwPlanner,
}
4748
4749impl FftPlanner {
4750 /// Create a new FFT planner with empty cache.
4751 #[inline]
4752 #[must_use]
4753 pub fn new() -> Self {
4754 Self {
4755 #[cfg(feature = "realfft")]
4756 inner: crate::RealFftPlanner::new(),
4757 #[cfg(feature = "fftw")]
4758 inner: crate::FftwPlanner::new(),
4759 }
4760 }
4761
4762 /// Compute forward FFT, reusing cached plans.
4763 ///
4764 /// This is more efficient than calling the standalone `fft()` function
4765 /// repeatedly for the same FFT size.
4766 ///
4767 /// # Automatic Zero-Padding
4768 ///
4769 /// If the input signal is shorter than `n_fft`, it will be automatically
4770 /// zero-padded to the required length.
4771 ///
4772 /// # Errors
4773 ///
4774 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4775 ///
4776 /// # Examples
4777 ///
4778 /// ```
4779 /// use spectrograms::*;
4780 /// use non_empty_slice::non_empty_vec;
4781 ///
4782 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4783 /// let mut planner = FftPlanner::new();
4784 ///
4785 /// let signal = non_empty_vec![1.0; nzu!(512)];
4786 /// let spectrum = planner.fft(&signal, nzu!(512))?;
4787 ///
4788 /// assert_eq!(spectrum.len(), 257); // 512/2 + 1
4789 /// # Ok(())
4790 /// # }
4791 /// ```
4792 #[inline]
4793 pub fn fft(
4794 &mut self,
4795 samples: &NonEmptySlice<f64>,
4796 n_fft: NonZeroUsize,
4797 ) -> SpectrogramResult<Array1<Complex<f64>>> {
4798 use crate::fft_backend::{R2cPlan, R2cPlanner, r2c_output_size};
4799
4800 if samples.len() > n_fft {
4801 return Err(SpectrogramError::invalid_input(format!(
4802 "Input length ({}) exceeds FFT size ({})",
4803 samples.len(),
4804 n_fft
4805 )));
4806 }
4807
4808 let out_len = r2c_output_size(n_fft.get());
4809 let mut plan = self.inner.plan_r2c(n_fft.get())?;
4810
4811 let input = if samples.len() < n_fft {
4812 let mut padded = vec![0.0; n_fft.get()];
4813 padded[..samples.len().get()].copy_from_slice(samples);
4814
4815 // safety: samples.len() < n_fft checked above and n_fft > 0
4816 unsafe { NonEmptyVec::new_unchecked(padded) }
4817 } else {
4818 samples.to_non_empty_vec()
4819 };
4820
4821 let mut output = vec![Complex::new(0.0, 0.0); out_len];
4822 plan.process(&input, &mut output)?;
4823
4824 let output = Array1::from_vec(output);
4825 Ok(output)
4826 }
4827
4828 /// Compute forward real FFT magnitude
4829 ///
4830 /// # Errors
4831 ///
4832 /// Returns an error if:
4833 /// - `n_fft` doesn't match the samples length
4834 ///
4835 ///
4836 #[inline]
4837 pub fn rfft(
4838 &mut self,
4839 samples: &NonEmptySlice<f64>,
4840 n_fft: NonZeroUsize,
4841 ) -> SpectrogramResult<Array1<f64>> {
4842 let fft_with_complex = fft(samples, n_fft)?;
4843 Ok(fft_with_complex.mapv(Complex::norm))
4844 }
4845
4846 /// Compute inverse FFT, reusing cached plans.
4847 ///
4848 /// This is more efficient than calling the standalone `irfft()` function
4849 /// repeatedly for the same FFT size.
4850 ///
4851 /// # Errors
4852 /// Returns an error if:
4853 ///
4854 /// - The calculated expected length of `spectrum` doesn't match its actual length
4855 ///
4856 /// # Examples
4857 ///
4858 /// ```
4859 /// use spectrograms::*;
4860 /// use non_empty_slice::{non_empty_vec, NonEmptySlice};
4861 ///
4862 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4863 /// let mut planner = FftPlanner::new();
4864 ///
4865 /// // Forward FFT
4866 /// let signal = non_empty_vec![1.0; nzu!(512)];
4867 /// let spectrum = planner.fft(&signal, nzu!(512))?;
4868 ///
4869 /// // Inverse FFT
4870 /// let spectrum_slice = NonEmptySlice::new(spectrum.as_slice().unwrap()).unwrap();
4871 /// let reconstructed = planner.irfft(spectrum_slice, nzu!(512))?;
4872 ///
4873 /// assert_eq!(reconstructed.len(), nzu!(512));
4874 /// # Ok(())
4875 /// # }
4876 /// ```
4877 #[inline]
4878 pub fn irfft(
4879 &mut self,
4880 spectrum: &NonEmptySlice<Complex<f64>>,
4881 n_fft: NonZeroUsize,
4882 ) -> SpectrogramResult<NonEmptyVec<f64>> {
4883 use crate::fft_backend::{C2rPlan, C2rPlanner, r2c_output_size};
4884
4885 let expected_len = r2c_output_size(n_fft.get());
4886 if spectrum.len().get() != expected_len {
4887 return Err(SpectrogramError::dimension_mismatch(
4888 expected_len,
4889 spectrum.len().get(),
4890 ));
4891 }
4892
4893 let mut plan = self.inner.plan_c2r(n_fft.get())?;
4894 let mut output = vec![0.0; n_fft.get()];
4895 plan.process(spectrum, &mut output)?;
4896 // Safety: output is non-empty since n_fft > 0
4897 let output = unsafe { NonEmptyVec::new_unchecked(output) };
4898 Ok(output)
4899 }
4900
4901 /// Compute power spectrum with optional windowing, reusing cached plans.
4902 ///
4903 /// # Automatic Zero-Padding
4904 ///
4905 /// If the input signal is shorter than `n_fft`, it will be automatically
4906 /// zero-padded to the required length.
4907 ///
4908 /// # Errors
4909 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4910 ///
4911 /// # Examples
4912 ///
4913 /// ```
4914 /// use spectrograms::*;
4915 /// use non_empty_slice::non_empty_vec;
4916 ///
4917 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4918 /// let mut planner = FftPlanner::new();
4919 ///
4920 /// let signal = non_empty_vec![1.0; nzu!(512)];
4921 /// let power = planner.power_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
4922 ///
4923 /// assert_eq!(power.len(), nzu!(257));
4924 /// # Ok(())
4925 /// # }
4926 /// ```
4927 #[inline]
4928 pub fn power_spectrum(
4929 &mut self,
4930 samples: &NonEmptySlice<f64>,
4931 n_fft: NonZeroUsize,
4932 window: Option<WindowType>,
4933 ) -> SpectrogramResult<NonEmptyVec<f64>> {
4934 if samples.len() > n_fft {
4935 return Err(SpectrogramError::invalid_input(format!(
4936 "Input length ({}) exceeds FFT size ({})",
4937 samples.len(),
4938 n_fft
4939 )));
4940 }
4941
4942 let mut windowed = vec![0.0; n_fft.get()];
4943 windowed[..samples.len().get()].copy_from_slice(samples);
4944 if let Some(win_type) = window {
4945 let window_samples = make_window(win_type, n_fft);
4946 for i in 0..n_fft.get() {
4947 windowed[i] *= window_samples[i];
4948 }
4949 }
4950
4951 // safety: windowed is non-empty since n_fft > 0
4952 let windowed = unsafe { NonEmptySlice::new_unchecked(&windowed) };
4953 let fft_result = self.fft(windowed, n_fft)?;
4954 let f = fft_result
4955 .iter()
4956 .map(num_complex::Complex::norm_sqr)
4957 .collect();
4958 // safety: fft_result is non-empty since fft returned successfully
4959 Ok(unsafe { NonEmptyVec::new_unchecked(f) })
4960 }
4961
4962 /// Compute magnitude spectrum with optional windowing, reusing cached plans.
4963 ///
4964 /// # Automatic Zero-Padding
4965 ///
4966 /// If the input signal is shorter than `n_fft`, it will be automatically
4967 /// zero-padded to the required length.
4968 ///
4969 /// # Errors
4970 /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4971 ///
4972 /// # Examples
4973 ///
4974 /// ```
4975 /// use spectrograms::*;
4976 /// use non_empty_slice::non_empty_vec;
4977 ///
4978 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4979 /// let mut planner = FftPlanner::new();
4980 ///
4981 /// let signal = non_empty_vec![1.0; nzu!(512)];
4982 /// let magnitude = planner.magnitude_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
4983 ///
4984 /// assert_eq!(magnitude.len(), nzu!(257));
4985 /// # Ok(())
4986 /// # }
4987 /// ```
4988 #[inline]
4989 pub fn magnitude_spectrum(
4990 &mut self,
4991 samples: &NonEmptySlice<f64>,
4992 n_fft: NonZeroUsize,
4993 window: Option<WindowType>,
4994 ) -> SpectrogramResult<NonEmptyVec<f64>> {
4995 let power = self.power_spectrum(samples, n_fft, window)?;
4996 let power = power.iter().map(|&p| p.sqrt()).collect::<Vec<f64>>();
4997 // safety: power is non-empty since power_spectrum returned successfully
4998 Ok(unsafe { NonEmptyVec::new_unchecked(power) })
4999 }
5000}
5001
5002impl Default for FftPlanner {
5003 #[inline]
5004 fn default() -> Self {
5005 Self::new()
5006 }
5007}
5008
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_matrix_basic() {
        // 3x5 matrix populated with these explicit entries:
        //   row 0: col 1 = 2.0
        //   row 1: col 2 = 0.5, col 3 = 1.5
        //   row 2: col 0 = 3.0, col 4 = 1.0
        let mut sparse = SparseMatrix::new(3, 5);
        let entries = [
            (0, 1, 2.0),
            (1, 2, 0.5),
            (1, 3, 1.5),
            (2, 0, 3.0),
            (2, 4, 1.0),
        ];
        for &(row, col, value) in entries.iter() {
            sparse.set(row, col, value);
        }

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut output = vec![0.0; 3];
        sparse.multiply_vec(&input, &mut output);

        // Row-wise dot products with `input`:
        //   row 0: 2.0 * 2.0             = 4.0
        //   row 1: 0.5 * 3.0 + 1.5 * 4.0 = 7.5
        //   row 2: 3.0 * 1.0 + 1.0 * 5.0 = 8.0
        for (row, expected) in [4.0, 7.5, 8.0].iter().enumerate() {
            assert_eq!(output[row], *expected);
        }
    }

    #[test]
    fn test_sparse_matrix_zeros_ignored() {
        // Setting an explicit 0.0 must not create a stored entry.
        let mut sparse = SparseMatrix::new(2, 3);
        sparse.set(0, 0, 1.0);
        sparse.set(0, 1, 0.0);
        sparse.set(0, 2, 2.0);

        let (row_values, row_indices) = (&sparse.values[0], &sparse.indices[0]);
        assert_eq!(row_values.len(), 2);
        assert_eq!(row_indices.len(), 2);
        assert_eq!(*row_indices, vec![0, 2]);
        assert_eq!(*row_values, vec![1.0, 2.0]);
    }

    #[test]
    fn test_loghz_matrix_sparsity() {
        // Linear interpolation between neighbouring FFT bins touches 1-2 bins per row.
        let sample_rate = 16000.0;
        let n_fft = nzu!(512);
        let n_bins = nzu!(128);
        let (f_min, f_max) = (20.0, sample_rate / 2.0);

        let (matrix, _freqs) =
            build_loghz_matrix(sample_rate, n_fft, n_bins, f_min, f_max).unwrap();

        for row_idx in 0..matrix.nrows() {
            let nnz = matrix.values[row_idx].len();
            assert!(
                nnz <= 2,
                "Row {} has {} non-zeros, expected at most 2",
                row_idx,
                nnz
            );
            assert!(nnz >= 1, "Row {} has no non-zeros", row_idx);
        }

        // Overall: between one and two stored entries per output bin.
        let total_nnz: usize = matrix.values.iter().map(|row| row.len()).sum();
        let rows = n_bins.get();
        assert!(total_nnz <= rows * 2);
        assert!(total_nnz >= rows);
    }

    #[test]
    fn test_mel_matrix_sparsity() {
        // Triangular mel filters each cover only a narrow band of FFT bins.
        let sample_rate = 16000.0;
        let n_fft = nzu!(512);
        let n_mels = nzu!(40);
        let (f_min, f_max) = (0.0, sample_rate / 2.0);

        let matrix =
            build_mel_filterbank_matrix(sample_rate, n_fft, n_mels, f_min, f_max, MelNorm::None)
                .unwrap();
        let n_fft_bins = r2c_output_size(n_fft.get());

        let total_nnz: usize = matrix.values.iter().map(|row| row.len()).sum();
        let total_elements = n_mels.get() * n_fft_bins;
        let sparsity = 1.0 - total_nnz as f64 / total_elements as f64;
        assert!(
            sparsity > 0.8,
            "Mel matrix sparsity is only {:.1}%, expected >80%",
            sparsity * 100.0
        );

        // No single filter may span half the spectrum or more.
        for row_idx in 0..matrix.nrows() {
            assert!(
                matrix.values[row_idx].len() < n_fft_bins / 2,
                "Mel filter {} is not sparse enough",
                row_idx
            );
        }
    }
}