Skip to main content

spectrograms/
spectrogram.rs

1use std::marker::PhantomData;
2use std::num::NonZeroUsize;
3use std::ops::{Deref, DerefMut};
4
5use ndarray::{Array1, Array2};
6use non_empty_slice::{NonEmptySlice, NonEmptyVec, non_empty_vec};
7use num_complex::Complex;
8
9#[cfg(feature = "python")]
10use pyo3::prelude::*;
11
12use crate::cqt::CqtKernel;
13use crate::erb::ErbFilterbank;
14use crate::{
15    CqtParams, ErbParams, R2cPlan, SpectrogramError, SpectrogramResult, WindowType,
16    min_max_single_pass, nzu,
17};
/// Tiny positive constant (`1e-12`).
///
/// NOTE(review): every use of this constant lies outside this excerpt; it is
/// presumably a floor guarding against `log(0)` / division by zero — confirm.
const EPS: f64 = 1.0e-12;
19
20//
21// ========================
22// Sparse Matrix for efficient filterbank multiplication
23// ========================
24//
25
/// Sparse matrix stored row by row, tuned for sparse matrix-vector products.
///
/// Every row keeps two parallel vectors: the stored values and the column
/// index of each value. Rows are independent of each other, so the matrix can
/// be filled incrementally, one row at a time, without the up-front sizing a
/// classic CSR layout needs.
///
/// The layout targets matrices with only a handful of non-zeros per row, such
/// as mel filterbanks (triangular filters) and logarithmic frequency maps
/// (linear interpolation between one or two neighbouring bins).
///
/// Typical sparsity for spectrogram mappings:
/// - `LogHz` interpolation: 1-2 non-zeros per row (~99% sparse)
/// - Mel filterbank: ~10-50 non-zeros per row depending on FFT size (~90-98% sparse)
///
/// Only stored entries are multiplied, so no work is spent on the zeros a
/// dense product would have to visit.
#[derive(Debug, Clone)]
struct SparseMatrix {
    /// Logical row count.
    nrows: usize,
    /// Logical column count.
    ncols: usize,
    /// Stored (non-negligible) values, one `Vec` per row.
    values: Vec<Vec<f64>>,
    /// Column index of each entry in `values`, one `Vec` per row.
    indices: Vec<Vec<usize>>,
}

impl SparseMatrix {
    /// Build an `nrows` x `ncols` matrix with no stored entries.
    fn new(nrows: usize, ncols: usize) -> Self {
        Self {
            nrows,
            ncols,
            values: (0..nrows).map(|_| Vec::new()).collect(),
            indices: (0..nrows).map(|_| Vec::new()).collect(),
        }
    }

    /// Record `value` at position (`row`, `col`).
    ///
    /// Values whose magnitude is at most `1e-10` (and NaN) are dropped instead
    /// of stored. Each accepted call *appends* an entry — nothing is
    /// overwritten — so setting the same position twice stores two entries,
    /// and `multiply_vec` sums both.
    ///
    /// # Panics
    ///
    /// Debug builds panic when `row` or `col` is out of bounds; release builds
    /// silently ignore such writes.
    fn set(&mut self, row: usize, col: usize, value: f64) {
        let in_bounds = row < self.nrows && col < self.ncols;
        debug_assert!(
            in_bounds,
            "SparseMatrix index out of bounds: ({}, {}) for {}x{} matrix",
            row,
            col,
            self.nrows,
            self.ncols
        );

        // Tiny magnitudes are treated as structural zeros and never stored.
        let significant = value.abs() > 1e-10;
        if in_bounds && significant {
            self.values[row].push(value);
            self.indices[row].push(col);
        }
    }

    /// Logical number of rows.
    const fn nrows(&self) -> usize {
        self.nrows
    }

    /// Logical number of columns.
    const fn ncols(&self) -> usize {
        self.ncols
    }

    /// Compute `out = self * input`, touching only the stored entries.
    ///
    /// `input` must hold `ncols` values and `out` `nrows` values (checked in
    /// debug builds). The first `nrows` slots of `out` are overwritten; a row
    /// with no stored entries yields `0.0`.
    #[inline]
    fn multiply_vec(&self, input: &[f64], out: &mut [f64]) {
        debug_assert_eq!(input.len(), self.ncols);
        debug_assert_eq!(out.len(), self.nrows);

        let rows = self.values.iter().zip(self.indices.iter());
        for (r, (vals, cols)) in rows.enumerate() {
            // Accumulate left to right, starting from +0.0.
            out[r] = vals
                .iter()
                .zip(cols.iter())
                .fold(0.0, |sum, (&v, &c)| sum + v * input[c]);
        }
    }
}
117
118// Linear frequency
119pub type LinearPowerSpectrogram = Spectrogram<LinearHz, Power>;
120pub type LinearMagnitudeSpectrogram = Spectrogram<LinearHz, Magnitude>;
121pub type LinearDbSpectrogram = Spectrogram<LinearHz, Decibels>;
122pub type LinearSpectrogram<AmpScale> = Spectrogram<LinearHz, AmpScale>;
123
124// Log-frequency (e.g. CQT-style)
125pub type LogHzPowerSpectrogram = Spectrogram<LogHz, Power>;
126pub type LogHzMagnitudeSpectrogram = Spectrogram<LogHz, Magnitude>;
127pub type LogHzDbSpectrogram = Spectrogram<LogHz, Decibels>;
128pub type LogHzSpectrogram<AmpScale> = Spectrogram<LogHz, AmpScale>;
129
130// ERB / gammatone
131pub type ErbPowerSpectrogram = Spectrogram<Erb, Power>;
132pub type ErbMagnitudeSpectrogram = Spectrogram<Erb, Magnitude>;
133pub type ErbDbSpectrogram = Spectrogram<Erb, Decibels>;
134pub type GammatonePowerSpectrogram = ErbPowerSpectrogram;
135pub type GammatoneMagnitudeSpectrogram = ErbMagnitudeSpectrogram;
136pub type GammatoneDbSpectrogram = ErbDbSpectrogram;
137pub type ErbSpectrogram<AmpScale> = Spectrogram<Erb, AmpScale>;
138pub type GammatoneSpectrogram<AmpScale> = ErbSpectrogram<AmpScale>;
139
140// Mel
141pub type MelMagnitudeSpectrogram = Spectrogram<Mel, Magnitude>;
142pub type MelPowerSpectrogram = Spectrogram<Mel, Power>;
143pub type MelDbSpectrogram = Spectrogram<Mel, Decibels>;
144pub type LogMelSpectrogram = MelDbSpectrogram;
145pub type MelSpectrogram<AmpScale> = Spectrogram<Mel, AmpScale>;
146
147// CQT
148pub type CqtPowerSpectrogram = Spectrogram<Cqt, Power>;
149pub type CqtMagnitudeSpectrogram = Spectrogram<Cqt, Magnitude>;
150pub type CqtDbSpectrogram = Spectrogram<Cqt, Decibels>;
151pub type CqtSpectrogram<AmpScale> = Spectrogram<Cqt, AmpScale>;
152
153use crate::fft_backend::r2c_output_size;
154
155/// A spectrogram plan is the compiled, reusable execution object.
156///
157/// It owns:
158/// - FFT plan (reusable)
159/// - window samples
160/// - mapping (identity / mel filterbank / etc.)
161/// - amplitude scaling config
162/// - workspace buffers to avoid allocations in hot loops
163///
164/// It computes one specific spectrogram type: `Spectrogram<FreqScale, AmpScale>`.
165///
166/// # Type Parameters
167///
168/// - `FreqScale`: Frequency scale type (e.g. `LinearHz`, `LogHz`, `Mel`, etc.)
169/// - `AmpScale`: Amplitude scaling type (e.g. `Power`, `Magnitude`, `Decibels`, etc.)
170pub struct SpectrogramPlan<FreqScale, AmpScale>
171where
172    AmpScale: AmpScaleSpec + 'static,
173    FreqScale: Copy + Clone + 'static,
174{
175    params: SpectrogramParams,
176
177    stft: StftPlan,
178    mapping: FrequencyMapping<FreqScale>,
179    scaling: AmplitudeScaling<AmpScale>,
180
181    freq_axis: FrequencyAxis<FreqScale>,
182    workspace: Workspace,
183
184    _amp: PhantomData<AmpScale>,
185}
186
187impl<FreqScale, AmpScale> SpectrogramPlan<FreqScale, AmpScale>
188where
189    AmpScale: AmpScaleSpec + 'static,
190    FreqScale: Copy + Clone + 'static,
191{
192    /// Get the spectrogram parameters used to create this plan.
193    ///
194    /// # Returns
195    ///
196    /// A reference to the `SpectrogramParams` used in this plan.
197    #[inline]
198    #[must_use]
199    pub const fn params(&self) -> &SpectrogramParams {
200        &self.params
201    }
202
203    /// Get the frequency axis for this spectrogram plan.
204    ///
205    /// # Returns
206    ///
207    /// A reference to the `FrequencyAxis<FreqScale>` used in this plan.
208    #[inline]
209    #[must_use]
210    pub const fn freq_axis(&self) -> &FrequencyAxis<FreqScale> {
211        &self.freq_axis
212    }
213
214    /// Compute a spectrogram for a mono signal.
215    ///
216    /// This function performs:
217    /// - framing + windowing
218    /// - FFT per frame
219    /// - magnitude/power
220    /// - frequency mapping (identity/mel/etc.)
221    /// - amplitude scaling (linear or dB)
222    ///
223    /// It allocates the output `Array2` once, but does not allocate per-frame.
224    ///
225    /// # Arguments
226    ///
227    /// * `samples` - Audio samples
228    ///
229    /// # Returns
230    ///
231    /// A `Spectrogram<FreqScale, AmpScale>` containing the computed spectrogram.
232    ///
233    /// # Errors
234    ///
235    /// Returns an error if STFT computation or mapping fails.
236    #[inline]
237    pub fn compute(
238        &mut self,
239        samples: &NonEmptySlice<f64>,
240    ) -> SpectrogramResult<Spectrogram<FreqScale, AmpScale>> {
241        let n_frames = self.stft.frame_count(samples.len())?;
242        let n_bins = self.mapping.output_bins();
243
244        // Create output matrix: (n_bins, n_frames)
245        let mut data = Array2::<f64>::zeros((n_bins.get(), n_frames.get()));
246
247        // Ensure workspace is correctly sized
248        self.workspace
249            .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
250
251        // Main loop: fill each frame (column)
252        // TODO: Parallelize with rayon once thread-safety issues are resolved
253        for frame_idx in 0..n_frames.get() {
254            self.stft
255                .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
256
257            // mapping: spectrum(out_len) -> mapped(n_bins)
258            // For CQT, this uses workspace.frame; for others, workspace.spectrum
259            // For ERB, we need the complex FFT output (fft_out)
260            // We need to borrow workspace fields separately to avoid borrow conflicts
261            let Workspace {
262                spectrum,
263                mapped,
264                frame,
265                ..
266            } = &mut self.workspace;
267
268            self.mapping.apply(spectrum, frame, mapped)?;
269
270            // amplitude scaling in-place on mapped vector
271            self.scaling.apply_in_place(mapped)?;
272
273            // write column into output
274            for (row, &val) in mapped.iter().enumerate() {
275                data[[row, frame_idx]] = val;
276            }
277        }
278
279        let times = build_time_axis_seconds(&self.params, n_frames);
280        let axes = Axes::new(self.freq_axis.clone(), times);
281
282        Ok(Spectrogram::new(data, axes, self.params))
283    }
284
285    /// Compute a single frame of the spectrogram.
286    ///
287    /// This is useful for streaming/online processing where you want to
288    /// process audio frame-by-frame without computing the entire spectrogram.
289    ///
290    /// # Arguments
291    ///
292    /// * `samples` - Audio samples (must contain at least enough samples for the requested frame)
293    /// * `frame_idx` - Frame index to compute
294    ///
295    /// # Returns
296    ///
297    /// A vector of frequency bin values for the requested frame.
298    ///
299    /// # Errors
300    ///
301    /// Returns an error if the frame index is out of bounds or if STFT computation fails.
302    ///
303    /// # Examples
304    ///
305    /// ```
306    /// use spectrograms::*;
307    /// use non_empty_slice::non_empty_vec;
308    ///
309    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
310    /// let samples = non_empty_vec![0.0; nzu!(16000)];
311    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
312    /// let params = SpectrogramParams::new(stft, 16000.0)?;
313    ///
314    /// let planner = SpectrogramPlanner::new();
315    /// let mut plan = planner.linear_plan::<Power>(&params, None)?;
316    ///
317    /// // Compute just the first frame
318    /// let frame = plan.compute_frame(&samples, 0)?;
319    /// assert_eq!(frame.len(), nzu!(257)); // n_fft/2 + 1
320    /// # Ok(())
321    /// # }
322    /// ```
323    #[inline]
324    pub fn compute_frame(
325        &mut self,
326        samples: &NonEmptySlice<f64>,
327        frame_idx: usize,
328    ) -> SpectrogramResult<NonEmptyVec<f64>> {
329        let n_bins = self.mapping.output_bins();
330
331        // Ensure workspace is correctly sized
332        self.workspace
333            .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
334
335        // Compute frame spectrum
336        self.stft
337            .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
338
339        // Apply mapping (using split borrows to avoid borrow conflicts)
340        let Workspace {
341            spectrum,
342            mapped,
343            frame,
344            ..
345        } = &mut self.workspace;
346
347        self.mapping.apply(spectrum, frame, mapped)?;
348
349        // Apply amplitude scaling
350        self.scaling.apply_in_place(mapped)?;
351
352        Ok(mapped.clone())
353    }
354
355    /// Compute spectrogram into a pre-allocated buffer.
356    ///
357    /// This avoids allocating the output matrix, which is useful when
358    /// you want to reuse buffers or have strict memory requirements.
359    ///
360    /// # Arguments
361    ///
362    /// * `samples` - Audio samples
363    /// * `output` - Pre-allocated output matrix (must be correct size: `n_bins` x `n_frames`)
364    ///
365    /// # Returns
366    ///
367    /// An empty result on success.
368    ///
369    /// # Errors
370    ///
371    /// Returns an error if the output buffer dimensions don't match the expected size.
372    ///
373    /// # Examples
374    ///
375    /// ```
376    /// use spectrograms::*;
377    /// use ndarray::Array2;
378    /// use non_empty_slice::non_empty_vec;
379    ///
380    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
381    /// let samples = non_empty_vec![0.0; nzu!(16000)];
382    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
383    /// let params = SpectrogramParams::new(stft, 16000.0)?;
384    ///
385    /// let planner = SpectrogramPlanner::new();
386    /// let mut plan = planner.linear_plan::<Power>(&params, None)?;
387    ///
388    /// // Pre-allocate output buffer
389    /// let mut output = Array2::<f64>::zeros((257, 63));
390    /// plan.compute_into(&samples, &mut output)?;
391    /// # Ok(())
392    /// # }
393    /// ```
394    #[inline]
395    pub fn compute_into(
396        &mut self,
397        samples: &NonEmptySlice<f64>,
398        output: &mut Array2<f64>,
399    ) -> SpectrogramResult<()> {
400        let n_frames = self.stft.frame_count(samples.len())?;
401        let n_bins = self.mapping.output_bins();
402
403        // Validate output dimensions
404        if output.nrows() != n_bins.get() {
405            return Err(SpectrogramError::dimension_mismatch(
406                n_bins.get(),
407                output.nrows(),
408            ));
409        }
410        if output.ncols() != n_frames.get() {
411            return Err(SpectrogramError::dimension_mismatch(
412                n_frames.get(),
413                output.ncols(),
414            ));
415        }
416
417        // Ensure workspace is correctly sized
418        self.workspace
419            .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
420
421        // Main loop: fill each frame (column)
422        for frame_idx in 0..n_frames.get() {
423            self.stft
424                .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
425
426            // mapping: spectrum(out_len) -> mapped(n_bins)
427            // For CQT, this uses workspace.frame; for others, workspace.spectrum
428            // For ERB, we need the complex FFT output (fft_out)
429            // We need to borrow workspace fields separately to avoid borrow conflicts
430            let Workspace {
431                spectrum,
432                mapped,
433                frame,
434                ..
435            } = &mut self.workspace;
436
437            self.mapping.apply(spectrum, frame, mapped)?;
438
439            // amplitude scaling in-place on mapped vector
440            self.scaling.apply_in_place(mapped)?;
441
442            // write column into output
443            for (row, &val) in mapped.iter().enumerate() {
444                output[[row, frame_idx]] = val;
445            }
446        }
447
448        Ok(())
449    }
450
451    /// Get the expected output dimensions for a given signal length.
452    ///
453    /// # Arguments
454    ///
455    /// * `signal_length` - Length of the input signal in samples.
456    ///
457    /// # Returns
458    ///
459    /// A tuple `(n_bins, n_frames)` representing the output spectrogram shape.
460    ///
461    /// # Errors
462    ///
463    /// Returns an error if the signal length is too short to produce any frames.
464    ///
465    /// # Examples
466    ///
467    /// ```
468    /// use spectrograms::*;
469    ///
470    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
471    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
472    /// let params = SpectrogramParams::new(stft, 16000.0)?;
473    ///
474    /// let planner = SpectrogramPlanner::new();
475    /// let plan = planner.linear_plan::<Power>(&params, None)?;
476    ///
477    /// let (n_bins, n_frames) = plan.output_shape(nzu!(16000))?;
478    /// assert_eq!(n_bins, nzu!(257));
479    /// assert_eq!(n_frames, nzu!(63));
480    /// # Ok(())
481    /// # }
482    /// ```
483    #[inline]
484    pub fn output_shape(
485        &self,
486        signal_length: NonZeroUsize,
487    ) -> SpectrogramResult<(NonZeroUsize, NonZeroUsize)> {
488        let n_frames = self.stft.frame_count(signal_length)?;
489        let n_bins = self.mapping.output_bins();
490        Ok((n_bins, n_frames))
491    }
492}
493
494/// STFT (Short-Time Fourier Transform) result containing complex frequency bins.
495///
496/// This is the raw STFT output before any frequency mapping or amplitude scaling.
497///
498/// # Fields
499///
500/// - `data`: Complex STFT matrix with shape (`frequency_bins`, `time_frames`)
501/// - `frequencies`: Frequency axis in Hz
502/// - `sample_rate`: Sample rate in Hz
503/// - `params`: STFT computation parameters
504#[derive(Debug, Clone)]
505#[non_exhaustive]
506pub struct StftResult {
507    /// Complex STFT matrix with shape (`frequency_bins`, `time_frames`)
508    pub data: Array2<Complex<f64>>,
509    /// Frequency axis in Hz
510    pub frequencies: NonEmptyVec<f64>,
511    /// Sample rate in Hz
512    pub sample_rate: f64,
513    pub params: StftParams,
514}
515
516impl StftResult {
517    /// Get the number of frequency bins.
518    ///
519    /// # Returns
520    ///
521    /// Number of frequency bins in the STFT result.
522    #[inline]
523    #[must_use]
524    pub fn n_bins(&self) -> NonZeroUsize {
525        // safety: nrows() > 0 for NonEmptyVec
526        unsafe { NonZeroUsize::new_unchecked(self.data.nrows()) }
527    }
528
529    /// Get the number of time frames.
530    ///
531    /// # Returns
532    ///
533    /// Number of time frames in the STFT result.
534    #[inline]
535    #[must_use]
536    pub fn n_frames(&self) -> NonZeroUsize {
537        // safety: ncols() > 0 for NonEmptyVec
538        unsafe { NonZeroUsize::new_unchecked(self.data.ncols()) }
539    }
540
541    /// Get the frequency resolution in Hz
542    ///
543    /// # Returns
544    ///
545    /// Frequency bin width in Hz.
546    #[inline]
547    #[must_use]
548    pub fn frequency_resolution(&self) -> f64 {
549        self.sample_rate / self.params.n_fft().get() as f64
550    }
551
552    /// Get the time resolution in seconds.
553    ///
554    /// # Returns
555    ///
556    /// Time between successive frames in seconds.
557    #[inline]
558    #[must_use]
559    pub fn time_resolution(&self) -> f64 {
560        self.params.hop_size().get() as f64 / self.sample_rate
561    }
562
563    /// Normalizes self.data to remove the complex aspect of it.
564    ///
565    /// # Returns
566    ///
567    /// An Array2<f64> containing the norms of each complex number in self.data.
568    #[inline]
569    pub fn norm(&self) -> Array2<f64> {
570        self.as_ref().mapv(Complex::norm)
571    }
572}
573
574impl AsRef<Array2<Complex<f64>>> for StftResult {
575    #[inline]
576    fn as_ref(&self) -> &Array2<Complex<f64>> {
577        &self.data
578    }
579}
580
581impl AsMut<Array2<Complex<f64>>> for StftResult {
582    #[inline]
583    fn as_mut(&mut self) -> &mut Array2<Complex<f64>> {
584        &mut self.data
585    }
586}
587
588impl Deref for StftResult {
589    type Target = Array2<Complex<f64>>;
590
591    #[inline]
592    fn deref(&self) -> &Self::Target {
593        &self.data
594    }
595}
596
597impl DerefMut for StftResult {
598    #[inline]
599    fn deref_mut(&mut self) -> &mut Self::Target {
600        &mut self.data
601    }
602}
603
604/// A planner is an object that can build spectrogram plans.
605///
606/// In your design, this is where:
607/// - FFT plans are created
608/// - mapping matrices are compiled
609/// - axes are computed
610///
611/// This allows you to keep plan building separate from the output types.
612#[derive(Debug, Default)]
613#[non_exhaustive]
614pub struct SpectrogramPlanner;
615
616impl SpectrogramPlanner {
617    /// Create a new spectrogram planner.
618    ///
619    /// # Returns
620    ///
621    /// A new `SpectrogramPlanner` instance.
622    #[inline]
623    #[must_use]
624    pub const fn new() -> Self {
625        Self
626    }
627
628    /// Compute the Short-Time Fourier Transform (STFT) of a signal.
629    ///
630    /// This returns the raw complex STFT matrix before any frequency mapping
631    /// or amplitude scaling. Useful for applications that need the full complex
632    /// spectrum or custom processing.
633    ///
634    /// # Arguments
635    ///
636    /// * `samples` - Audio samples (any type that can be converted to a slice)
637    /// * `params` - STFT computation parameters
638    ///
639    /// # Returns
640    ///
641    /// An `StftResult` containing the complex STFT matrix and metadata.
642    ///
643    /// # Errors
644    ///
645    /// Returns an error if STFT computation fails.
646    ///
647    /// # Examples
648    ///
649    /// ```
650    /// use spectrograms::*;
651    /// use non_empty_slice::non_empty_vec;
652    ///
653    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
654    /// let samples = non_empty_vec![0.0; nzu!(16000)];
655    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
656    /// let params = SpectrogramParams::new(stft, 16000.0)?;
657    ///
658    /// let planner = SpectrogramPlanner::new();
659    /// let stft_result = planner.compute_stft(&samples, &params)?;
660    ///
661    /// println!("STFT: {} bins x {} frames", stft_result.n_bins(), stft_result.n_frames());
662    /// # Ok(())
663    /// # }
664    /// ```
665    ///
666    /// # Performance Note
667    ///
668    /// This method creates a new FFT plan each time. For processing multiple
669    /// signals, create a reusable plan with `StftPlan::new()` instead.
670    ///
671    /// # Examples
672    ///
673    /// ```rust
674    /// use spectrograms::*;
675    /// use non_empty_slice::non_empty_vec;
676    ///
677    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
678    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
679    /// let params = SpectrogramParams::new(stft, 16000.0)?;
680    ///
681    /// // One-shot (convenient)
682    /// let planner = SpectrogramPlanner::new();
683    /// let stft_result = planner.compute_stft(&non_empty_vec![0.0; nzu!(16000)], &params)?;
684    ///
685    /// // Reusable plan (efficient for batch)
686    /// let mut plan = StftPlan::new(&params)?;
687    /// for signal in &[non_empty_vec![0.0; nzu!(16000)], non_empty_vec![1.0; nzu!(16000)]] {
688    ///     let stft = plan.compute(&signal, &params)?;
689    /// }
690    /// # Ok(())
691    /// # }
692    /// ```
693    #[inline]
694    pub fn compute_stft(
695        &self,
696        samples: &NonEmptySlice<f64>,
697        params: &SpectrogramParams,
698    ) -> SpectrogramResult<StftResult> {
699        let mut plan = StftPlan::new(params)?;
700        plan.compute(samples, params)
701    }
702
703    /// Compute the power spectrum of a single audio frame.
704    ///
705    /// This is useful for real-time processing or analyzing individual frames.
706    ///
707    /// # Arguments
708    ///
709    /// * `samples` - Audio frame (length ≤ n_fft, will be zero-padded if shorter)
710    /// * `n_fft` - FFT size
711    /// * `window` - Window type to apply
712    ///
713    /// # Returns
714    ///
715    /// A vector of power values (|X|²) with length `n_fft/2` + 1.
716    ///
717    /// # Automatic Zero-Padding
718    ///
719    /// If the input signal is shorter than `n_fft`, it will be automatically
720    /// zero-padded to the required length.
721    ///
722    /// # Errors
723    ///
724    /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
725    ///
726    /// # Examples
727    ///
728    /// ```
729    /// use spectrograms::*;
730    /// use non_empty_slice::non_empty_vec;
731    ///
732    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
733    /// let frame = non_empty_vec![0.0; nzu!(512)];
734    ///
735    /// let planner = SpectrogramPlanner::new();
736    /// let power = planner.compute_power_spectrum(frame.as_ref(), nzu!(512), WindowType::Hanning)?;
737    ///
738    /// assert_eq!(power.len(), nzu!(257)); // 512/2 + 1
739    /// # Ok(())
740    /// # }
741    /// ```
742    #[inline]
743    pub fn compute_power_spectrum(
744        &self,
745        samples: &NonEmptySlice<f64>,
746        n_fft: NonZeroUsize,
747        window: WindowType,
748    ) -> SpectrogramResult<NonEmptyVec<f64>> {
749        if samples.len() > n_fft {
750            return Err(SpectrogramError::invalid_input(format!(
751                "Input length ({}) exceeds FFT size ({})",
752                samples.len(),
753                n_fft
754            )));
755        }
756
757        let window_samples = make_window(window, n_fft);
758        let out_len = r2c_output_size(n_fft.get());
759
760        // Create FFT plan
761        #[cfg(feature = "realfft")]
762        let mut fft = {
763            let mut planner = crate::RealFftPlanner::new();
764            let plan = planner.get_or_create(n_fft.get());
765            crate::RealFftPlan::new(n_fft.get(), plan)
766        };
767
768        #[cfg(feature = "fftw")]
769        let mut fft = {
770            use std::sync::Arc;
771            let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
772            crate::FftwPlan::new(Arc::new(plan))
773        };
774
775        // Apply window and compute FFT
776        let mut windowed = vec![0.0; n_fft.get()];
777        for i in 0..samples.len().get() {
778            windowed[i] = samples[i] * window_samples[i];
779        }
780        // The rest is already zero-padded
781        let mut fft_out = vec![Complex::new(0.0, 0.0); out_len];
782        fft.process(&windowed, &mut fft_out)?;
783
784        // Convert to power
785        let power: Vec<f64> = fft_out.iter().map(num_complex::Complex::norm_sqr).collect();
786        // safety: power is non-empty since n_fft > 0
787        Ok(unsafe { NonEmptyVec::new_unchecked(power) })
788    }
789
790    /// Compute the magnitude spectrum of a single audio frame.
791    ///
792    /// This is useful for real-time processing or analyzing individual frames.
793    ///
794    /// # Arguments
795    ///
796    /// * `samples` - Audio frame (length ≤ n_fft, will be zero-padded if shorter)
797    /// * `n_fft` - FFT size
798    /// * `window` - Window type to apply
799    ///
800    /// # Returns
801    ///
802    /// A vector of magnitude values (|X|) with length `n_fft/2` + 1.
803    ///
804    /// # Automatic Zero-Padding
805    ///
806    /// If the input signal is shorter than `n_fft`, it will be automatically
807    /// zero-padded to the required length.
808    ///
809    /// # Errors
810    ///
811    /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
812    ///
813    /// # Examples
814    ///
815    /// ```
816    /// use spectrograms::*;
817    /// use non_empty_slice::non_empty_vec;
818    ///
819    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
820    /// let frame = non_empty_vec![0.0; nzu!(512)];
821    ///
822    /// let planner = SpectrogramPlanner::new();
823    /// let magnitude = planner.compute_magnitude_spectrum(frame.as_ref(), nzu!(512), WindowType::Hanning)?;
824    ///
825    /// assert_eq!(magnitude.len(), nzu!(257)); // 512/2 + 1
826    /// # Ok(())
827    /// # }
828    /// ```
829    #[inline]
830    pub fn compute_magnitude_spectrum(
831        &self,
832        samples: &NonEmptySlice<f64>,
833        n_fft: NonZeroUsize,
834        window: WindowType,
835    ) -> SpectrogramResult<NonEmptyVec<f64>> {
836        let power = self.compute_power_spectrum(samples, n_fft, window)?;
837        let power = power.iter().map(|&p| p.sqrt()).collect::<Vec<f64>>();
838        // safety: power is non-empty since power_spectrum returned successfully
839        Ok(unsafe { NonEmptyVec::new_unchecked(power) })
840    }
841
842    /// Build a linear-frequency spectrogram plan.
843    ///
844    /// # Type Parameters
845    ///
846    /// `AmpScale` determines whether output is:
847    /// - Magnitude
848    /// - Power
849    /// - Decibels
850    ///
851    /// # Arguments
852    ///
853    /// * `params` - Spectrogram parameters
854    /// * `db` - Logarithmic scaling parameters (only used if `AmpScale
855    /// is `Decibels`)
856    ///
857    /// # Returns
858    ///
859    /// A `SpectrogramPlan` configured for linear-frequency spectrogram computation.
860    ///
861    /// # Errors
862    ///
863    /// Returns an error if the plan cannot be created due to invalid parameters.
864    #[inline]
865    pub fn linear_plan<AmpScale>(
866        &self,
867        params: &SpectrogramParams,
868        db: Option<&LogParams>, // only used when AmpScale = Decibels
869    ) -> SpectrogramResult<SpectrogramPlan<LinearHz, AmpScale>>
870    where
871        AmpScale: AmpScaleSpec + 'static,
872    {
873        let stft = StftPlan::new(params)?;
874        let mapping = FrequencyMapping::<LinearHz>::new(params)?;
875        let scaling = AmplitudeScaling::<AmpScale>::new(db);
876        let freq_axis = build_frequency_axis::<LinearHz>(params, &mapping);
877
878        let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
879
880        Ok(SpectrogramPlan {
881            params: *params,
882            stft,
883            mapping,
884            scaling,
885            freq_axis,
886            workspace,
887            _amp: PhantomData,
888        })
889    }
890
891    /// Build a mel-frequency spectrogram plan.
892    ///
893    /// This compiles a mel filterbank matrix and caches it inside the plan.
894    ///
895    /// # Type Parameters
896    ///
897    /// `AmpScale`: determines whether output is:
898    /// - Magnitude
899    /// - Power
900    /// - Decibels
901    ///
902    /// # Arguments
903    ///
904    /// * `params` - Spectrogram parameters
905    /// * `mel` - Mel-specific parameters
906    /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
907    ///
908    /// # Returns
909    ///
910    /// A `SpectrogramPlan` configured for mel spectrogram computation.
911    ///
912    /// # Errors
913    ///
914    /// Returns an error if the plan cannot be created due to invalid parameters.
915    #[inline]
916    pub fn mel_plan<AmpScale>(
917        &self,
918        params: &SpectrogramParams,
919        mel: &MelParams,
920        db: Option<&LogParams>, // only used when AmpScale = Decibels
921    ) -> SpectrogramResult<SpectrogramPlan<Mel, AmpScale>>
922    where
923        AmpScale: AmpScaleSpec + 'static,
924    {
925        // cross-validation: mel range must be compatible with sample rate
926        let nyquist = params.nyquist_hz();
927        if mel.f_max() > nyquist {
928            return Err(SpectrogramError::invalid_input(
929                "mel f_max must be <= Nyquist",
930            ));
931        }
932
933        let stft = StftPlan::new(params)?;
934        let mapping = FrequencyMapping::<Mel>::new_mel(params, mel)?;
935        let scaling = AmplitudeScaling::<AmpScale>::new(db);
936        let freq_axis = build_frequency_axis::<Mel>(params, &mapping);
937
938        let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
939
940        Ok(SpectrogramPlan {
941            params: *params,
942            stft,
943            mapping,
944            scaling,
945            freq_axis,
946            workspace,
947            _amp: PhantomData,
948        })
949    }
950
951    /// Build an ERB-scale spectrogram plan.
952    ///
953    /// This creates a spectrogram with ERB-spaced frequency bands using gammatone
954    /// filterbank approximation in the frequency domain.
955    ///
956    /// # Type Parameters
957    ///
958    /// `AmpScale`: determines whether output is:
959    /// - Magnitude
960    /// - Power
961    /// - Decibels
962    ///
963    /// # Arguments
964    ///
965    /// * `params` - Spectrogram parameters
966    /// * `erb` - ERB-specific parameters
967    /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
968    ///
969    /// # Returns
970    ///
971    /// A `SpectrogramPlan` configured for ERB spectrogram computation.
972    ///
973    /// # Errors
974    ///
975    /// Returns an error if the plan cannot be created due to invalid parameters.
976    #[inline]
977    pub fn erb_plan<AmpScale>(
978        &self,
979        params: &SpectrogramParams,
980        erb: &ErbParams,
981        db: Option<&LogParams>,
982    ) -> SpectrogramResult<SpectrogramPlan<Erb, AmpScale>>
983    where
984        AmpScale: AmpScaleSpec + 'static,
985    {
986        // cross-validation: erb range must be compatible with sample rate
987        let nyquist = params.nyquist_hz();
988        if erb.f_max() > nyquist {
989            return Err(SpectrogramError::invalid_input(format!(
990                "f_max={} exceeds Nyquist={}",
991                erb.f_max(),
992                nyquist
993            )));
994        }
995
996        let stft = StftPlan::new(params)?;
997        let mapping = FrequencyMapping::<Erb>::new_erb(params, erb)?;
998        let scaling = AmplitudeScaling::<AmpScale>::new(db);
999        let freq_axis = build_frequency_axis::<Erb>(params, &mapping);
1000
1001        let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1002
1003        Ok(SpectrogramPlan {
1004            params: *params,
1005            stft,
1006            mapping,
1007            scaling,
1008            freq_axis,
1009            workspace,
1010            _amp: PhantomData,
1011        })
1012    }
1013
1014    /// Build a log-frequency plan.
1015    ///
1016    /// This creates a spectrogram with logarithmically-spaced frequency bins.
1017    ///
1018    /// # Type Parameters
1019    ///
1020    /// `AmpScale`: determines whether output is:
1021    /// - Magnitude
1022    /// - Power
1023    /// - Decibels
1024    ///
1025    /// # Arguments
1026    ///
1027    /// * `params` - Spectrogram parameters
1028    /// * `loghz` - LogHz-specific parameters
1029    /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
1030    ///
1031    /// # Returns
1032    ///
1033    /// A `SpectrogramPlan` configured for log-frequency spectrogram computation.
1034    ///
1035    /// # Errors
1036    ///
1037    /// Returns an error if the plan cannot be created due to invalid parameters.
1038    #[inline]
1039    pub fn log_hz_plan<AmpScale>(
1040        &self,
1041        params: &SpectrogramParams,
1042        loghz: &LogHzParams,
1043        db: Option<&LogParams>,
1044    ) -> SpectrogramResult<SpectrogramPlan<LogHz, AmpScale>>
1045    where
1046        AmpScale: AmpScaleSpec + 'static,
1047    {
1048        // cross-validation: loghz range must be compatible with sample rate
1049        let nyquist = params.nyquist_hz();
1050        if loghz.f_max() > nyquist {
1051            return Err(SpectrogramError::invalid_input(format!(
1052                "f_max={} exceeds Nyquist={}",
1053                loghz.f_max(),
1054                nyquist
1055            )));
1056        }
1057
1058        let stft = StftPlan::new(params)?;
1059        let mapping = FrequencyMapping::<LogHz>::new_loghz(params, loghz)?;
1060        let scaling = AmplitudeScaling::<AmpScale>::new(db);
1061        let freq_axis = build_frequency_axis::<LogHz>(params, &mapping);
1062
1063        let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1064
1065        Ok(SpectrogramPlan {
1066            params: *params,
1067            stft,
1068            mapping,
1069            scaling,
1070            freq_axis,
1071            workspace,
1072            _amp: PhantomData,
1073        })
1074    }
1075
1076    /// Build a cqt spectrogram plan.
1077    ///
1078    /// # Type Parameters
1079    ///
1080    /// `AmpScale`: determines whether output is:
1081    /// - Magnitude
1082    /// - Power
1083    /// - Decibels
1084    ///
1085    /// # Arguments
1086    ///
1087    /// * `params` - Spectrogram parameters
1088    /// * `cqt` - CQT-specific parameters
1089    /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
1090    ///
1091    /// # Returns
1092    ///
1093    /// A `SpectrogramPlan` configured for CQT spectrogram computation.
1094    ///
1095    /// # Errors
1096    ///
1097    /// Returns an error if the plan cannot be created due to invalid parameters.
1098    #[inline]
1099    pub fn cqt_plan<AmpScale>(
1100        &self,
1101        params: &SpectrogramParams,
1102        cqt: &CqtParams,
1103        db: Option<&LogParams>, // only used when AmpScale = Decibels
1104    ) -> SpectrogramResult<SpectrogramPlan<Cqt, AmpScale>>
1105    where
1106        AmpScale: AmpScaleSpec + 'static,
1107    {
1108        let stft = StftPlan::new(params)?;
1109        let mapping = FrequencyMapping::<Cqt>::new(params, cqt)?;
1110        let scaling = AmplitudeScaling::<AmpScale>::new(db);
1111        let freq_axis = build_frequency_axis::<Cqt>(params, &mapping);
1112
1113        let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1114
1115        Ok(SpectrogramPlan {
1116            params: *params,
1117            stft,
1118            mapping,
1119            scaling,
1120            freq_axis,
1121            workspace,
1122            _amp: PhantomData,
1123        })
1124    }
1125}
1126
1127/// STFT plan containing reusable FFT plan and buffers.
1128///
1129/// This struct is responsible for performing the Short-Time Fourier Transform (STFT)
1130/// on audio signals based on the provided parameters.
1131///
1132/// It encapsulates the FFT plan, windowing function, and internal buffers to efficiently
1133/// compute the STFT for multiple frames of audio data.
1134///
1135/// # Fields
1136///
1137/// - `n_fft`: Size of the FFT.
1138/// - `hop_size`: Hop size between consecutive frames.
1139/// - `window`: Windowing function samples.
1140/// - `centre`: Whether to centre the frames with padding.
1141/// - `out_len`: Length of the FFT output.
1142/// - `fft`: Boxed FFT plan for real-to-complex transformation.
1143/// - `fft_out`: Internal buffer for FFT output.
1144/// - `frame`: Internal buffer for windowed audio frames.
1145pub struct StftPlan {
1146    n_fft: NonZeroUsize,
1147    hop_size: NonZeroUsize,
1148    window: NonEmptyVec<f64>,
1149    centre: bool,
1150
1151    out_len: NonZeroUsize,
1152
1153    // FFT plan (reused for all frames)
1154    fft: Box<dyn R2cPlan>,
1155
1156    // internal scratch
1157    fft_out: NonEmptyVec<Complex<f64>>,
1158    frame: NonEmptyVec<f64>,
1159}
1160
1161impl StftPlan {
1162    /// Create a new STFT plan from parameters.
1163    ///
1164    /// # Arguments
1165    ///
1166    /// * `params` - Spectrogram parameters containing STFT config
1167    ///
1168    /// # Returns
1169    ///
1170    /// A new `StftPlan` instance.
1171    ///
1172    /// # Errors
1173    ///
1174    /// Returns an error if the FFT plan cannot be created.
1175    #[inline]
1176    pub fn new(params: &SpectrogramParams) -> SpectrogramResult<Self> {
1177        let stft = params.stft();
1178        let n_fft = stft.n_fft();
1179        let hop_size = stft.hop_size();
1180        let centre = stft.centre();
1181
1182        let window = make_window(stft.window(), n_fft);
1183
1184        let out_len = r2c_output_size(n_fft.get());
1185        let out_len = NonZeroUsize::new(out_len)
1186            .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1187
1188        #[cfg(feature = "realfft")]
1189        let fft = {
1190            let mut planner = crate::RealFftPlanner::new();
1191            let plan = planner.get_or_create(n_fft.get());
1192            let plan = crate::RealFftPlan::new(n_fft.get(), plan);
1193            Box::new(plan)
1194        };
1195
1196        #[cfg(feature = "fftw")]
1197        let fft = {
1198            use std::sync::Arc;
1199            let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
1200            Box::new(crate::FftwPlan::new(Arc::new(plan)))
1201        };
1202
1203        Ok(Self {
1204            n_fft,
1205            hop_size,
1206            window,
1207            centre,
1208            out_len,
1209            fft,
1210            fft_out: non_empty_vec![Complex::new(0.0, 0.0); out_len],
1211            frame: non_empty_vec![0.0; n_fft],
1212        })
1213    }
1214
1215    fn frame_count(&self, n_samples: NonZeroUsize) -> SpectrogramResult<NonZeroUsize> {
1216        // Framing policy:
1217        // - centre = true: implicit padding of n_fft/2 on both sides
1218        // - centre = false: no padding
1219        //
1220        // Define the number of frames such that each frame has a valid centre sample position.
1221        let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1222        let padded_len = n_samples.get() + 2 * pad;
1223
1224        if padded_len < self.n_fft.get() {
1225            // still produce 1 frame (all padding / partial)
1226            return Ok(nzu!(1));
1227        }
1228
1229        let remaining = padded_len - self.n_fft.get();
1230        let n_frames = remaining / self.hop_size().get() + 1;
1231        let n_frames = NonZeroUsize::new(n_frames).ok_or_else(|| {
1232            SpectrogramError::invalid_input("computed number of frames must be non-zero")
1233        })?;
1234        Ok(n_frames)
1235    }
1236
1237    /// Compute one frame FFT using internal buffers only.
1238    fn compute_frame_fft_simple(
1239        &mut self,
1240        samples: &NonEmptySlice<f64>,
1241        frame_idx: usize,
1242    ) -> SpectrogramResult<()> {
1243        let out = self.frame.as_mut_slice();
1244        debug_assert_eq!(out.len(), self.n_fft.get());
1245
1246        let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1247        let start = frame_idx
1248            .checked_mul(self.hop_size.get())
1249            .ok_or_else(|| SpectrogramError::invalid_input("frame index overflow"))?;
1250
1251        // Fill windowed frame
1252        for (i, sample) in out.iter_mut().enumerate().take(self.n_fft.get()) {
1253            let v_idx = start + i;
1254            let s_idx = v_idx as isize - pad as isize;
1255
1256            let sample_val = if s_idx < 0 || (s_idx as usize) >= samples.len().get() {
1257                0.0
1258            } else {
1259                samples[s_idx as usize]
1260            };
1261            *sample = sample_val * self.window[i];
1262        }
1263
1264        // Compute FFT
1265        let fft_out = self.fft_out.as_mut_slice();
1266        self.fft.process(out, fft_out)?;
1267
1268        Ok(())
1269    }
1270
1271    /// Compute one frame spectrum into workspace:
1272    /// - fills windowed frame
1273    /// - runs FFT
1274    /// - converts to magnitude/power based on `AmpScale` later
1275    fn compute_frame_spectrum(
1276        &mut self,
1277        samples: &NonEmptySlice<f64>,
1278        frame_idx: usize,
1279        workspace: &mut Workspace,
1280    ) -> SpectrogramResult<()> {
1281        let out = workspace.frame.as_mut_slice();
1282
1283        // self.fill_frame(samples, frame_idx, frame)?;
1284        debug_assert_eq!(out.len(), self.n_fft.get());
1285
1286        let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1287        let start = frame_idx
1288            .checked_mul(self.hop_size().get())
1289            .ok_or_else(|| SpectrogramError::invalid_input("frame index overflow"))?;
1290
1291        // The "virtual" signal is samples with pad zeros on both sides.
1292        // Virtual index 0..padded_len
1293        // Map virtual index to original samples by subtracting pad.
1294        for (i, sample) in out.iter_mut().enumerate().take(self.n_fft.get()) {
1295            let v_idx = start + i;
1296            let s_idx = v_idx as isize - pad as isize;
1297
1298            let sample_val = if s_idx < 0 || (s_idx as usize) >= samples.len().get() {
1299                0.0
1300            } else {
1301                samples[s_idx as usize]
1302            };
1303
1304            *sample = sample_val * self.window[i];
1305        }
1306        let fft_out = workspace.fft_out.as_mut_slice();
1307        // FFT
1308        self.fft.process(out, fft_out)?;
1309
1310        // Convert complex spectrum to linear magnitude OR power here? No:
1311        // Keep "spectrum" as power by default? That would entangle semantics.
1312        //
1313        // Instead, we store magnitude^2 (power) as the canonical intermediate,
1314        // and let AmpScale decide later whether output is magnitude or power.
1315        //
1316        // This is consistent and avoids recomputing norms multiple times.
1317        for (i, c) in workspace.fft_out.iter().enumerate() {
1318            workspace.spectrum[i] = c.norm_sqr();
1319        }
1320
1321        Ok(())
1322    }
1323
1324    /// Compute the full STFT for a signal, returning an `StftResult`.
1325    ///
1326    /// This is a convenience method that handles frame iteration and
1327    /// builds the complete STFT matrix.
1328    ///
1329    ///
1330    /// # Arguments
1331    ///
1332    /// * `samples` - Input audio samples
1333    /// * `params` - STFT computation parameters
1334    ///
1335    /// # Returns
1336    ///
1337    /// An `StftResult` containing the complex STFT matrix and metadata.
1338    ///
1339    /// # Errors
1340    ///
1341    /// Returns an error if computation fails.
1342    ///
1343    /// # Examples
1344    ///
1345    /// ```rust
1346    /// use spectrograms::*;
1347    /// use non_empty_slice::non_empty_vec;
1348    ///
1349    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1350    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1351    /// let params = SpectrogramParams::new(stft, 16000.0)?;
1352    /// let mut plan = StftPlan::new(&params)?;
1353    ///
1354    /// let samples = non_empty_vec![0.0; nzu!(16000)];
1355    /// let stft_result = plan.compute(&samples, &params)?;
1356    ///
1357    /// println!("STFT: {} bins x {} frames", stft_result.n_bins(), stft_result.n_frames());
1358    /// # Ok(())
1359    /// # }
1360    /// ```
1361    #[inline]
1362    pub fn compute(
1363        &mut self,
1364        samples: &NonEmptySlice<f64>,
1365        params: &SpectrogramParams,
1366    ) -> SpectrogramResult<StftResult> {
1367        let n_frames = self.frame_count(samples.len())?;
1368        let n_bins = self.out_len;
1369
1370        // Allocate output matrix (frequency_bins x time_frames)
1371        let mut data = Array2::<Complex<f64>>::zeros((n_bins.get(), n_frames.get()));
1372
1373        // Compute each frame
1374        for frame_idx in 0..n_frames.get() {
1375            self.compute_frame_fft_simple(samples, frame_idx)?;
1376
1377            // Copy from internal buffer to output
1378            for (bin_idx, &value) in self.fft_out.iter().enumerate() {
1379                data[[bin_idx, frame_idx]] = value;
1380            }
1381        }
1382
1383        // Build frequency axis
1384        let frequencies: Vec<f64> = (0..n_bins.get())
1385            .map(|k| k as f64 * params.sample_rate_hz() / params.stft().n_fft().get() as f64)
1386            .collect();
1387        // SAFETY: n_bins > 0
1388        let frequencies = unsafe { NonEmptyVec::new_unchecked(frequencies) };
1389
1390        Ok(StftResult {
1391            data,
1392            frequencies,
1393            sample_rate: params.sample_rate_hz(),
1394            params: *params.stft(),
1395        })
1396    }
1397
1398    /// Compute a single frame of STFT, returning the complex spectrum.
1399    ///
1400    /// This is useful for streaming/online processing where you want to
1401    /// process audio frame-by-frame.
1402    ///
1403    /// # Arguments
1404    ///
1405    /// * `samples` - Input audio samples
1406    /// * `frame_idx` - Index of the frame to compute
1407    ///
1408    /// # Returns
1409    ///
1410    /// A `NonEmptyVec` containing the complex spectrum for the specified frame.
1411    ///
1412    /// # Errors
1413    ///
1414    /// Returns an error if computation fails.
1415    ///
1416    /// # Examples
1417    ///
1418    /// ```rust
1419    /// use spectrograms::*;
1420    /// use non_empty_slice::non_empty_vec;
1421    ///
1422    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1423    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1424    /// let params = SpectrogramParams::new(stft, 16000.0)?;
1425    /// let mut plan = StftPlan::new(&params)?;
1426    ///
1427    /// let samples = non_empty_vec![0.0; nzu!(16000)];
1428    /// let (_, n_frames) = plan.output_shape(samples.len())?;
1429    ///
1430    /// for frame_idx in 0..n_frames.get() {
1431    ///     let spectrum = plan.compute_frame_simple(&samples, frame_idx)?;
1432    ///     // Process spectrum...
1433    /// }
1434    /// # Ok(())
1435    /// # }
1436    /// ```
1437    #[inline]
1438    pub fn compute_frame_simple(
1439        &mut self,
1440        samples: &NonEmptySlice<f64>,
1441        frame_idx: usize,
1442    ) -> SpectrogramResult<NonEmptyVec<Complex<f64>>> {
1443        self.compute_frame_fft_simple(samples, frame_idx)?;
1444        Ok(self.fft_out.clone())
1445    }
1446
1447    /// Compute STFT into a pre-allocated buffer.
1448    ///
1449    /// This avoids allocating the output matrix, useful for reusing buffers.
1450    ///
1451    /// # Arguments
1452    ///
1453    /// * `samples` - Input audio samples
1454    /// * `output` - Pre-allocated output buffer (shape: `n_bins` x `n_frames`)
1455    ///  
1456    /// # Returns
1457    ///
1458    /// `Ok(())` on success, or an error if dimensions mismatch.
1459    ///
1460    /// # Errors
1461    ///
1462    /// Returns `DimensionMismatch` error if the output buffer has incorrect shape.
1463    ///
1464    /// # Examples
1465    ///
1466    /// ```rust
1467    /// use spectrograms::*;
1468    /// use ndarray::Array2;
1469    /// use num_complex::Complex;
1470    /// use non_empty_slice::non_empty_vec;
1471    ///
1472    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1473    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1474    /// let params = SpectrogramParams::new(stft, 16000.0)?;
1475    /// let mut plan = StftPlan::new(&params)?;
1476    ///
1477    /// let samples = non_empty_vec![0.0; nzu!(16000)];
1478    /// let (n_bins, n_frames) = plan.output_shape(samples.len())?;
1479    /// let mut output = Array2::<Complex<f64>>::zeros((n_bins.get(), n_frames.get()));
1480    ///
1481    /// plan.compute_into(&samples, &mut output)?;
1482    /// # Ok(())
1483    /// # }
1484    /// ```
1485    #[inline]
1486    pub fn compute_into(
1487        &mut self,
1488        samples: &NonEmptySlice<f64>,
1489        output: &mut Array2<Complex<f64>>,
1490    ) -> SpectrogramResult<()> {
1491        let n_frames = self.frame_count(samples.len())?;
1492        let n_bins = self.out_len;
1493
1494        // Validate output dimensions
1495        if output.nrows() != n_bins.get() {
1496            return Err(SpectrogramError::dimension_mismatch(
1497                n_bins.get(),
1498                output.nrows(),
1499            ));
1500        }
1501        if output.ncols() != n_frames.get() {
1502            return Err(SpectrogramError::dimension_mismatch(
1503                n_frames.get(),
1504                output.ncols(),
1505            ));
1506        }
1507
1508        // Compute into pre-allocated buffer
1509        for frame_idx in 0..n_frames.get() {
1510            self.compute_frame_fft_simple(samples, frame_idx)?;
1511
1512            for (bin_idx, &value) in self.fft_out.iter().enumerate() {
1513                output[[bin_idx, frame_idx]] = value;
1514            }
1515        }
1516
1517        Ok(())
1518    }
1519
1520    /// Get the expected output dimensions for a given signal length.
1521    ///
1522    /// # Arguments
1523    ///
1524    /// * `signal_length` - Length of the input signal in samples
1525    ///
1526    /// # Returns
1527    ///
1528    /// A tuple `(n_frequency_bins, n_time_frames)`.
1529    ///
1530    /// # Errors
1531    ///
1532    /// Returns an error if the computed number of frames is invalid.
1533    #[inline]
1534    pub fn output_shape(
1535        &self,
1536        signal_length: NonZeroUsize,
1537    ) -> SpectrogramResult<(NonZeroUsize, NonZeroUsize)> {
1538        let n_frames = self.frame_count(signal_length)?;
1539        Ok((self.out_len, n_frames))
1540    }
1541
1542    /// Get the number of frequency bins in the output.
1543    ///
1544    /// # Returns
1545    ///
1546    /// The number of frequency bins.
1547    #[inline]
1548    #[must_use]
1549    pub const fn n_bins(&self) -> NonZeroUsize {
1550        self.out_len
1551    }
1552
1553    /// Get the FFT size.
1554    ///
1555    /// # Returns
1556    ///
1557    /// The FFT size.
1558    #[inline]
1559    #[must_use]
1560    pub const fn n_fft(&self) -> NonZeroUsize {
1561        self.n_fft
1562    }
1563
1564    /// Get the hop size.
1565    ///
1566    /// # Returns
1567    ///
1568    /// The hop size.
1569    #[inline]
1570    #[must_use]
1571    pub const fn hop_size(&self) -> NonZeroUsize {
1572        self.hop_size
1573    }
1574}
1575
1576#[derive(Debug, Clone)]
1577enum MappingKind {
1578    Identity {
1579        out_len: NonZeroUsize,
1580    },
1581    Mel {
1582        matrix: SparseMatrix,
1583    }, // shape: (n_mels, out_len)
1584    LogHz {
1585        matrix: SparseMatrix,
1586        frequencies: NonEmptyVec<f64>,
1587    }, // shape: (n_bins, out_len)
1588    Erb {
1589        filterbank: ErbFilterbank,
1590    },
1591    Cqt {
1592        kernel: CqtKernel,
1593    },
1594}
1595
1596/// Typed mapping wrapper.
1597#[derive(Debug, Clone)]
1598struct FrequencyMapping<FreqScale> {
1599    kind: MappingKind,
1600    _marker: PhantomData<FreqScale>,
1601}
1602
1603impl FrequencyMapping<LinearHz> {
1604    fn new(params: &SpectrogramParams) -> SpectrogramResult<Self> {
1605        let out_len = r2c_output_size(params.stft().n_fft().get());
1606        let out_len = NonZeroUsize::new(out_len)
1607            .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1608        Ok(Self {
1609            kind: MappingKind::Identity { out_len },
1610            _marker: PhantomData,
1611        })
1612    }
1613}
1614
1615impl FrequencyMapping<Mel> {
1616    fn new_mel(params: &SpectrogramParams, mel: &MelParams) -> SpectrogramResult<Self> {
1617        let n_fft = params.stft().n_fft();
1618        let out_len = r2c_output_size(n_fft.get());
1619        let out_len = NonZeroUsize::new(out_len)
1620            .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1621
1622        // Validate: mel bins must be <= something sensible
1623        if mel.n_mels() > nzu!(10_000) {
1624            return Err(SpectrogramError::invalid_input(
1625                "n_mels is unreasonably large",
1626            ));
1627        }
1628
1629        let matrix = build_mel_filterbank_matrix(
1630            params.sample_rate_hz(),
1631            n_fft,
1632            mel.n_mels(),
1633            mel.f_min(),
1634            mel.f_max(),
1635            mel.norm(),
1636        )?;
1637
1638        // matrix must be (n_mels, out_len)
1639        if matrix.nrows() != mel.n_mels().get() || matrix.ncols() != out_len.get() {
1640            return Err(SpectrogramError::invalid_input(
1641                "mel filterbank matrix shape mismatch",
1642            ));
1643        }
1644
1645        Ok(Self {
1646            kind: MappingKind::Mel { matrix },
1647            _marker: PhantomData,
1648        })
1649    }
1650}
1651
1652impl FrequencyMapping<LogHz> {
1653    fn new_loghz(params: &SpectrogramParams, loghz: &LogHzParams) -> SpectrogramResult<Self> {
1654        let n_fft = params.stft().n_fft();
1655        let out_len = r2c_output_size(n_fft.get());
1656        let out_len = NonZeroUsize::new(out_len)
1657            .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1658        // Validate: n_bins must be <= something sensible
1659        if loghz.n_bins() > nzu!(10_000) {
1660            return Err(SpectrogramError::invalid_input(
1661                "n_bins is unreasonably large",
1662            ));
1663        }
1664
1665        let (matrix, frequencies) = build_loghz_matrix(
1666            params.sample_rate_hz(),
1667            n_fft,
1668            loghz.n_bins(),
1669            loghz.f_min(),
1670            loghz.f_max(),
1671        )?;
1672
1673        // matrix must be (n_bins, out_len)
1674        if matrix.nrows() != loghz.n_bins().get() || matrix.ncols() != out_len.get() {
1675            return Err(SpectrogramError::invalid_input(
1676                "loghz matrix shape mismatch",
1677            ));
1678        }
1679
1680        Ok(Self {
1681            kind: MappingKind::LogHz {
1682                matrix,
1683                frequencies,
1684            },
1685            _marker: PhantomData,
1686        })
1687    }
1688}
1689
1690impl FrequencyMapping<Erb> {
1691    fn new_erb(params: &SpectrogramParams, erb: &crate::erb::ErbParams) -> SpectrogramResult<Self> {
1692        let n_fft = params.stft().n_fft();
1693        let sample_rate = params.sample_rate_hz();
1694
1695        // Validate: n_filters must be <= something sensible
1696        if erb.n_filters() > nzu!(10_000) {
1697            return Err(SpectrogramError::invalid_input(
1698                "n_filters is unreasonably large",
1699            ));
1700        }
1701
1702        // Generate ERB filterbank with pre-computed frequency responses
1703        let filterbank = crate::erb::ErbFilterbank::generate(erb, sample_rate, n_fft)?;
1704
1705        Ok(Self {
1706            kind: MappingKind::Erb { filterbank },
1707            _marker: PhantomData,
1708        })
1709    }
1710}
1711
1712impl FrequencyMapping<Cqt> {
1713    fn new(params: &SpectrogramParams, cqt: &CqtParams) -> SpectrogramResult<Self> {
1714        let sample_rate = params.sample_rate_hz();
1715        let n_fft = params.stft().n_fft();
1716
1717        // Validate that frequency range is reasonable
1718        let f_max = cqt.bin_frequency(cqt.num_bins().get().saturating_sub(1));
1719        if f_max >= sample_rate / 2.0 {
1720            return Err(SpectrogramError::invalid_input(
1721                "CQT maximum frequency must be below Nyquist frequency",
1722            ));
1723        }
1724
1725        // Generate CQT kernel using n_fft as the signal length for kernel generation
1726        let kernel = CqtKernel::generate(cqt, sample_rate, n_fft);
1727
1728        Ok(Self {
1729            kind: MappingKind::Cqt { kernel },
1730            _marker: PhantomData,
1731        })
1732    }
1733}
1734
1735impl<FreqScale> FrequencyMapping<FreqScale> {
1736    const fn output_bins(&self) -> NonZeroUsize {
1737        // safety: all variants ensure output bins > 0 OR rely on a matrix that is guaranteed to have rows > 0
1738        match &self.kind {
1739            MappingKind::Identity { out_len } => *out_len,
1740            // safety: matrix.nrows() > 0
1741            MappingKind::LogHz { matrix, .. } | MappingKind::Mel { matrix } => unsafe {
1742                NonZeroUsize::new_unchecked(matrix.nrows())
1743            },
1744            MappingKind::Erb { filterbank, .. } => filterbank.num_filters(),
1745            MappingKind::Cqt { kernel, .. } => kernel.num_bins(),
1746        }
1747    }
1748
1749    fn apply(
1750        &self,
1751        spectrum: &NonEmptySlice<f64>,
1752        frame: &NonEmptySlice<f64>,
1753        out: &mut NonEmptySlice<f64>,
1754    ) -> SpectrogramResult<()> {
1755        match &self.kind {
1756            MappingKind::Identity { out_len } => {
1757                if spectrum.len() != *out_len {
1758                    return Err(SpectrogramError::dimension_mismatch(
1759                        (*out_len).get(),
1760                        spectrum.len().get(),
1761                    ));
1762                }
1763                if out.len() != *out_len {
1764                    return Err(SpectrogramError::dimension_mismatch(
1765                        (*out_len).get(),
1766                        out.len().get(),
1767                    ));
1768                }
1769                out.copy_from_slice(spectrum);
1770                Ok(())
1771            }
1772            MappingKind::LogHz { matrix, .. } | MappingKind::Mel { matrix } => {
1773                let out_bins = matrix.nrows();
1774                let in_bins = matrix.ncols();
1775
1776                if spectrum.len().get() != in_bins {
1777                    return Err(SpectrogramError::dimension_mismatch(
1778                        in_bins,
1779                        spectrum.len().get(),
1780                    ));
1781                }
1782                if out.len().get() != out_bins {
1783                    return Err(SpectrogramError::dimension_mismatch(
1784                        out_bins,
1785                        out.len().get(),
1786                    ));
1787                }
1788
1789                // Sparse matrix-vector multiplication: out = matrix * spectrum
1790                matrix.multiply_vec(spectrum, out);
1791                Ok(())
1792            }
1793            MappingKind::Erb { filterbank } => {
1794                // Apply ERB filterbank using pre-computed frequency responses
1795                // The filterbank already has |H(f)|^2 pre-computed, so we just
1796                // apply it to the power spectrum
1797                let erb_out = filterbank.apply_to_power_spectrum(spectrum)?;
1798
1799                if out.len().get() != erb_out.len().get() {
1800                    return Err(SpectrogramError::dimension_mismatch(
1801                        erb_out.len().get(),
1802                        out.len().get(),
1803                    ));
1804                }
1805
1806                out.copy_from_slice(&erb_out);
1807                Ok(())
1808            }
1809            MappingKind::Cqt { kernel } => {
1810                // CQT works on time-domain windowed frame, not FFT spectrum
1811                // Apply CQT kernel to get complex coefficients
1812                let cqt_complex = kernel.apply(frame)?;
1813
1814                if out.len().get() != cqt_complex.len().get() {
1815                    return Err(SpectrogramError::dimension_mismatch(
1816                        cqt_complex.len().get(),
1817                        out.len().get(),
1818                    ));
1819                }
1820
1821                // Convert complex coefficients to power (|z|^2)
1822                // This matches the convention where intermediate values are in power domain
1823                for (i, c) in cqt_complex.iter().enumerate() {
1824                    out[i] = c.norm_sqr();
1825                }
1826
1827                Ok(())
1828            }
1829        }
1830    }
1831
1832    fn frequencies_hz(&self, params: &SpectrogramParams) -> NonEmptyVec<f64> {
1833        match &self.kind {
1834            MappingKind::Identity { out_len } => {
1835                // Standard R2C bins: k * sr / n_fft
1836                let n_fft = params.stft().n_fft().get() as f64;
1837                let sr = params.sample_rate_hz();
1838                let df = sr / n_fft;
1839
1840                let mut f = Vec::with_capacity((*out_len).get());
1841                for k in 0..(*out_len).get() {
1842                    f.push(k as f64 * df);
1843                }
1844                // safety: out_len > 0
1845                unsafe { NonEmptyVec::new_unchecked(f) }
1846            }
1847            MappingKind::Mel { matrix } => {
1848                // For mel, the axis is defined by the mel band centre frequencies.
1849                // We compute and store them consistently with how we built the filterbank.
1850                let n_mels = matrix.nrows();
1851                // safety: n_mels > 0
1852                let n_mels = unsafe { NonZeroUsize::new_unchecked(n_mels) };
1853                mel_band_centres_hz(n_mels, params.sample_rate_hz(), params.nyquist_hz())
1854            }
1855            MappingKind::LogHz { frequencies, .. } => {
1856                // Frequencies are stored when the mapping is created
1857                frequencies.clone()
1858            }
1859            MappingKind::Erb { filterbank, .. } => {
1860                // ERB center frequencies
1861                filterbank.center_frequencies().to_non_empty_vec()
1862            }
1863            MappingKind::Cqt { kernel, .. } => {
1864                // CQT center frequencies from the kernel
1865                kernel.frequencies().to_non_empty_vec()
1866            }
1867        }
1868    }
1869}
1870
1871//
1872// ========================
1873// Amplitude scaling
1874// ========================
1875//
1876
1877/// Marker trait so we can specialise behaviour by `AmpScale`.
1878pub trait AmpScaleSpec: Sized + Send + Sync {
1879    /// Apply conversion from power-domain value to the desired amplitude scale.
1880    ///
1881    /// # Arguments
1882    ///
1883    /// - `power`: input power-domain value.
1884    ///
1885    /// # Returns
1886    ///
1887    /// Converted amplitude value.
1888    fn apply_from_power(power: f64) -> f64;
1889
1890    /// Apply dB conversion in-place on a power-domain vector.
1891    ///
1892    /// This is a no-op for Power and Magnitude scales.
1893    ///
1894    /// # Arguments
1895    ///
1896    /// - `x`: power-domain values to convert to dB in-place.
1897    /// - `floor_db`: dB floor value to apply.
1898    ///
1899    /// # Returns
1900    ///
1901    /// `SpectrogramResult<()>`: Ok on success, error on invalid input.
1902    ///
1903    /// # Errors
1904    ///
1905    /// - If `floor_db` is not finite.
1906    fn apply_db_in_place(x: &mut [f64], floor_db: f64) -> SpectrogramResult<()>;
1907}
1908
1909impl AmpScaleSpec for Power {
1910    #[inline]
1911    fn apply_from_power(power: f64) -> f64 {
1912        power
1913    }
1914
1915    #[inline]
1916    fn apply_db_in_place(_x: &mut [f64], _floor_db: f64) -> SpectrogramResult<()> {
1917        Ok(())
1918    }
1919}
1920
1921impl AmpScaleSpec for Magnitude {
1922    #[inline]
1923    fn apply_from_power(power: f64) -> f64 {
1924        power.sqrt()
1925    }
1926
1927    #[inline]
1928    fn apply_db_in_place(_x: &mut [f64], _floor_db: f64) -> SpectrogramResult<()> {
1929        Ok(())
1930    }
1931}
1932
1933impl AmpScaleSpec for Decibels {
1934    #[inline]
1935    fn apply_from_power(power: f64) -> f64 {
1936        // dB conversion is applied in batch, not here.
1937        power
1938    }
1939
1940    #[inline]
1941    fn apply_db_in_place(x: &mut [f64], floor_db: f64) -> SpectrogramResult<()> {
1942        // Convert power -> dB: 10*log10(max(power, eps))
1943        // where eps is derived from floor_db to ensure consistency
1944        if !floor_db.is_finite() {
1945            return Err(SpectrogramError::invalid_input("floor_db must be finite"));
1946        }
1947
1948        // Convert floor_db to linear scale to get epsilon
1949        // e.g., floor_db = -80 dB -> eps = 10^(-80/10) = 1e-8
1950        let eps = 10.0_f64.powf(floor_db / 10.0);
1951
1952        for v in x.iter_mut() {
1953            // Clamp power to epsilon before log to avoid log(0) and ensure floor
1954            *v = 10.0 * v.max(eps).log10();
1955        }
1956        Ok(())
1957    }
1958}
1959
/// Final stage of the pipeline: turns power-domain values into the requested
/// amplitude scale (`Power`, `Magnitude` or `Decibels`).
#[derive(Debug, Clone)]
struct AmplitudeScaling<AmpScale> {
    /// dB floor taken from the caller's `LogParams`, when they were supplied.
    db_floor: Option<f64>,
    /// Binds the scaling behaviour to `AmpScale` without storing a value of it.
    _marker: PhantomData<AmpScale>,
}
1968
1969impl<AmpScale> AmplitudeScaling<AmpScale>
1970where
1971    AmpScale: AmpScaleSpec + 'static,
1972{
1973    fn new(db: Option<&LogParams>) -> Self {
1974        let db_floor = db.map(LogParams::floor_db);
1975        Self {
1976            db_floor,
1977            _marker: PhantomData,
1978        }
1979    }
1980
1981    /// Apply amplitude scaling in-place on a mapped spectrum vector.
1982    ///
1983    /// The input vector is assumed to be in the *power* domain (|X|^2),
1984    /// because the STFT stage produces power as the canonical intermediate.
1985    ///
1986    /// - Power: leaves values unchanged.
1987    /// - Magnitude: sqrt(power).
1988    /// - Decibels: converts power -> dB and floors at `db_floor`.
1989    pub fn apply_in_place(&self, x: &mut [f64]) -> SpectrogramResult<()> {
1990        // Convert from canonical power-domain intermediate into the requested linear domain.
1991        for v in x.iter_mut() {
1992            *v = AmpScale::apply_from_power(*v);
1993        }
1994
1995        // Apply dB conversion if configured (no-op for Power/Magnitude via trait impls).
1996        if let Some(floor_db) = self.db_floor {
1997            AmpScale::apply_db_in_place(x, floor_db)?;
1998        }
1999
2000        Ok(())
2001    }
2002}
2003
2004#[derive(Debug, Clone)]
2005struct Workspace {
2006    spectrum: NonEmptyVec<f64>,         // out_len (power spectrum)
2007    mapped: NonEmptyVec<f64>,           // n_bins (after mapping)
2008    frame: NonEmptyVec<f64>,            // n_fft (windowed frame for FFT)
2009    fft_out: NonEmptyVec<Complex<f64>>, // out_len (FFT output)
2010}
2011
2012impl Workspace {
2013    fn new(n_fft: NonZeroUsize, out_len: NonZeroUsize, n_bins: NonZeroUsize) -> Self {
2014        Self {
2015            spectrum: non_empty_vec![0.0; out_len],
2016            mapped: non_empty_vec![0.0; n_bins],
2017            frame: non_empty_vec![0.0; n_fft],
2018            fft_out: non_empty_vec![Complex::new(0.0, 0.0); out_len],
2019        }
2020    }
2021
2022    fn ensure_sizes(&mut self, n_fft: NonZeroUsize, out_len: NonZeroUsize, n_bins: NonZeroUsize) {
2023        if self.spectrum.len() != out_len {
2024            self.spectrum.resize(out_len, 0.0);
2025        }
2026        if self.mapped.len() != n_bins {
2027            self.mapped.resize(n_bins, 0.0);
2028        }
2029        if self.frame.len() != n_fft {
2030            self.frame.resize(n_fft, 0.0);
2031        }
2032        if self.fft_out.len() != out_len {
2033            self.fft_out.resize(out_len, Complex::new(0.0, 0.0));
2034        }
2035    }
2036}
2037
2038fn build_frequency_axis<FreqScale>(
2039    params: &SpectrogramParams,
2040    mapping: &FrequencyMapping<FreqScale>,
2041) -> FrequencyAxis<FreqScale>
2042where
2043    FreqScale: Copy + Clone + 'static,
2044{
2045    let frequencies = mapping.frequencies_hz(params);
2046    FrequencyAxis::new(frequencies)
2047}
2048
2049fn build_time_axis_seconds(params: &SpectrogramParams, n_frames: NonZeroUsize) -> NonEmptyVec<f64> {
2050    let dt = params.frame_period_seconds();
2051    let mut times = Vec::with_capacity(n_frames.get());
2052
2053    for i in 0..n_frames.get() {
2054        times.push(i as f64 * dt);
2055    }
2056
2057    // safety: times is guaranteed non-empty since n_frames > 0
2058
2059    unsafe { NonEmptyVec::new_unchecked(times) }
2060}
2061
2062/// Generate window function samples.
2063///
2064/// Supports various window types including Rectangular, Hanning, Hamming, Blackman, Kaiser, and Gaussian.
2065///
2066/// # Arguments
2067///
2068/// * `window` - The type of window function to generate.
2069/// * `n_fft` - The size of the FFT, which determines the length of the window.
2070///
2071/// # Returns
2072///
2073/// A `NonEmptyVec<f64>` containing the window function samples.
2074#[inline]
2075#[must_use]
2076pub fn make_window(window: WindowType, n_fft: NonZeroUsize) -> NonEmptyVec<f64> {
2077    let n_fft = n_fft.get();
2078    let mut w = vec![0.0; n_fft];
2079
2080    match window {
2081        WindowType::Rectangular => {
2082            w.fill(1.0);
2083        }
2084        WindowType::Hanning => {
2085            // Hann: 0.5 - 0.5*cos(2Ï€n/(N-1))
2086            let n1 = (n_fft - 1) as f64;
2087            for (n, v) in w.iter_mut().enumerate() {
2088                *v = 0.5f64.mul_add(-(2.0 * std::f64::consts::PI * (n as f64) / n1).cos(), 0.5);
2089            }
2090        }
2091        WindowType::Hamming => {
2092            // Hamming: 0.54 - 0.46*cos(2Ï€n/(N-1))
2093            let n1 = (n_fft - 1) as f64;
2094            for (n, v) in w.iter_mut().enumerate() {
2095                *v = 0.46f64.mul_add(-(2.0 * std::f64::consts::PI * (n as f64) / n1).cos(), 0.54);
2096            }
2097        }
2098        WindowType::Blackman => {
2099            // Blackman: 0.42 - 0.5*cos(2Ï€n/(N-1)) + 0.08*cos(4Ï€n/(N-1))
2100            let n1 = (n_fft - 1) as f64;
2101            for (n, v) in w.iter_mut().enumerate() {
2102                let a = 2.0 * std::f64::consts::PI * (n as f64) / n1;
2103                *v = 0.08f64.mul_add((2.0 * a).cos(), 0.5f64.mul_add(-a.cos(), 0.42));
2104            }
2105        }
2106        WindowType::Kaiser { beta } => {
2107            (0..n_fft).for_each(|i| {
2108                let n = i as f64;
2109                let n_max: f64 = (n_fft - 1) as f64;
2110                let alpha: f64 = (n - n_max / 2.0) / (n_max / 2.0);
2111                let bessel_arg = beta * alpha.mul_add(-alpha, 1.0).sqrt();
2112                // Simplified approximation of modified Bessel function
2113                let x = 1.0
2114                    + bessel_arg / 2.0
2115                        // Normalize by I0(beta) approximation
2116                        / (1.0 + beta / 2.0);
2117                w[i] = x;
2118            });
2119        }
2120        WindowType::Gaussian { std } => (0..n_fft).for_each(|i| {
2121            let n = i as f64;
2122            let center: f64 = (n_fft - 1) as f64 / 2.0;
2123            let exponent: f64 = -0.5 * ((n - center) / std).powi(2);
2124            w[i] = exponent.exp();
2125        }),
2126    }
2127
2128    // safety: window is guaranteed non-empty since n_fft > 0
2129    unsafe { NonEmptyVec::new_unchecked(w) }
2130}
2131
/// Width (Hz) of one mel in the linear region of the Slaney scale: 200/3 Hz.
const SLANEY_F_SP: f64 = 200.0 / 3.0;
/// Frequency (Hz) where the Slaney scale switches from linear to logarithmic.
const SLANEY_MIN_LOG_HZ: f64 = 1000.0;
/// Mel value at the linear/log break point: 1000 / (200/3) = 15.
const SLANEY_MIN_LOG_MEL: f64 = SLANEY_MIN_LOG_HZ / SLANEY_F_SP;
/// Log-region step, `ln(6.4) / 27`, exactly as librosa defines it.
/// (`f64::ln` is not `const`, so the value is spelled out.)
const SLANEY_LOGSTEP: f64 = 0.06875177742094912;

/// Convert Hz to mel scale using Slaney formula (librosa default, htk=False).
///
/// Uses a hybrid scale (with f_min = 0):
/// - Linear below 1000 Hz: mel = hz / (200/3)
/// - Logarithmic from 1000 Hz: mel = 15 + ln(hz/1000) / logstep, logstep = ln(6.4)/27
///
/// This matches librosa's default behavior; e.g. 6400 Hz maps to exactly 42 mel.
fn hz_to_mel(hz: f64) -> f64 {
    if hz >= SLANEY_MIN_LOG_HZ {
        // Logarithmic region
        SLANEY_MIN_LOG_MEL + (hz / SLANEY_MIN_LOG_HZ).ln() / SLANEY_LOGSTEP
    } else {
        // Linear region
        hz / SLANEY_F_SP
    }
}

/// Convert mel to Hz using Slaney formula (librosa default, htk=False).
///
/// Exact inverse of [`hz_to_mel`]; it uses the same break point and log step.
fn mel_to_hz(mel: f64) -> f64 {
    if mel >= SLANEY_MIN_LOG_MEL {
        // Logarithmic region
        SLANEY_MIN_LOG_HZ * (SLANEY_LOGSTEP * (mel - SLANEY_MIN_LOG_MEL)).exp()
    } else {
        // Linear region
        SLANEY_F_SP * mel
    }
}
2173
2174fn build_mel_filterbank_matrix(
2175    sample_rate_hz: f64,
2176    n_fft: NonZeroUsize,
2177    n_mels: NonZeroUsize,
2178    f_min: f64,
2179    f_max: f64,
2180    norm: MelNorm,
2181) -> SpectrogramResult<SparseMatrix> {
2182    if sample_rate_hz <= 0.0 || !sample_rate_hz.is_finite() {
2183        return Err(SpectrogramError::invalid_input(
2184            "sample_rate_hz must be finite and > 0",
2185        ));
2186    }
2187    if f_min < 0.0 || f_min.is_infinite() {
2188        return Err(SpectrogramError::invalid_input("f_min must be >= 0"));
2189    }
2190    if f_max <= f_min {
2191        return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
2192    }
2193    if f_max > sample_rate_hz * 0.5 {
2194        return Err(SpectrogramError::invalid_input("f_max must be <= Nyquist"));
2195    }
2196    let n_mels = n_mels.get();
2197    let n_fft = n_fft.get();
2198    let out_len = r2c_output_size(n_fft);
2199
2200    // FFT bin frequencies
2201    let df = sample_rate_hz / n_fft as f64;
2202
2203    // Mel points: n_mels + 2 (for triangular edges)
2204    let mel_min = hz_to_mel(f_min);
2205    let mel_max = hz_to_mel(f_max);
2206
2207    let n_points = n_mels + 2;
2208    let step = (mel_max - mel_min) / (n_points - 1) as f64;
2209
2210    let mut mel_points = Vec::with_capacity(n_points);
2211    for i in 0..n_points {
2212        mel_points.push((i as f64).mul_add(step, mel_min));
2213    }
2214
2215    let mut hz_points = Vec::with_capacity(n_points);
2216    for m in &mel_points {
2217        hz_points.push(mel_to_hz(*m));
2218    }
2219
2220    // Convert Hz points into FFT bin indices
2221    let mut bin_points = Vec::with_capacity(n_points);
2222    for hz in hz_points {
2223        let b = (hz / df).floor() as isize;
2224        let b = b.clamp(0, (out_len - 1) as isize);
2225        bin_points.push(b as usize);
2226    }
2227
2228    // Build filterbank as sparse matrix
2229    let mut fb = SparseMatrix::new(n_mels, out_len);
2230
2231    for m in 0..n_mels {
2232        let left = bin_points[m];
2233        let centre = bin_points[m + 1];
2234        let right = bin_points[m + 2];
2235
2236        if left == centre || centre == right {
2237            // Degenerate triangle, skip (should be rare if params are sensible)
2238            continue;
2239        }
2240
2241        // Rising slope: left -> centre
2242        for k in left..centre {
2243            let v = (k - left) as f64 / (centre - left) as f64;
2244            fb.set(m, k, v);
2245        }
2246
2247        // Falling slope: centre -> right
2248        for k in centre..right {
2249            let v = (right - k) as f64 / (right - centre) as f64;
2250            fb.set(m, k, v);
2251        }
2252    }
2253
2254    // Apply normalization
2255    match norm {
2256        MelNorm::None => {
2257            // No normalization needed
2258        }
2259        MelNorm::Slaney => {
2260            // Slaney-style area normalization: 2 / (hz_max - hz_min) for each triangle
2261            // NOTE: Uses Hz bandwidth, not mel bandwidth (to match librosa's implementation)
2262            for m in 0..n_mels {
2263                let mel_left = mel_points[m];
2264                let mel_right = mel_points[m + 2];
2265                let hz_left = mel_to_hz(mel_left);
2266                let hz_right = mel_to_hz(mel_right);
2267                let enorm = 2.0 / (hz_right - hz_left);
2268
2269                // Normalize all values in this row
2270                for val in &mut fb.values[m] {
2271                    *val *= enorm;
2272                }
2273            }
2274        }
2275        MelNorm::L1 => {
2276            // L1 normalization: sum of weights = 1.0
2277            for m in 0..n_mels {
2278                let sum: f64 = fb.values[m].iter().sum();
2279                if sum > 0.0 {
2280                    let normalizer = 1.0 / sum;
2281                    for val in &mut fb.values[m] {
2282                        *val *= normalizer;
2283                    }
2284                }
2285            }
2286        }
2287        MelNorm::L2 => {
2288            // L2 normalization: L2 norm = 1.0
2289            for m in 0..n_mels {
2290                let norm_val: f64 = fb.values[m].iter().map(|&v| v * v).sum::<f64>().sqrt();
2291                if norm_val > 0.0 {
2292                    let normalizer = 1.0 / norm_val;
2293                    for val in &mut fb.values[m] {
2294                        *val *= normalizer;
2295                    }
2296                }
2297            }
2298        }
2299    }
2300
2301    Ok(fb)
2302}
2303
2304/// Build a logarithmic frequency interpolation matrix.
2305///
2306/// Maps linearly-spaced FFT bins to logarithmically-spaced frequency bins
2307/// using linear interpolation.
2308fn build_loghz_matrix(
2309    sample_rate_hz: f64,
2310    n_fft: NonZeroUsize,
2311    n_bins: NonZeroUsize,
2312    f_min: f64,
2313    f_max: f64,
2314) -> SpectrogramResult<(SparseMatrix, NonEmptyVec<f64>)> {
2315    if sample_rate_hz <= 0.0 || !sample_rate_hz.is_finite() {
2316        return Err(SpectrogramError::invalid_input(
2317            "sample_rate_hz must be finite and > 0",
2318        ));
2319    }
2320    if f_min <= 0.0 || f_min.is_infinite() {
2321        return Err(SpectrogramError::invalid_input(
2322            "f_min must be finite and > 0",
2323        ));
2324    }
2325    if f_max <= f_min {
2326        return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
2327    }
2328    if f_max > sample_rate_hz * 0.5 {
2329        return Err(SpectrogramError::invalid_input("f_max must be <= Nyquist"));
2330    }
2331
2332    let n_bins = n_bins.get();
2333    let n_fft = n_fft.get();
2334
2335    let out_len = r2c_output_size(n_fft);
2336    let df = sample_rate_hz / n_fft as f64;
2337
2338    // Generate logarithmically-spaced frequencies
2339    let log_f_min = f_min.ln();
2340    let log_f_max = f_max.ln();
2341    let log_step = (log_f_max - log_f_min) / (n_bins - 1) as f64;
2342
2343    let mut log_frequencies = Vec::with_capacity(n_bins);
2344    for i in 0..n_bins {
2345        let log_f = (i as f64).mul_add(log_step, log_f_min);
2346        log_frequencies.push(log_f.exp());
2347    }
2348    // safety: n_bins > 0
2349    let log_frequencies = unsafe { NonEmptyVec::new_unchecked(log_frequencies) };
2350
2351    // Build interpolation matrix as sparse matrix
2352    let mut matrix = SparseMatrix::new(n_bins, out_len);
2353
2354    for (bin_idx, &target_freq) in log_frequencies.iter().enumerate() {
2355        // Find the two FFT bins that bracket this frequency
2356        let exact_bin = target_freq / df;
2357        let lower_bin = exact_bin.floor() as usize;
2358        let upper_bin = (exact_bin.ceil() as usize).min(out_len - 1);
2359
2360        if lower_bin >= out_len {
2361            continue;
2362        }
2363
2364        if lower_bin == upper_bin {
2365            // Exact match
2366            matrix.set(bin_idx, lower_bin, 1.0);
2367        } else {
2368            // Linear interpolation
2369            let frac = exact_bin - lower_bin as f64;
2370            matrix.set(bin_idx, lower_bin, 1.0 - frac);
2371            if upper_bin < out_len {
2372                matrix.set(bin_idx, upper_bin, frac);
2373            }
2374        }
2375    }
2376
2377    Ok((matrix, log_frequencies))
2378}
2379
2380fn mel_band_centres_hz(
2381    n_mels: NonZeroUsize,
2382    sample_rate_hz: f64,
2383    nyquist_hz: f64,
2384) -> NonEmptyVec<f64> {
2385    let f_min = 0.0;
2386    let f_max = nyquist_hz.min(sample_rate_hz * 0.5);
2387
2388    let mel_min = hz_to_mel(f_min);
2389    let mel_max = hz_to_mel(f_max);
2390    let n_mels = n_mels.get();
2391    let step = (mel_max - mel_min) / (n_mels + 1) as f64;
2392
2393    let mut centres = Vec::with_capacity(n_mels);
2394    for i in 0..n_mels {
2395        let mel = (i as f64 + 1.0).mul_add(step, mel_min);
2396        centres.push(mel_to_hz(mel));
2397    }
2398    // safety: centres is guaranteed non-empty since n_mels > 0
2399    unsafe { NonEmptyVec::new_unchecked(centres) }
2400}
2401
2402/// Spectrogram structure holding the computed spectrogram data and metadata.
2403///
2404/// # Type Parameters
2405///
2406/// * `FreqScale`: The frequency scale type (e.g., `LinearHz`, `Mel`, `LogHz`, etc.).
2407/// * `AmpScale`: The amplitude scale type (e.g., `Power`, `Magnitude`, `Decibels`).
2408///
2409/// # Fields
2410///
2411/// * `data`: A 2D array containing the spectrogram data.
2412/// * `axes`: The axes of the spectrogram (frequency and time).
2413/// * `params`: The parameters used to compute the spectrogram.
2414/// * `_amp`: A phantom data marker for the amplitude scale type.
2415#[derive(Debug, Clone)]
2416#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
2417pub struct Spectrogram<FreqScale, AmpScale>
2418where
2419    AmpScale: AmpScaleSpec + 'static,
2420    FreqScale: Copy + Clone + 'static,
2421{
2422    data: Array2<f64>,
2423    axes: Axes<FreqScale>,
2424    params: SpectrogramParams,
2425    #[cfg_attr(feature = "serde", serde(skip))]
2426    _amp: PhantomData<AmpScale>,
2427}
2428
2429impl<FreqScale, AmpScale> Spectrogram<FreqScale, AmpScale>
2430where
2431    AmpScale: AmpScaleSpec + 'static,
2432    FreqScale: Copy + Clone + 'static,
2433{
2434    /// Get the X-axis label.
2435    ///
2436    /// # Returns
2437    ///
2438    /// A static string slice representing the X-axis label.
2439    #[inline]
2440    #[must_use]
2441    pub const fn x_axis_label() -> &'static str {
2442        "Time (s)"
2443    }
2444
2445    /// Get the Y-axis label based on the frequency scale type.
2446    ///
2447    /// # Returns
2448    ///
2449    /// A static string slice representing the Y-axis label.
2450    #[inline]
2451    #[must_use]
2452    pub fn y_axis_label() -> &'static str {
2453        match std::any::TypeId::of::<FreqScale>() {
2454            id if id == std::any::TypeId::of::<LinearHz>() => "Frequency (Hz)",
2455            id if id == std::any::TypeId::of::<Mel>() => "Frequency (Mel)",
2456            id if id == std::any::TypeId::of::<LogHz>() => "Frequency (Log Hz)",
2457            id if id == std::any::TypeId::of::<Erb>() => "Frequency (ERB)",
2458            id if id == std::any::TypeId::of::<Cqt>() => "Frequency (CQT Bins)",
2459            _ => "Frequency",
2460        }
2461    }
2462
2463    /// Internal constructor. Only callable inside the crate.
2464    ///
2465    /// All inputs must already be validated and consistent.
2466    pub(crate) fn new(data: Array2<f64>, axes: Axes<FreqScale>, params: SpectrogramParams) -> Self {
2467        debug_assert_eq!(data.nrows(), axes.frequencies().len().get());
2468        debug_assert_eq!(data.ncols(), axes.times().len().get());
2469
2470        Self {
2471            data,
2472            axes,
2473            params,
2474            _amp: PhantomData,
2475        }
2476    }
2477
2478    /// Set spectrogram data matrix
2479    ///
2480    /// # Arguments
2481    ///
2482    /// * `data` - The new spectrogram data matrix.
2483    #[inline]
2484    pub fn set_data(&mut self, data: Array2<f64>) {
2485        self.data = data;
2486    }
2487
2488    /// Spectrogram data matrix
2489    ///
2490    /// # Returns
2491    ///
2492    /// A reference to the spectrogram data matrix.
2493    #[inline]
2494    #[must_use]
2495    pub const fn data(&self) -> &Array2<f64> {
2496        &self.data
2497    }
2498
2499    /// Axes of the spectrogram
2500    ///
2501    /// # Returns
2502    ///
2503    /// A reference to the axes of the spectrogram.
2504    #[inline]
2505    #[must_use]
2506    pub const fn axes(&self) -> &Axes<FreqScale> {
2507        &self.axes
2508    }
2509
2510    /// Frequency axis in Hz
2511    ///
2512    /// # Returns
2513    ///
2514    /// A reference to the frequency axis in Hz.
2515    #[inline]
2516    #[must_use]
2517    pub fn frequencies(&self) -> &NonEmptySlice<f64> {
2518        self.axes.frequencies()
2519    }
2520
2521    /// Frequency range in Hz (min, max)
2522    ///
2523    /// # Returns
2524    ///
2525    /// A tuple containing the minimum and maximum frequencies in Hz.
2526    #[inline]
2527    #[must_use]
2528    pub const fn frequency_range(&self) -> (f64, f64) {
2529        self.axes.frequency_range()
2530    }
2531
2532    /// Time axis in seconds
2533    ///
2534    /// # Returns
2535    ///
2536    /// A reference to the time axis in seconds.
2537    #[inline]
2538    #[must_use]
2539    pub fn times(&self) -> &NonEmptySlice<f64> {
2540        self.axes.times()
2541    }
2542
2543    /// Spectrogram computation parameters
2544    ///
2545    /// # Returns
2546    ///
2547    /// A reference to the spectrogram computation parameters.
2548    #[inline]
2549    #[must_use]
2550    pub const fn params(&self) -> &SpectrogramParams {
2551        &self.params
2552    }
2553
2554    /// Duration of the spectrogram in seconds
2555    ///
2556    /// # Returns
2557    ///
2558    /// The duration of the spectrogram in seconds.
2559    #[inline]
2560    #[must_use]
2561    pub fn duration(&self) -> f64 {
2562        self.axes.duration()
2563    }
2564
2565    /// If this is a dB spectrogram, return the (min, max) dB values. otherwise do the maths to compute dB range.
2566    ///
2567    /// # Returns
2568    ///
2569    /// The (min, max) dB values of the spectrogram, or `None` if the amplitude scale is unknown.
2570    #[inline]
2571    #[must_use]
2572    pub fn db_range(&self) -> Option<(f64, f64)> {
2573        let type_self = std::any::TypeId::of::<AmpScale>();
2574
2575        if type_self == std::any::TypeId::of::<Decibels>() {
2576            let (min, max) = min_max_single_pass(self.data.as_slice()?);
2577            Some((min, max))
2578        } else if type_self == std::any::TypeId::of::<Power>() {
2579            // Not a dB spectrogram; compute dB range from power values
2580            let mut min_db = f64::INFINITY;
2581            let mut max_db = f64::NEG_INFINITY;
2582            for &v in &self.data {
2583                let db = 10.0 * (v + EPS).log10();
2584                if db < min_db {
2585                    min_db = db;
2586                }
2587                if db > max_db {
2588                    max_db = db;
2589                }
2590            }
2591            Some((min_db, max_db))
2592        } else if type_self == std::any::TypeId::of::<Magnitude>() {
2593            // Not a dB spectrogram; compute dB range from magnitude values
2594            let mut min_db = f64::INFINITY;
2595            let mut max_db = f64::NEG_INFINITY;
2596
2597            for &v in &self.data {
2598                let power = v * v;
2599                let db = 10.0 * (power + EPS).log10();
2600                if db < min_db {
2601                    min_db = db;
2602                }
2603                if db > max_db {
2604                    max_db = db;
2605                }
2606            }
2607
2608            Some((min_db, max_db))
2609        } else {
2610            // Unknown AmpScale type; return dummy values
2611            None
2612        }
2613    }
2614
2615    /// Number of frequency bins
2616    ///
2617    /// # Returns
2618    ///
2619    /// The number of frequency bins in the spectrogram.
2620    #[inline]
2621    #[must_use]
2622    pub fn n_bins(&self) -> NonZeroUsize {
2623        // safety: data.nrows() > 0 is guaranteed by construction
2624        unsafe { NonZeroUsize::new_unchecked(self.data.nrows()) }
2625    }
2626
2627    /// Number of time frames in the spectrogram
2628    ///
2629    /// # Returns
2630    ///
2631    /// The number of time frames (columns) in the spectrogram.
2632    #[inline]
2633    #[must_use]
2634    pub fn n_frames(&self) -> NonZeroUsize {
2635        // safety: data.ncols() > 0 is guaranteed by construction
2636        unsafe { NonZeroUsize::new_unchecked(self.data.ncols()) }
2637    }
2638}
2639
2640impl AsRef<Array2<f64>> for Spectrogram<LinearHz, Power> {
2641    #[inline]
2642    fn as_ref(&self) -> &Array2<f64> {
2643        &self.data
2644    }
2645}
2646
2647impl Deref for Spectrogram<LinearHz, Power> {
2648    type Target = Array2<f64>;
2649
2650    #[inline]
2651    fn deref(&self) -> &Self::Target {
2652        &self.data
2653    }
2654}
2655
2656impl<AmpScale> Spectrogram<LinearHz, AmpScale>
2657where
2658    AmpScale: AmpScaleSpec + 'static,
2659{
2660    /// Compute a linear-frequency spectrogram from audio samples.
2661    ///
2662    /// This is a convenience method that creates a planner internally and computes
2663    /// the spectrogram in one call. For processing multiple signals with the same
2664    /// parameters, use [`SpectrogramPlanner::linear_plan`] to create a reusable plan.
2665    ///
2666    /// # Arguments
2667    ///
2668    /// * `samples` - Audio samples (any type that can be converted to a slice)
2669    /// * `params` - Spectrogram computation parameters
2670    /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2671    ///
2672    /// # Returns
2673    ///
2674    /// A linear-frequency spectrogram with the specified amplitude scale.
2675    ///
2676    /// # Errors
2677    ///
2678    /// Returns an error if:
2679    /// - The samples slice is empty
2680    /// - Parameters are invalid
2681    /// - FFT computation fails
2682    ///
2683    /// # Examples
2684    ///
2685    /// ```
2686    /// use spectrograms::*;
2687    /// use non_empty_slice::non_empty_vec;
2688    ///
2689    /// # fn example() -> SpectrogramResult<()> {
2690    /// // Create a simple test signal
2691    /// let sample_rate = 16000.0;
2692    /// let samples_vec: Vec<f64> = (0..16000).map(|i| {
2693    ///     (2.0 * std::f64::consts::PI * 440.0 * i as f64 / sample_rate).sin()
2694    /// }).collect();
2695    /// let samples = non_empty_slice::NonEmptyVec::new(samples_vec).unwrap();
2696    ///
2697    /// // Set up parameters
2698    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2699    /// let params = SpectrogramParams::new(stft, sample_rate)?;
2700    ///
2701    /// // Compute power spectrogram
2702    /// let spec = LinearPowerSpectrogram::compute(&samples, &params, None)?;
2703    ///
2704    /// println!("Computed spectrogram: {} bins x {} frames", spec.n_bins(), spec.n_frames());
2705    /// # Ok(())
2706    /// # }
2707    /// ```
2708    #[inline]
2709    pub fn compute(
2710        samples: &NonEmptySlice<f64>,
2711        params: &SpectrogramParams,
2712        db: Option<&LogParams>,
2713    ) -> SpectrogramResult<Self> {
2714        let planner = SpectrogramPlanner::new();
2715        let mut plan = planner.linear_plan(params, db)?;
2716        plan.compute(samples)
2717    }
2718}
2719
2720impl<AmpScale> Spectrogram<Mel, AmpScale>
2721where
2722    AmpScale: AmpScaleSpec + 'static,
2723{
2724    /// Compute a mel-frequency spectrogram from audio samples.
2725    ///
2726    /// This is a convenience method that creates a planner internally and computes
2727    /// the spectrogram in one call. For processing multiple signals with the same
2728    /// parameters, use [`SpectrogramPlanner::mel_plan`] to create a reusable plan.
2729    ///
2730    /// # Arguments
2731    ///
2732    /// * `samples` - Audio samples (any type that can be converted to a slice)
2733    /// * `params` - Spectrogram computation parameters
2734    /// * `mel` - Mel filterbank parameters
2735    /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2736    ///
2737    /// # Returns
2738    ///
2739    /// A mel-frequency spectrogram with the specified amplitude scale.
2740    ///
2741    /// # Errors
2742    ///
2743    /// Returns an error if:
2744    /// - The samples slice is empty
2745    /// - Parameters are invalid
2746    /// - Mel `f_max` exceeds Nyquist frequency
2747    /// - FFT computation fails
2748    ///
2749    /// # Examples
2750    ///
2751    /// ```
2752    /// use spectrograms::*;
2753    /// use non_empty_slice::non_empty_vec;
2754    ///
2755    /// # fn example() -> SpectrogramResult<()> {
2756    /// // Create a simple test signal
2757    /// let sample_rate = 16000.0;
2758    /// let samples_vec: Vec<f64> = (0..16000).map(|i| {
2759    ///     (2.0 * std::f64::consts::PI * 440.0 * i as f64 / sample_rate).sin()
2760    /// }).collect();
2761    /// let samples = non_empty_slice::NonEmptyVec::new(samples_vec).unwrap();
2762    ///
2763    /// // Set up parameters
2764    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2765    /// let params = SpectrogramParams::new(stft, sample_rate)?;
2766    /// let mel = MelParams::new(nzu!(80), 0.0, 8000.0)?;
2767    ///
2768    /// // Compute mel spectrogram in dB scale
2769    /// let db = LogParams::new(-80.0)?;
2770    /// let spec = MelDbSpectrogram::compute(&samples, &params, &mel, Some(&db))?;
2771    ///
2772    /// println!("Computed mel spectrogram: {} mels x {} frames", spec.n_bins(), spec.n_frames());
2773    /// # Ok(())
2774    /// # }
2775    /// ```
2776    #[inline]
2777    pub fn compute(
2778        samples: &NonEmptySlice<f64>,
2779        params: &SpectrogramParams,
2780        mel: &MelParams,
2781        db: Option<&LogParams>,
2782    ) -> SpectrogramResult<Self> {
2783        let planner = SpectrogramPlanner::new();
2784        let mut plan = planner.mel_plan(params, mel, db)?;
2785        plan.compute(samples)
2786    }
2787}
2788
2789impl<AmpScale> Spectrogram<Erb, AmpScale>
2790where
2791    AmpScale: AmpScaleSpec + 'static,
2792{
2793    /// Compute an ERB-frequency spectrogram from audio samples.
2794    ///
2795    /// This is a convenience method that creates a planner internally and computes
2796    /// the spectrogram in one call. For processing multiple signals with the same
2797    /// parameters, use [`SpectrogramPlanner::erb_plan`] to create a reusable plan.
2798    ///
2799    /// # Arguments
2800    ///
2801    /// * `samples` - Audio samples (any type that can be converted to a slice)
2802    /// * `params` - Spectrogram computation parameters
2803    /// * `erb` - ERB frequency scale parameters
2804    /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2805    ///
2806    /// # Returns
2807    ///
2808    /// An ERB-scale spectrogram with the specified amplitude scale.
2809    ///
2810    /// # Errors
2811    ///
2812    /// Returns an error if:
2813    /// - The samples slice is empty
2814    /// - Parameters are invalid
2815    /// - FFT computation fails
2816    ///
2817    /// # Examples
2818    ///
2819    /// ```
2820    /// use spectrograms::*;
2821    /// use non_empty_slice::non_empty_vec;
2822    ///
2823    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
2824    /// let samples = non_empty_vec![0.0; nzu!(16000)];
2825    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2826    /// let params = SpectrogramParams::new(stft, 16000.0)?;
2827    /// let erb = ErbParams::speech_standard();
2828    ///
2829    /// let spec = ErbPowerSpectrogram::compute(&samples, &params, &erb, None)?;
2830    /// assert_eq!(spec.n_bins(), nzu!(40));
2831    /// # Ok(())
2832    /// # }
2833    /// ```
2834    #[inline]
2835    pub fn compute(
2836        samples: &NonEmptySlice<f64>,
2837        params: &SpectrogramParams,
2838        erb: &ErbParams,
2839        db: Option<&LogParams>,
2840    ) -> SpectrogramResult<Self> {
2841        let planner = SpectrogramPlanner::new();
2842        let mut plan = planner.erb_plan(params, erb, db)?;
2843        plan.compute(samples)
2844    }
2845}
2846
2847impl<AmpScale> Spectrogram<LogHz, AmpScale>
2848where
2849    AmpScale: AmpScaleSpec + 'static,
2850{
2851    /// Compute a logarithmic-frequency spectrogram from audio samples.
2852    ///
2853    /// This is a convenience method that creates a planner internally and computes
2854    /// the spectrogram in one call. For processing multiple signals with the same
2855    /// parameters, use [`SpectrogramPlanner::log_hz_plan`] to create a reusable plan.
2856    ///
2857    /// # Arguments
2858    ///
2859    /// * `samples` - Audio samples (any type that can be converted to a slice)
2860    /// * `params` - Spectrogram computation parameters
2861    /// * `loghz` - Logarithmic frequency scale parameters
2862    /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2863    ///
2864    /// # Returns
2865    ///
2866    /// A logarithmic-frequency spectrogram with the specified amplitude scale.
2867    ///
2868    /// # Errors
2869    ///
2870    /// Returns an error if:
2871    /// - The samples slice is empty
2872    /// - Parameters are invalid
2873    /// - FFT computation fails
2874    ///
2875    /// # Examples
2876    ///
2877    /// ```
2878    /// use spectrograms::*;
2879    /// use non_empty_slice::non_empty_vec;
2880    ///
2881    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
2882    /// let samples = non_empty_vec![0.0; nzu!(16000)];
2883    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2884    /// let params = SpectrogramParams::new(stft, 16000.0)?;
2885    /// let loghz = LogHzParams::new(nzu!(128), 20.0, 8000.0)?;
2886    ///
2887    /// let spec = LogHzPowerSpectrogram::compute(&samples, &params, &loghz, None)?;
2888    /// assert_eq!(spec.n_bins(), nzu!(128));
2889    /// # Ok(())
2890    /// # }
2891    /// ```
2892    #[inline]
2893    pub fn compute(
2894        samples: &NonEmptySlice<f64>,
2895        params: &SpectrogramParams,
2896        loghz: &LogHzParams,
2897        db: Option<&LogParams>,
2898    ) -> SpectrogramResult<Self> {
2899        let planner = SpectrogramPlanner::new();
2900        let mut plan = planner.log_hz_plan(params, loghz, db)?;
2901        plan.compute(samples)
2902    }
2903}
2904
2905impl<AmpScale> Spectrogram<Cqt, AmpScale>
2906where
2907    AmpScale: AmpScaleSpec + 'static,
2908{
2909    /// Compute a constant-Q transform (CQT) spectrogram from audio samples.
2910    ///
2911    /// # Arguments
2912    ///
2913    /// * `samples` - Audio samples (any type that can be converted to a slice)
2914    /// * `params` - Spectrogram computation parameters
2915    /// * `cqt` - CQT parameters
2916    /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2917    ///
2918    /// # Returns
2919    ///
2920    /// A CQT spectrogram with the specified amplitude scale.
2921    ///
2922    /// # Errors
2923    ///
2924    /// Returns an error if:
2925    ///
2926    #[inline]
2927    pub fn compute(
2928        samples: &NonEmptySlice<f64>,
2929        params: &SpectrogramParams,
2930        cqt: &CqtParams,
2931        db: Option<&LogParams>,
2932    ) -> SpectrogramResult<Self> {
2933        let planner = SpectrogramPlanner::new();
2934        let mut plan = planner.cqt_plan(params, cqt, db)?;
2935        plan.compute(samples)
2936    }
2937}
2938
2939// ========================
2940// Display implementations
2941// ========================
2942
2943/// Helper function to get amplitude scale name
2944fn amp_scale_name<AmpScale>() -> &'static str
2945where
2946    AmpScale: AmpScaleSpec + 'static,
2947{
2948    match std::any::TypeId::of::<AmpScale>() {
2949        id if id == std::any::TypeId::of::<Power>() => "Power",
2950        id if id == std::any::TypeId::of::<Magnitude>() => "Magnitude",
2951        id if id == std::any::TypeId::of::<Decibels>() => "Decibels",
2952        _ => "Unknown",
2953    }
2954}
2955
2956/// Helper function to get frequency scale name
2957fn freq_scale_name<FreqScale>() -> &'static str
2958where
2959    FreqScale: Copy + Clone + 'static,
2960{
2961    match std::any::TypeId::of::<FreqScale>() {
2962        id if id == std::any::TypeId::of::<LinearHz>() => "Linear Hz",
2963        id if id == std::any::TypeId::of::<LogHz>() => "Log Hz",
2964        id if id == std::any::TypeId::of::<Mel>() => "Mel",
2965        id if id == std::any::TypeId::of::<Erb>() => "ERB",
2966        id if id == std::any::TypeId::of::<Cqt>() => "CQT",
2967        _ => "Unknown",
2968    }
2969}
2970
2971impl<FreqScale, AmpScale> core::fmt::Display for Spectrogram<FreqScale, AmpScale>
2972where
2973    AmpScale: AmpScaleSpec + 'static,
2974    FreqScale: Copy + Clone + 'static,
2975{
2976    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
2977        let (freq_min, freq_max) = self.frequency_range();
2978        let duration = self.duration();
2979        let (rows, cols) = self.data.dim();
2980
2981        // Alternative formatting (#) provides more detailed output with data
2982        if f.alternate() {
2983            writeln!(f, "Spectrogram {{")?;
2984            writeln!(f, "  Frequency Scale: {}", freq_scale_name::<FreqScale>())?;
2985            writeln!(f, "  Amplitude Scale: {}", amp_scale_name::<AmpScale>())?;
2986            writeln!(f, "  Shape: {} frequency bins × {} time frames", rows, cols)?;
2987            writeln!(
2988                f,
2989                "  Frequency Range: {:.2} Hz - {:.2} Hz",
2990                freq_min, freq_max
2991            )?;
2992            writeln!(f, "  Duration: {:.3} s", duration)?;
2993            writeln!(f)?;
2994
2995            // Display parameters
2996            writeln!(f, "  Parameters:")?;
2997            writeln!(f, "    Sample Rate: {} Hz", self.params.sample_rate_hz())?;
2998            writeln!(f, "    FFT Size: {}", self.params.stft().n_fft())?;
2999            writeln!(f, "    Hop Size: {}", self.params.stft().hop_size())?;
3000            writeln!(f, "    Window: {:?}", self.params.stft().window())?;
3001            writeln!(f, "    Centered: {}", self.params.stft().centre())?;
3002            writeln!(f)?;
3003
3004            // Display data statistics
3005            let data_slice = self.data.as_slice().unwrap_or(&[]);
3006            if !data_slice.is_empty() {
3007                let (min_val, max_val) = min_max_single_pass(data_slice);
3008                let mean = data_slice.iter().sum::<f64>() / data_slice.len() as f64;
3009                writeln!(f, "  Data Statistics:")?;
3010                writeln!(f, "    Min: {:.6}", min_val)?;
3011                writeln!(f, "    Max: {:.6}", max_val)?;
3012                writeln!(f, "    Mean: {:.6}", mean)?;
3013                writeln!(f)?;
3014            }
3015
3016            // Display actual data (truncated if too large)
3017            writeln!(f, "  Data Matrix:")?;
3018            let max_rows_to_display = 8;
3019            let max_cols_to_display = 8;
3020
3021            for i in 0..rows.min(max_rows_to_display) {
3022                write!(f, "    [")?;
3023                for j in 0..cols.min(max_cols_to_display) {
3024                    if j > 0 {
3025                        write!(f, ", ")?;
3026                    }
3027                    write!(f, "{:9.4}", self.data[[i, j]])?;
3028                }
3029                if cols > max_cols_to_display {
3030                    write!(f, ", ... ({} more)", cols - max_cols_to_display)?;
3031                }
3032                writeln!(f, "]")?;
3033            }
3034
3035            if rows > max_rows_to_display {
3036                writeln!(f, "    ... ({} more rows)", rows - max_rows_to_display)?;
3037            }
3038
3039            write!(f, "}}")?;
3040        } else {
3041            // Default formatting: compact summary
3042            write!(
3043                f,
3044                "Spectrogram<{}, {}>[{}×{}] ({:.2}-{:.2} Hz, {:.3}s)",
3045                freq_scale_name::<FreqScale>(),
3046                amp_scale_name::<AmpScale>(),
3047                rows,
3048                cols,
3049                freq_min,
3050                freq_max,
3051                duration
3052            )?;
3053        }
3054
3055        Ok(())
3056    }
3057}
3058
/// Frequency axis of a spectrogram.
///
/// Stores one frequency value (in Hz) per frequency bin. `FreqScale` is a
/// type-level marker (e.g. `LinearHz`, `Mel`) recording which scale produced the
/// values; it is not stored at runtime (`PhantomData`) and is skipped by serde.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FrequencyAxis<FreqScale>
where
    FreqScale: Copy + Clone + 'static,
{
    // Per-bin frequency values in Hz; non-empty by type.
    frequencies: NonEmptyVec<f64>,
    #[cfg_attr(feature = "serde", serde(skip))]
    _marker: PhantomData<FreqScale>,
}
3069
3070impl<FreqScale> FrequencyAxis<FreqScale>
3071where
3072    FreqScale: Copy + Clone + 'static,
3073{
3074    pub(crate) const fn new(frequencies: NonEmptyVec<f64>) -> Self {
3075        Self {
3076            frequencies,
3077            _marker: PhantomData,
3078        }
3079    }
3080
3081    /// Get the frequency values in Hz.
3082    ///
3083    /// # Returns
3084    ///
3085    /// Returns a non-empty slice of frequencies.
3086    #[inline]
3087    #[must_use]
3088    pub fn frequencies(&self) -> &NonEmptySlice<f64> {
3089        &self.frequencies
3090    }
3091
3092    /// Get the frequency range (min, max) in Hz.
3093    ///
3094    /// # Returns
3095    ///
3096    /// Returns a tuple containing the minimum and maximum frequency.
3097    #[inline]
3098    #[must_use]
3099    pub const fn frequency_range(&self) -> (f64, f64) {
3100        let data = self.frequencies.as_slice();
3101        let min = data[0];
3102        let max_idx = data.len().saturating_sub(1); // safe for non-empty
3103        let max = data[max_idx];
3104        (min, max)
3105    }
3106
3107    /// Get the number of frequency bins.
3108    ///
3109    /// # Returns
3110    ///
3111    /// Returns the number of frequency bins as a NonZeroUsize.
3112    #[inline]
3113    #[must_use]
3114    pub const fn len(&self) -> NonZeroUsize {
3115        self.frequencies.len()
3116    }
3117}
3118
/// Spectrogram axes container.
///
/// Holds the frequency axis (one value in Hz per frequency bin, tagged with the
/// `FreqScale` marker) and the time axis (time values in seconds — presumably one
/// per frame; NOTE(review): confirm against the code that builds `Axes`).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Axes<FreqScale>
where
    FreqScale: Copy + Clone + 'static,
{
    // Frequency axis (Hz).
    freq: FrequencyAxis<FreqScale>,
    // Time axis (seconds); its last entry is reported as the duration.
    times: NonEmptyVec<f64>,
}
3131
3132impl<FreqScale> Axes<FreqScale>
3133where
3134    FreqScale: Copy + Clone + 'static,
3135{
3136    pub(crate) const fn new(freq: FrequencyAxis<FreqScale>, times: NonEmptyVec<f64>) -> Self {
3137        Self { freq, times }
3138    }
3139
3140    /// Get the frequency values in Hz.
3141    ///
3142    /// # Returns
3143    ///
3144    /// Returns a non-empty slice of frequencies.
3145    #[inline]
3146    #[must_use]
3147    pub fn frequencies(&self) -> &NonEmptySlice<f64> {
3148        self.freq.frequencies()
3149    }
3150
3151    /// Get the time values in seconds.
3152    ///
3153    /// # Returns
3154    ///
3155    /// Returns a non-empty slice of time values.
3156    #[inline]
3157    #[must_use]
3158    pub fn times(&self) -> &NonEmptySlice<f64> {
3159        &self.times
3160    }
3161
3162    /// Get the frequency range (min, max) in Hz.
3163    ///
3164    /// # Returns
3165    ///
3166    /// Returns a tuple containing the minimum and maximum frequency.
3167    #[inline]
3168    #[must_use]
3169    pub const fn frequency_range(&self) -> (f64, f64) {
3170        self.freq.frequency_range()
3171    }
3172
3173    /// Get the duration of the spectrogram in seconds.
3174    ///
3175    /// # Returns
3176    ///
3177    /// Returns the duration in seconds.
3178    #[inline]
3179    #[must_use]
3180    pub fn duration(&self) -> f64 {
3181        *self.times.last()
3182    }
3183}
3184
3185// Enum types for frequency and amplitude scales
3186
/// Linear frequency scale.
///
/// Type-level marker used as the `FreqScale` parameter of `Spectrogram`
/// (e.g. `Spectrogram<LinearHz, Power>`); it carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LinearHz {
    // Placeholder variant. NOTE(review): presumably only exists so the marker is a
    // non-empty enum for the derives / `pyclass` — confirm nothing matches on it.
    _Phantom,
}
3195
/// Logarithmic frequency scale.
///
/// Type-level marker used as the `FreqScale` parameter of `Spectrogram`
/// (configured by `LogHzParams`); it carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LogHz {
    // Placeholder variant. NOTE(review): presumably only exists so the marker is a
    // non-empty enum for the derives / `pyclass` — confirm nothing matches on it.
    _Phantom,
}
3204
/// Mel frequency scale.
///
/// Type-level marker used as the `FreqScale` parameter of `Spectrogram`
/// (configured by `MelParams`); it carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Mel {
    // Placeholder variant. NOTE(review): presumably only exists so the marker is a
    // non-empty enum for the derives / `pyclass` — confirm nothing matches on it.
    _Phantom,
}
3213
/// ERB/gammatone frequency scale.
///
/// Type-level marker used as the `FreqScale` parameter of `Spectrogram`
/// (configured by `ErbParams`); it carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Erb {
    // Placeholder variant. NOTE(review): presumably only exists so the marker is a
    // non-empty enum for the derives / `pyclass` — confirm nothing matches on it.
    _Phantom,
}
/// Alias for [`Erb`].
pub type Gammatone = Erb;
3223
/// Constant-Q Transform frequency scale.
///
/// Type-level marker used as the `FreqScale` parameter of `Spectrogram`
/// (configured by `CqtParams`); it carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Cqt {
    // Placeholder variant. NOTE(review): presumably only exists so the marker is a
    // non-empty enum for the derives / `pyclass` — confirm nothing matches on it.
    _Phantom,
}
3232
3233// Amplitude scales
3234
/// Power amplitude scale.
///
/// Type-level marker used as the `AmpScale` parameter of `Spectrogram`; it
/// carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Power {
    // Placeholder variant. NOTE(review): presumably only exists so the marker is a
    // non-empty enum for the derives / `pyclass` — confirm nothing matches on it.
    _Phantom,
}
3243
/// Decibel amplitude scale.
///
/// Type-level marker used as the `AmpScale` parameter of `Spectrogram`; it
/// carries no data. The `db: Option<&LogParams>` argument of the `compute`
/// constructors is documented as applying only to this scale.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Decibels {
    // Placeholder variant. NOTE(review): presumably only exists so the marker is a
    // non-empty enum for the derives / `pyclass` — confirm nothing matches on it.
    _Phantom,
}
3252
/// Magnitude amplitude scale.
///
/// Type-level marker used as the `AmpScale` parameter of `Spectrogram`; it
/// carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Magnitude {
    // Placeholder variant. NOTE(review): presumably only exists so the marker is a
    // non-empty enum for the derives / `pyclass` — confirm nothing matches on it.
    _Phantom,
}
3261
/// STFT parameters for spectrogram computation.
///
/// * `n_fft`: Size of the FFT window.
/// * `hop_size`: Number of samples between successive frames. [`StftParams::new`]
///   (and therefore the builder) rejects `hop_size > n_fft`.
/// * `window`: Window function to apply to each frame.
/// * `centre`: Whether to pad the input signal so that frames are centered.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StftParams {
    n_fft: NonZeroUsize,
    hop_size: NonZeroUsize,
    window: WindowType,
    centre: bool,
}
3276
3277impl StftParams {
3278    /// Create new STFT parameters.
3279    ///
3280    /// # Arguments
3281    ///
3282    /// * `n_fft` - Size of the FFT window
3283    /// * `hop_size` - Number of samples between successive frames
3284    /// * `window` - Window function to apply to each frame
3285    /// * `centre` - Whether to pad the input signal so that frames are centered
3286    ///
3287    /// # Errors
3288    ///
3289    /// Returns an error if `hop_size` > `n_fft`.
3290    ///
3291    /// # Returns
3292    ///
3293    /// New `StftParams` instance.
3294    #[inline]
3295    pub fn new(
3296        n_fft: NonZeroUsize,
3297        hop_size: NonZeroUsize,
3298        window: WindowType,
3299        centre: bool,
3300    ) -> SpectrogramResult<Self> {
3301        if hop_size.get() > n_fft.get() {
3302            return Err(SpectrogramError::invalid_input("hop_size must be <= n_fft"));
3303        }
3304
3305        Ok(Self {
3306            n_fft,
3307            hop_size,
3308            window,
3309            centre,
3310        })
3311    }
3312
3313    const unsafe fn new_unchecked(
3314        n_fft: NonZeroUsize,
3315        hop_size: NonZeroUsize,
3316        window: WindowType,
3317        centre: bool,
3318    ) -> Self {
3319        Self {
3320            n_fft,
3321            hop_size,
3322            window,
3323            centre,
3324        }
3325    }
3326
3327    /// Get the FFT window size.
3328    ///
3329    /// # Returns
3330    ///
3331    /// The FFT window size.
3332    #[inline]
3333    #[must_use]
3334    pub const fn n_fft(&self) -> NonZeroUsize {
3335        self.n_fft
3336    }
3337
3338    /// Get the hop size (samples between successive frames).
3339    ///
3340    /// # Returns
3341    ///
3342    /// The hop size.
3343    #[inline]
3344    #[must_use]
3345    pub const fn hop_size(&self) -> NonZeroUsize {
3346        self.hop_size
3347    }
3348
3349    /// Get the window function.
3350    ///
3351    /// # Returns
3352    ///
3353    /// The window function.
3354    #[inline]
3355    #[must_use]
3356    pub const fn window(&self) -> WindowType {
3357        self.window
3358    }
3359
3360    /// Get whether frames are centered (input signal is padded).
3361    ///
3362    /// # Returns
3363    ///
3364    /// `true` if frames are centered, `false` otherwise.
3365    #[inline]
3366    #[must_use]
3367    pub const fn centre(&self) -> bool {
3368        self.centre
3369    }
3370
3371    /// Create a builder for STFT parameters.
3372    ///
3373    /// # Returns
3374    ///
3375    /// A `StftParamsBuilder` instance.
3376    ///
3377    /// # Examples
3378    ///
3379    /// ```
3380    /// use spectrograms::{StftParams, WindowType, nzu};
3381    ///
3382    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
3383    /// let stft = StftParams::builder()
3384    ///     .n_fft(nzu!(2048))
3385    ///     .hop_size(nzu!(512))
3386    ///     .window(WindowType::Hanning)
3387    ///     .centre(true)
3388    ///     .build()?;
3389    ///
3390    /// assert_eq!(stft.n_fft(), nzu!(2048));
3391    /// assert_eq!(stft.hop_size(), nzu!(512));
3392    /// # Ok(())
3393    /// # }
3394    /// ```
3395    #[inline]
3396    #[must_use]
3397    pub fn builder() -> StftParamsBuilder {
3398        StftParamsBuilder::default()
3399    }
3400}
3401
/// Builder for [`StftParams`].
///
/// `n_fft` and `hop_size` start unset and must be provided before
/// [`StftParamsBuilder::build`]; `window` starts as `WindowType::Hanning` and
/// `centre` as `true` (see the `Default` impl).
#[derive(Debug, Clone)]
pub struct StftParamsBuilder {
    // `None` until set; `build` fails if still unset.
    n_fft: Option<NonZeroUsize>,
    // `None` until set; `build` fails if still unset.
    hop_size: Option<NonZeroUsize>,
    window: WindowType,
    centre: bool,
}
3410
3411impl Default for StftParamsBuilder {
3412    #[inline]
3413    fn default() -> Self {
3414        Self {
3415            n_fft: None,
3416            hop_size: None,
3417            window: WindowType::Hanning,
3418            centre: true,
3419        }
3420    }
3421}
3422
3423impl StftParamsBuilder {
3424    /// Set the FFT window size.
3425    ///
3426    /// # Arguments
3427    ///
3428    /// * `n_fft` - Size of the FFT window
3429    ///
3430    /// # Returns
3431    ///
3432    /// The builder with the updated FFT window size.
3433    #[inline]
3434    #[must_use]
3435    pub const fn n_fft(mut self, n_fft: NonZeroUsize) -> Self {
3436        self.n_fft = Some(n_fft);
3437        self
3438    }
3439
3440    /// Set the hop size (samples between successive frames).
3441    ///
3442    /// # Arguments
3443    ///
3444    /// * `hop_size` - Number of samples between successive frames
3445    ///
3446    /// # Returns
3447    ///
3448    /// The builder with the updated hop size.
3449    #[inline]
3450    #[must_use]
3451    pub const fn hop_size(mut self, hop_size: NonZeroUsize) -> Self {
3452        self.hop_size = Some(hop_size);
3453        self
3454    }
3455
3456    /// Set the window function.
3457    ///
3458    /// # Arguments
3459    ///
3460    /// * `window` - Window function to apply to each frame
3461    ///
3462    /// # Returns
3463    ///
3464    /// The builder with the updated window function.
3465    #[inline]
3466    #[must_use]
3467    pub const fn window(mut self, window: WindowType) -> Self {
3468        self.window = window;
3469        self
3470    }
3471
3472    /// Set whether to center frames (pad input signal).
3473    #[inline]
3474    #[must_use]
3475    pub const fn centre(mut self, centre: bool) -> Self {
3476        self.centre = centre;
3477        self
3478    }
3479
3480    /// Build the [`StftParams`].
3481    ///
3482    /// # Errors
3483    ///
3484    /// Returns an error if:
3485    /// - `n_fft` or `hop_size` are not set or are zero
3486    /// - `hop_size` > `n_fft`
3487    #[inline]
3488    pub fn build(self) -> SpectrogramResult<StftParams> {
3489        let n_fft = self
3490            .n_fft
3491            .ok_or_else(|| SpectrogramError::invalid_input("n_fft must be set"))?;
3492        let hop_size = self
3493            .hop_size
3494            .ok_or_else(|| SpectrogramError::invalid_input("hop_size must be set"))?;
3495
3496        StftParams::new(n_fft, hop_size, self.window, self.centre)
3497    }
3498}
3499
3500//
3501// ========================
3502// Mel parameters
3503// ========================
3504//
3505
/// Mel filterbank normalization strategy.
///
/// Determines how the triangular mel filters are normalized after construction.
/// `MelNorm::None` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum MelNorm {
    /// No normalization (triangular filters with peak = 1.0).
    ///
    /// This is the default and fastest option.
    None,

    /// Slaney-style area normalization (librosa default).
    ///
    /// Each mel filter is scaled by `2.0 / (f_max - f_min)`, where `f_min` and
    /// `f_max` are that filter's own edge frequencies in Hz (not the `MelParams`
    /// range), i.e. each filter is divided by half its bandwidth. This keeps the
    /// energy per mel band approximately constant regardless of bandwidth.
    ///
    /// Use this for compatibility with librosa's default behavior.
    Slaney,

    /// L1 normalization (sum of weights = 1.0).
    ///
    /// Each mel filter's weights are divided by their sum.
    /// Useful when you want each filter to act as a weighted average.
    L1,

    /// L2 normalization (Euclidean norm = 1.0).
    ///
    /// Each mel filter's weights are divided by their L2 norm.
    /// Provides unit-norm filters in the L2 sense.
    L2,
}
3537
3538impl Default for MelNorm {
3539    fn default() -> Self {
3540        Self::None
3541    }
3542}
3543
/// Mel filter bank parameters.
///
/// * `n_mels`: Number of mel bands
/// * `f_min`: Minimum frequency (Hz); the validating constructors require `>= 0`
/// * `f_max`: Maximum frequency (Hz); the validating constructors require `> f_min`
/// * `norm`: Filterbank normalization strategy
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MelParams {
    n_mels: NonZeroUsize,
    f_min: f64,
    f_max: f64,
    norm: MelNorm,
}
3558
3559impl MelParams {
3560    /// Create new mel filter bank parameters.
3561    ///
3562    /// # Arguments
3563    ///
3564    /// * `n_mels` - Number of mel bands
3565    /// * `f_min` - Minimum frequency (Hz)
3566    /// * `f_max` - Maximum frequency (Hz)
3567    ///
3568    /// # Errors
3569    ///
3570    /// Returns an error if:
3571    /// - `f_min` is not >= 0
3572    /// - `f_max` is not > `f_min`
3573    ///
3574    /// # Returns
3575    ///
3576    /// A `MelParams` instance with no normalization (default).
3577    #[inline]
3578    pub fn new(n_mels: NonZeroUsize, f_min: f64, f_max: f64) -> SpectrogramResult<Self> {
3579        Self::with_norm(n_mels, f_min, f_max, MelNorm::None)
3580    }
3581
3582    /// Create new mel filter bank parameters with specified normalization.
3583    ///
3584    /// # Arguments
3585    ///
3586    /// * `n_mels` - Number of mel bands
3587    /// * `f_min` - Minimum frequency (Hz)
3588    /// * `f_max` - Maximum frequency (Hz)
3589    /// * `norm` - Filterbank normalization strategy
3590    ///
3591    /// # Errors
3592    ///
3593    /// Returns an error if:
3594    /// - `f_min` is not >= 0
3595    /// - `f_max` is not > `f_min`
3596    ///
3597    /// # Returns
3598    ///
3599    /// A `MelParams` instance.
3600    #[inline]
3601    pub fn with_norm(
3602        n_mels: NonZeroUsize,
3603        f_min: f64,
3604        f_max: f64,
3605        norm: MelNorm,
3606    ) -> SpectrogramResult<Self> {
3607        if f_min < 0.0 {
3608            return Err(SpectrogramError::invalid_input("f_min must be >= 0"));
3609        }
3610
3611        if f_max <= f_min {
3612            return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
3613        }
3614
3615        Ok(Self {
3616            n_mels,
3617            f_min,
3618            f_max,
3619            norm,
3620        })
3621    }
3622
3623    const unsafe fn new_unchecked(n_mels: NonZeroUsize, f_min: f64, f_max: f64) -> Self {
3624        Self {
3625            n_mels,
3626            f_min,
3627            f_max,
3628            norm: MelNorm::None,
3629        }
3630    }
3631
3632    /// Get the number of mel bands.
3633    ///
3634    /// # Returns
3635    ///
3636    /// The number of mel bands.
3637    #[inline]
3638    #[must_use]
3639    pub const fn n_mels(&self) -> NonZeroUsize {
3640        self.n_mels
3641    }
3642
3643    /// Get the minimum frequency (Hz).
3644    ///
3645    /// # Returns
3646    ///
3647    /// The minimum frequency in Hz.
3648    #[inline]
3649    #[must_use]
3650    pub const fn f_min(&self) -> f64 {
3651        self.f_min
3652    }
3653
3654    /// Get the maximum frequency (Hz).
3655    ///
3656    /// # Returns
3657    ///
3658    /// The maximum frequency in Hz.
3659    #[inline]
3660    #[must_use]
3661    pub const fn f_max(&self) -> f64 {
3662        self.f_max
3663    }
3664
3665    /// Get the filterbank normalization strategy.
3666    ///
3667    /// # Returns
3668    ///
3669    /// The normalization strategy.
3670    #[inline]
3671    #[must_use]
3672    pub const fn norm(&self) -> MelNorm {
3673        self.norm
3674    }
3675
3676    /// Create standard mel filterbank parameters.
3677    ///
3678    /// Uses 128 mel bands from 0 Hz to the Nyquist frequency.
3679    ///
3680    /// # Arguments
3681    ///
3682    /// * `sample_rate` - Sample rate in Hz (used to determine `f_max`)
3683    ///
3684    /// # Returns
3685    ///
3686    /// A `MelParams` instance with standard settings.
3687    ///
3688    /// # Panics
3689    ///
3690    /// Panics if `sample_rate` is not greater than 0.
3691    #[inline]
3692    #[must_use]
3693    pub const fn standard(sample_rate: f64) -> Self {
3694        assert!(sample_rate > 0.0);
3695        // safety: parameters are known to be valid
3696        unsafe { Self::new_unchecked(nzu!(128), 0.0, sample_rate / 2.0) }
3697    }
3698
3699    /// Create mel filterbank parameters optimized for speech.
3700    ///
3701    /// Uses 40 mel bands from 0 Hz to 8000 Hz (typical speech bandwidth).
3702    ///
3703    /// # Returns
3704    ///
3705    /// A `MelParams` instance with speech-optimized settings.
3706    #[inline]
3707    #[must_use]
3708    pub const fn speech_standard() -> Self {
3709        // safety: parameters are known to be valid
3710        unsafe { Self::new_unchecked(nzu!(40), 0.0, 8000.0) }
3711    }
3712}
3713
3714//
3715// ========================
3716// LogHz parameters
3717// ========================
3718//
3719
/// Logarithmic frequency scale parameters.
///
/// * `n_bins`: Number of logarithmically-spaced frequency bins
/// * `f_min`: Minimum frequency (Hz)
/// * `f_max`: Maximum frequency (Hz)
///
/// [`LogHzParams::new`] rejects a non-finite or non-positive `f_min` and an `f_max`
/// that is `<= f_min`.
///
/// NOTE(review): with the `serde` feature, the derived `Deserialize` builds values
/// directly and does not go through `new`, so those checks are not applied to
/// deserialized parameters.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LogHzParams {
    /// Number of logarithmically-spaced frequency bins.
    n_bins: NonZeroUsize,
    /// Minimum frequency in Hz.
    f_min: f64,
    /// Maximum frequency in Hz.
    f_max: f64,
}
3732
3733impl LogHzParams {
3734    /// Create new logarithmic frequency scale parameters.
3735    ///
3736    /// # Arguments
3737    ///
3738    /// * `n_bins` - Number of logarithmically-spaced frequency bins
3739    /// * `f_min` - Minimum frequency (Hz)
3740    /// * `f_max` - Maximum frequency (Hz)
3741    ///
3742    /// # Errors
3743    ///
3744    /// Returns an error if:
3745    /// - `f_min` is not finite and > 0
3746    /// - `f_max` is not > `f_min`
3747    ///
3748    /// # Returns
3749    ///
3750    /// A `LogHzParams` instance.
3751    #[inline]
3752    pub fn new(n_bins: NonZeroUsize, f_min: f64, f_max: f64) -> SpectrogramResult<Self> {
3753        if !(f_min > 0.0 && f_min.is_finite()) {
3754            return Err(SpectrogramError::invalid_input(
3755                "f_min must be finite and > 0",
3756            ));
3757        }
3758
3759        if f_max <= f_min {
3760            return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
3761        }
3762
3763        Ok(Self {
3764            n_bins,
3765            f_min,
3766            f_max,
3767        })
3768    }
3769
3770    const unsafe fn new_unchecked(n_bins: NonZeroUsize, f_min: f64, f_max: f64) -> Self {
3771        Self {
3772            n_bins,
3773            f_min,
3774            f_max,
3775        }
3776    }
3777
3778    /// Get the number of frequency bins.
3779    ///
3780    /// # Returns
3781    ///
3782    /// The number of frequency bins.
3783    #[inline]
3784    #[must_use]
3785    pub const fn n_bins(&self) -> NonZeroUsize {
3786        self.n_bins
3787    }
3788
3789    /// Get the minimum frequency (Hz).
3790    ///
3791    /// # Returns
3792    ///
3793    /// The minimum frequency in Hz.
3794    #[inline]
3795    #[must_use]
3796    pub const fn f_min(&self) -> f64 {
3797        self.f_min
3798    }
3799
3800    /// Get the maximum frequency (Hz).
3801    ///
3802    /// # Returns
3803    ///
3804    /// The maximum frequency in Hz.
3805    #[inline]
3806    #[must_use]
3807    pub const fn f_max(&self) -> f64 {
3808        self.f_max
3809    }
3810
3811    /// Create standard logarithmic frequency parameters.
3812    ///
3813    /// Uses 128 log bins from 20 Hz to the Nyquist frequency.
3814    ///
3815    /// # Arguments
3816    ///
3817    /// * `sample_rate` - Sample rate in Hz (used to determine `f_max`)
3818    #[inline]
3819    #[must_use]
3820    pub fn standard(sample_rate: f64) -> Self {
3821        // safety: parameters are known to be valid
3822        unsafe { Self::new_unchecked(nzu!(128), 20.0, sample_rate / 2.0) }
3823    }
3824
3825    /// Create logarithmic frequency parameters optimized for music.
3826    ///
3827    /// Uses 84 bins (7 octaves * 12 bins/octave) from 27.5 Hz (A0) to 4186 Hz (C8).
3828    #[inline]
3829    #[must_use]
3830    pub const fn music_standard() -> Self {
3831        // safety: parameters are known to be valid
3832        unsafe { Self::new_unchecked(nzu!(84), 27.5, 4186.0) }
3833    }
3834}
3835
3836//
3837// ========================
3838// Log scaling parameters
3839// ========================
3840//
3841
/// Parameters for logarithmic (decibel) amplitude scaling.
///
/// * `floor_db`: Minimum dB value (floor) for logarithmic scaling
///
/// [`LogParams::new`] only accepts a finite `floor_db`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LogParams {
    /// Minimum dB value (floor) for logarithmic scaling.
    floor_db: f64,
}
3847
3848impl LogParams {
3849    /// Create new logarithmic scaling parameters.
3850    ///
3851    /// # Arguments
3852    ///
3853    /// * `floor_db` - Minimum dB value (floor) for logarithmic scaling
3854    ///
3855    /// # Errors
3856    ///
3857    /// Returns an error if `floor_db` is not finite.
3858    ///
3859    /// # Returns
3860    ///
3861    /// A `LogParams` instance.
3862    #[inline]
3863    pub fn new(floor_db: f64) -> SpectrogramResult<Self> {
3864        if !floor_db.is_finite() {
3865            return Err(SpectrogramError::invalid_input("floor_db must be finite"));
3866        }
3867
3868        Ok(Self { floor_db })
3869    }
3870
3871    /// Get the floor dB value.
3872    #[inline]
3873    #[must_use]
3874    pub const fn floor_db(&self) -> f64 {
3875        self.floor_db
3876    }
3877}
3878
/// Spectrogram computation parameters.
///
/// Pairs the STFT settings with the sample rate of the signal being analysed.
///
/// * `stft`: STFT parameters
/// * `sample_rate_hz`: Sample rate in Hz
///
/// NOTE(review): with the `serde` feature, the derived `Deserialize` bypasses
/// [`SpectrogramParams::new`], so the sample-rate check is not applied to
/// deserialized values.
#[derive(Debug, Copy, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SpectrogramParams {
    /// STFT settings (FFT size, hop size, window, centring).
    stft: StftParams,
    /// Sample rate in Hz; `new` requires it to be finite and > 0.
    sample_rate_hz: f64,
}
3889
3890impl SpectrogramParams {
3891    /// Create new spectrogram parameters.
3892    ///
3893    /// # Arguments
3894    ///
3895    /// * `stft` - STFT parameters
3896    /// * `sample_rate_hz` - Sample rate in Hz
3897    ///
3898    /// # Errors
3899    ///
3900    /// Returns an error if the sample rate is not positive and finite.
3901    ///
3902    /// # Returns
3903    ///
3904    /// A `SpectrogramParams` instance.
3905    #[inline]
3906    pub fn new(stft: StftParams, sample_rate_hz: f64) -> SpectrogramResult<Self> {
3907        if !(sample_rate_hz > 0.0 && sample_rate_hz.is_finite()) {
3908            return Err(SpectrogramError::invalid_input(
3909                "sample_rate_hz must be finite and > 0",
3910            ));
3911        }
3912
3913        Ok(Self {
3914            stft,
3915            sample_rate_hz,
3916        })
3917    }
3918
3919    /// Create a builder for spectrogram parameters.
3920    ///
3921    /// # Errors
3922    ///
3923    /// Returns an error if required parameters are not set or are invalid.
3924    ///
3925    /// # Returns
3926    ///
3927    /// A builder for [`SpectrogramParams`].
3928    ///
3929    /// # Examples
3930    ///
3931    /// ```
3932    /// use spectrograms::{SpectrogramParams, WindowType, nzu};
3933    ///
3934    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
3935    /// let params = SpectrogramParams::builder()
3936    ///     .sample_rate(16000.0)
3937    ///     .n_fft(nzu!(512))
3938    ///     .hop_size(nzu!(256))
3939    ///     .window(WindowType::Hanning)
3940    ///     .centre(true)
3941    ///     .build()?;
3942    ///
3943    /// assert_eq!(params.sample_rate_hz(), 16000.0);
3944    /// # Ok(())
3945    /// # }
3946    /// ```
3947    #[inline]
3948    #[must_use]
3949    pub fn builder() -> SpectrogramParamsBuilder {
3950        SpectrogramParamsBuilder::default()
3951    }
3952
3953    /// Create default parameters for speech processing.
3954    ///
3955    /// # Arguments
3956    ///
3957    /// * `sample_rate_hz` - Sample rate in Hz
3958    ///
3959    /// # Returns
3960    ///
3961    /// A `SpectrogramParams` instance with default settings for music analysis.
3962    ///
3963    /// # Errors
3964    ///
3965    /// Returns an error if the sample rate is not positive and finite.
3966    ///
3967    /// Uses:
3968    /// - `n_fft`: 512 (32ms at 16kHz)
3969    /// - `hop_size`: 160 (10ms at 16kHz)
3970    /// - window: Hanning
3971    /// - centre: true
3972    #[inline]
3973    pub fn speech_default(sample_rate_hz: f64) -> SpectrogramResult<Self> {
3974        // safety: parameters are known to be valid
3975        let stft =
3976            unsafe { StftParams::new_unchecked(nzu!(512), nzu!(160), WindowType::Hanning, true) };
3977
3978        Self::new(stft, sample_rate_hz)
3979    }
3980
3981    /// Create default parameters for music processing.
3982    ///
3983    /// # Arguments
3984    ///
3985    /// * `sample_rate_hz` - Sample rate in Hz
3986    ///
3987    /// # Returns
3988    ///
3989    /// A `SpectrogramParams` instance with default settings for music analysis.
3990    ///
3991    /// # Errors
3992    ///
3993    /// Returns an error if the sample rate is not positive and finite.
3994    ///
3995    /// Uses:
3996    /// - `n_fft`: 2048 (46ms at 44.1kHz)
3997    /// - `hop_size`: 512 (11.6ms at 44.1kHz)
3998    /// - window: Hanning
3999    /// - centre: true
4000    #[inline]
4001    pub fn music_default(sample_rate_hz: f64) -> SpectrogramResult<Self> {
4002        // safety: parameters are known to be valid
4003        let stft =
4004            unsafe { StftParams::new_unchecked(nzu!(2048), nzu!(512), WindowType::Hanning, true) };
4005        Self::new(stft, sample_rate_hz)
4006    }
4007
4008    /// Get the STFT parameters.
4009    #[inline]
4010    #[must_use]
4011    pub const fn stft(&self) -> &StftParams {
4012        &self.stft
4013    }
4014
4015    /// Get the sample rate in Hz.
4016    #[inline]
4017    #[must_use]
4018    pub const fn sample_rate_hz(&self) -> f64 {
4019        self.sample_rate_hz
4020    }
4021
4022    /// Get the frame period in seconds.
4023    #[inline]
4024    #[must_use]
4025    #[allow(clippy::cast_precision_loss)]
4026    pub fn frame_period_seconds(&self) -> f64 {
4027        self.stft.hop_size().get() as f64 / self.sample_rate_hz
4028    }
4029
4030    /// Get the Nyquist frequency in Hz.
4031    #[inline]
4032    #[must_use]
4033    pub fn nyquist_hz(&self) -> f64 {
4034        self.sample_rate_hz * 0.5
4035    }
4036}
4037
/// Builder for [`SpectrogramParams`].
///
/// `sample_rate`, `n_fft` and `hop_size` must be set before calling
/// [`SpectrogramParamsBuilder::build`]. By default `window` is Hanning and `centre`
/// is `true`.
#[derive(Debug, Clone)]
pub struct SpectrogramParamsBuilder {
    /// Sample rate in Hz; required.
    sample_rate: Option<f64>,
    /// FFT size in samples; required.
    n_fft: Option<NonZeroUsize>,
    /// Samples between successive frames; required.
    hop_size: Option<NonZeroUsize>,
    /// Window applied to each frame.
    window: WindowType,
    /// Whether frames are centred by padding the input signal.
    centre: bool,
}
4047
4048impl Default for SpectrogramParamsBuilder {
4049    #[inline]
4050    fn default() -> Self {
4051        Self {
4052            sample_rate: None,
4053            n_fft: None,
4054            hop_size: None,
4055            window: WindowType::Hanning,
4056            centre: true,
4057        }
4058    }
4059}
4060
4061impl SpectrogramParamsBuilder {
4062    /// Set the sample rate in Hz.
4063    ///
4064    /// # Arguments
4065    ///
4066    /// * `sample_rate` - Sample rate in Hz.
4067    ///
4068    /// # Returns
4069    ///
4070    /// The updated builder instance.
4071    #[inline]
4072    #[must_use]
4073    pub const fn sample_rate(mut self, sample_rate: f64) -> Self {
4074        self.sample_rate = Some(sample_rate);
4075        self
4076    }
4077
4078    /// Set the FFT window size.
4079    ///
4080    /// # Arguments
4081    ///
4082    /// * `n_fft` - FFT size.
4083    ///
4084    /// # Returns
4085    ///
4086    /// The updated builder instance.
4087    #[inline]
4088    #[must_use]
4089    pub const fn n_fft(mut self, n_fft: NonZeroUsize) -> Self {
4090        self.n_fft = Some(n_fft);
4091        self
4092    }
4093
4094    /// Set the hop size (samples between successive frames).
4095    ///
4096    /// # Arguments
4097    ///
4098    /// * `hop_size` - Hop size in samples.
4099    ///
4100    /// # Returns
4101    ///
4102    /// The updated builder instance.
4103    #[inline]
4104    #[must_use]
4105    pub const fn hop_size(mut self, hop_size: NonZeroUsize) -> Self {
4106        self.hop_size = Some(hop_size);
4107        self
4108    }
4109
4110    /// Set the window function.
4111    ///
4112    /// # Arguments
4113    ///
4114    /// * `window` - Window function to apply to each frame.
4115    ///
4116    /// # Returns
4117    ///
4118    /// The updated builder instance.
4119    #[inline]
4120    #[must_use]
4121    pub const fn window(mut self, window: WindowType) -> Self {
4122        self.window = window;
4123        self
4124    }
4125
4126    /// Set whether to center frames (pad input signal).
4127    ///
4128    /// # Arguments
4129    ///
4130    /// * `centre` - If true, frames are centered by padding the input signal.
4131    ///
4132    /// # Returns
4133    ///
4134    /// The updated builder instance.
4135    #[inline]
4136    #[must_use]
4137    pub const fn centre(mut self, centre: bool) -> Self {
4138        self.centre = centre;
4139        self
4140    }
4141
4142    /// Build the [`SpectrogramParams`].
4143    ///
4144    /// # Errors
4145    ///
4146    /// Returns an error if required parameters are not set or are invalid.
4147    ///
4148    /// # Returns
4149    ///
4150    /// A `SpectrogramParams` instance.
4151    #[inline]
4152    pub fn build(self) -> SpectrogramResult<SpectrogramParams> {
4153        let sample_rate = self
4154            .sample_rate
4155            .ok_or_else(|| SpectrogramError::invalid_input("sample_rate must be set"))?;
4156        let n_fft = self
4157            .n_fft
4158            .ok_or_else(|| SpectrogramError::invalid_input("n_fft must be set"))?;
4159        let hop_size = self
4160            .hop_size
4161            .ok_or_else(|| SpectrogramError::invalid_input("hop_size must be set"))?;
4162
4163        let stft = StftParams::new(n_fft, hop_size, self.window, self.centre)?;
4164        SpectrogramParams::new(stft, sample_rate)
4165    }
4166}
4167
4168//
4169// ========================
4170// Standalone FFT Functions
4171// ========================
4172//
4173
4174/// Compute the real-to-complex FFT of a real-valued signal.
4175///
4176/// This function performs a forward FFT on real-valued input, returning the
4177/// complex frequency domain representation. Only the positive frequencies
4178/// are returned (length = `n_fft/2` + 1) due to conjugate symmetry.
4179///
4180/// # Arguments
4181///
4182/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4183/// * `n_fft` - FFT size
4184///
4185/// # Returns
4186///
4187/// A vector of complex frequency bins with length `n_fft/2` + 1.
4188///
4189/// # Automatic Zero-Padding
4190///
4191/// If the input signal is shorter than `n_fft`, it will be automatically
4192/// zero-padded to the required length. This is standard DSP practice and
4193/// preserves frequency resolution (bin spacing = sample_rate / n_fft).
4194///
4195/// ```
4196/// use spectrograms::{fft, nzu};
4197/// use non_empty_slice::non_empty_vec;
4198///
4199/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4200/// let signal = non_empty_vec![1.0, 2.0, 3.0]; // Only 3 samples
4201/// let spectrum = fft(&signal, nzu!(8))?;   // Automatically padded to 8
4202/// assert_eq!(spectrum.len(), 5);     // Output: 8/2 + 1 = 5 bins
4203/// # Ok(())
4204/// # }
4205/// ```
4206///
4207/// # Errors
4208///
4209/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4210///
4211/// # Examples
4212///
4213/// ```
4214/// use spectrograms::*;
4215/// use non_empty_slice::non_empty_vec;
4216///
4217/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4218/// let signal = non_empty_vec![0.0; nzu!(512)];
4219/// let spectrum = fft(&signal, nzu!(512))?;
4220///
4221/// assert_eq!(spectrum.len(), 257); // 512/2 + 1
4222/// # Ok(())
4223/// # }
4224/// ```
4225#[inline]
4226pub fn fft(
4227    samples: &NonEmptySlice<f64>,
4228    n_fft: NonZeroUsize,
4229) -> SpectrogramResult<Array1<Complex<f64>>> {
4230    if samples.len() > n_fft {
4231        return Err(SpectrogramError::invalid_input(format!(
4232            "Input length ({}) exceeds FFT size ({})",
4233            samples.len(),
4234            n_fft
4235        )));
4236    }
4237
4238    let out_len = r2c_output_size(n_fft.get());
4239
4240    // Get FFT plan from global cache (or create if first use)
4241    #[cfg(feature = "realfft")]
4242    let mut fft = {
4243        use crate::fft_backend::get_or_create_r2c_plan;
4244        let plan = get_or_create_r2c_plan(n_fft.get())?;
4245        // Clone the plan to get our own mutable copy with independent scratch buffer
4246        // This is cheap - only clones the scratch buffer, not the expensive twiddle factors
4247        (*plan).clone()
4248    };
4249
4250    #[cfg(feature = "fftw")]
4251    let mut fft = {
4252        use std::sync::Arc;
4253        let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
4254        crate::FftwPlan::new(Arc::new(plan))
4255    };
4256
4257    let input = if samples.len() < n_fft {
4258        let mut padded = vec![0.0; n_fft.get()];
4259        padded[..samples.len().get()].copy_from_slice(samples);
4260        // safety: samples.len() < n_fft checked above and n_fft > 0
4261
4262        unsafe { NonEmptyVec::new_unchecked(padded) }
4263    } else {
4264        samples.to_non_empty_vec()
4265    };
4266
4267    let mut output = vec![Complex::new(0.0, 0.0); out_len];
4268    fft.process(&input, &mut output)?;
4269    let output = Array1::from_vec(output);
4270    Ok(output)
4271}
4272
4273#[inline]
4274/// Compute the real-valued fft of a signal.
4275///
4276/// # Arguments
4277/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4278/// * `n_fft` - FFT size
4279///
4280/// # Returns
4281///
4282/// An array with length `n_fft/2` + 1.
4283///
4284/// # Errors
4285///
4286/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4287///
4288/// # Examples
4289///
4290/// ```
4291/// use spectrograms::*;
4292/// use non_empty_slice::non_empty_vec;
4293///
4294/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4295/// let signal = non_empty_vec![0.0; nzu!(512)];
4296/// let rfft_result = rfft(&signal, nzu!(512))?;
4297/// // equivalent to
4298/// let fft_result = fft(&signal, nzu!(512))?;
4299/// let rfft_result = fft_result.mapv(num_complex::Complex::norm);
4300/// # Ok(())
4301/// # }
4302///
4303pub fn rfft(samples: &NonEmptySlice<f64>, n_fft: NonZeroUsize) -> SpectrogramResult<Array1<f64>> {
4304    Ok(fft(samples, n_fft)?.mapv(Complex::norm))
4305}
4306
4307/// Compute the power spectrum of a signal (|X|²).
4308///
4309/// This function applies an optional window function and computes the
4310/// power spectrum via FFT. The result contains only positive frequencies.
4311///
4312/// # Arguments
4313///
4314/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4315/// * `n_fft` - FFT size
4316/// * `window` - Optional window function (None for rectangular window)
4317///
4318/// # Returns
4319///
4320/// A vector of power values with length `n_fft/2` + 1.
4321///
4322/// # Automatic Zero-Padding
4323///
4324/// If the input signal is shorter than `n_fft`, it will be automatically
4325/// zero-padded to the required length. This is standard DSP practice and
4326/// preserves frequency resolution (bin spacing = sample_rate / n_fft).
4327///
4328/// ```
4329/// use spectrograms::{power_spectrum, nzu};
4330/// use non_empty_slice::non_empty_vec;
4331///
4332/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4333/// let signal = non_empty_vec![1.0, 2.0, 3.0]; // Only 3 samples
4334/// let power = power_spectrum(&signal, nzu!(8), None)?;
4335/// assert_eq!(power.len(), nzu!(5));     // Output: 8/2 + 1 = 5 bins
4336/// # Ok(())
4337/// # }
4338/// ```
4339///
4340/// # Errors
4341///
4342/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4343///
4344/// # Examples
4345///
4346/// ```
4347/// use spectrograms::*;
4348/// use non_empty_slice::non_empty_vec;
4349///
4350/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4351/// let signal = non_empty_vec![0.0; nzu!(512)];
4352/// let power = power_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
4353///
4354/// assert_eq!(power.len(), nzu!(257)); // 512/2 + 1
4355/// # Ok(())
4356/// # }
4357/// ```
4358#[inline]
4359pub fn power_spectrum(
4360    samples: &NonEmptySlice<f64>,
4361    n_fft: NonZeroUsize,
4362    window: Option<WindowType>,
4363) -> SpectrogramResult<NonEmptyVec<f64>> {
4364    if samples.len() > n_fft {
4365        return Err(SpectrogramError::invalid_input(format!(
4366            "Input length ({}) exceeds FFT size ({})",
4367            samples.len(),
4368            n_fft
4369        )));
4370    }
4371
4372    let mut windowed = vec![0.0; n_fft.get()];
4373    windowed[..samples.len().get()].copy_from_slice(samples);
4374
4375    if let Some(win_type) = window {
4376        let window_samples = make_window(win_type, n_fft);
4377        for i in 0..n_fft.get() {
4378            windowed[i] *= window_samples[i];
4379        }
4380    }
4381
4382    // safety: windowed is non-empty since n_fft > 0
4383    let windowed = unsafe { NonEmptySlice::new_unchecked(&windowed) };
4384    let fft_result = fft(windowed, n_fft)?;
4385    let fft_result = fft_result
4386        .iter()
4387        .map(num_complex::Complex::norm_sqr)
4388        .collect();
4389    // safety: fft_result is non-empty since fft returned successfully
4390    Ok(unsafe { NonEmptyVec::new_unchecked(fft_result) })
4391}
4392
4393/// Compute the magnitude spectrum of a signal (|X|).
4394///
4395/// This function applies an optional window function and computes the
4396/// magnitude spectrum via FFT. The result contains only positive frequencies.
4397///
4398/// # Arguments
4399///
4400/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4401/// * `n_fft` - FFT size
4402/// * `window` - Optional window function (None for rectangular window)
4403///
4404/// # Automatic Zero-Padding
4405///
4406/// If the input signal is shorter than `n_fft`, it will be automatically
4407/// zero-padded to the required length. This preserves frequency resolution.
4408///
4409/// # Errors
4410///
4411/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4412///
4413/// # Returns
4414///
4415/// A vector of magnitude values with length `n_fft/2` + 1.
4416///
4417/// # Examples
4418///
4419/// ```
4420/// use spectrograms::*;
4421/// use non_empty_slice::non_empty_vec;
4422///
4423/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4424/// let signal = non_empty_vec![0.0; nzu!(512)];
4425/// let magnitude = magnitude_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
4426///
4427/// assert_eq!(magnitude.len(), nzu!(257)); // 512/2 + 1
4428/// # Ok(())
4429/// # }
4430/// ```
4431#[inline]
4432pub fn magnitude_spectrum(
4433    samples: &NonEmptySlice<f64>,
4434    n_fft: NonZeroUsize,
4435    window: Option<WindowType>,
4436) -> SpectrogramResult<NonEmptyVec<f64>> {
4437    let power = power_spectrum(samples, n_fft, window)?;
4438    let power = power.iter().map(|&p| p.sqrt()).collect();
4439    // safety: power is non-empty since power_spectrum returned successfully
4440    Ok(unsafe { NonEmptyVec::new_unchecked(power) })
4441}
4442
4443/// Compute the Short-Time Fourier Transform (STFT) of a signal.
4444///
4445/// This function computes the STFT by applying a sliding window and FFT
4446/// to sequential frames of the input signal.
4447///
4448/// # Arguments
4449///
4450/// * `samples` - Input signal (any type that can be converted to a slice)
4451/// * `n_fft` - FFT size
4452/// * `hop_size` - Number of samples between successive frames
4453/// * `window` - Window function to apply to each frame
4454/// * `center` - If true, pad the signal to center frames
4455///
4456/// # Returns
4457///
4458/// A 2D array with shape (`frequency_bins`, `time_frames`) containing complex STFT values.
4459///
4460/// # Errors
4461///
4462/// Returns an error if:
4463/// - `hop_size` > `n_fft`
4464/// - STFT computation fails
4465///
4466/// # Examples
4467///
4468/// ```
4469/// use spectrograms::*;
4470/// use non_empty_slice::non_empty_vec;
4471///
4472/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4473/// let signal = non_empty_vec![0.0; nzu!(16000)];
4474/// let stft_matrix = stft(&signal, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4475///
4476/// println!("STFT: {} bins x {} frames", stft_matrix.nrows(), stft_matrix.ncols());
4477/// # Ok(())
4478/// # }
4479/// ```
4480#[inline]
4481pub fn stft(
4482    samples: &NonEmptySlice<f64>,
4483    n_fft: NonZeroUsize,
4484    hop_size: NonZeroUsize,
4485    window: WindowType,
4486    center: bool,
4487) -> SpectrogramResult<Array2<Complex<f64>>> {
4488    let stft_params = StftParams::new(n_fft, hop_size, window, center)?;
4489    let params = SpectrogramParams::new(stft_params, 1.0)?; // dummy sample rate
4490
4491    let planner = SpectrogramPlanner::new();
4492    let result = planner.compute_stft(samples, &params)?;
4493
4494    Ok(result.data)
4495}
4496
4497/// Compute the inverse real FFT (complex-to-real IFFT).
4498///
4499/// This function performs an inverse FFT, converting frequency domain data
4500/// back to the time domain. Only the positive frequencies need to be provided
4501/// (length = `n_fft/2` + 1) due to conjugate symmetry.
4502///
4503/// # Arguments
4504///
4505/// * `spectrum` - Complex frequency bins (length should be `n_fft/2` + 1)
4506/// * `n_fft` - FFT size (length of the output signal)
4507///
4508/// # Returns
4509///
4510/// A vector of real-valued time-domain samples with length `n_fft`.
4511///
4512/// # Errors
4513///
4514/// Returns an error if:
4515/// - `spectrum` length doesn't match `n_fft/2` + 1
4516/// - Inverse FFT computation fails
4517///
4518/// # Examples
4519///
4520/// ```
4521/// use spectrograms::*;
4522/// use non_empty_slice::{non_empty_vec, NonEmptySlice};
4523/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4524/// // Forward FFT
4525/// let signal = non_empty_vec![1.0, 0.0, -1.0, 0.0, 1.0, 0.0, -1.0, 0.0];
4526/// let spectrum = fft(&signal, nzu!(8))?;
4527/// let slice = spectrum.as_slice().unwrap();
4528/// let spectrum_slice = NonEmptySlice::new(slice).unwrap();
4529/// // Inverse FFT
4530/// let reconstructed = irfft(spectrum_slice, nzu!(8))?;
4531///
4532/// assert_eq!(reconstructed.len(), nzu!(8));
4533/// # Ok(())
4534/// # }
4535/// ```
4536#[inline]
4537pub fn irfft(
4538    spectrum: &NonEmptySlice<Complex<f64>>,
4539    n_fft: NonZeroUsize,
4540) -> SpectrogramResult<NonEmptyVec<f64>> {
4541    use crate::fft_backend::{C2rPlan, r2c_output_size};
4542
4543    let n_fft = n_fft.get();
4544    let expected_len = r2c_output_size(n_fft);
4545    if spectrum.len().get() != expected_len {
4546        return Err(SpectrogramError::dimension_mismatch(
4547            expected_len,
4548            spectrum.len().get(),
4549        ));
4550    }
4551
4552    // Get inverse FFT plan from global cache (or create if first use)
4553    #[cfg(feature = "realfft")]
4554    let mut ifft = {
4555        use crate::fft_backend::get_or_create_c2r_plan;
4556        let plan = get_or_create_c2r_plan(n_fft)?;
4557        // Clone to get our own mutable copy with independent scratch buffer
4558        (*plan).clone()
4559    };
4560
4561    #[cfg(feature = "fftw")]
4562    let mut ifft = {
4563        use crate::fft_backend::C2rPlanner;
4564        let mut planner = crate::FftwPlanner::new();
4565        planner.plan_c2r(n_fft)?
4566    };
4567
4568    let mut output = vec![0.0; n_fft];
4569    ifft.process(spectrum.as_slice(), &mut output)?;
4570
4571    // Safety: output is non-empty since n_fft > 0
4572    Ok(unsafe { NonEmptyVec::new_unchecked(output) })
4573}
4574
4575/// Reconstruct a time-domain signal from its STFT using overlap-add.
4576///
4577/// This function performs the inverse Short-Time Fourier Transform, converting
4578/// a 2D complex STFT matrix back to a 1D time-domain signal using overlap-add
4579/// synthesis with the specified window function.
4580///
4581/// # Arguments
4582///
4583/// * `stft_matrix` - Complex STFT with shape (`frequency_bins`, `time_frames`)
4584/// * `n_fft` - FFT size
4585/// * `hop_size` - Number of samples between successive frames
4586/// * `window` - Window function to apply (should match forward STFT window)
4587/// * `center` - If true, assume the forward STFT was centered
4588///
4589/// # Returns
4590///
4591/// A vector of reconstructed time-domain samples.
4592///
4593/// # Errors
4594///
4595/// Returns an error if:
4596/// - `stft_matrix` dimensions are inconsistent with `n_fft`
4597/// - `hop_size` > `n_fft`
4598/// - Inverse STFT computation fails
4599///
4600/// # Examples
4601///
4602/// ```
4603/// use spectrograms::*;
4604/// use non_empty_slice::non_empty_vec;
4605///
4606/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4607/// // Generate signal
4608/// let signal = non_empty_vec![1.0; nzu!(16000)];
4609///
4610/// // Forward STFT
4611/// let stft_matrix = stft(&signal, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4612///
4613/// // Inverse STFT
4614/// let reconstructed = istft(&stft_matrix, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4615///
4616/// println!("Original: {} samples", signal.len());
4617/// println!("Reconstructed: {} samples", reconstructed.len());
4618/// # Ok(())
4619/// # }
4620/// ```
4621#[inline]
4622pub fn istft(
4623    stft_matrix: &Array2<Complex<f64>>,
4624    n_fft: NonZeroUsize,
4625    hop_size: NonZeroUsize,
4626    window: WindowType,
4627    center: bool,
4628) -> SpectrogramResult<Vec<f64>> {
4629    use crate::fft_backend::{C2rPlan, C2rPlanner, r2c_output_size};
4630
4631    let n_bins = stft_matrix.nrows();
4632    let n_frames = stft_matrix.ncols();
4633
4634    let expected_bins = r2c_output_size(n_fft.get());
4635    if n_bins != expected_bins {
4636        return Err(SpectrogramError::dimension_mismatch(expected_bins, n_bins));
4637    }
4638    if hop_size.get() > n_fft.get() {
4639        return Err(SpectrogramError::invalid_input("hop_size must be <= n_fft"));
4640    }
4641    // Create inverse FFT plan
4642    #[cfg(feature = "realfft")]
4643    let mut ifft = {
4644        let mut planner = crate::RealFftPlanner::new();
4645        planner.plan_c2r(n_fft.get())?
4646    };
4647
4648    #[cfg(feature = "fftw")]
4649    let mut ifft = {
4650        let mut planner = crate::FftwPlanner::new();
4651        planner.plan_c2r(n_fft.get())?
4652    };
4653
4654    // Generate window
4655    let window_samples = make_window(window, n_fft);
4656    let n_fft = n_fft.get();
4657    let hop_size = hop_size.get();
4658    // Calculate output length
4659    let pad = if center { n_fft / 2 } else { 0 };
4660    let output_len = (n_frames - 1) * hop_size + n_fft;
4661    let unpadded_len = output_len.saturating_sub(2 * pad);
4662
4663    // Allocate output buffer and normalization buffer
4664    let mut output = vec![0.0; output_len];
4665    let mut norm = vec![0.0; output_len];
4666
4667    // Overlap-add synthesis
4668    let mut frame_buffer = vec![Complex::new(0.0, 0.0); n_bins];
4669    let mut time_frame = vec![0.0; n_fft];
4670
4671    for frame_idx in 0..n_frames {
4672        // Extract complex frame from STFT matrix
4673        for bin_idx in 0..n_bins {
4674            frame_buffer[bin_idx] = stft_matrix[[bin_idx, frame_idx]];
4675        }
4676
4677        // Inverse FFT
4678        ifft.process(&frame_buffer, &mut time_frame)?;
4679
4680        // Apply window
4681        for i in 0..n_fft {
4682            time_frame[i] *= window_samples[i];
4683        }
4684
4685        // Overlap-add into output buffer
4686        let start = frame_idx * hop_size;
4687        for i in 0..n_fft {
4688            let pos = start + i;
4689            if pos < output_len {
4690                output[pos] += time_frame[i];
4691                norm[pos] += window_samples[i] * window_samples[i];
4692            }
4693        }
4694    }
4695
4696    // Normalize by window energy
4697    for i in 0..output_len {
4698        if norm[i] > 1e-10 {
4699            output[i] /= norm[i];
4700        }
4701    }
4702
4703    // Remove padding if centered
4704    if center && unpadded_len > 0 {
4705        let start = pad;
4706        let end = start + unpadded_len;
4707        output = output[start..end.min(output_len)].to_vec();
4708    }
4709
4710    Ok(output)
4711}
4712
4713//
4714// ========================
4715// Reusable FFT Plans
4716// ========================
4717//
4718
/// Reusable FFT front-end that keeps one backend planner alive between calls.
///
/// The backend is selected at compile time: `crate::FftwPlanner` when the
/// `fftw` feature is enabled, `crate::RealFftPlanner` when the `realfft`
/// feature is enabled. Holding a single planner across many transforms lets
/// any plan caching performed by that backend be shared between calls, instead
/// of starting from scratch in every standalone `fft()` invocation.
///
/// # Examples
///
/// ```
/// use spectrograms::*;
/// use non_empty_slice::non_empty_vec;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let mut planner = FftPlanner::new();
///
/// // One planner, many equally-sized frames.
/// for _ in 0..100 {
///     let frame = non_empty_vec![0.0; nzu!(512)];
///     let bins = planner.fft(&frame, nzu!(512))?;
///     // ... process bins ...
/// }
/// # Ok(())
/// # }
/// ```
pub struct FftPlanner {
    #[cfg(feature = "fftw")]
    inner: crate::FftwPlanner,
    #[cfg(feature = "realfft")]
    inner: crate::RealFftPlanner,
}
4748
4749impl FftPlanner {
4750    /// Create a new FFT planner with empty cache.
4751    #[inline]
4752    #[must_use]
4753    pub fn new() -> Self {
4754        Self {
4755            #[cfg(feature = "realfft")]
4756            inner: crate::RealFftPlanner::new(),
4757            #[cfg(feature = "fftw")]
4758            inner: crate::FftwPlanner::new(),
4759        }
4760    }
4761
4762    /// Compute forward FFT, reusing cached plans.
4763    ///
4764    /// This is more efficient than calling the standalone `fft()` function
4765    /// repeatedly for the same FFT size.
4766    ///
4767    /// # Automatic Zero-Padding
4768    ///
4769    /// If the input signal is shorter than `n_fft`, it will be automatically
4770    /// zero-padded to the required length.
4771    ///
4772    /// # Errors
4773    ///
4774    /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4775    ///
4776    /// # Examples
4777    ///
4778    /// ```
4779    /// use spectrograms::*;
4780    /// use non_empty_slice::non_empty_vec;
4781    ///
4782    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4783    /// let mut planner = FftPlanner::new();
4784    ///
4785    /// let signal = non_empty_vec![1.0; nzu!(512)];
4786    /// let spectrum = planner.fft(&signal, nzu!(512))?;
4787    ///
4788    /// assert_eq!(spectrum.len(), 257); // 512/2 + 1
4789    /// # Ok(())
4790    /// # }
4791    /// ```
4792    #[inline]
4793    pub fn fft(
4794        &mut self,
4795        samples: &NonEmptySlice<f64>,
4796        n_fft: NonZeroUsize,
4797    ) -> SpectrogramResult<Array1<Complex<f64>>> {
4798        use crate::fft_backend::{R2cPlan, R2cPlanner, r2c_output_size};
4799
4800        if samples.len() > n_fft {
4801            return Err(SpectrogramError::invalid_input(format!(
4802                "Input length ({}) exceeds FFT size ({})",
4803                samples.len(),
4804                n_fft
4805            )));
4806        }
4807
4808        let out_len = r2c_output_size(n_fft.get());
4809        let mut plan = self.inner.plan_r2c(n_fft.get())?;
4810
4811        let input = if samples.len() < n_fft {
4812            let mut padded = vec![0.0; n_fft.get()];
4813            padded[..samples.len().get()].copy_from_slice(samples);
4814
4815            // safety: samples.len() < n_fft checked above and n_fft > 0
4816            unsafe { NonEmptyVec::new_unchecked(padded) }
4817        } else {
4818            samples.to_non_empty_vec()
4819        };
4820
4821        let mut output = vec![Complex::new(0.0, 0.0); out_len];
4822        plan.process(&input, &mut output)?;
4823
4824        let output = Array1::from_vec(output);
4825        Ok(output)
4826    }
4827
4828    /// Compute forward real FFT magnitude
4829    ///
4830    /// # Errors
4831    ///
4832    /// Returns an error if:
4833    /// - `n_fft` doesn't match the samples length
4834    ///
4835    ///
4836    #[inline]
4837    pub fn rfft(
4838        &mut self,
4839        samples: &NonEmptySlice<f64>,
4840        n_fft: NonZeroUsize,
4841    ) -> SpectrogramResult<Array1<f64>> {
4842        let fft_with_complex = fft(samples, n_fft)?;
4843        Ok(fft_with_complex.mapv(Complex::norm))
4844    }
4845
4846    /// Compute inverse FFT, reusing cached plans.
4847    ///
4848    /// This is more efficient than calling the standalone `irfft()` function
4849    /// repeatedly for the same FFT size.
4850    ///
4851    /// # Errors
4852    /// Returns an error if:
4853    ///
4854    /// - The calculated expected length of `spectrum` doesn't match its actual length
4855    ///
4856    /// # Examples
4857    ///
4858    /// ```
4859    /// use spectrograms::*;
4860    /// use non_empty_slice::{non_empty_vec, NonEmptySlice};
4861    ///
4862    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4863    /// let mut planner = FftPlanner::new();
4864    ///
4865    /// // Forward FFT
4866    /// let signal = non_empty_vec![1.0; nzu!(512)];
4867    /// let spectrum = planner.fft(&signal, nzu!(512))?;
4868    ///
4869    /// // Inverse FFT
4870    /// let spectrum_slice = NonEmptySlice::new(spectrum.as_slice().unwrap()).unwrap();
4871    /// let reconstructed = planner.irfft(spectrum_slice, nzu!(512))?;
4872    ///
4873    /// assert_eq!(reconstructed.len(), nzu!(512));
4874    /// # Ok(())
4875    /// # }
4876    /// ```
4877    #[inline]
4878    pub fn irfft(
4879        &mut self,
4880        spectrum: &NonEmptySlice<Complex<f64>>,
4881        n_fft: NonZeroUsize,
4882    ) -> SpectrogramResult<NonEmptyVec<f64>> {
4883        use crate::fft_backend::{C2rPlan, C2rPlanner, r2c_output_size};
4884
4885        let expected_len = r2c_output_size(n_fft.get());
4886        if spectrum.len().get() != expected_len {
4887            return Err(SpectrogramError::dimension_mismatch(
4888                expected_len,
4889                spectrum.len().get(),
4890            ));
4891        }
4892
4893        let mut plan = self.inner.plan_c2r(n_fft.get())?;
4894        let mut output = vec![0.0; n_fft.get()];
4895        plan.process(spectrum, &mut output)?;
4896        // Safety: output is non-empty since n_fft > 0
4897        let output = unsafe { NonEmptyVec::new_unchecked(output) };
4898        Ok(output)
4899    }
4900
4901    /// Compute power spectrum with optional windowing, reusing cached plans.
4902    ///
4903    /// # Automatic Zero-Padding
4904    ///
4905    /// If the input signal is shorter than `n_fft`, it will be automatically
4906    /// zero-padded to the required length.
4907    ///
4908    /// # Errors
4909    /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4910    ///
4911    /// # Examples
4912    ///
4913    /// ```
4914    /// use spectrograms::*;
4915    /// use non_empty_slice::non_empty_vec;
4916    ///
4917    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4918    /// let mut planner = FftPlanner::new();
4919    ///
4920    /// let signal = non_empty_vec![1.0; nzu!(512)];
4921    /// let power = planner.power_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
4922    ///
4923    /// assert_eq!(power.len(), nzu!(257));
4924    /// # Ok(())
4925    /// # }
4926    /// ```
4927    #[inline]
4928    pub fn power_spectrum(
4929        &mut self,
4930        samples: &NonEmptySlice<f64>,
4931        n_fft: NonZeroUsize,
4932        window: Option<WindowType>,
4933    ) -> SpectrogramResult<NonEmptyVec<f64>> {
4934        if samples.len() > n_fft {
4935            return Err(SpectrogramError::invalid_input(format!(
4936                "Input length ({}) exceeds FFT size ({})",
4937                samples.len(),
4938                n_fft
4939            )));
4940        }
4941
4942        let mut windowed = vec![0.0; n_fft.get()];
4943        windowed[..samples.len().get()].copy_from_slice(samples);
4944        if let Some(win_type) = window {
4945            let window_samples = make_window(win_type, n_fft);
4946            for i in 0..n_fft.get() {
4947                windowed[i] *= window_samples[i];
4948            }
4949        }
4950
4951        // safety: windowed is non-empty since n_fft > 0
4952        let windowed = unsafe { NonEmptySlice::new_unchecked(&windowed) };
4953        let fft_result = self.fft(windowed, n_fft)?;
4954        let f = fft_result
4955            .iter()
4956            .map(num_complex::Complex::norm_sqr)
4957            .collect();
4958        // safety: fft_result is non-empty since fft returned successfully
4959        Ok(unsafe { NonEmptyVec::new_unchecked(f) })
4960    }
4961
4962    /// Compute magnitude spectrum with optional windowing, reusing cached plans.
4963    ///
4964    /// # Automatic Zero-Padding
4965    ///
4966    /// If the input signal is shorter than `n_fft`, it will be automatically
4967    /// zero-padded to the required length.
4968    ///
4969    /// # Errors
4970    /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4971    ///
4972    /// # Examples
4973    ///
4974    /// ```
4975    /// use spectrograms::*;
4976    /// use non_empty_slice::non_empty_vec;
4977    ///
4978    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4979    /// let mut planner = FftPlanner::new();
4980    ///
4981    /// let signal = non_empty_vec![1.0; nzu!(512)];
4982    /// let magnitude = planner.magnitude_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
4983    ///
4984    /// assert_eq!(magnitude.len(), nzu!(257));
4985    /// # Ok(())
4986    /// # }
4987    /// ```
4988    #[inline]
4989    pub fn magnitude_spectrum(
4990        &mut self,
4991        samples: &NonEmptySlice<f64>,
4992        n_fft: NonZeroUsize,
4993        window: Option<WindowType>,
4994    ) -> SpectrogramResult<NonEmptyVec<f64>> {
4995        let power = self.power_spectrum(samples, n_fft, window)?;
4996        let power = power.iter().map(|&p| p.sqrt()).collect::<Vec<f64>>();
4997        // safety: power is non-empty since power_spectrum returned successfully
4998        Ok(unsafe { NonEmptyVec::new_unchecked(power) })
4999    }
5000}
5001
5002impl Default for FftPlanner {
5003    #[inline]
5004    fn default() -> Self {
5005        Self::new()
5006    }
5007}
5008
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sparse_matrix_basic() {
        // 3x5 matrix, entries listed as (row, col, value):
        //   row 0: col 1 = 2.0
        //   row 1: col 2 = 0.5, col 3 = 1.5
        //   row 2: col 0 = 3.0, col 4 = 1.0
        let mut sparse = SparseMatrix::new(3, 5);
        let entries = [
            (0, 1, 2.0),
            (1, 2, 0.5),
            (1, 3, 1.5),
            (2, 0, 3.0),
            (2, 4, 1.0),
        ];
        for &(row, col, value) in &entries {
            sparse.set(row, col, value);
        }

        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut y = vec![0.0; 3];
        sparse.multiply_vec(&x, &mut y);

        // Expected:
        //   row 0: 2.0 * 2.0             = 4.0
        //   row 1: 0.5 * 3.0 + 1.5 * 4.0 = 7.5
        //   row 2: 3.0 * 1.0 + 1.0 * 5.0 = 8.0
        assert_eq!(y, vec![4.0, 7.5, 8.0]);
    }

    #[test]
    fn test_sparse_matrix_zeros_ignored() {
        // Writing an explicit 0.0 must not create a stored entry.
        let mut sparse = SparseMatrix::new(2, 3);
        sparse.set(0, 0, 1.0);
        sparse.set(0, 1, 0.0);
        sparse.set(0, 2, 2.0);

        let row_values = &sparse.values[0];
        let row_indices = &sparse.indices[0];

        assert_eq!(row_values.len(), 2);
        assert_eq!(row_indices.len(), 2);
        assert_eq!(row_indices, &vec![0, 2]);
        assert_eq!(row_values, &vec![1.0, 2.0]);
    }

    #[test]
    fn test_loghz_matrix_sparsity() {
        // Log-frequency rows interpolate between neighbouring FFT bins, so
        // every row should hold one or two non-zeros.
        let sample_rate = 16000.0;
        let n_fft = nzu!(512);
        let n_bins = nzu!(128);

        let (matrix, _freqs) =
            build_loghz_matrix(sample_rate, n_fft, n_bins, 20.0, sample_rate / 2.0).unwrap();

        for row in 0..matrix.nrows() {
            let stored = matrix.values[row].len();
            assert!(
                stored <= 2,
                "Row {} has {} non-zeros, expected at most 2",
                row,
                stored
            );
            assert!(stored >= 1, "Row {} has no non-zeros", row);
        }

        // Between one and two entries per row in total.
        let stored_total: usize = matrix.values.iter().map(|v| v.len()).sum();
        assert!(stored_total <= n_bins.get() * 2);
        assert!(stored_total >= n_bins.get());
    }

    #[test]
    fn test_mel_matrix_sparsity() {
        // Triangular mel filters only cover a narrow band of FFT bins each.
        let sample_rate = 16000.0;
        let n_fft = nzu!(512);
        let n_mels = nzu!(40);

        let matrix = build_mel_filterbank_matrix(
            sample_rate,
            n_fft,
            n_mels,
            0.0,
            sample_rate / 2.0,
            MelNorm::None,
        )
        .unwrap();

        let fft_bins = r2c_output_size(n_fft.get());

        let stored_total: usize = matrix.values.iter().map(|v| v.len()).sum();
        let dense_total = n_mels.get() * fft_bins;
        let sparsity = 1.0 - (stored_total as f64 / dense_total as f64);

        assert!(
            sparsity > 0.8,
            "Mel matrix sparsity is only {:.1}%, expected >80%",
            sparsity * 100.0
        );

        for row in 0..matrix.nrows() {
            let stored = matrix.values[row].len();
            assert!(
                stored < fft_bins / 2,
                "Mel filter {} is not sparse enough",
                row
            );
        }
    }
}