Skip to main content

spectrograms/
spectrogram.rs

1use std::marker::PhantomData;
2use std::num::NonZeroUsize;
3use std::ops::{Deref, DerefMut};
4
5use ndarray::{Array1, Array2};
6use non_empty_slice::{NonEmptySlice, NonEmptyVec, non_empty_vec};
7use num_complex::Complex;
8
9#[cfg(feature = "python")]
10use pyo3::prelude::*;
11
12use crate::cqt::CqtKernel;
13use crate::erb::ErbFilterbank;
14use crate::{
15    CqtParams, ErbParams, R2cPlan, SpectrogramError, SpectrogramResult, WindowType,
16    min_max_single_pass, nzu,
17};
/// Tiny positive constant (`1e-12`) intended as a numerical floor.
///
/// NOTE(review): not referenced in the visible part of this file; presumably used
/// further down (e.g. to guard logarithms or divisions) — confirm before changing.
const EPS: f64 = 1.0e-12;
19
20//
21// ========================
22// Sparse Matrix for efficient filterbank multiplication
23// ========================
24//
25
/// Row-wise sparse matrix optimized for matrix-vector multiplication.
///
/// Each row keeps its own list of non-zero values together with the matching
/// column indices: `values[r][k]` is the entry stored at column `indices[r][k]`.
/// Unlike a single-buffer CSR layout, this lets rows be filled independently
/// (at the cost of separate heap buffers per row).
///
/// Intended for matrices with very few non-zero values per row, such as mel
/// filterbanks (triangular filters) and logarithmic frequency mappings (linear
/// interpolation between 1-2 adjacent bins).
///
/// Typical sparsity for spectrograms:
/// - `LogHz` interpolation: only 1-2 non-zeros per row (~99% sparse)
/// - Mel filterbank: ~10-50 non-zeros per row depending on FFT size (~90-98% sparse)
///
/// Storing only the non-zeros skips multiplications by zero, so the product costs
/// O(stored entries) instead of O(nrows * ncols). NOTE(review): the previously quoted
/// 10-100x speedup over a dense product is an estimate, not a benchmark in this file.
#[derive(Debug, Clone)]
struct SparseMatrix {
    /// Number of rows
    nrows: usize,
    /// Number of columns
    ncols: usize,
    /// Stored (non-zero) values, one `Vec` per row
    values: Vec<Vec<f64>>,
    /// Column index of each stored value; parallel to `values`, row by row
    indices: Vec<Vec<usize>>,
}
53
54impl SparseMatrix {
55    /// Create a new sparse matrix with the given dimensions.
56    fn new(nrows: usize, ncols: usize) -> Self {
57        Self {
58            nrows,
59            ncols,
60            values: vec![Vec::new(); nrows],
61            indices: vec![Vec::new(); nrows],
62        }
63    }
64
65    /// Set a value in the matrix. Only stores if value is non-zero.
66    ///
67    /// # Panics (debug mode only)
68    /// Panics in debug builds if row or col are out of bounds.
69    fn set(&mut self, row: usize, col: usize, value: f64) {
70        debug_assert!(
71            row < self.nrows && col < self.ncols,
72            "SparseMatrix index out of bounds: ({}, {}) for {}x{} matrix",
73            row,
74            col,
75            self.nrows,
76            self.ncols
77        );
78        if row >= self.nrows || col >= self.ncols {
79            return;
80        }
81
82        // Only store non-zero values (with small epsilon for numerical stability)
83        if value.abs() > 1e-10 {
84            self.values[row].push(value);
85            self.indices[row].push(col);
86        }
87    }
88
89    /// Get the number of rows.
90    const fn nrows(&self) -> usize {
91        self.nrows
92    }
93
94    /// Get the number of columns.
95    const fn ncols(&self) -> usize {
96        self.ncols
97    }
98
99    /// Perform sparse matrix-vector multiplication: out = self * input
100    /// This is much faster than dense multiplication when the matrix is sparse.
101    #[inline]
102    fn multiply_vec(&self, input: &[f64], out: &mut [f64]) {
103        debug_assert_eq!(input.len(), self.ncols);
104        debug_assert_eq!(out.len(), self.nrows);
105
106        for (row_idx, (row_values, row_indices)) in
107            self.values.iter().zip(&self.indices).enumerate()
108        {
109            let mut acc = 0.0;
110            for (&value, &col_idx) in row_values.iter().zip(row_indices) {
111                acc += value * input[col_idx];
112            }
113            out[row_idx] = acc;
114        }
115    }
116}
117
118// Linear frequency
119pub type LinearPowerSpectrogram = Spectrogram<LinearHz, Power>;
120pub type LinearMagnitudeSpectrogram = Spectrogram<LinearHz, Magnitude>;
121pub type LinearDbSpectrogram = Spectrogram<LinearHz, Decibels>;
122pub type LinearSpectrogram<AmpScale> = Spectrogram<LinearHz, AmpScale>;
123
124// Log-frequency (e.g. CQT-style)
125pub type LogHzPowerSpectrogram = Spectrogram<LogHz, Power>;
126pub type LogHzMagnitudeSpectrogram = Spectrogram<LogHz, Magnitude>;
127pub type LogHzDbSpectrogram = Spectrogram<LogHz, Decibels>;
128pub type LogHzSpectrogram<AmpScale> = Spectrogram<LogHz, AmpScale>;
129
130// ERB / gammatone
131pub type ErbPowerSpectrogram = Spectrogram<Erb, Power>;
132pub type ErbMagnitudeSpectrogram = Spectrogram<Erb, Magnitude>;
133pub type ErbDbSpectrogram = Spectrogram<Erb, Decibels>;
134pub type GammatonePowerSpectrogram = ErbPowerSpectrogram;
135pub type GammatoneMagnitudeSpectrogram = ErbMagnitudeSpectrogram;
136pub type GammatoneDbSpectrogram = ErbDbSpectrogram;
137pub type ErbSpectrogram<AmpScale> = Spectrogram<Erb, AmpScale>;
138pub type GammatoneSpectrogram<AmpScale> = ErbSpectrogram<AmpScale>;
139
140// Mel
141pub type MelMagnitudeSpectrogram = Spectrogram<Mel, Magnitude>;
142pub type MelPowerSpectrogram = Spectrogram<Mel, Power>;
143pub type MelDbSpectrogram = Spectrogram<Mel, Decibels>;
144pub type LogMelSpectrogram = MelDbSpectrogram;
145pub type MelSpectrogram<AmpScale> = Spectrogram<Mel, AmpScale>;
146
147// CQT
148pub type CqtPowerSpectrogram = Spectrogram<Cqt, Power>;
149pub type CqtMagnitudeSpectrogram = Spectrogram<Cqt, Magnitude>;
150pub type CqtDbSpectrogram = Spectrogram<Cqt, Decibels>;
151pub type CqtSpectrogram<AmpScale> = Spectrogram<Cqt, AmpScale>;
152
153use crate::fft_backend::r2c_output_size;
154
/// A spectrogram plan is the compiled, reusable execution object.
///
/// It owns:
/// - FFT plan (reusable)
/// - window samples
/// - mapping (identity / mel filterbank / etc.)
/// - amplitude scaling config
/// - workspace buffers to avoid allocations in hot loops
///
/// It computes one specific spectrogram type: `Spectrogram<FreqScale, AmpScale>`.
/// Plans are built through `SpectrogramPlanner` (e.g. `linear_plan`, `mel_plan`).
///
/// # Type Parameters
///
/// - `FreqScale`: Frequency scale type (e.g. `LinearHz`, `LogHz`, `Mel`, etc.)
/// - `AmpScale`: Amplitude scaling type (e.g. `Power`, `Magnitude`, `Decibels`, etc.)
pub struct SpectrogramPlan<FreqScale, AmpScale>
where
    AmpScale: AmpScaleSpec + 'static,
    FreqScale: Copy + Clone + 'static,
{
    /// Parameters the plan was built with; cloned into every computed spectrogram.
    params: SpectrogramParams,

    /// STFT stage: framing, optional windowing and the FFT for each frame.
    stft: StftPlan,
    /// Maps the per-frame spectrum (or the raw frame, for CQT) to the output bins.
    mapping: FrequencyMapping<FreqScale>,
    /// Amplitude scaling applied in place to each mapped frame.
    scaling: AmplitudeScaling<AmpScale>,

    /// Frequency axis attached to every computed spectrogram.
    freq_axis: FrequencyAxis<FreqScale>,
    /// Scratch buffers (frame, spectrum, mapped, ...) reused across frames.
    workspace: Workspace,

    /// Type-level marker for `AmpScale`.
    _amp: PhantomData<AmpScale>,
}
186
187impl<FreqScale, AmpScale> SpectrogramPlan<FreqScale, AmpScale>
188where
189    AmpScale: AmpScaleSpec + 'static,
190    FreqScale: Copy + Clone + 'static,
191{
192    /// Get the spectrogram parameters used to create this plan.
193    ///
194    /// # Returns
195    ///
196    /// A reference to the `SpectrogramParams` used in this plan.
197    #[inline]
198    #[must_use]
199    pub const fn params(&self) -> &SpectrogramParams {
200        &self.params
201    }
202
203    /// Get the frequency axis for this spectrogram plan.
204    ///
205    /// # Returns
206    ///
207    /// A reference to the `FrequencyAxis<FreqScale>` used in this plan.
208    #[inline]
209    #[must_use]
210    pub const fn freq_axis(&self) -> &FrequencyAxis<FreqScale> {
211        &self.freq_axis
212    }
213
214    /// Compute a spectrogram for a mono signal.
215    ///
216    /// This function performs:
217    /// - framing + windowing
218    /// - FFT per frame
219    /// - magnitude/power
220    /// - frequency mapping (identity/mel/etc.)
221    /// - amplitude scaling (linear or dB)
222    ///
223    /// It allocates the output `Array2` once, but does not allocate per-frame.
224    ///
225    /// # Arguments
226    ///
227    /// * `samples` - Audio samples
228    ///
229    /// # Returns
230    ///
231    /// A `Spectrogram<FreqScale, AmpScale>` containing the computed spectrogram.
232    ///
233    /// # Errors
234    ///
235    /// Returns an error if STFT computation or mapping fails.
236    #[inline]
237    pub fn compute(
238        &mut self,
239        samples: &NonEmptySlice<f64>,
240    ) -> SpectrogramResult<Spectrogram<FreqScale, AmpScale>> {
241        let n_frames = self.stft.frame_count(samples.len())?;
242        let n_bins = self.mapping.output_bins();
243
244        // Create output matrix: (n_bins, n_frames)
245        let mut data = Array2::<f64>::zeros((n_bins.get(), n_frames.get()));
246
247        // Ensure workspace is correctly sized
248        self.workspace
249            .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
250
251        // Main loop: fill each frame (column)
252        for frame_idx in 0..n_frames.get() {
253            // CQT needs unwindowed frames because its kernels already contain windowing.
254            // Other mappings use the FFT spectrum, so they need the windowed frame.
255            if self.mapping.kind.needs_unwindowed_frame() {
256                // Fill frame without windowing (for CQT)
257                self.stft
258                    .fill_frame_unwindowed(samples, frame_idx, &mut self.workspace)?;
259            } else {
260                // Compute windowed frame spectrum (for FFT-based mappings)
261                self.stft
262                    .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
263            }
264
265            // mapping: spectrum(out_len) -> mapped(n_bins)
266            // For CQT, this uses workspace.frame; for others, workspace.spectrum
267            // For ERB, we need the complex FFT output (fft_out)
268            // We need to borrow workspace fields separately to avoid borrow conflicts
269            let Workspace {
270                spectrum,
271                mapped,
272                frame,
273                ..
274            } = &mut self.workspace;
275
276            self.mapping.apply(spectrum, frame, mapped)?;
277
278            // amplitude scaling in-place on mapped vector
279            self.scaling.apply_in_place(mapped)?;
280
281            // write column into output
282            for (row, &val) in mapped.iter().enumerate() {
283                data[[row, frame_idx]] = val;
284            }
285        }
286
287        let times = build_time_axis_seconds(&self.params, n_frames);
288        let axes = Axes::new(self.freq_axis.clone(), times);
289
290        Ok(Spectrogram::new(data, axes, self.params.clone()))
291    }
292
293    /// Compute a single frame of the spectrogram.
294    ///
295    /// This is useful for streaming/online processing where you want to
296    /// process audio frame-by-frame without computing the entire spectrogram.
297    ///
298    /// # Arguments
299    ///
300    /// * `samples` - Audio samples (must contain at least enough samples for the requested frame)
301    /// * `frame_idx` - Frame index to compute
302    ///
303    /// # Returns
304    ///
305    /// A vector of frequency bin values for the requested frame.
306    ///
307    /// # Errors
308    ///
309    /// Returns an error if the frame index is out of bounds or if STFT computation fails.
310    ///
311    /// # Examples
312    ///
313    /// ```
314    /// use spectrograms::*;
315    /// use non_empty_slice::non_empty_vec;
316    ///
317    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
318    /// let samples = non_empty_vec![0.0; nzu!(16000)];
319    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
320    /// let params = SpectrogramParams::new(stft, 16000.0)?;
321    ///
322    /// let planner = SpectrogramPlanner::new();
323    /// let mut plan = planner.linear_plan::<Power>(&params, None)?;
324    ///
325    /// // Compute just the first frame
326    /// let frame = plan.compute_frame(&samples, 0)?;
327    /// assert_eq!(frame.len(), nzu!(257)); // n_fft/2 + 1
328    /// # Ok(())
329    /// # }
330    /// ```
331    #[inline]
332    pub fn compute_frame(
333        &mut self,
334        samples: &NonEmptySlice<f64>,
335        frame_idx: usize,
336    ) -> SpectrogramResult<NonEmptyVec<f64>> {
337        let n_bins = self.mapping.output_bins();
338
339        // Ensure workspace is correctly sized
340        self.workspace
341            .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
342
343        // CQT needs unwindowed frames because its kernels already contain windowing.
344        // Other mappings use the FFT spectrum, so they need the windowed frame.
345        if self.mapping.kind.needs_unwindowed_frame() {
346            // Fill frame without windowing (for CQT)
347            self.stft
348                .fill_frame_unwindowed(samples, frame_idx, &mut self.workspace)?;
349        } else {
350            // Compute windowed frame spectrum (for FFT-based mappings)
351            self.stft
352                .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
353        }
354
355        // Apply mapping (using split borrows to avoid borrow conflicts)
356        let Workspace {
357            spectrum,
358            mapped,
359            frame,
360            ..
361        } = &mut self.workspace;
362
363        self.mapping.apply(spectrum, frame, mapped)?;
364
365        // Apply amplitude scaling
366        self.scaling.apply_in_place(mapped)?;
367
368        Ok(mapped.clone())
369    }
370
371    /// Compute spectrogram into a pre-allocated buffer.
372    ///
373    /// This avoids allocating the output matrix, which is useful when
374    /// you want to reuse buffers or have strict memory requirements.
375    ///
376    /// # Arguments
377    ///
378    /// * `samples` - Audio samples
379    /// * `output` - Pre-allocated output matrix (must be correct size: `n_bins` x `n_frames`)
380    ///
381    /// # Returns
382    ///
383    /// An empty result on success.
384    ///
385    /// # Errors
386    ///
387    /// Returns an error if the output buffer dimensions don't match the expected size.
388    ///
389    /// # Examples
390    ///
391    /// ```
392    /// use spectrograms::*;
393    /// use ndarray::Array2;
394    /// use non_empty_slice::non_empty_vec;
395    ///
396    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
397    /// let samples = non_empty_vec![0.0; nzu!(16000)];
398    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
399    /// let params = SpectrogramParams::new(stft, 16000.0)?;
400    ///
401    /// let planner = SpectrogramPlanner::new();
402    /// let mut plan = planner.linear_plan::<Power>(&params, None)?;
403    ///
404    /// // Pre-allocate output buffer
405    /// let mut output = Array2::<f64>::zeros((257, 63));
406    /// plan.compute_into(&samples, &mut output)?;
407    /// # Ok(())
408    /// # }
409    /// ```
410    #[inline]
411    pub fn compute_into(
412        &mut self,
413        samples: &NonEmptySlice<f64>,
414        output: &mut Array2<f64>,
415    ) -> SpectrogramResult<()> {
416        let n_frames = self.stft.frame_count(samples.len())?;
417        let n_bins = self.mapping.output_bins();
418
419        // Validate output dimensions
420        if output.nrows() != n_bins.get() {
421            return Err(SpectrogramError::dimension_mismatch(
422                n_bins.get(),
423                output.nrows(),
424            ));
425        }
426        if output.ncols() != n_frames.get() {
427            return Err(SpectrogramError::dimension_mismatch(
428                n_frames.get(),
429                output.ncols(),
430            ));
431        }
432
433        // Ensure workspace is correctly sized
434        self.workspace
435            .ensure_sizes(self.stft.n_fft, self.stft.out_len, n_bins);
436
437        // Main loop: fill each frame (column)
438        for frame_idx in 0..n_frames.get() {
439            // CQT needs unwindowed frames because its kernels already contain windowing.
440            // Other mappings use the FFT spectrum, so they need the windowed frame.
441            if self.mapping.kind.needs_unwindowed_frame() {
442                // Fill frame without windowing (for CQT)
443                self.stft
444                    .fill_frame_unwindowed(samples, frame_idx, &mut self.workspace)?;
445            } else {
446                // Compute windowed frame spectrum (for FFT-based mappings)
447                self.stft
448                    .compute_frame_spectrum(samples, frame_idx, &mut self.workspace)?;
449            }
450
451            // mapping: spectrum(out_len) -> mapped(n_bins)
452            // For CQT, this uses workspace.frame; for others, workspace.spectrum
453            // For ERB, we need the complex FFT output (fft_out)
454            // We need to borrow workspace fields separately to avoid borrow conflicts
455            let Workspace {
456                spectrum,
457                mapped,
458                frame,
459                ..
460            } = &mut self.workspace;
461
462            self.mapping.apply(spectrum, frame, mapped)?;
463
464            // amplitude scaling in-place on mapped vector
465            self.scaling.apply_in_place(mapped)?;
466
467            // write column into output
468            for (row, &val) in mapped.iter().enumerate() {
469                output[[row, frame_idx]] = val;
470            }
471        }
472
473        Ok(())
474    }
475
476    /// Get the expected output dimensions for a given signal length.
477    ///
478    /// # Arguments
479    ///
480    /// * `signal_length` - Length of the input signal in samples.
481    ///
482    /// # Returns
483    ///
484    /// A tuple `(n_bins, n_frames)` representing the output spectrogram shape.
485    ///
486    /// # Errors
487    ///
488    /// Returns an error if the signal length is too short to produce any frames.
489    ///
490    /// # Examples
491    ///
492    /// ```
493    /// use spectrograms::*;
494    ///
495    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
496    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
497    /// let params = SpectrogramParams::new(stft, 16000.0)?;
498    ///
499    /// let planner = SpectrogramPlanner::new();
500    /// let plan = planner.linear_plan::<Power>(&params, None)?;
501    ///
502    /// let (n_bins, n_frames) = plan.output_shape(nzu!(16000))?;
503    /// assert_eq!(n_bins, nzu!(257));
504    /// assert_eq!(n_frames, nzu!(63));
505    /// # Ok(())
506    /// # }
507    /// ```
508    #[inline]
509    pub fn output_shape(
510        &self,
511        signal_length: NonZeroUsize,
512    ) -> SpectrogramResult<(NonZeroUsize, NonZeroUsize)> {
513        let n_frames = self.stft.frame_count(signal_length)?;
514        let n_bins = self.mapping.output_bins();
515        Ok((n_bins, n_frames))
516    }
517}
518
/// STFT (Short-Time Fourier Transform) result containing complex frequency bins.
///
/// This is the raw STFT output before any frequency mapping or amplitude scaling.
///
/// # Fields
///
/// - `data`: Complex STFT matrix with shape (`frequency_bins`, `time_frames`)
/// - `frequencies`: Frequency axis in Hz
/// - `sample_rate`: Sample rate in Hz
/// - `params`: STFT computation parameters
///
/// `data` is expected to have at least one row and one column; `n_bins` and
/// `n_frames` rely on this.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct StftResult {
    /// Complex STFT matrix with shape (`frequency_bins`, `time_frames`)
    pub data: Array2<Complex<f64>>,
    /// Frequency axis in Hz
    pub frequencies: NonEmptyVec<f64>,
    /// Sample rate in Hz
    pub sample_rate: f64,
    /// STFT parameters (FFT size, hop size, ...) that produced `data`
    pub params: StftParams,
}
540
541impl StftResult {
542    /// Get the number of frequency bins.
543    ///
544    /// # Returns
545    ///
546    /// Number of frequency bins in the STFT result.
547    #[inline]
548    #[must_use]
549    pub fn n_bins(&self) -> NonZeroUsize {
550        // safety: nrows() > 0 for NonEmptyVec
551        unsafe { NonZeroUsize::new_unchecked(self.data.nrows()) }
552    }
553
554    /// Get the number of time frames.
555    ///
556    /// # Returns
557    ///
558    /// Number of time frames in the STFT result.
559    #[inline]
560    #[must_use]
561    pub fn n_frames(&self) -> NonZeroUsize {
562        // safety: ncols() > 0 for NonEmptyVec
563        unsafe { NonZeroUsize::new_unchecked(self.data.ncols()) }
564    }
565
566    /// Get the frequency resolution in Hz
567    ///
568    /// # Returns
569    ///
570    /// Frequency bin width in Hz.
571    #[inline]
572    #[must_use]
573    pub fn frequency_resolution(&self) -> f64 {
574        self.sample_rate / self.params.n_fft().get() as f64
575    }
576
577    /// Get the time resolution in seconds.
578    ///
579    /// # Returns
580    ///
581    /// Time between successive frames in seconds.
582    #[inline]
583    #[must_use]
584    pub fn time_resolution(&self) -> f64 {
585        self.params.hop_size().get() as f64 / self.sample_rate
586    }
587
588    /// Normalizes self.data to remove the complex aspect of it.
589    ///
590    /// # Returns
591    ///
592    /// An Array2\<f64\> containing the norms of each complex number in self.data.
593    #[inline]
594    pub fn norm(&self) -> Array2<f64> {
595        self.as_ref().mapv(Complex::norm)
596    }
597}
598
599impl AsRef<Array2<Complex<f64>>> for StftResult {
600    #[inline]
601    fn as_ref(&self) -> &Array2<Complex<f64>> {
602        &self.data
603    }
604}
605
606impl AsMut<Array2<Complex<f64>>> for StftResult {
607    #[inline]
608    fn as_mut(&mut self) -> &mut Array2<Complex<f64>> {
609        &mut self.data
610    }
611}
612
613impl Deref for StftResult {
614    type Target = Array2<Complex<f64>>;
615
616    #[inline]
617    fn deref(&self) -> &Self::Target {
618        &self.data
619    }
620}
621
622impl DerefMut for StftResult {
623    #[inline]
624    fn deref_mut(&mut self) -> &mut Self::Target {
625        &mut self.data
626    }
627}
628
/// Factory for [`SpectrogramPlan`]s, plus one-shot spectral helpers.
///
/// Plan construction is where the up-front work happens:
/// - FFT plans are created
/// - frequency-mapping matrices (identity, mel filterbank, ...) are compiled
/// - frequency axes are computed
///
/// Keeping this setup separate from the output types means a built plan can be
/// reused across many signals without repeating it.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct SpectrogramPlanner;
640
641impl SpectrogramPlanner {
642    /// Create a new spectrogram planner.
643    ///
644    /// # Returns
645    ///
646    /// A new `SpectrogramPlanner` instance.
647    #[inline]
648    #[must_use]
649    pub const fn new() -> Self {
650        Self
651    }
652
653    /// Compute the Short-Time Fourier Transform (STFT) of a signal.
654    ///
655    /// This returns the raw complex STFT matrix before any frequency mapping
656    /// or amplitude scaling. Useful for applications that need the full complex
657    /// spectrum or custom processing.
658    ///
659    /// # Arguments
660    ///
661    /// * `samples` - Audio samples (any type that can be converted to a slice)
662    /// * `params` - STFT computation parameters
663    ///
664    /// # Returns
665    ///
666    /// An `StftResult` containing the complex STFT matrix and metadata.
667    ///
668    /// # Errors
669    ///
670    /// Returns an error if STFT computation fails.
671    ///
672    /// # Examples
673    ///
674    /// ```
675    /// use spectrograms::*;
676    /// use non_empty_slice::non_empty_vec;
677    ///
678    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
679    /// let samples = non_empty_vec![0.0; nzu!(16000)];
680    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
681    /// let params = SpectrogramParams::new(stft, 16000.0)?;
682    ///
683    /// let planner = SpectrogramPlanner::new();
684    /// let stft_result = planner.compute_stft(&samples, &params)?;
685    ///
686    /// println!("STFT: {} bins x {} frames", stft_result.n_bins(), stft_result.n_frames());
687    /// # Ok(())
688    /// # }
689    /// ```
690    ///
691    /// # Performance Note
692    ///
693    /// This method creates a new FFT plan each time. For processing multiple
694    /// signals, create a reusable plan with `StftPlan::new()` instead.
695    ///
696    /// # Examples
697    ///
698    /// ```rust
699    /// use spectrograms::*;
700    /// use non_empty_slice::non_empty_vec;
701    ///
702    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
703    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
704    /// let params = SpectrogramParams::new(stft, 16000.0)?;
705    ///
706    /// // One-shot (convenient)
707    /// let planner = SpectrogramPlanner::new();
708    /// let stft_result = planner.compute_stft(&non_empty_vec![0.0; nzu!(16000)], &params)?;
709    ///
710    /// // Reusable plan (efficient for batch)
711    /// let mut plan = StftPlan::new(&params)?;
712    /// for signal in &[non_empty_vec![0.0; nzu!(16000)], non_empty_vec![1.0; nzu!(16000)]] {
713    ///     let stft = plan.compute(&signal, &params)?;
714    /// }
715    /// # Ok(())
716    /// # }
717    /// ```
718    #[inline]
719    pub fn compute_stft(
720        &self,
721        samples: &NonEmptySlice<f64>,
722        params: &SpectrogramParams,
723    ) -> SpectrogramResult<StftResult> {
724        let mut plan = StftPlan::new(params)?;
725        plan.compute(samples, params)
726    }
727
728    /// Compute the power spectrum of a single audio frame.
729    ///
730    /// This is useful for real-time processing or analyzing individual frames.
731    ///
732    /// # Arguments
733    ///
734    /// * `samples` - Audio frame (length ≤ n_fft, will be zero-padded if shorter)
735    /// * `n_fft` - FFT size
736    /// * `window` - Window type to apply
737    ///
738    /// # Returns
739    ///
740    /// A vector of power values (|X|²) with length `n_fft/2` + 1.
741    ///
742    /// # Automatic Zero-Padding
743    ///
744    /// If the input signal is shorter than `n_fft`, it will be automatically
745    /// zero-padded to the required length.
746    ///
747    /// # Errors
748    ///
749    /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
750    ///
751    /// # Examples
752    ///
753    /// ```
754    /// use spectrograms::*;
755    /// use non_empty_slice::non_empty_vec;
756    ///
757    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
758    /// let frame = non_empty_vec![0.0; nzu!(512)];
759    ///
760    /// let planner = SpectrogramPlanner::new();
761    /// let power = planner.compute_power_spectrum(frame.as_ref(), nzu!(512), WindowType::Hanning)?;
762    ///
763    /// assert_eq!(power.len(), nzu!(257)); // 512/2 + 1
764    /// # Ok(())
765    /// # }
766    /// ```
767    #[inline]
768    pub fn compute_power_spectrum(
769        &self,
770        samples: &NonEmptySlice<f64>,
771        n_fft: NonZeroUsize,
772        window: WindowType,
773    ) -> SpectrogramResult<NonEmptyVec<f64>> {
774        if samples.len() > n_fft {
775            return Err(SpectrogramError::invalid_input(format!(
776                "Input length ({}) exceeds FFT size ({})",
777                samples.len(),
778                n_fft
779            )));
780        }
781
782        let window_samples = make_window(window, n_fft);
783        let out_len = r2c_output_size(n_fft.get());
784
785        // Create FFT plan
786        #[cfg(feature = "realfft")]
787        let mut fft = {
788            let mut planner = crate::RealFftPlanner::new();
789            let plan = planner.get_or_create(n_fft.get());
790            crate::RealFftPlan::new(n_fft.get(), plan)
791        };
792
793        #[cfg(feature = "fftw")]
794        let mut fft = {
795            use std::sync::Arc;
796            let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
797            crate::FftwPlan::new(Arc::new(plan))
798        };
799
800        // Apply window and compute FFT
801        let mut windowed = vec![0.0; n_fft.get()];
802        for i in 0..samples.len().get() {
803            windowed[i] = samples[i] * window_samples[i];
804        }
805        // The rest is already zero-padded
806        let mut fft_out = vec![Complex::new(0.0, 0.0); out_len];
807        fft.process(&windowed, &mut fft_out)?;
808
809        // Convert to power
810        let power: Vec<f64> = fft_out.iter().map(num_complex::Complex::norm_sqr).collect();
811        // safety: power is non-empty since n_fft > 0
812        Ok(unsafe { NonEmptyVec::new_unchecked(power) })
813    }
814
815    /// Compute the magnitude spectrum of a single audio frame.
816    ///
817    /// This is useful for real-time processing or analyzing individual frames.
818    ///
819    /// # Arguments
820    ///
821    /// * `samples` - Audio frame (length ≤ n_fft, will be zero-padded if shorter)
822    /// * `n_fft` - FFT size
823    /// * `window` - Window type to apply
824    ///
825    /// # Returns
826    ///
827    /// A vector of magnitude values (|X|) with length `n_fft/2` + 1.
828    ///
829    /// # Automatic Zero-Padding
830    ///
831    /// If the input signal is shorter than `n_fft`, it will be automatically
832    /// zero-padded to the required length.
833    ///
834    /// # Errors
835    ///
836    /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
837    ///
838    /// # Examples
839    ///
840    /// ```
841    /// use spectrograms::*;
842    /// use non_empty_slice::non_empty_vec;
843    ///
844    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
845    /// let frame = non_empty_vec![0.0; nzu!(512)];
846    ///
847    /// let planner = SpectrogramPlanner::new();
848    /// let magnitude = planner.compute_magnitude_spectrum(frame.as_ref(), nzu!(512), WindowType::Hanning)?;
849    ///
850    /// assert_eq!(magnitude.len(), nzu!(257)); // 512/2 + 1
851    /// # Ok(())
852    /// # }
853    /// ```
854    #[inline]
855    pub fn compute_magnitude_spectrum(
856        &self,
857        samples: &NonEmptySlice<f64>,
858        n_fft: NonZeroUsize,
859        window: WindowType,
860    ) -> SpectrogramResult<NonEmptyVec<f64>> {
861        let power = self.compute_power_spectrum(samples, n_fft, window)?;
862        let power = power.iter().map(|&p| p.sqrt()).collect::<Vec<f64>>();
863        // safety: power is non-empty since power_spectrum returned successfully
864        Ok(unsafe { NonEmptyVec::new_unchecked(power) })
865    }
866
867    /// Build a linear-frequency spectrogram plan.
868    ///
869    /// # Type Parameters
870    ///
871    /// `AmpScale` determines whether output is:
872    /// - Magnitude
873    /// - Power
874    /// - Decibels
875    ///
876    /// # Arguments
877    ///
878    /// * `params` - Spectrogram parameters
879    /// * `db` - Logarithmic scaling parameters (only used if `AmpScale
880    /// is `Decibels`)
881    ///
882    /// # Returns
883    ///
884    /// A `SpectrogramPlan` configured for linear-frequency spectrogram computation.
885    ///
886    /// # Errors
887    ///
888    /// Returns an error if the plan cannot be created due to invalid parameters.
889    #[inline]
890    pub fn linear_plan<AmpScale>(
891        &self,
892        params: &SpectrogramParams,
893        db: Option<&LogParams>, // only used when AmpScale = Decibels
894    ) -> SpectrogramResult<SpectrogramPlan<LinearHz, AmpScale>>
895    where
896        AmpScale: AmpScaleSpec + 'static,
897    {
898        let stft = StftPlan::new(params)?;
899        let mapping = FrequencyMapping::<LinearHz>::new(params)?;
900        let scaling = AmplitudeScaling::<AmpScale>::new(db);
901        let freq_axis = build_frequency_axis::<LinearHz>(params, &mapping);
902
903        let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
904
905        Ok(SpectrogramPlan {
906            params: params.clone(),
907            stft,
908            mapping,
909            scaling,
910            freq_axis,
911            workspace,
912            _amp: PhantomData,
913        })
914    }
915
916    /// Build a mel-frequency spectrogram plan.
917    ///
918    /// This compiles a mel filterbank matrix and caches it inside the plan.
919    ///
920    /// # Type Parameters
921    ///
922    /// `AmpScale`: determines whether output is:
923    /// - Magnitude
924    /// - Power
925    /// - Decibels
926    ///
927    /// # Arguments
928    ///
929    /// * `params` - Spectrogram parameters
930    /// * `mel` - Mel-specific parameters
931    /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
932    ///
933    /// # Returns
934    ///
935    /// A `SpectrogramPlan` configured for mel spectrogram computation.
936    ///
937    /// # Errors
938    ///
939    /// Returns an error if the plan cannot be created due to invalid parameters.
940    #[inline]
941    pub fn mel_plan<AmpScale>(
942        &self,
943        params: &SpectrogramParams,
944        mel: &MelParams,
945        db: Option<&LogParams>, // only used when AmpScale = Decibels
946    ) -> SpectrogramResult<SpectrogramPlan<Mel, AmpScale>>
947    where
948        AmpScale: AmpScaleSpec + 'static,
949    {
950        // cross-validation: mel range must be compatible with sample rate
951        let nyquist = params.nyquist_hz();
952        if mel.f_max() > nyquist {
953            return Err(SpectrogramError::invalid_input(
954                "mel f_max must be <= Nyquist",
955            ));
956        }
957
958        let stft = StftPlan::new(params)?;
959        let mapping = FrequencyMapping::<Mel>::new_mel(params, mel)?;
960        let scaling = AmplitudeScaling::<AmpScale>::new(db);
961        let freq_axis = build_frequency_axis::<Mel>(params, &mapping);
962
963        let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
964
965        Ok(SpectrogramPlan {
966            params: params.clone(),
967            stft,
968            mapping,
969            scaling,
970            freq_axis,
971            workspace,
972            _amp: PhantomData,
973        })
974    }
975
976    /// Build an ERB-scale spectrogram plan.
977    ///
978    /// This creates a spectrogram with ERB-spaced frequency bands using gammatone
979    /// filterbank approximation in the frequency domain.
980    ///
981    /// # Type Parameters
982    ///
983    /// `AmpScale`: determines whether output is:
984    /// - Magnitude
985    /// - Power
986    /// - Decibels
987    ///
988    /// # Arguments
989    ///
990    /// * `params` - Spectrogram parameters
991    /// * `erb` - ERB-specific parameters
992    /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
993    ///
994    /// # Returns
995    ///
996    /// A `SpectrogramPlan` configured for ERB spectrogram computation.
997    ///
998    /// # Errors
999    ///
1000    /// Returns an error if the plan cannot be created due to invalid parameters.
1001    #[inline]
1002    pub fn erb_plan<AmpScale>(
1003        &self,
1004        params: &SpectrogramParams,
1005        erb: &ErbParams,
1006        db: Option<&LogParams>,
1007    ) -> SpectrogramResult<SpectrogramPlan<Erb, AmpScale>>
1008    where
1009        AmpScale: AmpScaleSpec + 'static,
1010    {
1011        // cross-validation: erb range must be compatible with sample rate
1012        let nyquist = params.nyquist_hz();
1013        if erb.f_max() > nyquist {
1014            return Err(SpectrogramError::invalid_input(format!(
1015                "f_max={} exceeds Nyquist={}",
1016                erb.f_max(),
1017                nyquist
1018            )));
1019        }
1020
1021        let stft = StftPlan::new(params)?;
1022        let mapping = FrequencyMapping::<Erb>::new_erb(params, erb)?;
1023        let scaling = AmplitudeScaling::<AmpScale>::new(db);
1024        let freq_axis = build_frequency_axis::<Erb>(params, &mapping);
1025
1026        let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1027
1028        Ok(SpectrogramPlan {
1029            params: params.clone(),
1030            stft,
1031            mapping,
1032            scaling,
1033            freq_axis,
1034            workspace,
1035            _amp: PhantomData,
1036        })
1037    }
1038
1039    /// Build a log-frequency plan.
1040    ///
1041    /// This creates a spectrogram with logarithmically-spaced frequency bins.
1042    ///
1043    /// # Type Parameters
1044    ///
1045    /// `AmpScale`: determines whether output is:
1046    /// - Magnitude
1047    /// - Power
1048    /// - Decibels
1049    ///
1050    /// # Arguments
1051    ///
1052    /// * `params` - Spectrogram parameters
1053    /// * `loghz` - LogHz-specific parameters
1054    /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
1055    ///
1056    /// # Returns
1057    ///
1058    /// A `SpectrogramPlan` configured for log-frequency spectrogram computation.
1059    ///
1060    /// # Errors
1061    ///
1062    /// Returns an error if the plan cannot be created due to invalid parameters.
1063    #[inline]
1064    pub fn log_hz_plan<AmpScale>(
1065        &self,
1066        params: &SpectrogramParams,
1067        loghz: &LogHzParams,
1068        db: Option<&LogParams>,
1069    ) -> SpectrogramResult<SpectrogramPlan<LogHz, AmpScale>>
1070    where
1071        AmpScale: AmpScaleSpec + 'static,
1072    {
1073        // cross-validation: loghz range must be compatible with sample rate
1074        let nyquist = params.nyquist_hz();
1075        if loghz.f_max() > nyquist {
1076            return Err(SpectrogramError::invalid_input(format!(
1077                "f_max={} exceeds Nyquist={}",
1078                loghz.f_max(),
1079                nyquist
1080            )));
1081        }
1082
1083        let stft = StftPlan::new(params)?;
1084        let mapping = FrequencyMapping::<LogHz>::new_loghz(params, loghz)?;
1085        let scaling = AmplitudeScaling::<AmpScale>::new(db);
1086        let freq_axis = build_frequency_axis::<LogHz>(params, &mapping);
1087
1088        let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1089
1090        Ok(SpectrogramPlan {
1091            params: params.clone(),
1092            stft,
1093            mapping,
1094            scaling,
1095            freq_axis,
1096            workspace,
1097            _amp: PhantomData,
1098        })
1099    }
1100
1101    /// Build a cqt spectrogram plan.
1102    ///
1103    /// # Type Parameters
1104    ///
1105    /// `AmpScale`: determines whether output is:
1106    /// - Magnitude
1107    /// - Power
1108    /// - Decibels
1109    ///
1110    /// # Arguments
1111    ///
1112    /// * `params` - Spectrogram parameters
1113    /// * `cqt` - CQT-specific parameters
1114    /// * `db` - Logarithmic scaling parameters (only used if `AmpScale` is `Decibels`)
1115    ///
1116    /// # Returns
1117    ///
1118    /// A `SpectrogramPlan` configured for CQT spectrogram computation.
1119    ///
1120    /// # Errors
1121    ///
1122    /// Returns an error if the plan cannot be created due to invalid parameters.
1123    #[inline]
1124    pub fn cqt_plan<AmpScale>(
1125        &self,
1126        params: &SpectrogramParams,
1127        cqt: &CqtParams,
1128        db: Option<&LogParams>, // only used when AmpScale = Decibels
1129    ) -> SpectrogramResult<SpectrogramPlan<Cqt, AmpScale>>
1130    where
1131        AmpScale: AmpScaleSpec + 'static,
1132    {
1133        let stft = StftPlan::new(params)?;
1134        let mapping = FrequencyMapping::<Cqt>::new(params, cqt)?;
1135        let scaling = AmplitudeScaling::<AmpScale>::new(db);
1136        let freq_axis = build_frequency_axis::<Cqt>(params, &mapping);
1137
1138        let workspace = Workspace::new(stft.n_fft, stft.out_len, mapping.output_bins());
1139
1140        Ok(SpectrogramPlan {
1141            params: params.clone(),
1142            stft,
1143            mapping,
1144            scaling,
1145            freq_axis,
1146            workspace,
1147            _amp: PhantomData,
1148        })
1149    }
1150}
1151
1152/// STFT plan containing reusable FFT plan and buffers.
1153///
1154/// This struct is responsible for performing the Short-Time Fourier Transform (STFT)
1155/// on audio signals based on the provided parameters.
1156///
1157/// It encapsulates the FFT plan, windowing function, and internal buffers to efficiently
1158/// compute the STFT for multiple frames of audio data.
1159///
1160/// # Fields
1161///
1162/// - `n_fft`: Size of the FFT.
1163/// - `hop_size`: Hop size between consecutive frames.
1164/// - `window`: Windowing function samples.
1165/// - `centre`: Whether to centre the frames with padding.
1166/// - `out_len`: Length of the FFT output.
1167/// - `fft`: Boxed FFT plan for real-to-complex transformation.
1168/// - `fft_out`: Internal buffer for FFT output.
1169/// - `frame`: Internal buffer for windowed audio frames.
1170pub struct StftPlan {
1171    n_fft: NonZeroUsize,
1172    hop_size: NonZeroUsize,
1173    window: NonEmptyVec<f64>,
1174    centre: bool,
1175
1176    out_len: NonZeroUsize,
1177
1178    // FFT plan (reused for all frames)
1179    fft: Box<dyn R2cPlan>,
1180
1181    // internal scratch
1182    fft_out: NonEmptyVec<Complex<f64>>,
1183    frame: NonEmptyVec<f64>,
1184}
1185
1186impl StftPlan {
1187    /// Create a new STFT plan from parameters.
1188    ///
1189    /// # Arguments
1190    ///
1191    /// * `params` - Spectrogram parameters containing STFT config
1192    ///
1193    /// # Returns
1194    ///
1195    /// A new `StftPlan` instance.
1196    ///
1197    /// # Errors
1198    ///
1199    /// Returns an error if the FFT plan cannot be created.
1200    #[inline]
1201    pub fn new(params: &SpectrogramParams) -> SpectrogramResult<Self> {
1202        let stft = params.stft();
1203        let n_fft = stft.n_fft();
1204        let hop_size = stft.hop_size();
1205        let centre = stft.centre();
1206
1207        let window = make_window(stft.window(), n_fft);
1208
1209        let out_len = r2c_output_size(n_fft.get());
1210        let out_len = NonZeroUsize::new(out_len)
1211            .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1212
1213        #[cfg(feature = "realfft")]
1214        let fft = {
1215            let mut planner = crate::RealFftPlanner::new();
1216            let plan = planner.get_or_create(n_fft.get());
1217            let plan = crate::RealFftPlan::new(n_fft.get(), plan);
1218            Box::new(plan)
1219        };
1220
1221        #[cfg(feature = "fftw")]
1222        let fft = {
1223            use std::sync::Arc;
1224            let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
1225            Box::new(crate::FftwPlan::new(Arc::new(plan)))
1226        };
1227
1228        Ok(Self {
1229            n_fft,
1230            hop_size,
1231            window,
1232            centre,
1233            out_len,
1234            fft,
1235            fft_out: non_empty_vec![Complex::new(0.0, 0.0); out_len],
1236            frame: non_empty_vec![0.0; n_fft],
1237        })
1238    }
1239
1240    fn frame_count(&self, n_samples: NonZeroUsize) -> SpectrogramResult<NonZeroUsize> {
1241        // Framing policy:
1242        // - centre = true: implicit padding of n_fft/2 on both sides
1243        // - centre = false: no padding
1244        //
1245        // Define the number of frames such that each frame has a valid centre sample position.
1246        let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1247        let padded_len = n_samples.get() + 2 * pad;
1248
1249        if padded_len < self.n_fft.get() {
1250            // still produce 1 frame (all padding / partial)
1251            return Ok(nzu!(1));
1252        }
1253
1254        let remaining = padded_len - self.n_fft.get();
1255        let n_frames = remaining / self.hop_size().get() + 1;
1256        let n_frames = NonZeroUsize::new(n_frames).ok_or_else(|| {
1257            SpectrogramError::invalid_input("computed number of frames must be non-zero")
1258        })?;
1259        Ok(n_frames)
1260    }
1261
1262    /// Compute one frame FFT using internal buffers only.
1263    fn compute_frame_fft_simple(
1264        &mut self,
1265        samples: &NonEmptySlice<f64>,
1266        frame_idx: usize,
1267    ) -> SpectrogramResult<()> {
1268        let out = self.frame.as_mut_slice();
1269        debug_assert_eq!(out.len(), self.n_fft.get());
1270
1271        let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1272        let start = frame_idx
1273            .checked_mul(self.hop_size.get())
1274            .ok_or_else(|| SpectrogramError::invalid_input("frame index overflow"))?;
1275
1276        // Fill windowed frame
1277        for (i, sample) in out.iter_mut().enumerate().take(self.n_fft.get()) {
1278            let v_idx = start + i;
1279            let s_idx = v_idx as isize - pad as isize;
1280
1281            let sample_val = if s_idx < 0 || (s_idx as usize) >= samples.len().get() {
1282                0.0
1283            } else {
1284                samples[s_idx as usize]
1285            };
1286            *sample = sample_val * self.window[i];
1287        }
1288
1289        // Compute FFT
1290        let fft_out = self.fft_out.as_mut_slice();
1291        self.fft.process(out, fft_out)?;
1292
1293        Ok(())
1294    }
1295
1296    /// Compute one frame spectrum into workspace:
1297    /// - fills windowed frame
1298    /// - runs FFT
1299    /// - converts to magnitude/power based on `AmpScale` later
1300    fn compute_frame_spectrum(
1301        &mut self,
1302        samples: &NonEmptySlice<f64>,
1303        frame_idx: usize,
1304        workspace: &mut Workspace,
1305    ) -> SpectrogramResult<()> {
1306        let out = workspace.frame.as_mut_slice();
1307
1308        // self.fill_frame(samples, frame_idx, frame)?;
1309        debug_assert_eq!(out.len(), self.n_fft.get());
1310
1311        let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1312        let start = frame_idx
1313            .checked_mul(self.hop_size().get())
1314            .ok_or_else(|| SpectrogramError::invalid_input("frame index overflow"))?;
1315
1316        // The "virtual" signal is samples with pad zeros on both sides.
1317        // Virtual index 0..padded_len
1318        // Map virtual index to original samples by subtracting pad.
1319        for (i, sample) in out.iter_mut().enumerate().take(self.n_fft.get()) {
1320            let v_idx = start + i;
1321            let s_idx = v_idx as isize - pad as isize;
1322
1323            let sample_val = if s_idx < 0 || (s_idx as usize) >= samples.len().get() {
1324                0.0
1325            } else {
1326                samples[s_idx as usize]
1327            };
1328
1329            *sample = sample_val * self.window[i];
1330        }
1331        let fft_out = workspace.fft_out.as_mut_slice();
1332        // FFT
1333        self.fft.process(out, fft_out)?;
1334
1335        // Convert complex spectrum to linear magnitude OR power here? No:
1336        // Keep "spectrum" as power by default? That would entangle semantics.
1337        //
1338        // Instead, we store magnitude^2 (power) as the canonical intermediate,
1339        // and let AmpScale decide later whether output is magnitude or power.
1340        //
1341        // This is consistent and avoids recomputing norms multiple times.
1342        for (i, c) in workspace.fft_out.iter().enumerate() {
1343            workspace.spectrum[i] = c.norm_sqr();
1344        }
1345
1346        Ok(())
1347    }
1348
1349    /// Fill a time-domain frame WITHOUT applying the window.
1350    ///
1351    /// This is used for CQT mapping, where the CQT kernels already contain
1352    /// windowing applied during kernel generation. Applying the STFT window
1353    /// would result in double-windowing.
1354    ///
1355    /// # Arguments
1356    ///
1357    /// * `samples` - Input audio samples
1358    /// * `frame_idx` - Frame index
1359    /// * `workspace` - Workspace containing the frame buffer to fill
1360    ///
1361    /// # Errors
1362    ///
1363    /// Returns an error if frame index is out of bounds.
1364    fn fill_frame_unwindowed(
1365        &self,
1366        samples: &NonEmptySlice<f64>,
1367        frame_idx: usize,
1368        workspace: &mut Workspace,
1369    ) -> SpectrogramResult<()> {
1370        let out = workspace.frame.as_mut_slice();
1371        debug_assert_eq!(out.len(), self.n_fft.get());
1372
1373        let pad = if self.centre { self.n_fft.get() / 2 } else { 0 };
1374        let start = frame_idx
1375            .checked_mul(self.hop_size().get())
1376            .ok_or_else(|| SpectrogramError::invalid_input("frame index overflow"))?;
1377
1378        // Fill frame WITHOUT windowing
1379        for (i, sample) in out.iter_mut().enumerate().take(self.n_fft.get()) {
1380            let v_idx = start + i;
1381            let s_idx = v_idx as isize - pad as isize;
1382
1383            let sample_val = if s_idx < 0 || (s_idx as usize) >= samples.len().get() {
1384                0.0
1385            } else {
1386                samples[s_idx as usize]
1387            };
1388
1389            // No window multiplication - just copy the sample
1390            *sample = sample_val;
1391        }
1392
1393        Ok(())
1394    }
1395
1396    /// Compute the full STFT for a signal, returning an `StftResult`.
1397    ///
1398    /// This is a convenience method that handles frame iteration and
1399    /// builds the complete STFT matrix.
1400    ///
1401    ///
1402    /// # Arguments
1403    ///
1404    /// * `samples` - Input audio samples
1405    /// * `params` - STFT computation parameters
1406    ///
1407    /// # Returns
1408    ///
1409    /// An `StftResult` containing the complex STFT matrix and metadata.
1410    ///
1411    /// # Errors
1412    ///
1413    /// Returns an error if computation fails.
1414    ///
1415    /// # Examples
1416    ///
1417    /// ```rust
1418    /// use spectrograms::*;
1419    /// use non_empty_slice::non_empty_vec;
1420    ///
1421    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1422    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1423    /// let params = SpectrogramParams::new(stft, 16000.0)?;
1424    /// let mut plan = StftPlan::new(&params)?;
1425    ///
1426    /// let samples = non_empty_vec![0.0; nzu!(16000)];
1427    /// let stft_result = plan.compute(&samples, &params)?;
1428    ///
1429    /// println!("STFT: {} bins x {} frames", stft_result.n_bins(), stft_result.n_frames());
1430    /// # Ok(())
1431    /// # }
1432    /// ```
1433    #[inline]
1434    pub fn compute(
1435        &mut self,
1436        samples: &NonEmptySlice<f64>,
1437        params: &SpectrogramParams,
1438    ) -> SpectrogramResult<StftResult> {
1439        let n_frames = self.frame_count(samples.len())?;
1440        let n_bins = self.out_len;
1441
1442        // Allocate output matrix (frequency_bins x time_frames)
1443        let mut data = Array2::<Complex<f64>>::zeros((n_bins.get(), n_frames.get()));
1444
1445        // Compute each frame
1446        for frame_idx in 0..n_frames.get() {
1447            self.compute_frame_fft_simple(samples, frame_idx)?;
1448
1449            // Copy from internal buffer to output
1450            for (bin_idx, &value) in self.fft_out.iter().enumerate() {
1451                data[[bin_idx, frame_idx]] = value;
1452            }
1453        }
1454
1455        // Build frequency axis
1456        let frequencies: Vec<f64> = (0..n_bins.get())
1457            .map(|k| k as f64 * params.sample_rate_hz() / params.stft().n_fft().get() as f64)
1458            .collect();
1459        // SAFETY: n_bins > 0
1460        let frequencies = unsafe { NonEmptyVec::new_unchecked(frequencies) };
1461
1462        Ok(StftResult {
1463            data,
1464            frequencies,
1465            sample_rate: params.sample_rate_hz(),
1466            params: params.stft().clone(),
1467        })
1468    }
1469
1470    /// Compute a single frame of STFT, returning the complex spectrum.
1471    ///
1472    /// This is useful for streaming/online processing where you want to
1473    /// process audio frame-by-frame.
1474    ///
1475    /// # Arguments
1476    ///
1477    /// * `samples` - Input audio samples
1478    /// * `frame_idx` - Index of the frame to compute
1479    ///
1480    /// # Returns
1481    ///
1482    /// A `NonEmptyVec` containing the complex spectrum for the specified frame.
1483    ///
1484    /// # Errors
1485    ///
1486    /// Returns an error if computation fails.
1487    ///
1488    /// # Examples
1489    ///
1490    /// ```rust
1491    /// use spectrograms::*;
1492    /// use non_empty_slice::non_empty_vec;
1493    ///
1494    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1495    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1496    /// let params = SpectrogramParams::new(stft, 16000.0)?;
1497    /// let mut plan = StftPlan::new(&params)?;
1498    ///
1499    /// let samples = non_empty_vec![0.0; nzu!(16000)];
1500    /// let (_, n_frames) = plan.output_shape(samples.len())?;
1501    ///
1502    /// for frame_idx in 0..n_frames.get() {
1503    ///     let spectrum = plan.compute_frame_simple(&samples, frame_idx)?;
1504    ///     // Process spectrum...
1505    /// }
1506    /// # Ok(())
1507    /// # }
1508    /// ```
1509    #[inline]
1510    pub fn compute_frame_simple(
1511        &mut self,
1512        samples: &NonEmptySlice<f64>,
1513        frame_idx: usize,
1514    ) -> SpectrogramResult<NonEmptyVec<Complex<f64>>> {
1515        self.compute_frame_fft_simple(samples, frame_idx)?;
1516        Ok(self.fft_out.clone())
1517    }
1518
1519    /// Compute STFT into a pre-allocated buffer.
1520    ///
1521    /// This avoids allocating the output matrix, useful for reusing buffers.
1522    ///
1523    /// # Arguments
1524    ///
1525    /// * `samples` - Input audio samples
1526    /// * `output` - Pre-allocated output buffer (shape: `n_bins` x `n_frames`)
1527    ///  
1528    /// # Returns
1529    ///
1530    /// `Ok(())` on success, or an error if dimensions mismatch.
1531    ///
1532    /// # Errors
1533    ///
1534    /// Returns `DimensionMismatch` error if the output buffer has incorrect shape.
1535    ///
1536    /// # Examples
1537    ///
1538    /// ```rust
1539    /// use spectrograms::*;
1540    /// use ndarray::Array2;
1541    /// use num_complex::Complex;
1542    /// use non_empty_slice::non_empty_vec;
1543    ///
1544    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1545    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
1546    /// let params = SpectrogramParams::new(stft, 16000.0)?;
1547    /// let mut plan = StftPlan::new(&params)?;
1548    ///
1549    /// let samples = non_empty_vec![0.0; nzu!(16000)];
1550    /// let (n_bins, n_frames) = plan.output_shape(samples.len())?;
1551    /// let mut output = Array2::<Complex<f64>>::zeros((n_bins.get(), n_frames.get()));
1552    ///
1553    /// plan.compute_into(&samples, &mut output)?;
1554    /// # Ok(())
1555    /// # }
1556    /// ```
1557    #[inline]
1558    pub fn compute_into(
1559        &mut self,
1560        samples: &NonEmptySlice<f64>,
1561        output: &mut Array2<Complex<f64>>,
1562    ) -> SpectrogramResult<()> {
1563        let n_frames = self.frame_count(samples.len())?;
1564        let n_bins = self.out_len;
1565
1566        // Validate output dimensions
1567        if output.nrows() != n_bins.get() {
1568            return Err(SpectrogramError::dimension_mismatch(
1569                n_bins.get(),
1570                output.nrows(),
1571            ));
1572        }
1573        if output.ncols() != n_frames.get() {
1574            return Err(SpectrogramError::dimension_mismatch(
1575                n_frames.get(),
1576                output.ncols(),
1577            ));
1578        }
1579
1580        // Compute into pre-allocated buffer
1581        for frame_idx in 0..n_frames.get() {
1582            self.compute_frame_fft_simple(samples, frame_idx)?;
1583
1584            for (bin_idx, &value) in self.fft_out.iter().enumerate() {
1585                output[[bin_idx, frame_idx]] = value;
1586            }
1587        }
1588
1589        Ok(())
1590    }
1591
1592    /// Get the expected output dimensions for a given signal length.
1593    ///
1594    /// # Arguments
1595    ///
1596    /// * `signal_length` - Length of the input signal in samples
1597    ///
1598    /// # Returns
1599    ///
1600    /// A tuple `(n_frequency_bins, n_time_frames)`.
1601    ///
1602    /// # Errors
1603    ///
1604    /// Returns an error if the computed number of frames is invalid.
1605    #[inline]
1606    pub fn output_shape(
1607        &self,
1608        signal_length: NonZeroUsize,
1609    ) -> SpectrogramResult<(NonZeroUsize, NonZeroUsize)> {
1610        let n_frames = self.frame_count(signal_length)?;
1611        Ok((self.out_len, n_frames))
1612    }
1613
1614    /// Get the number of frequency bins in the output.
1615    ///
1616    /// # Returns
1617    ///
1618    /// The number of frequency bins.
1619    #[inline]
1620    #[must_use]
1621    pub const fn n_bins(&self) -> NonZeroUsize {
1622        self.out_len
1623    }
1624
1625    /// Get the FFT size.
1626    ///
1627    /// # Returns
1628    ///
1629    /// The FFT size.
1630    #[inline]
1631    #[must_use]
1632    pub const fn n_fft(&self) -> NonZeroUsize {
1633        self.n_fft
1634    }
1635
1636    /// Get the hop size.
1637    ///
1638    /// # Returns
1639    ///
1640    /// The hop size.
1641    #[inline]
1642    #[must_use]
1643    pub const fn hop_size(&self) -> NonZeroUsize {
1644        self.hop_size
1645    }
1646}
1647
1648#[derive(Debug, Clone)]
1649enum MappingKind {
1650    Identity {
1651        out_len: NonZeroUsize,
1652    },
1653    Mel {
1654        matrix: SparseMatrix,
1655    }, // shape: (n_mels, out_len)
1656    LogHz {
1657        matrix: SparseMatrix,
1658        frequencies: NonEmptyVec<f64>,
1659    }, // shape: (n_bins, out_len)
1660    Erb {
1661        filterbank: ErbFilterbank,
1662    },
1663    Cqt {
1664        kernel: CqtKernel,
1665    },
1666}
1667
1668impl MappingKind {
1669    /// Check if this mapping requires unwindowed time-domain frames.
1670    ///
1671    /// CQT kernels already contain windowing applied during kernel generation,
1672    /// so they should receive unwindowed frames to avoid double-windowing.
1673    /// All other mappings work on FFT spectra and don't use the time-domain frame.
1674    const fn needs_unwindowed_frame(&self) -> bool {
1675        matches!(self, Self::Cqt { .. })
1676    }
1677}
1678
1679/// Typed mapping wrapper.
1680#[derive(Debug, Clone)]
1681struct FrequencyMapping<FreqScale> {
1682    kind: MappingKind,
1683    _marker: PhantomData<FreqScale>,
1684}
1685
1686impl FrequencyMapping<LinearHz> {
1687    fn new(params: &SpectrogramParams) -> SpectrogramResult<Self> {
1688        let out_len = r2c_output_size(params.stft().n_fft().get());
1689        let out_len = NonZeroUsize::new(out_len)
1690            .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1691        Ok(Self {
1692            kind: MappingKind::Identity { out_len },
1693            _marker: PhantomData,
1694        })
1695    }
1696}
1697
1698impl FrequencyMapping<Mel> {
1699    fn new_mel(params: &SpectrogramParams, mel: &MelParams) -> SpectrogramResult<Self> {
1700        let n_fft = params.stft().n_fft();
1701        let out_len = r2c_output_size(n_fft.get());
1702        let out_len = NonZeroUsize::new(out_len)
1703            .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1704
1705        // Validate: mel bins must be <= something sensible
1706        if mel.n_mels() > nzu!(10_000) {
1707            return Err(SpectrogramError::invalid_input(
1708                "n_mels is unreasonably large",
1709            ));
1710        }
1711
1712        let matrix = build_mel_filterbank_matrix(
1713            params.sample_rate_hz(),
1714            n_fft,
1715            mel.n_mels(),
1716            mel.f_min(),
1717            mel.f_max(),
1718            mel.norm(),
1719        )?;
1720
1721        // matrix must be (n_mels, out_len)
1722        if matrix.nrows() != mel.n_mels().get() || matrix.ncols() != out_len.get() {
1723            return Err(SpectrogramError::invalid_input(
1724                "mel filterbank matrix shape mismatch",
1725            ));
1726        }
1727
1728        Ok(Self {
1729            kind: MappingKind::Mel { matrix },
1730            _marker: PhantomData,
1731        })
1732    }
1733}
1734
1735impl FrequencyMapping<LogHz> {
1736    fn new_loghz(params: &SpectrogramParams, loghz: &LogHzParams) -> SpectrogramResult<Self> {
1737        let n_fft = params.stft().n_fft();
1738        let out_len = r2c_output_size(n_fft.get());
1739        let out_len = NonZeroUsize::new(out_len)
1740            .ok_or_else(|| SpectrogramError::invalid_input("FFT output length must be non-zero"))?;
1741        // Validate: n_bins must be <= something sensible
1742        if loghz.n_bins() > nzu!(10_000) {
1743            return Err(SpectrogramError::invalid_input(
1744                "n_bins is unreasonably large",
1745            ));
1746        }
1747
1748        let (matrix, frequencies) = build_loghz_matrix(
1749            params.sample_rate_hz(),
1750            n_fft,
1751            loghz.n_bins(),
1752            loghz.f_min(),
1753            loghz.f_max(),
1754        )?;
1755
1756        // matrix must be (n_bins, out_len)
1757        if matrix.nrows() != loghz.n_bins().get() || matrix.ncols() != out_len.get() {
1758            return Err(SpectrogramError::invalid_input(
1759                "loghz matrix shape mismatch",
1760            ));
1761        }
1762
1763        Ok(Self {
1764            kind: MappingKind::LogHz {
1765                matrix,
1766                frequencies,
1767            },
1768            _marker: PhantomData,
1769        })
1770    }
1771}
1772
1773impl FrequencyMapping<Erb> {
1774    fn new_erb(params: &SpectrogramParams, erb: &crate::erb::ErbParams) -> SpectrogramResult<Self> {
1775        let n_fft = params.stft().n_fft();
1776        let sample_rate = params.sample_rate_hz();
1777
1778        // Validate: n_filters must be <= something sensible
1779        if erb.n_filters() > nzu!(10_000) {
1780            return Err(SpectrogramError::invalid_input(
1781                "n_filters is unreasonably large",
1782            ));
1783        }
1784
1785        // Generate ERB filterbank with pre-computed frequency responses
1786        let filterbank = crate::erb::ErbFilterbank::generate(erb, sample_rate, n_fft)?;
1787
1788        Ok(Self {
1789            kind: MappingKind::Erb { filterbank },
1790            _marker: PhantomData,
1791        })
1792    }
1793}
1794
1795impl FrequencyMapping<Cqt> {
1796    fn new(params: &SpectrogramParams, cqt: &CqtParams) -> SpectrogramResult<Self> {
1797        let sample_rate = params.sample_rate_hz();
1798        let n_fft = params.stft().n_fft();
1799
1800        // Validate that frequency range is reasonable
1801        let f_max = cqt.bin_frequency(cqt.num_bins().get().saturating_sub(1));
1802        if f_max >= sample_rate / 2.0 {
1803            return Err(SpectrogramError::invalid_input(
1804                "CQT maximum frequency must be below Nyquist frequency",
1805            ));
1806        }
1807
1808        // Generate CQT kernel using n_fft as the signal length for kernel generation
1809        let kernel = CqtKernel::generate(cqt, sample_rate, n_fft);
1810
1811        Ok(Self {
1812            kind: MappingKind::Cqt { kernel },
1813            _marker: PhantomData,
1814        })
1815    }
1816}
1817
1818impl<FreqScale> FrequencyMapping<FreqScale> {
1819    const fn output_bins(&self) -> NonZeroUsize {
1820        // safety: all variants ensure output bins > 0 OR rely on a matrix that is guaranteed to have rows > 0
1821        match &self.kind {
1822            MappingKind::Identity { out_len } => *out_len,
1823            // safety: matrix.nrows() > 0
1824            MappingKind::LogHz { matrix, .. } | MappingKind::Mel { matrix } => unsafe {
1825                NonZeroUsize::new_unchecked(matrix.nrows())
1826            },
1827            MappingKind::Erb { filterbank, .. } => filterbank.num_filters(),
1828            MappingKind::Cqt { kernel, .. } => kernel.num_bins(),
1829        }
1830    }
1831
1832    fn apply(
1833        &self,
1834        spectrum: &NonEmptySlice<f64>,
1835        frame: &NonEmptySlice<f64>,
1836        out: &mut NonEmptySlice<f64>,
1837    ) -> SpectrogramResult<()> {
1838        match &self.kind {
1839            MappingKind::Identity { out_len } => {
1840                if spectrum.len() != *out_len {
1841                    return Err(SpectrogramError::dimension_mismatch(
1842                        (*out_len).get(),
1843                        spectrum.len().get(),
1844                    ));
1845                }
1846                if out.len() != *out_len {
1847                    return Err(SpectrogramError::dimension_mismatch(
1848                        (*out_len).get(),
1849                        out.len().get(),
1850                    ));
1851                }
1852                out.copy_from_slice(spectrum);
1853                Ok(())
1854            }
1855            MappingKind::LogHz { matrix, .. } | MappingKind::Mel { matrix } => {
1856                let out_bins = matrix.nrows();
1857                let in_bins = matrix.ncols();
1858
1859                if spectrum.len().get() != in_bins {
1860                    return Err(SpectrogramError::dimension_mismatch(
1861                        in_bins,
1862                        spectrum.len().get(),
1863                    ));
1864                }
1865                if out.len().get() != out_bins {
1866                    return Err(SpectrogramError::dimension_mismatch(
1867                        out_bins,
1868                        out.len().get(),
1869                    ));
1870                }
1871
1872                // Sparse matrix-vector multiplication: out = matrix * spectrum
1873                matrix.multiply_vec(spectrum, out);
1874                Ok(())
1875            }
1876            MappingKind::Erb { filterbank } => {
1877                // Apply ERB filterbank using pre-computed frequency responses
1878                // The filterbank already has |H(f)|^2 pre-computed, so we just
1879                // apply it to the power spectrum
1880                let erb_out = filterbank.apply_to_power_spectrum(spectrum)?;
1881
1882                if out.len().get() != erb_out.len().get() {
1883                    return Err(SpectrogramError::dimension_mismatch(
1884                        erb_out.len().get(),
1885                        out.len().get(),
1886                    ));
1887                }
1888
1889                out.copy_from_slice(&erb_out);
1890                Ok(())
1891            }
1892            MappingKind::Cqt { kernel } => {
1893                // CQT works on time-domain unwindowed frame (not FFT spectrum).
1894                // The CQT kernels contain windowing applied during kernel generation,
1895                // so the input frame should be unwindowed to avoid double-windowing.
1896                let cqt_complex = kernel.apply(frame)?;
1897
1898                if out.len().get() != cqt_complex.len().get() {
1899                    return Err(SpectrogramError::dimension_mismatch(
1900                        cqt_complex.len().get(),
1901                        out.len().get(),
1902                    ));
1903                }
1904
1905                // Convert complex coefficients to power (|z|^2)
1906                // This matches the convention where intermediate values are in power domain
1907                for (i, c) in cqt_complex.iter().enumerate() {
1908                    out[i] = c.norm_sqr();
1909                }
1910
1911                Ok(())
1912            }
1913        }
1914    }
1915
1916    fn frequencies_hz(&self, params: &SpectrogramParams) -> NonEmptyVec<f64> {
1917        match &self.kind {
1918            MappingKind::Identity { out_len } => {
1919                // Standard R2C bins: k * sr / n_fft
1920                let n_fft = params.stft().n_fft().get() as f64;
1921                let sr = params.sample_rate_hz();
1922                let df = sr / n_fft;
1923
1924                let mut f = Vec::with_capacity((*out_len).get());
1925                for k in 0..(*out_len).get() {
1926                    f.push(k as f64 * df);
1927                }
1928                // safety: out_len > 0
1929                unsafe { NonEmptyVec::new_unchecked(f) }
1930            }
1931            MappingKind::Mel { matrix } => {
1932                // For mel, the axis is defined by the mel band centre frequencies.
1933                // We compute and store them consistently with how we built the filterbank.
1934                let n_mels = matrix.nrows();
1935                // safety: n_mels > 0
1936                let n_mels = unsafe { NonZeroUsize::new_unchecked(n_mels) };
1937                mel_band_centres_hz(n_mels, params.sample_rate_hz(), params.nyquist_hz())
1938            }
1939            MappingKind::LogHz { frequencies, .. } => {
1940                // Frequencies are stored when the mapping is created
1941                frequencies.clone()
1942            }
1943            MappingKind::Erb { filterbank, .. } => {
1944                // ERB center frequencies
1945                filterbank.center_frequencies().to_non_empty_vec()
1946            }
1947            MappingKind::Cqt { kernel, .. } => {
1948                // CQT center frequencies from the kernel
1949                kernel.frequencies().to_non_empty_vec()
1950            }
1951        }
1952    }
1953}
1954
1955//
1956// ========================
1957// Amplitude scaling
1958// ========================
1959//
1960
1961/// Marker trait so we can specialise behaviour by `AmpScale`.
1962pub trait AmpScaleSpec: Sized + Send + Sync {
1963    /// Apply conversion from power-domain value to the desired amplitude scale.
1964    ///
1965    /// # Arguments
1966    ///
1967    /// - `power`: input power-domain value.
1968    ///
1969    /// # Returns
1970    ///
1971    /// Converted amplitude value.
1972    fn apply_from_power(power: f64) -> f64;
1973
1974    /// Apply dB conversion in-place on a power-domain vector.
1975    ///
1976    /// This is a no-op for Power and Magnitude scales.
1977    ///
1978    /// # Arguments
1979    ///
1980    /// - `x`: power-domain values to convert to dB in-place.
1981    /// - `floor_db`: dB floor value to apply.
1982    ///
1983    /// # Returns
1984    ///
1985    /// `SpectrogramResult<()>`: Ok on success, error on invalid input.
1986    ///
1987    /// # Errors
1988    ///
1989    /// - If `floor_db` is not finite.
1990    fn apply_db_in_place(x: &mut [f64], floor_db: f64) -> SpectrogramResult<()>;
1991}
1992
1993impl AmpScaleSpec for Power {
1994    #[inline]
1995    fn apply_from_power(power: f64) -> f64 {
1996        power
1997    }
1998
1999    #[inline]
2000    fn apply_db_in_place(_x: &mut [f64], _floor_db: f64) -> SpectrogramResult<()> {
2001        Ok(())
2002    }
2003}
2004
2005impl AmpScaleSpec for Magnitude {
2006    #[inline]
2007    fn apply_from_power(power: f64) -> f64 {
2008        power.sqrt()
2009    }
2010
2011    #[inline]
2012    fn apply_db_in_place(_x: &mut [f64], _floor_db: f64) -> SpectrogramResult<()> {
2013        Ok(())
2014    }
2015}
2016
2017impl AmpScaleSpec for Decibels {
2018    #[inline]
2019    fn apply_from_power(power: f64) -> f64 {
2020        // dB conversion is applied in batch, not here.
2021        power
2022    }
2023
2024    #[inline]
2025    fn apply_db_in_place(x: &mut [f64], floor_db: f64) -> SpectrogramResult<()> {
2026        // Convert power -> dB: 10*log10(max(power, eps))
2027        // where eps is derived from floor_db to ensure consistency
2028        if !floor_db.is_finite() {
2029            return Err(SpectrogramError::invalid_input("floor_db must be finite"));
2030        }
2031
2032        // Convert floor_db to linear scale to get epsilon
2033        // e.g., floor_db = -80 dB -> eps = 10^(-80/10) = 1e-8
2034        let eps = 10.0_f64.powf(floor_db / 10.0);
2035
2036        for v in x.iter_mut() {
2037            // Clamp power to epsilon before log to avoid log(0) and ensure floor
2038            *v = 10.0 * v.max(eps).log10();
2039        }
2040        Ok(())
2041    }
2042}
2043
/// Converts the power-domain intermediate into the requested amplitude scale
/// (`Power`, `Magnitude` or `Decibels`).
#[derive(Debug, Clone)]
struct AmplitudeScaling<AmpScale> {
    /// dB floor taken from `LogParams`; `None` means the dB stage is skipped.
    db_floor: Option<f64>,
    /// Ties this configuration to an amplitude-scale type without storing a value of it.
    _marker: PhantomData<AmpScale>,
}
2052
2053impl<AmpScale> AmplitudeScaling<AmpScale>
2054where
2055    AmpScale: AmpScaleSpec + 'static,
2056{
2057    fn new(db: Option<&LogParams>) -> Self {
2058        let db_floor = db.map(LogParams::floor_db);
2059        Self {
2060            db_floor,
2061            _marker: PhantomData,
2062        }
2063    }
2064
2065    /// Apply amplitude scaling in-place on a mapped spectrum vector.
2066    ///
2067    /// The input vector is assumed to be in the *power* domain (|X|^2),
2068    /// because the STFT stage produces power as the canonical intermediate.
2069    ///
2070    /// - Power: leaves values unchanged.
2071    /// - Magnitude: sqrt(power).
2072    /// - Decibels: converts power -> dB and floors at `db_floor`.
2073    pub fn apply_in_place(&self, x: &mut [f64]) -> SpectrogramResult<()> {
2074        // Convert from canonical power-domain intermediate into the requested linear domain.
2075        for v in x.iter_mut() {
2076            *v = AmpScale::apply_from_power(*v);
2077        }
2078
2079        // Apply dB conversion if configured (no-op for Power/Magnitude via trait impls).
2080        if let Some(floor_db) = self.db_floor {
2081            AmpScale::apply_db_in_place(x, floor_db)?;
2082        }
2083
2084        Ok(())
2085    }
2086}
2087
2088#[derive(Debug, Clone)]
2089struct Workspace {
2090    spectrum: NonEmptyVec<f64>,         // out_len (power spectrum)
2091    mapped: NonEmptyVec<f64>,           // n_bins (after mapping)
2092    frame: NonEmptyVec<f64>,            // n_fft (windowed frame for FFT)
2093    fft_out: NonEmptyVec<Complex<f64>>, // out_len (FFT output)
2094}
2095
2096impl Workspace {
2097    fn new(n_fft: NonZeroUsize, out_len: NonZeroUsize, n_bins: NonZeroUsize) -> Self {
2098        Self {
2099            spectrum: non_empty_vec![0.0; out_len],
2100            mapped: non_empty_vec![0.0; n_bins],
2101            frame: non_empty_vec![0.0; n_fft],
2102            fft_out: non_empty_vec![Complex::new(0.0, 0.0); out_len],
2103        }
2104    }
2105
2106    fn ensure_sizes(&mut self, n_fft: NonZeroUsize, out_len: NonZeroUsize, n_bins: NonZeroUsize) {
2107        if self.spectrum.len() != out_len {
2108            self.spectrum.resize(out_len, 0.0);
2109        }
2110        if self.mapped.len() != n_bins {
2111            self.mapped.resize(n_bins, 0.0);
2112        }
2113        if self.frame.len() != n_fft {
2114            self.frame.resize(n_fft, 0.0);
2115        }
2116        if self.fft_out.len() != out_len {
2117            self.fft_out.resize(out_len, Complex::new(0.0, 0.0));
2118        }
2119    }
2120}
2121
2122fn build_frequency_axis<FreqScale>(
2123    params: &SpectrogramParams,
2124    mapping: &FrequencyMapping<FreqScale>,
2125) -> FrequencyAxis<FreqScale>
2126where
2127    FreqScale: Copy + Clone + 'static,
2128{
2129    let frequencies = mapping.frequencies_hz(params);
2130    FrequencyAxis::new(frequencies)
2131}
2132
2133fn build_time_axis_seconds(params: &SpectrogramParams, n_frames: NonZeroUsize) -> NonEmptyVec<f64> {
2134    let dt = params.frame_period_seconds();
2135    let mut times = Vec::with_capacity(n_frames.get());
2136
2137    for i in 0..n_frames.get() {
2138        times.push(i as f64 * dt);
2139    }
2140
2141    // safety: times is guaranteed non-empty since n_frames > 0
2142
2143    unsafe { NonEmptyVec::new_unchecked(times) }
2144}
2145
2146/// Generate window function samples.
2147///
2148/// Supports various window types including Rectangular, Hanning, Hamming, Blackman, Kaiser, and Gaussian.
2149///
2150/// # Arguments
2151///
2152/// * `window` - The type of window function to generate.
2153/// * `n_fft` - The size of the FFT, which determines the length of the window.
2154///
2155/// # Returns
2156///
2157/// A `NonEmptyVec<f64>` containing the window function samples.
2158///
2159/// # Panics
2160///
2161/// Panics if a custom window is provided with a size that does not match `n_fft`.
2162#[inline]
2163#[must_use]
2164pub fn make_window(window: WindowType, n_fft: NonZeroUsize) -> NonEmptyVec<f64> {
2165    let n_fft = n_fft.get();
2166    let mut w = vec![0.0; n_fft];
2167
2168    match window {
2169        WindowType::Rectangular => {
2170            w.fill(1.0);
2171        }
2172        WindowType::Hanning => {
2173            // Hann: 0.5 - 0.5*cos(2Ï€n/(N-1))
2174            let n1 = (n_fft - 1) as f64;
2175            for (n, v) in w.iter_mut().enumerate() {
2176                *v = 0.5f64.mul_add(-(2.0 * std::f64::consts::PI * (n as f64) / n1).cos(), 0.5);
2177            }
2178        }
2179        WindowType::Hamming => {
2180            // Hamming: 0.54 - 0.46*cos(2Ï€n/(N-1))
2181            let n1 = (n_fft - 1) as f64;
2182            for (n, v) in w.iter_mut().enumerate() {
2183                *v = 0.46f64.mul_add(-(2.0 * std::f64::consts::PI * (n as f64) / n1).cos(), 0.54);
2184            }
2185        }
2186        WindowType::Blackman => {
2187            // Blackman: 0.42 - 0.5*cos(2Ï€n/(N-1)) + 0.08*cos(4Ï€n/(N-1))
2188            let n1 = (n_fft - 1) as f64;
2189            for (n, v) in w.iter_mut().enumerate() {
2190                let a = 2.0 * std::f64::consts::PI * (n as f64) / n1;
2191                *v = 0.08f64.mul_add((2.0 * a).cos(), 0.5f64.mul_add(-a.cos(), 0.42));
2192            }
2193        }
2194        WindowType::Kaiser { beta } => {
2195            if n_fft == 1 {
2196                w[0] = 1.0;
2197            } else {
2198                let denom = modified_bessel_i0(beta);
2199                let n_max = (n_fft - 1) as f64 / 2.0;
2200
2201                for (i, value) in w.iter_mut().enumerate() {
2202                    let n = i as f64 - n_max;
2203                    let ratio = if n_max == 0.0 {
2204                        0.0
2205                    } else {
2206                        let normalized = n / n_max;
2207                        (1.0 - normalized * normalized).max(0.0)
2208                    };
2209                    let arg = beta * ratio.sqrt();
2210                    *value = if denom == 0.0 {
2211                        0.0
2212                    } else {
2213                        modified_bessel_i0(arg) / denom
2214                    };
2215                }
2216            }
2217        }
2218        WindowType::Gaussian { std } => (0..n_fft).for_each(|i| {
2219            let n = i as f64;
2220            let center: f64 = (n_fft - 1) as f64 / 2.0;
2221            let exponent: f64 = -0.5 * ((n - center) / std).powi(2);
2222            w[i] = exponent.exp();
2223        }),
2224        WindowType::Custom { coefficients, size } => {
2225            assert!(
2226                size.get() == n_fft,
2227                "Custom window size mismatch: expected {}, got {}. \
2228                 Custom windows must be pre-computed with the exact FFT size.",
2229                n_fft,
2230                size.get()
2231            );
2232            w.copy_from_slice(&coefficients);
2233        }
2234    }
2235
2236    // safety: window is guaranteed non-empty since n_fft > 0
2237    unsafe { NonEmptyVec::new_unchecked(w) }
2238}
2239
/// Modified Bessel function of the first kind, order zero: `I0(x)`.
///
/// Uses the polynomial approximations of Abramowitz & Stegun, eqs. 9.8.1
/// (`|x| <= 3.75`) and 9.8.2 (`|x| > 3.75`), which are accurate to roughly
/// 2e-7 relative error. `I0` is even, so only `|x|` matters. Used to build
/// Kaiser windows.
fn modified_bessel_i0(x: f64) -> f64 {
    let ax = x.abs();
    if ax <= 3.75 {
        // A&S 9.8.1: I0(x) = 1 + Σ c_k (x / 3.75)^(2k)
        let t = x / 3.75;
        let t2 = t * t;
        1.0 + t2
            * (3.515_622_9
                + t2 * (3.089_942_4
                    + t2 * (1.206_749_2
                        + t2 * (0.265_973_2 + t2 * (0.036_076_8 + t2 * 0.004_581_3)))))
    } else {
        // A&S 9.8.2: sqrt(x) * e^(-x) * I0(x) = Σ d_k (3.75 / x)^k
        // The leading coefficient 0.39894228 already equals 1/sqrt(2π), so no extra
        // normalisation factor may be applied when solving for I0(x).
        let t = 3.75 / ax;
        let poly = 0.398_942_28
            + t * (0.013_285_92
                + t * (0.002_253_19
                    + t * (-0.001_575_65
                        + t * (0.009_162_81
                            + t * (-0.020_577_06
                                + t * (0.026_355_37 + t * (-0.016_476_33 + t * 0.003_923_77)))))));

        ax.exp() / ax.sqrt() * poly
    }
}
2263
/// Convert a frequency in Hz to the Slaney mel scale (librosa default, `htk=False`).
///
/// The scale is piecewise:
/// - below 1000 Hz it is linear, `mel = hz / (200/3)` (so 1000 Hz is mel 15);
/// - at and above 1000 Hz it is logarithmic, `mel = 15 + ln(hz / 1000) / (ln(6.4) / 27)`.
fn hz_to_mel(hz: f64) -> f64 {
    const F_MIN: f64 = 0.0;
    const F_SP: f64 = 200.0 / 3.0;
    const MIN_LOG_HZ: f64 = 1000.0;
    const MIN_LOG_MEL: f64 = (MIN_LOG_HZ - F_MIN) / F_SP;
    // ln(6.4) / 27
    const LOGSTEP: f64 = 0.068_751_777_420_949_23;

    if hz < MIN_LOG_HZ {
        // Linear segment.
        return (hz - F_MIN) / F_SP;
    }
    let log_ratio = (hz / MIN_LOG_HZ).ln();
    MIN_LOG_MEL + log_ratio / LOGSTEP
}
2285
/// Convert a Slaney mel value back to Hz (librosa default, `htk=False`).
///
/// Inverse of `hz_to_mel`: linear below mel 15 (1000 Hz), exponential at and above it.
fn mel_to_hz(mel: f64) -> f64 {
    const F_MIN: f64 = 0.0;
    const F_SP: f64 = 200.0 / 3.0;
    const MIN_LOG_HZ: f64 = 1000.0;
    const MIN_LOG_MEL: f64 = (MIN_LOG_HZ - F_MIN) / F_SP;
    // ln(6.4) / 27
    const LOGSTEP: f64 = 0.068_751_777_420_949_23;

    if mel < MIN_LOG_MEL {
        // Linear segment: F_MIN + F_SP * mel.
        return F_SP.mul_add(mel, F_MIN);
    }
    let mels_above_break = mel - MIN_LOG_MEL;
    MIN_LOG_HZ * (LOGSTEP * mels_above_break).exp()
}
2304
2305fn build_mel_filterbank_matrix(
2306    sample_rate_hz: f64,
2307    n_fft: NonZeroUsize,
2308    n_mels: NonZeroUsize,
2309    f_min: f64,
2310    f_max: f64,
2311    norm: MelNorm,
2312) -> SpectrogramResult<SparseMatrix> {
2313    if sample_rate_hz <= 0.0 || !sample_rate_hz.is_finite() {
2314        return Err(SpectrogramError::invalid_input(
2315            "sample_rate_hz must be finite and > 0",
2316        ));
2317    }
2318    if f_min < 0.0 || f_min.is_infinite() {
2319        return Err(SpectrogramError::invalid_input("f_min must be >= 0"));
2320    }
2321    if f_max <= f_min {
2322        return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
2323    }
2324    if f_max > sample_rate_hz * 0.5 {
2325        return Err(SpectrogramError::invalid_input("f_max must be <= Nyquist"));
2326    }
2327    let n_mels = n_mels.get();
2328    let n_fft = n_fft.get();
2329    let out_len = r2c_output_size(n_fft);
2330
2331    // FFT bin frequencies
2332    let df = sample_rate_hz / n_fft as f64;
2333
2334    // Mel points: n_mels + 2 (for triangular edges)
2335    let mel_min = hz_to_mel(f_min);
2336    let mel_max = hz_to_mel(f_max);
2337
2338    let n_points = n_mels + 2;
2339    let step = (mel_max - mel_min) / (n_points - 1) as f64;
2340
2341    let mut mel_points = Vec::with_capacity(n_points);
2342    for i in 0..n_points {
2343        mel_points.push((i as f64).mul_add(step, mel_min));
2344    }
2345
2346    let mut hz_points = Vec::with_capacity(n_points);
2347    for m in &mel_points {
2348        hz_points.push(mel_to_hz(*m));
2349    }
2350
2351    // Build filterbank as sparse matrix (librosa-style, in frequency space)
2352    // This builds triangular filters based on actual frequencies, not bin indices
2353    let mut fb = SparseMatrix::new(n_mels, out_len);
2354
2355    for m in 0..n_mels {
2356        let freq_left = hz_points[m];
2357        let freq_center = hz_points[m + 1];
2358        let freq_right = hz_points[m + 2];
2359
2360        let fdiff_left = freq_center - freq_left;
2361        let fdiff_right = freq_right - freq_center;
2362
2363        if fdiff_left == 0.0 || fdiff_right == 0.0 {
2364            // Degenerate triangle, skip
2365            continue;
2366        }
2367
2368        // For each FFT bin, compute the triangular weight based on its frequency
2369        for k in 0..out_len {
2370            let bin_freq = k as f64 * df;
2371
2372            // Lower slope: rises from freq_left to freq_center
2373            let lower = (bin_freq - freq_left) / fdiff_left;
2374
2375            // Upper slope: falls from freq_center to freq_right
2376            let upper = (freq_right - bin_freq) / fdiff_right;
2377
2378            // Triangle is the minimum of the two slopes, clipped to [0, 1]
2379            let weight = lower.min(upper).clamp(0.0, 1.0);
2380
2381            if weight > 0.0 {
2382                fb.set(m, k, weight);
2383            }
2384        }
2385    }
2386
2387    // Apply normalization
2388    match norm {
2389        MelNorm::None => {
2390            // No normalization needed
2391        }
2392        MelNorm::Slaney => {
2393            // Slaney-style area normalization: 2 / (hz_max - hz_min) for each triangle
2394            // NOTE: Uses Hz bandwidth, not mel bandwidth (to match librosa's implementation)
2395            for m in 0..n_mels {
2396                let mel_left = mel_points[m];
2397                let mel_right = mel_points[m + 2];
2398                let hz_left = mel_to_hz(mel_left);
2399                let hz_right = mel_to_hz(mel_right);
2400                let enorm = 2.0 / (hz_right - hz_left);
2401
2402                // Normalize all values in this row
2403                for val in &mut fb.values[m] {
2404                    *val *= enorm;
2405                }
2406            }
2407        }
2408        MelNorm::L1 => {
2409            // L1 normalization: sum of weights = 1.0
2410            for m in 0..n_mels {
2411                let sum: f64 = fb.values[m].iter().sum();
2412                if sum > 0.0 {
2413                    let normalizer = 1.0 / sum;
2414                    for val in &mut fb.values[m] {
2415                        *val *= normalizer;
2416                    }
2417                }
2418            }
2419        }
2420        MelNorm::L2 => {
2421            // L2 normalization: L2 norm = 1.0
2422            for m in 0..n_mels {
2423                let norm_val: f64 = fb.values[m].iter().map(|&v| v * v).sum::<f64>().sqrt();
2424                if norm_val > 0.0 {
2425                    let normalizer = 1.0 / norm_val;
2426                    for val in &mut fb.values[m] {
2427                        *val *= normalizer;
2428                    }
2429                }
2430            }
2431        }
2432    }
2433
2434    Ok(fb)
2435}
2436
2437/// Build a logarithmic frequency interpolation matrix.
2438///
2439/// Maps linearly-spaced FFT bins to logarithmically-spaced frequency bins
2440/// using linear interpolation.
2441fn build_loghz_matrix(
2442    sample_rate_hz: f64,
2443    n_fft: NonZeroUsize,
2444    n_bins: NonZeroUsize,
2445    f_min: f64,
2446    f_max: f64,
2447) -> SpectrogramResult<(SparseMatrix, NonEmptyVec<f64>)> {
2448    if sample_rate_hz <= 0.0 || !sample_rate_hz.is_finite() {
2449        return Err(SpectrogramError::invalid_input(
2450            "sample_rate_hz must be finite and > 0",
2451        ));
2452    }
2453    if f_min <= 0.0 || f_min.is_infinite() {
2454        return Err(SpectrogramError::invalid_input(
2455            "f_min must be finite and > 0",
2456        ));
2457    }
2458    if f_max <= f_min {
2459        return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
2460    }
2461    if f_max > sample_rate_hz * 0.5 {
2462        return Err(SpectrogramError::invalid_input("f_max must be <= Nyquist"));
2463    }
2464
2465    let n_bins = n_bins.get();
2466    let n_fft = n_fft.get();
2467
2468    let out_len = r2c_output_size(n_fft);
2469    let df = sample_rate_hz / n_fft as f64;
2470
2471    // Generate logarithmically-spaced frequencies
2472    let log_f_min = f_min.ln();
2473    let log_f_max = f_max.ln();
2474    let log_step = (log_f_max - log_f_min) / (n_bins - 1) as f64;
2475
2476    let mut log_frequencies = Vec::with_capacity(n_bins);
2477    for i in 0..n_bins {
2478        let log_f = (i as f64).mul_add(log_step, log_f_min);
2479        log_frequencies.push(log_f.exp());
2480    }
2481    // safety: n_bins > 0
2482    let log_frequencies = unsafe { NonEmptyVec::new_unchecked(log_frequencies) };
2483
2484    // Build interpolation matrix as sparse matrix
2485    let mut matrix = SparseMatrix::new(n_bins, out_len);
2486
2487    for (bin_idx, &target_freq) in log_frequencies.iter().enumerate() {
2488        // Find the two FFT bins that bracket this frequency
2489        let exact_bin = target_freq / df;
2490        let lower_bin = exact_bin.floor() as usize;
2491        let upper_bin = (exact_bin.ceil() as usize).min(out_len - 1);
2492
2493        if lower_bin >= out_len {
2494            continue;
2495        }
2496
2497        if lower_bin == upper_bin {
2498            // Exact match
2499            matrix.set(bin_idx, lower_bin, 1.0);
2500        } else {
2501            // Linear interpolation
2502            let frac = exact_bin - lower_bin as f64;
2503            matrix.set(bin_idx, lower_bin, 1.0 - frac);
2504            if upper_bin < out_len {
2505                matrix.set(bin_idx, upper_bin, frac);
2506            }
2507        }
2508    }
2509
2510    Ok((matrix, log_frequencies))
2511}
2512
2513fn mel_band_centres_hz(
2514    n_mels: NonZeroUsize,
2515    sample_rate_hz: f64,
2516    nyquist_hz: f64,
2517) -> NonEmptyVec<f64> {
2518    let f_min = 0.0;
2519    let f_max = nyquist_hz.min(sample_rate_hz * 0.5);
2520
2521    let mel_min = hz_to_mel(f_min);
2522    let mel_max = hz_to_mel(f_max);
2523    let n_mels = n_mels.get();
2524    let step = (mel_max - mel_min) / (n_mels + 1) as f64;
2525
2526    let mut centres = Vec::with_capacity(n_mels);
2527    for i in 0..n_mels {
2528        let mel = (i as f64 + 1.0).mul_add(step, mel_min);
2529        centres.push(mel_to_hz(mel));
2530    }
2531    // safety: centres is guaranteed non-empty since n_mels > 0
2532    unsafe { NonEmptyVec::new_unchecked(centres) }
2533}
2534
2535/// Spectrogram structure holding the computed spectrogram data and metadata.
2536///
2537/// # Type Parameters
2538///
2539/// * `FreqScale`: The frequency scale type (e.g., `LinearHz`, `Mel`, `LogHz`, etc.).
2540/// * `AmpScale`: The amplitude scale type (e.g., `Power`, `Magnitude`, `Decibels`).
2541///
2542/// # Fields
2543///
2544/// * `data`: A 2D array containing the spectrogram data.
2545/// * `axes`: The axes of the spectrogram (frequency and time).
2546/// * `params`: The parameters used to compute the spectrogram.
2547/// * `_amp`: A phantom data marker for the amplitude scale type.
2548#[derive(Debug, Clone)]
2549#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
2550pub struct Spectrogram<FreqScale, AmpScale>
2551where
2552    AmpScale: AmpScaleSpec + 'static,
2553    FreqScale: Copy + Clone + 'static,
2554{
2555    data: Array2<f64>,
2556    axes: Axes<FreqScale>,
2557    params: SpectrogramParams,
2558    #[cfg_attr(feature = "serde", serde(skip))]
2559    _amp: PhantomData<AmpScale>,
2560}
2561
2562impl<FreqScale, AmpScale> Spectrogram<FreqScale, AmpScale>
2563where
2564    AmpScale: AmpScaleSpec + 'static,
2565    FreqScale: Copy + Clone + 'static,
2566{
2567    /// Get the X-axis label.
2568    ///
2569    /// # Returns
2570    ///
2571    /// A static string slice representing the X-axis label.
2572    #[inline]
2573    #[must_use]
2574    pub const fn x_axis_label() -> &'static str {
2575        "Time (s)"
2576    }
2577
2578    /// Get the Y-axis label based on the frequency scale type.
2579    ///
2580    /// # Returns
2581    ///
2582    /// A static string slice representing the Y-axis label.
2583    #[inline]
2584    #[must_use]
2585    pub fn y_axis_label() -> &'static str {
2586        match std::any::TypeId::of::<FreqScale>() {
2587            id if id == std::any::TypeId::of::<LinearHz>() => "Frequency (Hz)",
2588            id if id == std::any::TypeId::of::<Mel>() => "Frequency (Mel)",
2589            id if id == std::any::TypeId::of::<LogHz>() => "Frequency (Log Hz)",
2590            id if id == std::any::TypeId::of::<Erb>() => "Frequency (ERB)",
2591            id if id == std::any::TypeId::of::<Cqt>() => "Frequency (CQT Bins)",
2592            _ => "Frequency",
2593        }
2594    }
2595
2596    /// Internal constructor. Only callable inside the crate.
2597    ///
2598    /// All inputs must already be validated and consistent.
2599    pub(crate) fn new(data: Array2<f64>, axes: Axes<FreqScale>, params: SpectrogramParams) -> Self {
2600        debug_assert_eq!(data.nrows(), axes.frequencies().len().get());
2601        debug_assert_eq!(data.ncols(), axes.times().len().get());
2602
2603        Self {
2604            data,
2605            axes,
2606            params,
2607            _amp: PhantomData,
2608        }
2609    }
2610
2611    /// Set spectrogram data matrix
2612    ///
2613    /// # Arguments
2614    ///
2615    /// * `data` - The new spectrogram data matrix.
2616    #[inline]
2617    pub fn set_data(&mut self, data: Array2<f64>) {
2618        self.data = data;
2619    }
2620
2621    /// Spectrogram data matrix
2622    ///
2623    /// # Returns
2624    ///
2625    /// A reference to the spectrogram data matrix.
2626    #[inline]
2627    #[must_use]
2628    pub const fn data(&self) -> &Array2<f64> {
2629        &self.data
2630    }
2631
2632    /// Consume the spectrogram and extract the data matrix.
2633    ///
2634    /// This method moves the data out of the spectrogram, consuming it.
2635    /// Useful for transferring ownership to Python without copying.
2636    ///
2637    /// # Returns
2638    ///
2639    /// The owned spectrogram data matrix.
2640    #[inline]
2641    #[must_use]
2642    pub fn into_data(self) -> Array2<f64> {
2643        self.data
2644    }
2645
2646    /// Axes of the spectrogram
2647    ///
2648    /// # Returns
2649    ///
2650    /// A reference to the axes of the spectrogram.
2651    #[inline]
2652    #[must_use]
2653    pub const fn axes(&self) -> &Axes<FreqScale> {
2654        &self.axes
2655    }
2656
2657    /// Frequency axis in Hz
2658    ///
2659    /// # Returns
2660    ///
2661    /// A reference to the frequency axis in Hz.
2662    #[inline]
2663    #[must_use]
2664    pub fn frequencies(&self) -> &NonEmptySlice<f64> {
2665        self.axes.frequencies()
2666    }
2667
2668    /// Frequency range in Hz (min, max)
2669    ///
2670    /// # Returns
2671    ///
2672    /// A tuple containing the minimum and maximum frequencies in Hz.
2673    #[inline]
2674    #[must_use]
2675    pub const fn frequency_range(&self) -> (f64, f64) {
2676        self.axes.frequency_range()
2677    }
2678
2679    /// Time axis in seconds
2680    ///
2681    /// # Returns
2682    ///
2683    /// A reference to the time axis in seconds.
2684    #[inline]
2685    #[must_use]
2686    pub fn times(&self) -> &NonEmptySlice<f64> {
2687        self.axes.times()
2688    }
2689
2690    /// Spectrogram computation parameters
2691    ///
2692    /// # Returns
2693    ///
2694    /// A reference to the spectrogram computation parameters.
2695    #[inline]
2696    #[must_use]
2697    pub const fn params(&self) -> &SpectrogramParams {
2698        &self.params
2699    }
2700
2701    /// Duration of the spectrogram in seconds
2702    ///
2703    /// # Returns
2704    ///
2705    /// The duration of the spectrogram in seconds.
2706    #[inline]
2707    #[must_use]
2708    pub fn duration(&self) -> f64 {
2709        self.axes.duration()
2710    }
2711
2712    /// If this is a dB spectrogram, return the (min, max) dB values. otherwise do the maths to compute dB range.
2713    ///
2714    /// # Returns
2715    ///
2716    /// The (min, max) dB values of the spectrogram, or `None` if the amplitude scale is unknown.
2717    #[inline]
2718    #[must_use]
2719    pub fn db_range(&self) -> Option<(f64, f64)> {
2720        let type_self = std::any::TypeId::of::<AmpScale>();
2721
2722        if type_self == std::any::TypeId::of::<Decibels>() {
2723            let (min, max) = min_max_single_pass(self.data.as_slice()?);
2724            Some((min, max))
2725        } else if type_self == std::any::TypeId::of::<Power>() {
2726            // Not a dB spectrogram; compute dB range from power values
2727            let mut min_db = f64::INFINITY;
2728            let mut max_db = f64::NEG_INFINITY;
2729            for &v in &self.data {
2730                let db = 10.0 * (v + EPS).log10();
2731                if db < min_db {
2732                    min_db = db;
2733                }
2734                if db > max_db {
2735                    max_db = db;
2736                }
2737            }
2738            Some((min_db, max_db))
2739        } else if type_self == std::any::TypeId::of::<Magnitude>() {
2740            // Not a dB spectrogram; compute dB range from magnitude values
2741            let mut min_db = f64::INFINITY;
2742            let mut max_db = f64::NEG_INFINITY;
2743
2744            for &v in &self.data {
2745                let power = v * v;
2746                let db = 10.0 * (power + EPS).log10();
2747                if db < min_db {
2748                    min_db = db;
2749                }
2750                if db > max_db {
2751                    max_db = db;
2752                }
2753            }
2754
2755            Some((min_db, max_db))
2756        } else {
2757            // Unknown AmpScale type; return dummy values
2758            None
2759        }
2760    }
2761
2762    /// Number of frequency bins
2763    ///
2764    /// # Returns
2765    ///
2766    /// The number of frequency bins in the spectrogram.
2767    #[inline]
2768    #[must_use]
2769    pub fn n_bins(&self) -> NonZeroUsize {
2770        // safety: data.nrows() > 0 is guaranteed by construction
2771        unsafe { NonZeroUsize::new_unchecked(self.data.nrows()) }
2772    }
2773
2774    /// Number of time frames in the spectrogram
2775    ///
2776    /// # Returns
2777    ///
2778    /// The number of time frames (columns) in the spectrogram.
2779    #[inline]
2780    #[must_use]
2781    pub fn n_frames(&self) -> NonZeroUsize {
2782        // safety: data.ncols() > 0 is guaranteed by construction
2783        unsafe { NonZeroUsize::new_unchecked(self.data.ncols()) }
2784    }
2785}
2786
2787impl<FreqScale, AmpScale> AsRef<Array2<f64>> for Spectrogram<FreqScale, AmpScale>
2788where
2789    FreqScale: Copy + Clone + 'static,
2790    AmpScale: AmpScaleSpec + 'static,
2791{
2792    #[inline]
2793    fn as_ref(&self) -> &Array2<f64> {
2794        &self.data
2795    }
2796}
2797
2798impl<FreqScale, AmpScale> Deref for Spectrogram<FreqScale, AmpScale>
2799where
2800    FreqScale: Copy + Clone + 'static,
2801    AmpScale: AmpScaleSpec + 'static,
2802{
2803    type Target = Array2<f64>;
2804
2805    #[inline]
2806    fn deref(&self) -> &Self::Target {
2807        &self.data
2808    }
2809}
2810
2811impl<AmpScale> Spectrogram<LinearHz, AmpScale>
2812where
2813    AmpScale: AmpScaleSpec + 'static,
2814{
2815    /// Compute a linear-frequency spectrogram from audio samples.
2816    ///
2817    /// This is a convenience method that creates a planner internally and computes
2818    /// the spectrogram in one call. For processing multiple signals with the same
2819    /// parameters, use [`SpectrogramPlanner::linear_plan`] to create a reusable plan.
2820    ///
2821    /// # Arguments
2822    ///
2823    /// * `samples` - Audio samples (any type that can be converted to a slice)
2824    /// * `params` - Spectrogram computation parameters
2825    /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2826    ///
2827    /// # Returns
2828    ///
2829    /// A linear-frequency spectrogram with the specified amplitude scale.
2830    ///
2831    /// # Errors
2832    ///
2833    /// Returns an error if:
2834    /// - The samples slice is empty
2835    /// - Parameters are invalid
2836    /// - FFT computation fails
2837    ///
2838    /// # Examples
2839    ///
2840    /// ```
2841    /// use spectrograms::*;
2842    /// use non_empty_slice::non_empty_vec;
2843    ///
2844    /// # fn example() -> SpectrogramResult<()> {
2845    /// // Create a simple test signal
2846    /// let sample_rate = 16000.0;
2847    /// let samples_vec: Vec<f64> = (0..16000).map(|i| {
2848    ///     (2.0 * std::f64::consts::PI * 440.0 * i as f64 / sample_rate).sin()
2849    /// }).collect();
2850    /// let samples = non_empty_slice::NonEmptyVec::new(samples_vec).unwrap();
2851    ///
2852    /// // Set up parameters
2853    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2854    /// let params = SpectrogramParams::new(stft, sample_rate)?;
2855    ///
2856    /// // Compute power spectrogram
2857    /// let spec = LinearPowerSpectrogram::compute(&samples, &params, None)?;
2858    ///
2859    /// println!("Computed spectrogram: {} bins x {} frames", spec.n_bins(), spec.n_frames());
2860    /// # Ok(())
2861    /// # }
2862    /// ```
2863    #[inline]
2864    pub fn compute(
2865        samples: &NonEmptySlice<f64>,
2866        params: &SpectrogramParams,
2867        db: Option<&LogParams>,
2868    ) -> SpectrogramResult<Self> {
2869        let planner = SpectrogramPlanner::new();
2870        let mut plan = planner.linear_plan(params, db)?;
2871        plan.compute(samples)
2872    }
2873}
2874
2875impl<AmpScale> Spectrogram<Mel, AmpScale>
2876where
2877    AmpScale: AmpScaleSpec + 'static,
2878{
2879    /// Compute a mel-frequency spectrogram from audio samples.
2880    ///
2881    /// This is a convenience method that creates a planner internally and computes
2882    /// the spectrogram in one call. For processing multiple signals with the same
2883    /// parameters, use [`SpectrogramPlanner::mel_plan`] to create a reusable plan.
2884    ///
2885    /// # Arguments
2886    ///
2887    /// * `samples` - Audio samples (any type that can be converted to a slice)
2888    /// * `params` - Spectrogram computation parameters
2889    /// * `mel` - Mel filterbank parameters
2890    /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2891    ///
2892    /// # Returns
2893    ///
2894    /// A mel-frequency spectrogram with the specified amplitude scale.
2895    ///
2896    /// # Errors
2897    ///
2898    /// Returns an error if:
2899    /// - The samples slice is empty
2900    /// - Parameters are invalid
2901    /// - Mel `f_max` exceeds Nyquist frequency
2902    /// - FFT computation fails
2903    ///
2904    /// # Examples
2905    ///
2906    /// ```
2907    /// use spectrograms::*;
2908    /// use non_empty_slice::non_empty_vec;
2909    ///
2910    /// # fn example() -> SpectrogramResult<()> {
2911    /// // Create a simple test signal
2912    /// let sample_rate = 16000.0;
2913    /// let samples_vec: Vec<f64> = (0..16000).map(|i| {
2914    ///     (2.0 * std::f64::consts::PI * 440.0 * i as f64 / sample_rate).sin()
2915    /// }).collect();
2916    /// let samples = non_empty_slice::NonEmptyVec::new(samples_vec).unwrap();
2917    ///
2918    /// // Set up parameters
2919    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2920    /// let params = SpectrogramParams::new(stft, sample_rate)?;
2921    /// let mel = MelParams::new(nzu!(80), 0.0, 8000.0)?;
2922    ///
2923    /// // Compute mel spectrogram in dB scale
2924    /// let db = LogParams::new(-80.0)?;
2925    /// let spec = MelDbSpectrogram::compute(&samples, &params, &mel, Some(&db))?;
2926    ///
2927    /// println!("Computed mel spectrogram: {} mels x {} frames", spec.n_bins(), spec.n_frames());
2928    /// # Ok(())
2929    /// # }
2930    /// ```
2931    #[inline]
2932    pub fn compute(
2933        samples: &NonEmptySlice<f64>,
2934        params: &SpectrogramParams,
2935        mel: &MelParams,
2936        db: Option<&LogParams>,
2937    ) -> SpectrogramResult<Self> {
2938        let planner = SpectrogramPlanner::new();
2939        let mut plan = planner.mel_plan(params, mel, db)?;
2940        plan.compute(samples)
2941    }
2942}
2943
2944impl<AmpScale> Spectrogram<Erb, AmpScale>
2945where
2946    AmpScale: AmpScaleSpec + 'static,
2947{
2948    /// Compute an ERB-frequency spectrogram from audio samples.
2949    ///
2950    /// This is a convenience method that creates a planner internally and computes
2951    /// the spectrogram in one call. For processing multiple signals with the same
2952    /// parameters, use [`SpectrogramPlanner::erb_plan`] to create a reusable plan.
2953    ///
2954    /// # Arguments
2955    ///
2956    /// * `samples` - Audio samples (any type that can be converted to a slice)
2957    /// * `params` - Spectrogram computation parameters
2958    /// * `erb` - ERB frequency scale parameters
2959    /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
2960    ///
2961    /// # Returns
2962    ///
2963    /// An ERB-scale spectrogram with the specified amplitude scale.
2964    ///
2965    /// # Errors
2966    ///
2967    /// Returns an error if:
2968    /// - The samples slice is empty
2969    /// - Parameters are invalid
2970    /// - FFT computation fails
2971    ///
2972    /// # Examples
2973    ///
2974    /// ```
2975    /// use spectrograms::*;
2976    /// use non_empty_slice::non_empty_vec;
2977    ///
2978    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
2979    /// let samples = non_empty_vec![0.0; nzu!(16000)];
2980    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
2981    /// let params = SpectrogramParams::new(stft, 16000.0)?;
2982    /// let erb = ErbParams::speech_standard();
2983    ///
2984    /// let spec = ErbPowerSpectrogram::compute(&samples, &params, &erb, None)?;
2985    /// assert_eq!(spec.n_bins(), nzu!(40));
2986    /// # Ok(())
2987    /// # }
2988    /// ```
2989    #[inline]
2990    pub fn compute(
2991        samples: &NonEmptySlice<f64>,
2992        params: &SpectrogramParams,
2993        erb: &ErbParams,
2994        db: Option<&LogParams>,
2995    ) -> SpectrogramResult<Self> {
2996        let planner = SpectrogramPlanner::new();
2997        let mut plan = planner.erb_plan(params, erb, db)?;
2998        plan.compute(samples)
2999    }
3000}
3001
3002impl<AmpScale> Spectrogram<LogHz, AmpScale>
3003where
3004    AmpScale: AmpScaleSpec + 'static,
3005{
3006    /// Compute a logarithmic-frequency spectrogram from audio samples.
3007    ///
3008    /// This is a convenience method that creates a planner internally and computes
3009    /// the spectrogram in one call. For processing multiple signals with the same
3010    /// parameters, use [`SpectrogramPlanner::log_hz_plan`] to create a reusable plan.
3011    ///
3012    /// # Arguments
3013    ///
3014    /// * `samples` - Audio samples (any type that can be converted to a slice)
3015    /// * `params` - Spectrogram computation parameters
3016    /// * `loghz` - Logarithmic frequency scale parameters
3017    /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
3018    ///
3019    /// # Returns
3020    ///
3021    /// A logarithmic-frequency spectrogram with the specified amplitude scale.
3022    ///
3023    /// # Errors
3024    ///
3025    /// Returns an error if:
3026    /// - The samples slice is empty
3027    /// - Parameters are invalid
3028    /// - FFT computation fails
3029    ///
3030    /// # Examples
3031    ///
3032    /// ```
3033    /// use spectrograms::*;
3034    /// use non_empty_slice::non_empty_vec;
3035    ///
3036    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
3037    /// let samples = non_empty_vec![0.0; nzu!(16000)];
3038    /// let stft = StftParams::new(nzu!(512), nzu!(256), WindowType::Hanning, true)?;
3039    /// let params = SpectrogramParams::new(stft, 16000.0)?;
3040    /// let loghz = LogHzParams::new(nzu!(128), 20.0, 8000.0)?;
3041    ///
3042    /// let spec = LogHzPowerSpectrogram::compute(&samples, &params, &loghz, None)?;
3043    /// assert_eq!(spec.n_bins(), nzu!(128));
3044    /// # Ok(())
3045    /// # }
3046    /// ```
3047    #[inline]
3048    pub fn compute(
3049        samples: &NonEmptySlice<f64>,
3050        params: &SpectrogramParams,
3051        loghz: &LogHzParams,
3052        db: Option<&LogParams>,
3053    ) -> SpectrogramResult<Self> {
3054        let planner = SpectrogramPlanner::new();
3055        let mut plan = planner.log_hz_plan(params, loghz, db)?;
3056        plan.compute(samples)
3057    }
3058}
3059
3060impl<AmpScale> Spectrogram<Cqt, AmpScale>
3061where
3062    AmpScale: AmpScaleSpec + 'static,
3063{
3064    /// Compute a constant-Q transform (CQT) spectrogram from audio samples.
3065    ///
3066    /// # Arguments
3067    ///
3068    /// * `samples` - Audio samples (any type that can be converted to a slice)
3069    /// * `params` - Spectrogram computation parameters
3070    /// * `cqt` - CQT parameters
3071    /// * `db` - Optional logarithmic scaling parameters (only used when `AmpScale = Decibels`)
3072    ///
3073    /// # Returns
3074    ///
3075    /// A CQT spectrogram with the specified amplitude scale.
3076    ///
3077    /// # Errors
3078    ///
3079    /// Returns an error if:
3080    ///
3081    #[inline]
3082    pub fn compute(
3083        samples: &NonEmptySlice<f64>,
3084        params: &SpectrogramParams,
3085        cqt: &CqtParams,
3086        db: Option<&LogParams>,
3087    ) -> SpectrogramResult<Self> {
3088        let planner = SpectrogramPlanner::new();
3089        let mut plan = planner.cqt_plan(params, cqt, db)?;
3090        plan.compute(samples)
3091    }
3092}
3093
3094// ========================
3095// Display implementations
3096// ========================
3097
3098/// Helper function to get amplitude scale name
3099fn amp_scale_name<AmpScale>() -> &'static str
3100where
3101    AmpScale: AmpScaleSpec + 'static,
3102{
3103    match std::any::TypeId::of::<AmpScale>() {
3104        id if id == std::any::TypeId::of::<Power>() => "Power",
3105        id if id == std::any::TypeId::of::<Magnitude>() => "Magnitude",
3106        id if id == std::any::TypeId::of::<Decibels>() => "Decibels",
3107        _ => "Unknown",
3108    }
3109}
3110
3111/// Helper function to get frequency scale name
3112fn freq_scale_name<FreqScale>() -> &'static str
3113where
3114    FreqScale: Copy + Clone + 'static,
3115{
3116    match std::any::TypeId::of::<FreqScale>() {
3117        id if id == std::any::TypeId::of::<LinearHz>() => "Linear Hz",
3118        id if id == std::any::TypeId::of::<LogHz>() => "Log Hz",
3119        id if id == std::any::TypeId::of::<Mel>() => "Mel",
3120        id if id == std::any::TypeId::of::<Erb>() => "ERB",
3121        id if id == std::any::TypeId::of::<Cqt>() => "CQT",
3122        _ => "Unknown",
3123    }
3124}
3125
3126impl<FreqScale, AmpScale> core::fmt::Display for Spectrogram<FreqScale, AmpScale>
3127where
3128    AmpScale: AmpScaleSpec + 'static,
3129    FreqScale: Copy + Clone + 'static,
3130{
3131    #[inline]
3132    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
3133        let (freq_min, freq_max) = self.frequency_range();
3134        let duration = self.duration();
3135        let (rows, cols) = self.data.dim();
3136
3137        // Alternative formatting (#) provides more detailed output with data
3138        if f.alternate() {
3139            writeln!(f, "Spectrogram {{")?;
3140            writeln!(f, "  Frequency Scale: {}", freq_scale_name::<FreqScale>())?;
3141            writeln!(f, "  Amplitude Scale: {}", amp_scale_name::<AmpScale>())?;
3142            writeln!(f, "  Shape: {rows} frequency bins × {cols} time frames")?;
3143            writeln!(f, "  Frequency Range: {freq_min:.2} Hz - {freq_max:.2} Hz")?;
3144            writeln!(f, "  Duration: {duration:.3} s")?;
3145            writeln!(f)?;
3146
3147            // Display parameters
3148            writeln!(f, "  Parameters:")?;
3149            writeln!(f, "    Sample Rate: {} Hz", self.params.sample_rate_hz())?;
3150            writeln!(f, "    FFT Size: {}", self.params.stft().n_fft())?;
3151            writeln!(f, "    Hop Size: {}", self.params.stft().hop_size())?;
3152            writeln!(f, "    Window: {:?}", self.params.stft().window())?;
3153            writeln!(f, "    Centered: {}", self.params.stft().centre())?;
3154            writeln!(f)?;
3155
3156            // Display data statistics
3157            let data_slice = self.data.as_slice().unwrap_or(&[]);
3158            if !data_slice.is_empty() {
3159                let (min_val, max_val) = min_max_single_pass(data_slice);
3160                let mean = data_slice.iter().sum::<f64>() / data_slice.len() as f64;
3161                writeln!(f, "  Data Statistics:")?;
3162                writeln!(f, "    Min: {min_val:.6}")?;
3163                writeln!(f, "    Max: {max_val:.6}")?;
3164                writeln!(f, "    Mean: {mean:.6}")?;
3165                writeln!(f)?;
3166            }
3167
3168            // Display actual data (truncated if too large)
3169            writeln!(f, "  Data Matrix:")?;
3170            let max_rows_to_display = 5;
3171            let max_cols_to_display = 5;
3172
3173            for i in 0..rows.min(max_rows_to_display) {
3174                write!(f, "    [")?;
3175                for j in 0..cols.min(max_cols_to_display) {
3176                    if j > 0 {
3177                        write!(f, ", ")?;
3178                    }
3179                    write!(f, "{:9.4}", self.data[[i, j]])?;
3180                }
3181                if cols > max_cols_to_display {
3182                    write!(f, ", ... ({} more)", cols - max_cols_to_display)?;
3183                }
3184                writeln!(f, "]")?;
3185            }
3186
3187            if rows > max_rows_to_display {
3188                writeln!(f, "    ... ({} more rows)", rows - max_rows_to_display)?;
3189            }
3190
3191            write!(f, "}}")?;
3192        } else {
3193            // Default formatting: compact summary
3194            write!(
3195                f,
3196                "Spectrogram<{}, {}>[{}x{}] ({:.2}-{:.2} Hz, {:.3}s)",
3197                freq_scale_name::<FreqScale>(),
3198                amp_scale_name::<AmpScale>(),
3199                rows,
3200                cols,
3201                freq_min,
3202                freq_max,
3203                duration
3204            )?;
3205        }
3206
3207        Ok(())
3208    }
3209}
3210
/// Frequency axis of a spectrogram.
///
/// Stores one frequency value (in Hz) per frequency bin. The frequency scale
/// (`LinearHz`, `Mel`, ...) is carried only at the type level through `FreqScale`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct FrequencyAxis<FreqScale>
where
    FreqScale: Copy + Clone + 'static,
{
    /// Frequency of each bin in Hz; non-empty by type.
    frequencies: NonEmptyVec<f64>,
    /// Zero-sized type marker for the frequency scale; skipped by serde.
    #[cfg_attr(feature = "serde", serde(skip))]
    _marker: PhantomData<FreqScale>,
}
3221
3222impl<FreqScale> FrequencyAxis<FreqScale>
3223where
3224    FreqScale: Copy + Clone + 'static,
3225{
3226    pub(crate) const fn new(frequencies: NonEmptyVec<f64>) -> Self {
3227        Self {
3228            frequencies,
3229            _marker: PhantomData,
3230        }
3231    }
3232
3233    /// Get the frequency values in Hz.
3234    ///
3235    /// # Returns
3236    ///
3237    /// Returns a non-empty slice of frequencies.
3238    #[inline]
3239    #[must_use]
3240    pub fn frequencies(&self) -> &NonEmptySlice<f64> {
3241        &self.frequencies
3242    }
3243
3244    /// Get the frequency range (min, max) in Hz.
3245    ///
3246    /// # Returns
3247    ///
3248    /// Returns a tuple containing the minimum and maximum frequency.
3249    #[inline]
3250    #[must_use]
3251    pub const fn frequency_range(&self) -> (f64, f64) {
3252        let data = self.frequencies.as_slice();
3253        let min = data[0];
3254        let max_idx = data.len().saturating_sub(1); // safe for non-empty
3255        let max = data[max_idx];
3256        (min, max)
3257    }
3258
3259    /// Get the number of frequency bins.
3260    ///
3261    /// # Returns
3262    ///
3263    /// Returns the number of frequency bins as a NonZeroUsize.
3264    #[inline]
3265    #[must_use]
3266    pub const fn len(&self) -> NonZeroUsize {
3267        self.frequencies.len()
3268    }
3269}
3270
/// Spectrogram axes container.
///
/// Holds the frequency axis (values in Hz, one per frequency bin) and the time
/// axis (values in seconds).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Axes<FreqScale>
where
    FreqScale: Copy + Clone + 'static,
{
    /// Frequency axis, in Hz.
    freq: FrequencyAxis<FreqScale>,
    /// Time values in seconds; non-empty by type.
    times: NonEmptyVec<f64>,
}
3283
3284impl<FreqScale> Axes<FreqScale>
3285where
3286    FreqScale: Copy + Clone + 'static,
3287{
3288    pub(crate) const fn new(freq: FrequencyAxis<FreqScale>, times: NonEmptyVec<f64>) -> Self {
3289        Self { freq, times }
3290    }
3291
3292    /// Get the frequency values in Hz.
3293    ///
3294    /// # Returns
3295    ///
3296    /// Returns a non-empty slice of frequencies.
3297    #[inline]
3298    #[must_use]
3299    pub fn frequencies(&self) -> &NonEmptySlice<f64> {
3300        self.freq.frequencies()
3301    }
3302
3303    /// Get the time values in seconds.
3304    ///
3305    /// # Returns
3306    ///
3307    /// Returns a non-empty slice of time values.
3308    #[inline]
3309    #[must_use]
3310    pub fn times(&self) -> &NonEmptySlice<f64> {
3311        &self.times
3312    }
3313
3314    /// Get the frequency range (min, max) in Hz.
3315    ///
3316    /// # Returns
3317    ///
3318    /// Returns a tuple containing the minimum and maximum frequency.
3319    #[inline]
3320    #[must_use]
3321    pub const fn frequency_range(&self) -> (f64, f64) {
3322        self.freq.frequency_range()
3323    }
3324
3325    /// Get the duration of the spectrogram in seconds.
3326    ///
3327    /// # Returns
3328    ///
3329    /// Returns the duration in seconds.
3330    #[inline]
3331    #[must_use]
3332    pub fn duration(&self) -> f64 {
3333        *self.times.last()
3334    }
3335}
3336
3337// Enum types for frequency and amplitude scales
3338
/// Linear frequency scale
///
/// Type-level tag used as the `FreqScale` parameter of `Spectrogram`
/// (e.g. `Spectrogram<LinearHz, _>`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LinearHz {
    /// Placeholder variant; the tag carries no runtime data.
    _Phantom,
}
3347
/// Logarithmic frequency scale
///
/// Type-level tag used as the `FreqScale` parameter of `Spectrogram`
/// (e.g. `Spectrogram<LogHz, _>`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum LogHz {
    /// Placeholder variant; the tag carries no runtime data.
    _Phantom,
}
3356
/// Mel frequency scale
///
/// Type-level tag used as the `FreqScale` parameter of `Spectrogram`
/// (e.g. `Spectrogram<Mel, _>`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Mel {
    /// Placeholder variant; the tag carries no runtime data.
    _Phantom,
}
3365
/// ERB/gammatone frequency scale
///
/// Type-level tag used as the `FreqScale` parameter of `Spectrogram`
/// (e.g. `Spectrogram<Erb, _>`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Erb {
    /// Placeholder variant; the tag carries no runtime data.
    _Phantom,
}
/// Alias of [`Erb`] for callers who refer to ERB-scale spectrograms as
/// gammatone spectrograms.
pub type Gammatone = Erb;
3375
/// Constant-Q Transform frequency scale
///
/// Type-level tag used as the `FreqScale` parameter of `Spectrogram`
/// (e.g. `Spectrogram<Cqt, _>`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Cqt {
    /// Placeholder variant; the tag carries no runtime data.
    _Phantom,
}
3384
3385// Amplitude scales
3386
/// Power amplitude scale
///
/// Type-level tag used as the `AmpScale` parameter of `Spectrogram`. Values of a
/// power spectrogram are converted to dB as `10 * log10(v + EPS)` by `db_range`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Power {
    /// Placeholder variant; the tag carries no runtime data.
    _Phantom,
}
3395
/// Decibel amplitude scale
///
/// Type-level tag used as the `AmpScale` parameter of `Spectrogram`; values of a
/// decibel spectrogram are already in dB.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Decibels {
    /// Placeholder variant; the tag carries no runtime data.
    _Phantom,
}
3404
/// Magnitude amplitude scale
///
/// Type-level tag used as the `AmpScale` parameter of `Spectrogram`. Magnitude
/// values are squared to power before dB conversion in `db_range`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyclass(from_py_object))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum Magnitude {
    /// Placeholder variant; the tag carries no runtime data.
    _Phantom,
}
3413
/// STFT parameters for spectrogram computation.
///
/// * `n_fft`: Size of the FFT window.
/// * `hop_size`: Number of samples between successive frames.
/// * `window`: Window function to apply to each frame.
/// * `centre`: Whether to pad the input signal so that frames are centered.
///
/// Values created through [`StftParams::new`] or [`StftParams::builder`] are
/// validated: `hop_size <= n_fft`, and a custom window's size equals `n_fft`.
// NOTE(review): with the `serde` feature, the derived `Deserialize` fills these
// fields directly and skips the checks in `StftParams::new`; consider validating
// after deserialization (e.g. `#[serde(try_from = ...)]`) — TODO confirm.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct StftParams {
    /// Size of the FFT window.
    n_fft: NonZeroUsize,
    /// Number of samples between successive frames (at most `n_fft`).
    hop_size: NonZeroUsize,
    /// Window function applied to each frame.
    window: WindowType,
    /// Whether the input is padded so that frames are centred.
    centre: bool,
}
3428
3429impl StftParams {
3430    /// Create new STFT parameters.
3431    ///
3432    /// # Arguments
3433    ///
3434    /// * `n_fft` - Size of the FFT window
3435    /// * `hop_size` - Number of samples between successive frames
3436    /// * `window` - Window function to apply to each frame
3437    /// * `centre` - Whether to pad the input signal so that frames are centered
3438    ///
3439    /// # Errors
3440    ///
3441    /// Returns an error if:
3442    /// - `hop_size` > `n_fft`
3443    /// - Custom window size doesn't match `n_fft`
3444    ///
3445    /// # Returns
3446    ///
3447    /// New `StftParams` instance.
3448    #[inline]
3449    pub fn new(
3450        n_fft: NonZeroUsize,
3451        hop_size: NonZeroUsize,
3452        window: WindowType,
3453        centre: bool,
3454    ) -> SpectrogramResult<Self> {
3455        if hop_size.get() > n_fft.get() {
3456            return Err(SpectrogramError::invalid_input("hop_size must be <= n_fft"));
3457        }
3458
3459        // Validate custom window size matches n_fft
3460        if let WindowType::Custom { size, .. } = &window {
3461            if size.get() != n_fft.get() {
3462                return Err(SpectrogramError::invalid_input(format!(
3463                    "Custom window size ({}) must match n_fft ({})",
3464                    size.get(),
3465                    n_fft.get()
3466                )));
3467            }
3468        }
3469
3470        Ok(Self {
3471            n_fft,
3472            hop_size,
3473            window,
3474            centre,
3475        })
3476    }
3477
3478    const unsafe fn new_unchecked(
3479        n_fft: NonZeroUsize,
3480        hop_size: NonZeroUsize,
3481        window: WindowType,
3482        centre: bool,
3483    ) -> Self {
3484        Self {
3485            n_fft,
3486            hop_size,
3487            window,
3488            centre,
3489        }
3490    }
3491
3492    /// Get the FFT window size.
3493    ///
3494    /// # Returns
3495    ///
3496    /// The FFT window size.
3497    #[inline]
3498    #[must_use]
3499    pub const fn n_fft(&self) -> NonZeroUsize {
3500        self.n_fft
3501    }
3502
3503    /// Get the hop size (samples between successive frames).
3504    ///
3505    /// # Returns
3506    ///
3507    /// The hop size.
3508    #[inline]
3509    #[must_use]
3510    pub const fn hop_size(&self) -> NonZeroUsize {
3511        self.hop_size
3512    }
3513
3514    /// Get the window function.
3515    ///
3516    /// # Returns
3517    ///
3518    /// The window function.
3519    #[inline]
3520    #[must_use]
3521    pub fn window(&self) -> WindowType {
3522        self.window.clone()
3523    }
3524
3525    /// Get whether frames are centered (input signal is padded).
3526    ///
3527    /// # Returns
3528    ///
3529    /// `true` if frames are centered, `false` otherwise.
3530    #[inline]
3531    #[must_use]
3532    pub const fn centre(&self) -> bool {
3533        self.centre
3534    }
3535
3536    /// Create a builder for STFT parameters.
3537    ///
3538    /// # Returns
3539    ///
3540    /// A `StftParamsBuilder` instance.
3541    ///
3542    /// # Examples
3543    ///
3544    /// ```
3545    /// use spectrograms::{StftParams, WindowType, nzu};
3546    ///
3547    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
3548    /// let stft = StftParams::builder()
3549    ///     .n_fft(nzu!(2048))
3550    ///     .hop_size(nzu!(512))
3551    ///     .window(WindowType::Hanning)
3552    ///     .centre(true)
3553    ///     .build()?;
3554    ///
3555    /// assert_eq!(stft.n_fft(), nzu!(2048));
3556    /// assert_eq!(stft.hop_size(), nzu!(512));
3557    /// # Ok(())
3558    /// # }
3559    /// ```
3560    #[inline]
3561    #[must_use]
3562    pub fn builder() -> StftParamsBuilder {
3563        StftParamsBuilder::default()
3564    }
3565}
3566
/// Builder for [`StftParams`].
///
/// Starts (via its `Default` impl) with no `n_fft` and no `hop_size`, a Hanning
/// window and centred frames; `n_fft` and `hop_size` must be set before `build`.
#[derive(Debug, Clone)]
pub struct StftParamsBuilder {
    /// FFT window size; required.
    n_fft: Option<NonZeroUsize>,
    /// Samples between successive frames; required.
    hop_size: Option<NonZeroUsize>,
    /// Window function applied to each frame.
    window: WindowType,
    /// Whether to pad the input so frames are centred.
    centre: bool,
}
3575
3576impl Default for StftParamsBuilder {
3577    #[inline]
3578    fn default() -> Self {
3579        Self {
3580            n_fft: None,
3581            hop_size: None,
3582            window: WindowType::Hanning,
3583            centre: true,
3584        }
3585    }
3586}
3587
3588impl StftParamsBuilder {
3589    /// Set the FFT window size.
3590    ///
3591    /// # Arguments
3592    ///
3593    /// * `n_fft` - Size of the FFT window
3594    ///
3595    /// # Returns
3596    ///
3597    /// The builder with the updated FFT window size.
3598    #[inline]
3599    #[must_use]
3600    pub const fn n_fft(mut self, n_fft: NonZeroUsize) -> Self {
3601        self.n_fft = Some(n_fft);
3602        self
3603    }
3604
3605    /// Set the hop size (samples between successive frames).
3606    ///
3607    /// # Arguments
3608    ///
3609    /// * `hop_size` - Number of samples between successive frames
3610    ///
3611    /// # Returns
3612    ///
3613    /// The builder with the updated hop size.
3614    #[inline]
3615    #[must_use]
3616    pub const fn hop_size(mut self, hop_size: NonZeroUsize) -> Self {
3617        self.hop_size = Some(hop_size);
3618        self
3619    }
3620
3621    /// Set the window function.
3622    ///
3623    /// # Arguments
3624    ///
3625    /// * `window` - Window function to apply to each frame
3626    ///
3627    /// # Returns
3628    ///
3629    /// The builder with the updated window function.
3630    #[inline]
3631    #[must_use]
3632    pub fn window(mut self, window: WindowType) -> Self {
3633        self.window = window;
3634        self
3635    }
3636
3637    /// Set whether to center frames (pad input signal).
3638    #[inline]
3639    #[must_use]
3640    pub const fn centre(mut self, centre: bool) -> Self {
3641        self.centre = centre;
3642        self
3643    }
3644
3645    /// Build the [`StftParams`].
3646    ///
3647    /// # Errors
3648    ///
3649    /// Returns an error if:
3650    /// - `n_fft` or `hop_size` are not set or are zero
3651    /// - `hop_size` > `n_fft`
3652    #[inline]
3653    pub fn build(self) -> SpectrogramResult<StftParams> {
3654        let n_fft = self
3655            .n_fft
3656            .ok_or_else(|| SpectrogramError::invalid_input("n_fft must be set"))?;
3657        let hop_size = self
3658            .hop_size
3659            .ok_or_else(|| SpectrogramError::invalid_input("hop_size must be set"))?;
3660
3661        StftParams::new(n_fft, hop_size, self.window, self.centre)
3662    }
3663}
3664
3665//
3666// ========================
3667// Mel parameters
3668// ========================
3669//
3670
/// Mel filterbank normalization strategy.
///
/// Determines how the triangular mel filters are normalized after construction.
/// [`MelNorm::None`] is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum MelNorm {
    /// No normalization (triangular filters with peak = 1.0).
    ///
    /// This is the default and fastest option.
    #[default]
    None,

    /// Slaney-style area normalization (librosa default).
    ///
    /// Each mel filter's weights are multiplied by `2.0 / (f_upper - f_lower)`. Here
    /// `f_lower` and `f_upper` are that filter's own lower and upper edge frequencies in Hz,
    /// so each filter is divided by half of its bandwidth. This keeps the energy per mel
    /// band approximately constant regardless of bandwidth.
    ///
    /// Use this for compatibility with librosa's default behavior.
    Slaney,

    /// L1 normalization (sum of weights = 1.0).
    ///
    /// Each mel filter's weights are divided by their sum.
    /// Useful when you want each filter to act as a weighted average.
    L1,

    /// L2 normalization (Euclidean norm = 1.0).
    ///
    /// Each mel filter's weights are divided by their L2 norm.
    /// Provides unit-norm filters in the L2 sense.
    L2,
}
3705
/// Mel filter bank parameters
///
/// Build with the validating constructors [`MelParams::new`] / [`MelParams::with_norm`],
/// or use the presets [`MelParams::standard`] / [`MelParams::speech_standard`].
///
/// NOTE(review): with the `serde` feature enabled, the derived `Deserialize` sets these
/// private fields directly and does not run the constructor checks.
///
/// * `n_mels`: Number of mel bands
/// * `f_min`: Minimum frequency (Hz)
/// * `f_max`: Maximum frequency (Hz)
/// * `norm`: Filterbank normalization strategy
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MelParams {
    /// Number of mel bands.
    n_mels: NonZeroUsize,
    /// Minimum frequency in Hz (checked `>= 0` by `with_norm`).
    f_min: f64,
    /// Maximum frequency in Hz (checked `> f_min` by `with_norm`).
    f_max: f64,
    /// Filterbank normalization strategy.
    norm: MelNorm,
}
3720
3721impl MelParams {
3722    /// Create new mel filter bank parameters.
3723    ///
3724    /// # Arguments
3725    ///
3726    /// * `n_mels` - Number of mel bands
3727    /// * `f_min` - Minimum frequency (Hz)
3728    /// * `f_max` - Maximum frequency (Hz)
3729    ///
3730    /// # Errors
3731    ///
3732    /// Returns an error if:
3733    /// - `f_min` is not >= 0
3734    /// - `f_max` is not > `f_min`
3735    ///
3736    /// # Returns
3737    ///
3738    /// A `MelParams` instance with no normalization (default).
3739    #[inline]
3740    pub fn new(n_mels: NonZeroUsize, f_min: f64, f_max: f64) -> SpectrogramResult<Self> {
3741        Self::with_norm(n_mels, f_min, f_max, MelNorm::None)
3742    }
3743
3744    /// Create new mel filter bank parameters with specified normalization.
3745    ///
3746    /// # Arguments
3747    ///
3748    /// * `n_mels` - Number of mel bands
3749    /// * `f_min` - Minimum frequency (Hz)
3750    /// * `f_max` - Maximum frequency (Hz)
3751    /// * `norm` - Filterbank normalization strategy
3752    ///
3753    /// # Errors
3754    ///
3755    /// Returns an error if:
3756    /// - `f_min` is not >= 0
3757    /// - `f_max` is not > `f_min`
3758    ///
3759    /// # Returns
3760    ///
3761    /// A `MelParams` instance.
3762    #[inline]
3763    pub fn with_norm(
3764        n_mels: NonZeroUsize,
3765        f_min: f64,
3766        f_max: f64,
3767        norm: MelNorm,
3768    ) -> SpectrogramResult<Self> {
3769        if f_min < 0.0 {
3770            return Err(SpectrogramError::invalid_input("f_min must be >= 0"));
3771        }
3772
3773        if f_max <= f_min {
3774            return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
3775        }
3776
3777        Ok(Self {
3778            n_mels,
3779            f_min,
3780            f_max,
3781            norm,
3782        })
3783    }
3784
3785    const unsafe fn new_unchecked(n_mels: NonZeroUsize, f_min: f64, f_max: f64) -> Self {
3786        Self {
3787            n_mels,
3788            f_min,
3789            f_max,
3790            norm: MelNorm::None,
3791        }
3792    }
3793
3794    /// Get the number of mel bands.
3795    ///
3796    /// # Returns
3797    ///
3798    /// The number of mel bands.
3799    #[inline]
3800    #[must_use]
3801    pub const fn n_mels(&self) -> NonZeroUsize {
3802        self.n_mels
3803    }
3804
3805    /// Get the minimum frequency (Hz).
3806    ///
3807    /// # Returns
3808    ///
3809    /// The minimum frequency in Hz.
3810    #[inline]
3811    #[must_use]
3812    pub const fn f_min(&self) -> f64 {
3813        self.f_min
3814    }
3815
3816    /// Get the maximum frequency (Hz).
3817    ///
3818    /// # Returns
3819    ///
3820    /// The maximum frequency in Hz.
3821    #[inline]
3822    #[must_use]
3823    pub const fn f_max(&self) -> f64 {
3824        self.f_max
3825    }
3826
3827    /// Get the filterbank normalization strategy.
3828    ///
3829    /// # Returns
3830    ///
3831    /// The normalization strategy.
3832    #[inline]
3833    #[must_use]
3834    pub const fn norm(&self) -> MelNorm {
3835        self.norm
3836    }
3837
3838    /// Create standard mel filterbank parameters.
3839    ///
3840    /// Uses 128 mel bands from 0 Hz to the Nyquist frequency.
3841    ///
3842    /// # Arguments
3843    ///
3844    /// * `sample_rate` - Sample rate in Hz (used to determine `f_max`)
3845    ///
3846    /// # Returns
3847    ///
3848    /// A `MelParams` instance with standard settings.
3849    ///
3850    /// # Panics
3851    ///
3852    /// Panics if `sample_rate` is not greater than 0.
3853    #[inline]
3854    #[must_use]
3855    pub const fn standard(sample_rate: f64) -> Self {
3856        assert!(sample_rate > 0.0);
3857        // safety: parameters are known to be valid
3858        unsafe { Self::new_unchecked(nzu!(128), 0.0, sample_rate / 2.0) }
3859    }
3860
3861    /// Create mel filterbank parameters optimized for speech.
3862    ///
3863    /// Uses 40 mel bands from 0 Hz to 8000 Hz (typical speech bandwidth).
3864    ///
3865    /// # Returns
3866    ///
3867    /// A `MelParams` instance with speech-optimized settings.
3868    #[inline]
3869    #[must_use]
3870    pub const fn speech_standard() -> Self {
3871        // safety: parameters are known to be valid
3872        unsafe { Self::new_unchecked(nzu!(40), 0.0, 8000.0) }
3873    }
3874}
3875
3876//
3877// ========================
3878// LogHz parameters
3879// ========================
3880//
3881
/// Logarithmic frequency scale parameters
///
/// Build with the validating constructor [`LogHzParams::new`], or use the presets
/// [`LogHzParams::standard`] / [`LogHzParams::music_standard`].
///
/// NOTE(review): with the `serde` feature enabled, the derived `Deserialize` sets these
/// private fields directly and does not run the constructor checks.
///
/// * `n_bins`: Number of logarithmically-spaced frequency bins
/// * `f_min`: Minimum frequency (Hz)
/// * `f_max`: Maximum frequency (Hz)
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LogHzParams {
    /// Number of logarithmically-spaced frequency bins.
    n_bins: NonZeroUsize,
    /// Minimum frequency in Hz (checked finite and `> 0` by `new`).
    f_min: f64,
    /// Maximum frequency in Hz (checked `> f_min` by `new`).
    f_max: f64,
}
3894
3895impl LogHzParams {
3896    /// Create new logarithmic frequency scale parameters.
3897    ///
3898    /// # Arguments
3899    ///
3900    /// * `n_bins` - Number of logarithmically-spaced frequency bins
3901    /// * `f_min` - Minimum frequency (Hz)
3902    /// * `f_max` - Maximum frequency (Hz)
3903    ///
3904    /// # Errors
3905    ///
3906    /// Returns an error if:
3907    /// - `f_min` is not finite and > 0
3908    /// - `f_max` is not > `f_min`
3909    ///
3910    /// # Returns
3911    ///
3912    /// A `LogHzParams` instance.
3913    #[inline]
3914    pub fn new(n_bins: NonZeroUsize, f_min: f64, f_max: f64) -> SpectrogramResult<Self> {
3915        if !(f_min > 0.0 && f_min.is_finite()) {
3916            return Err(SpectrogramError::invalid_input(
3917                "f_min must be finite and > 0",
3918            ));
3919        }
3920
3921        if f_max <= f_min {
3922            return Err(SpectrogramError::invalid_input("f_max must be > f_min"));
3923        }
3924
3925        Ok(Self {
3926            n_bins,
3927            f_min,
3928            f_max,
3929        })
3930    }
3931
3932    const unsafe fn new_unchecked(n_bins: NonZeroUsize, f_min: f64, f_max: f64) -> Self {
3933        Self {
3934            n_bins,
3935            f_min,
3936            f_max,
3937        }
3938    }
3939
3940    /// Get the number of frequency bins.
3941    ///
3942    /// # Returns
3943    ///
3944    /// The number of frequency bins.
3945    #[inline]
3946    #[must_use]
3947    pub const fn n_bins(&self) -> NonZeroUsize {
3948        self.n_bins
3949    }
3950
3951    /// Get the minimum frequency (Hz).
3952    ///
3953    /// # Returns
3954    ///
3955    /// The minimum frequency in Hz.
3956    #[inline]
3957    #[must_use]
3958    pub const fn f_min(&self) -> f64 {
3959        self.f_min
3960    }
3961
3962    /// Get the maximum frequency (Hz).
3963    ///
3964    /// # Returns
3965    ///
3966    /// The maximum frequency in Hz.
3967    #[inline]
3968    #[must_use]
3969    pub const fn f_max(&self) -> f64 {
3970        self.f_max
3971    }
3972
3973    /// Create standard logarithmic frequency parameters.
3974    ///
3975    /// Uses 128 log bins from 20 Hz to the Nyquist frequency.
3976    ///
3977    /// # Arguments
3978    ///
3979    /// * `sample_rate` - Sample rate in Hz (used to determine `f_max`)
3980    #[inline]
3981    #[must_use]
3982    pub fn standard(sample_rate: f64) -> Self {
3983        // safety: parameters are known to be valid
3984        unsafe { Self::new_unchecked(nzu!(128), 20.0, sample_rate / 2.0) }
3985    }
3986
3987    /// Create logarithmic frequency parameters optimized for music.
3988    ///
3989    /// Uses 84 bins (7 octaves * 12 bins/octave) from 27.5 Hz (A0) to 4186 Hz (C8).
3990    #[inline]
3991    #[must_use]
3992    pub const fn music_standard() -> Self {
3993        // safety: parameters are known to be valid
3994        unsafe { Self::new_unchecked(nzu!(84), 27.5, 4186.0) }
3995    }
3996}
3997
3998//
3999// ========================
4000// Log scaling parameters
4001// ========================
4002//
4003
/// Logarithmic (dB) scaling parameters.
///
/// Build with [`LogParams::new`], which requires a finite floor.
///
/// * `floor_db`: Minimum dB value (floor) for logarithmic scaling
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct LogParams {
    /// Minimum dB value (floor); checked to be finite by `new`.
    floor_db: f64,
}
4009
4010impl LogParams {
4011    /// Create new logarithmic scaling parameters.
4012    ///
4013    /// # Arguments
4014    ///
4015    /// * `floor_db` - Minimum dB value (floor) for logarithmic scaling
4016    ///
4017    /// # Errors
4018    ///
4019    /// Returns an error if `floor_db` is not finite.
4020    ///
4021    /// # Returns
4022    ///
4023    /// A `LogParams` instance.
4024    #[inline]
4025    pub fn new(floor_db: f64) -> SpectrogramResult<Self> {
4026        if !floor_db.is_finite() {
4027            return Err(SpectrogramError::invalid_input("floor_db must be finite"));
4028        }
4029
4030        Ok(Self { floor_db })
4031    }
4032
4033    /// Get the floor dB value.
4034    #[inline]
4035    #[must_use]
4036    pub const fn floor_db(&self) -> f64 {
4037        self.floor_db
4038    }
4039}
4040
/// Spectrogram computation parameters.
///
/// Build with [`SpectrogramParams::new`], the [`SpectrogramParams::builder`], or the
/// presets [`SpectrogramParams::speech_default`] / [`SpectrogramParams::music_default`].
///
/// NOTE(review): with the `serde` feature enabled, the derived `Deserialize` sets these
/// private fields directly and does not run the sample-rate check in `new`.
///
/// * `stft`: STFT parameters
/// * `sample_rate_hz`: Sample rate in Hz
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SpectrogramParams {
    /// STFT parameters (FFT size, hop size, window, centring).
    stft: StftParams,
    /// Sample rate in Hz; checked finite and `> 0` by `new`.
    sample_rate_hz: f64,
}
4051
4052impl SpectrogramParams {
4053    /// Create new spectrogram parameters.
4054    ///
4055    /// # Arguments
4056    ///
4057    /// * `stft` - STFT parameters
4058    /// * `sample_rate_hz` - Sample rate in Hz
4059    ///
4060    /// # Errors
4061    ///
4062    /// Returns an error if the sample rate is not positive and finite.
4063    ///
4064    /// # Returns
4065    ///
4066    /// A `SpectrogramParams` instance.
4067    #[inline]
4068    pub fn new(stft: StftParams, sample_rate_hz: f64) -> SpectrogramResult<Self> {
4069        if !(sample_rate_hz > 0.0 && sample_rate_hz.is_finite()) {
4070            return Err(SpectrogramError::invalid_input(
4071                "sample_rate_hz must be finite and > 0",
4072            ));
4073        }
4074
4075        Ok(Self {
4076            stft,
4077            sample_rate_hz,
4078        })
4079    }
4080
4081    /// Create a builder for spectrogram parameters.
4082    ///
4083    /// # Errors
4084    ///
4085    /// Returns an error if required parameters are not set or are invalid.
4086    ///
4087    /// # Returns
4088    ///
4089    /// A builder for [`SpectrogramParams`].
4090    ///
4091    /// # Examples
4092    ///
4093    /// ```
4094    /// use spectrograms::{SpectrogramParams, WindowType, nzu};
4095    ///
4096    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4097    /// let params = SpectrogramParams::builder()
4098    ///     .sample_rate(16000.0)
4099    ///     .n_fft(nzu!(512))
4100    ///     .hop_size(nzu!(256))
4101    ///     .window(WindowType::Hanning)
4102    ///     .centre(true)
4103    ///     .build()?;
4104    ///
4105    /// assert_eq!(params.sample_rate_hz(), 16000.0);
4106    /// # Ok(())
4107    /// # }
4108    /// ```
4109    #[inline]
4110    #[must_use]
4111    pub fn builder() -> SpectrogramParamsBuilder {
4112        SpectrogramParamsBuilder::default()
4113    }
4114
4115    /// Create default parameters for speech processing.
4116    ///
4117    /// # Arguments
4118    ///
4119    /// * `sample_rate_hz` - Sample rate in Hz
4120    ///
4121    /// # Returns
4122    ///
4123    /// A `SpectrogramParams` instance with default settings for music analysis.
4124    ///
4125    /// # Errors
4126    ///
4127    /// Returns an error if the sample rate is not positive and finite.
4128    ///
4129    /// Uses:
4130    /// - `n_fft`: 512 (32ms at 16kHz)
4131    /// - `hop_size`: 160 (10ms at 16kHz)
4132    /// - window: Hanning
4133    /// - centre: true
4134    #[inline]
4135    pub fn speech_default(sample_rate_hz: f64) -> SpectrogramResult<Self> {
4136        // safety: parameters are known to be valid
4137        let stft =
4138            unsafe { StftParams::new_unchecked(nzu!(512), nzu!(160), WindowType::Hanning, true) };
4139
4140        Self::new(stft, sample_rate_hz)
4141    }
4142
4143    /// Create default parameters for music processing.
4144    ///
4145    /// # Arguments
4146    ///
4147    /// * `sample_rate_hz` - Sample rate in Hz
4148    ///
4149    /// # Returns
4150    ///
4151    /// A `SpectrogramParams` instance with default settings for music analysis.
4152    ///
4153    /// # Errors
4154    ///
4155    /// Returns an error if the sample rate is not positive and finite.
4156    ///
4157    /// Uses:
4158    /// - `n_fft`: 2048 (46ms at 44.1kHz)
4159    /// - `hop_size`: 512 (11.6ms at 44.1kHz)
4160    /// - window: Hanning
4161    /// - centre: true
4162    #[inline]
4163    pub fn music_default(sample_rate_hz: f64) -> SpectrogramResult<Self> {
4164        // safety: parameters are known to be valid
4165        let stft =
4166            unsafe { StftParams::new_unchecked(nzu!(2048), nzu!(512), WindowType::Hanning, true) };
4167        Self::new(stft, sample_rate_hz)
4168    }
4169
4170    /// Get the STFT parameters.
4171    #[inline]
4172    #[must_use]
4173    pub const fn stft(&self) -> &StftParams {
4174        &self.stft
4175    }
4176
4177    /// Get the sample rate in Hz.
4178    #[inline]
4179    #[must_use]
4180    pub const fn sample_rate_hz(&self) -> f64 {
4181        self.sample_rate_hz
4182    }
4183
4184    /// Get the frame period in seconds.
4185    #[inline]
4186    #[must_use]
4187    #[allow(clippy::cast_precision_loss)]
4188    pub fn frame_period_seconds(&self) -> f64 {
4189        self.stft.hop_size().get() as f64 / self.sample_rate_hz
4190    }
4191
4192    /// Get the Nyquist frequency in Hz.
4193    #[inline]
4194    #[must_use]
4195    pub fn nyquist_hz(&self) -> f64 {
4196        self.sample_rate_hz * 0.5
4197    }
4198}
4199
/// Builder for [`SpectrogramParams`].
///
/// `sample_rate`, `n_fft` and `hop_size` are required and must be set before calling
/// [`SpectrogramParamsBuilder::build`]. `window` and `centre` start from the values chosen
/// in the `Default` impl (Hanning window, centring enabled).
#[derive(Debug, Clone)]
pub struct SpectrogramParamsBuilder {
    /// Sample rate in Hz; `None` until set, and `build` fails while it is `None`.
    sample_rate: Option<f64>,
    /// FFT window size; `None` until set, and `build` fails while it is `None`.
    n_fft: Option<NonZeroUsize>,
    /// Samples between successive frames; `None` until set, and `build` fails while it is `None`.
    hop_size: Option<NonZeroUsize>,
    /// Window function applied to each frame.
    window: WindowType,
    /// Whether to centre frames by padding the input signal.
    centre: bool,
}
4209
4210impl Default for SpectrogramParamsBuilder {
4211    #[inline]
4212    fn default() -> Self {
4213        Self {
4214            sample_rate: None,
4215            n_fft: None,
4216            hop_size: None,
4217            window: WindowType::Hanning,
4218            centre: true,
4219        }
4220    }
4221}
4222
4223impl SpectrogramParamsBuilder {
4224    /// Set the sample rate in Hz.
4225    ///
4226    /// # Arguments
4227    ///
4228    /// * `sample_rate` - Sample rate in Hz.
4229    ///
4230    /// # Returns
4231    ///
4232    /// The updated builder instance.
4233    #[inline]
4234    #[must_use]
4235    pub const fn sample_rate(mut self, sample_rate: f64) -> Self {
4236        self.sample_rate = Some(sample_rate);
4237        self
4238    }
4239
4240    /// Set the FFT window size.
4241    ///
4242    /// # Arguments
4243    ///
4244    /// * `n_fft` - FFT size.
4245    ///
4246    /// # Returns
4247    ///
4248    /// The updated builder instance.
4249    #[inline]
4250    #[must_use]
4251    pub const fn n_fft(mut self, n_fft: NonZeroUsize) -> Self {
4252        self.n_fft = Some(n_fft);
4253        self
4254    }
4255
4256    /// Set the hop size (samples between successive frames).
4257    ///
4258    /// # Arguments
4259    ///
4260    /// * `hop_size` - Hop size in samples.
4261    ///
4262    /// # Returns
4263    ///
4264    /// The updated builder instance.
4265    #[inline]
4266    #[must_use]
4267    pub const fn hop_size(mut self, hop_size: NonZeroUsize) -> Self {
4268        self.hop_size = Some(hop_size);
4269        self
4270    }
4271
4272    /// Set the window function.
4273    ///
4274    /// # Arguments
4275    ///
4276    /// * `window` - Window function to apply to each frame.
4277    ///
4278    /// # Returns
4279    ///
4280    /// The updated builder instance.
4281    #[inline]
4282    #[must_use]
4283    pub fn window(mut self, window: WindowType) -> Self {
4284        self.window = window;
4285        self
4286    }
4287
4288    /// Set whether to center frames (pad input signal).
4289    ///
4290    /// # Arguments
4291    ///
4292    /// * `centre` - If true, frames are centered by padding the input signal.
4293    ///
4294    /// # Returns
4295    ///
4296    /// The updated builder instance.
4297    #[inline]
4298    #[must_use]
4299    pub const fn centre(mut self, centre: bool) -> Self {
4300        self.centre = centre;
4301        self
4302    }
4303
4304    /// Build the [`SpectrogramParams`].
4305    ///
4306    /// # Errors
4307    ///
4308    /// Returns an error if required parameters are not set or are invalid.
4309    ///
4310    /// # Returns
4311    ///
4312    /// A `SpectrogramParams` instance.
4313    #[inline]
4314    pub fn build(self) -> SpectrogramResult<SpectrogramParams> {
4315        let sample_rate = self
4316            .sample_rate
4317            .ok_or_else(|| SpectrogramError::invalid_input("sample_rate must be set"))?;
4318        let n_fft = self
4319            .n_fft
4320            .ok_or_else(|| SpectrogramError::invalid_input("n_fft must be set"))?;
4321        let hop_size = self
4322            .hop_size
4323            .ok_or_else(|| SpectrogramError::invalid_input("hop_size must be set"))?;
4324
4325        let stft = StftParams::new(n_fft, hop_size, self.window, self.centre)?;
4326        SpectrogramParams::new(stft, sample_rate)
4327    }
4328}
4329
4330//
4331// ========================
4332// Standalone FFT Functions
4333// ========================
4334//
4335
/// Compute the real-to-complex FFT of a real-valued signal.
///
/// This function performs a forward FFT on real-valued input, returning the
/// complex frequency domain representation. Only the positive frequencies
/// are returned (length = `n_fft/2` + 1) due to conjugate symmetry.
///
/// # Arguments
///
/// * `samples` - Input signal (length ≤ `n_fft`, will be zero-padded if shorter)
/// * `n_fft` - FFT size
///
/// # Returns
///
/// A one-dimensional array (`Array1`) of complex frequency bins with length `n_fft/2` + 1.
///
/// # Automatic Zero-Padding
///
/// If the input signal is shorter than `n_fft`, it will be automatically
/// zero-padded to the required length. This is standard DSP practice and
/// preserves frequency resolution (bin spacing = `sample_rate / n_fft`).
///
/// ```
/// use spectrograms::{fft, nzu};
/// use non_empty_slice::non_empty_vec;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let signal = non_empty_vec![1.0, 2.0, 3.0]; // Only 3 samples
/// let spectrum = fft(&signal, nzu!(8))?;   // Automatically padded to 8
/// assert_eq!(spectrum.len(), 5);     // Output: 8/2 + 1 = 5 bins
/// # Ok(())
/// # }
/// ```
///
/// # Errors
///
/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
/// Errors from obtaining the FFT plan or running it are propagated.
///
/// # Examples
///
/// ```
/// use spectrograms::*;
/// use non_empty_slice::non_empty_vec;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let signal = non_empty_vec![0.0; nzu!(512)];
/// let spectrum = fft(&signal, nzu!(512))?;
///
/// assert_eq!(spectrum.len(), 257); // 512/2 + 1
/// # Ok(())
/// # }
/// ```
#[inline]
pub fn fft(
    samples: &NonEmptySlice<f64>,
    n_fft: NonZeroUsize,
) -> SpectrogramResult<Array1<Complex<f64>>> {
    if samples.len() > n_fft {
        return Err(SpectrogramError::invalid_input(format!(
            "Input length ({}) exceeds FFT size ({})",
            samples.len(),
            n_fft
        )));
    }

    // Number of output bins for a real-to-complex transform of size n_fft.
    let out_len = r2c_output_size(n_fft.get());

    // Get FFT plan from global cache (or create if first use)
    #[cfg(feature = "realfft")]
    let mut fft = {
        use crate::fft_backend::get_or_create_r2c_plan;
        let plan = get_or_create_r2c_plan(n_fft.get())?;
        // Clone the plan to get our own mutable copy with independent scratch buffer
        // This is cheap - only clones the scratch buffer, not the expensive twiddle factors
        (*plan).clone()
    };

    // NOTE(review): if both `realfft` and `fftw` are enabled, this binding shadows the
    // `realfft` one above, so the FFTW backend is the one used.
    #[cfg(feature = "fftw")]
    let mut fft = {
        use std::sync::Arc;
        let plan = crate::FftwPlanner::build_plan(n_fft.get())?;
        crate::FftwPlan::new(Arc::new(plan))
    };

    // Zero-pad short inputs to exactly n_fft samples; full-length inputs are copied as-is.
    let input = if samples.len() < n_fft {
        let mut padded = vec![0.0; n_fft.get()];
        padded[..samples.len().get()].copy_from_slice(samples);
        // SAFETY: `padded` has exactly n_fft elements and n_fft is non-zero.

        unsafe { NonEmptyVec::new_unchecked(padded) }
    } else {
        samples.to_non_empty_vec()
    };

    let mut output = vec![Complex::new(0.0, 0.0); out_len];
    fft.process(&input, &mut output)?;
    let output = Array1::from_vec(output);
    Ok(output)
}
4434
4435#[inline]
4436/// Compute the real-valued fft of a signal.
4437///
4438/// # Arguments
4439/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4440/// * `n_fft` - FFT size
4441///
4442/// # Returns
4443///
4444/// An array with length `n_fft/2` + 1.
4445///
4446/// # Errors
4447///
4448/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4449///
4450/// # Examples
4451///
4452/// ```
4453/// use spectrograms::*;
4454/// use non_empty_slice::non_empty_vec;
4455///
4456/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4457/// let signal = non_empty_vec![0.0; nzu!(512)];
4458/// let rfft_result = rfft(&signal, nzu!(512))?;
4459/// // equivalent to
4460/// let fft_result = fft(&signal, nzu!(512))?;
4461/// let rfft_result = fft_result.mapv(num_complex::Complex::norm);
4462/// # Ok(())
4463/// # }
4464///
4465pub fn rfft(samples: &NonEmptySlice<f64>, n_fft: NonZeroUsize) -> SpectrogramResult<Array1<f64>> {
4466    Ok(fft(samples, n_fft)?.mapv(Complex::norm))
4467}
4468
4469/// Compute the power spectrum of a signal (|X|²).
4470///
4471/// This function applies an optional window function and computes the
4472/// power spectrum via FFT. The result contains only positive frequencies.
4473///
4474/// # Arguments
4475///
4476/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4477/// * `n_fft` - FFT size
4478/// * `window` - Optional window function (None for rectangular window)
4479///
4480/// # Returns
4481///
4482/// A vector of power values with length `n_fft/2` + 1.
4483///
4484/// # Automatic Zero-Padding
4485///
4486/// If the input signal is shorter than `n_fft`, it will be automatically
4487/// zero-padded to the required length. This is standard DSP practice and
4488/// preserves frequency resolution (bin spacing = sample_rate / n_fft).
4489///
4490/// ```
4491/// use spectrograms::{power_spectrum, nzu};
4492/// use non_empty_slice::non_empty_vec;
4493///
4494/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4495/// let signal = non_empty_vec![1.0, 2.0, 3.0]; // Only 3 samples
4496/// let power = power_spectrum(&signal, nzu!(8), None)?;
4497/// assert_eq!(power.len(), nzu!(5));     // Output: 8/2 + 1 = 5 bins
4498/// # Ok(())
4499/// # }
4500/// ```
4501///
4502/// # Errors
4503///
4504/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4505///
4506/// # Examples
4507///
4508/// ```
4509/// use spectrograms::*;
4510/// use non_empty_slice::non_empty_vec;
4511///
4512/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4513/// let signal = non_empty_vec![0.0; nzu!(512)];
4514/// let power = power_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
4515///
4516/// assert_eq!(power.len(), nzu!(257)); // 512/2 + 1
4517/// # Ok(())
4518/// # }
4519/// ```
4520#[inline]
4521pub fn power_spectrum(
4522    samples: &NonEmptySlice<f64>,
4523    n_fft: NonZeroUsize,
4524    window: Option<WindowType>,
4525) -> SpectrogramResult<NonEmptyVec<f64>> {
4526    if samples.len() > n_fft {
4527        return Err(SpectrogramError::invalid_input(format!(
4528            "Input length ({}) exceeds FFT size ({})",
4529            samples.len(),
4530            n_fft
4531        )));
4532    }
4533
4534    let mut windowed = vec![0.0; n_fft.get()];
4535    windowed[..samples.len().get()].copy_from_slice(samples);
4536
4537    if let Some(win_type) = window {
4538        let window_samples = make_window(win_type, n_fft);
4539        for i in 0..n_fft.get() {
4540            windowed[i] *= window_samples[i];
4541        }
4542    }
4543
4544    // safety: windowed is non-empty since n_fft > 0
4545    let windowed = unsafe { NonEmptySlice::new_unchecked(&windowed) };
4546    let fft_result = fft(windowed, n_fft)?;
4547    let fft_result = fft_result
4548        .iter()
4549        .map(num_complex::Complex::norm_sqr)
4550        .collect();
4551    // safety: fft_result is non-empty since fft returned successfully
4552    Ok(unsafe { NonEmptyVec::new_unchecked(fft_result) })
4553}
4554
4555/// Compute the magnitude spectrum of a signal (|X|).
4556///
4557/// This function applies an optional window function and computes the
4558/// magnitude spectrum via FFT. The result contains only positive frequencies.
4559///
4560/// # Arguments
4561///
4562/// * `samples` - Input signal (length ≤ n_fft, will be zero-padded if shorter)
4563/// * `n_fft` - FFT size
4564/// * `window` - Optional window function (None for rectangular window)
4565///
4566/// # Automatic Zero-Padding
4567///
4568/// If the input signal is shorter than `n_fft`, it will be automatically
4569/// zero-padded to the required length. This preserves frequency resolution.
4570///
4571/// # Errors
4572///
4573/// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4574///
4575/// # Returns
4576///
4577/// A vector of magnitude values with length `n_fft/2` + 1.
4578///
4579/// # Examples
4580///
4581/// ```
4582/// use spectrograms::*;
4583/// use non_empty_slice::non_empty_vec;
4584///
4585/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4586/// let signal = non_empty_vec![0.0; nzu!(512)];
4587/// let magnitude = magnitude_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
4588///
4589/// assert_eq!(magnitude.len(), nzu!(257)); // 512/2 + 1
4590/// # Ok(())
4591/// # }
4592/// ```
4593#[inline]
4594pub fn magnitude_spectrum(
4595    samples: &NonEmptySlice<f64>,
4596    n_fft: NonZeroUsize,
4597    window: Option<WindowType>,
4598) -> SpectrogramResult<NonEmptyVec<f64>> {
4599    let power = power_spectrum(samples, n_fft, window)?;
4600    let power = power.iter().map(|&p| p.sqrt()).collect();
4601    // safety: power is non-empty since power_spectrum returned successfully
4602    Ok(unsafe { NonEmptyVec::new_unchecked(power) })
4603}
4604
4605/// Compute the Short-Time Fourier Transform (STFT) of a signal.
4606///
4607/// This function computes the STFT by applying a sliding window and FFT
4608/// to sequential frames of the input signal.
4609///
4610/// # Arguments
4611///
4612/// * `samples` - Input signal (any type that can be converted to a slice)
4613/// * `n_fft` - FFT size
4614/// * `hop_size` - Number of samples between successive frames
4615/// * `window` - Window function to apply to each frame
4616/// * `center` - If true, pad the signal to center frames
4617///
4618/// # Returns
4619///
4620/// A 2D array with shape (`frequency_bins`, `time_frames`) containing complex STFT values.
4621///
4622/// # Errors
4623///
4624/// Returns an error if:
4625/// - `hop_size` > `n_fft`
4626/// - STFT computation fails
4627///
4628/// # Examples
4629///
4630/// ```
4631/// use spectrograms::*;
4632/// use non_empty_slice::non_empty_vec;
4633///
4634/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4635/// let signal = non_empty_vec![0.0; nzu!(16000)];
4636/// let stft_matrix = stft(&signal, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4637///
4638/// println!("STFT: {} bins x {} frames", stft_matrix.nrows(), stft_matrix.ncols());
4639/// # Ok(())
4640/// # }
4641/// ```
4642#[inline]
4643pub fn stft(
4644    samples: &NonEmptySlice<f64>,
4645    n_fft: NonZeroUsize,
4646    hop_size: NonZeroUsize,
4647    window: WindowType,
4648    center: bool,
4649) -> SpectrogramResult<Array2<Complex<f64>>> {
4650    let stft_params = StftParams::new(n_fft, hop_size, window, center)?;
4651    let params = SpectrogramParams::new(stft_params, 1.0)?; // dummy sample rate
4652
4653    let planner = SpectrogramPlanner::new();
4654    let result = planner.compute_stft(samples, &params)?;
4655
4656    Ok(result.data)
4657}
4658
4659/// Compute the inverse real FFT (complex-to-real IFFT).
4660///
4661/// This function performs an inverse FFT, converting frequency domain data
4662/// back to the time domain. Only the positive frequencies need to be provided
4663/// (length = `n_fft/2` + 1) due to conjugate symmetry.
4664///
4665/// # Arguments
4666///
4667/// * `spectrum` - Complex frequency bins (length should be `n_fft/2` + 1)
4668/// * `n_fft` - FFT size (length of the output signal)
4669///
4670/// # Returns
4671///
4672/// A vector of real-valued time-domain samples with length `n_fft`.
4673///
4674/// # Errors
4675///
4676/// Returns an error if:
4677/// - `spectrum` length doesn't match `n_fft/2` + 1
4678/// - Inverse FFT computation fails
4679///
4680/// # Examples
4681///
4682/// ```
4683/// use spectrograms::*;
4684/// use non_empty_slice::{non_empty_vec, NonEmptySlice};
4685/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4686/// // Forward FFT
4687/// let signal = non_empty_vec![1.0, 0.0, -1.0, 0.0, 1.0, 0.0, -1.0, 0.0];
4688/// let spectrum = fft(&signal, nzu!(8))?;
4689/// let slice = spectrum.as_slice().unwrap();
4690/// let spectrum_slice = NonEmptySlice::new(slice).unwrap();
4691/// // Inverse FFT
4692/// let reconstructed = irfft(spectrum_slice, nzu!(8))?;
4693///
4694/// assert_eq!(reconstructed.len(), nzu!(8));
4695/// # Ok(())
4696/// # }
4697/// ```
4698#[inline]
4699pub fn irfft(
4700    spectrum: &NonEmptySlice<Complex<f64>>,
4701    n_fft: NonZeroUsize,
4702) -> SpectrogramResult<NonEmptyVec<f64>> {
4703    use crate::fft_backend::{C2rPlan, r2c_output_size};
4704
4705    let n_fft = n_fft.get();
4706    let expected_len = r2c_output_size(n_fft);
4707    if spectrum.len().get() != expected_len {
4708        return Err(SpectrogramError::dimension_mismatch(
4709            expected_len,
4710            spectrum.len().get(),
4711        ));
4712    }
4713
4714    // Get inverse FFT plan from global cache (or create if first use)
4715    #[cfg(feature = "realfft")]
4716    let mut ifft = {
4717        use crate::fft_backend::get_or_create_c2r_plan;
4718        let plan = get_or_create_c2r_plan(n_fft)?;
4719        // Clone to get our own mutable copy with independent scratch buffer
4720        (*plan).clone()
4721    };
4722
4723    #[cfg(feature = "fftw")]
4724    let mut ifft = {
4725        use crate::fft_backend::C2rPlanner;
4726        let mut planner = crate::FftwPlanner::new();
4727        planner.plan_c2r(n_fft)?
4728    };
4729
4730    let mut output = vec![0.0; n_fft];
4731    ifft.process(spectrum.as_slice(), &mut output)?;
4732
4733    // Safety: output is non-empty since n_fft > 0
4734    Ok(unsafe { NonEmptyVec::new_unchecked(output) })
4735}
4736
4737/// Reconstruct a time-domain signal from its STFT using overlap-add.
4738///
4739/// This function performs the inverse Short-Time Fourier Transform, converting
4740/// a 2D complex STFT matrix back to a 1D time-domain signal using overlap-add
4741/// synthesis with the specified window function.
4742///
4743/// # Arguments
4744///
4745/// * `stft_matrix` - Complex STFT with shape (`frequency_bins`, `time_frames`)
4746/// * `n_fft` - FFT size
4747/// * `hop_size` - Number of samples between successive frames
4748/// * `window` - Window function to apply (should match forward STFT window)
4749/// * `center` - If true, assume the forward STFT was centered
4750///
4751/// # Returns
4752///
4753/// A vector of reconstructed time-domain samples.
4754///
4755/// # Errors
4756///
4757/// Returns an error if:
4758/// - `stft_matrix` dimensions are inconsistent with `n_fft`
4759/// - `hop_size` > `n_fft`
4760/// - Inverse STFT computation fails
4761///
4762/// # Examples
4763///
4764/// ```
4765/// use spectrograms::*;
4766/// use non_empty_slice::non_empty_vec;
4767///
4768/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4769/// // Generate signal
4770/// let signal = non_empty_vec![1.0; nzu!(16000)];
4771///
4772/// // Forward STFT
4773/// let stft_matrix = stft(&signal, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4774///
4775/// // Inverse STFT
4776/// let reconstructed = istft(&stft_matrix, nzu!(512), nzu!(256), WindowType::Hanning, true)?;
4777///
4778/// println!("Original: {} samples", signal.len());
4779/// println!("Reconstructed: {} samples", reconstructed.len());
4780/// # Ok(())
4781/// # }
4782/// ```
4783#[inline]
4784pub fn istft(
4785    stft_matrix: &Array2<Complex<f64>>,
4786    n_fft: NonZeroUsize,
4787    hop_size: NonZeroUsize,
4788    window: WindowType,
4789    center: bool,
4790) -> SpectrogramResult<NonEmptyVec<f64>> {
4791    use crate::fft_backend::{C2rPlan, C2rPlanner, r2c_output_size};
4792
4793    let n_bins = stft_matrix.nrows();
4794    let n_frames = stft_matrix.ncols();
4795
4796    let expected_bins = r2c_output_size(n_fft.get());
4797    if n_bins != expected_bins {
4798        return Err(SpectrogramError::dimension_mismatch(expected_bins, n_bins));
4799    }
4800    if hop_size.get() > n_fft.get() {
4801        return Err(SpectrogramError::invalid_input("hop_size must be <= n_fft"));
4802    }
4803    // Create inverse FFT plan
4804    #[cfg(feature = "realfft")]
4805    let mut ifft = {
4806        let mut planner = crate::RealFftPlanner::new();
4807        planner.plan_c2r(n_fft.get())?
4808    };
4809
4810    #[cfg(feature = "fftw")]
4811    let mut ifft = {
4812        let mut planner = crate::FftwPlanner::new();
4813        planner.plan_c2r(n_fft.get())?
4814    };
4815
4816    // Generate window
4817    let window_samples = make_window(window, n_fft);
4818    let n_fft = n_fft.get();
4819    let hop_size = hop_size.get();
4820    // Calculate output length
4821    let pad = if center { n_fft / 2 } else { 0 };
4822    let output_len = (n_frames - 1) * hop_size + n_fft;
4823    // safety: output_len > 0 since n_frames > 0 and n_fft, hop_size > 0
4824    let output_len = unsafe { NonZeroUsize::new_unchecked(output_len) };
4825    let unpadded_len = output_len.get().saturating_sub(2 * pad);
4826
4827    // Allocate output buffer and normalization buffer
4828    let mut output = non_empty_vec![0.0; output_len];
4829    let mut norm = non_empty_vec![0.0; output_len];
4830
4831    // Overlap-add synthesis
4832    let mut frame_buffer = vec![Complex::new(0.0, 0.0); n_bins];
4833    let mut time_frame = vec![0.0; n_fft];
4834
4835    for frame_idx in 0..n_frames {
4836        // Extract complex frame from STFT matrix
4837        for bin_idx in 0..n_bins {
4838            frame_buffer[bin_idx] = stft_matrix[[bin_idx, frame_idx]];
4839        }
4840
4841        // Inverse FFT
4842        ifft.process(&frame_buffer, &mut time_frame)?;
4843
4844        // Apply window
4845        for i in 0..n_fft {
4846            time_frame[i] *= window_samples[i];
4847        }
4848
4849        // Overlap-add into output buffer
4850        let start = frame_idx * hop_size;
4851        for i in 0..n_fft {
4852            let pos = start + i;
4853            if pos < output_len.get() {
4854                output[pos] += time_frame[i];
4855                norm[pos] += window_samples[i] * window_samples[i];
4856            }
4857        }
4858    }
4859
4860    // Normalize by window energy
4861    for i in 0..output_len.get() {
4862        if norm[i] > 1e-10 {
4863            output[i] /= norm[i];
4864        }
4865    }
4866
4867    // Remove padding if centered
4868    if center && unpadded_len > 0 {
4869        let start = pad;
4870        let end = start + unpadded_len;
4871        // safety: start < end <= output_len, therefore slice is non-empty
4872        output = unsafe {
4873            NonEmptySlice::new_unchecked(&output[start..end.min(output_len.get())])
4874                .to_non_empty_vec()
4875        };
4876    }
4877
4878    Ok(output)
4879}
4880
4881//
4882// ========================
4883// Reusable FFT Plans
4884// ========================
4885//
4886
/// A reusable FFT planner for efficient repeated FFT operations.
///
/// This planner caches FFT plans internally, making repeated FFT operations
/// of the same size much more efficient than calling `fft()` repeatedly.
///
/// The FFT backend is selected at compile time by cargo feature. With `realfft`
/// the planner wraps `crate::RealFftPlanner`; with `fftw` it wraps
/// `crate::FftwPlanner`.
///
/// # Examples
///
/// ```
/// use spectrograms::*;
/// use non_empty_slice::non_empty_vec;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let mut planner = FftPlanner::new();
///
/// // Process multiple signals of the same size efficiently
/// for _ in 0..100 {
///     let signal = non_empty_vec![0.0; nzu!(512)];
///     let spectrum = planner.fft(&signal, nzu!(512))?;
///     // ... process spectrum ...
/// }
/// # Ok(())
/// # }
/// ```
pub struct FftPlanner {
    /// Backend planner used when the `realfft` feature is enabled.
    #[cfg(feature = "realfft")]
    inner: crate::RealFftPlanner,
    /// Backend planner used when the `fftw` feature is enabled.
    #[cfg(feature = "fftw")]
    inner: crate::FftwPlanner,
}
4916
4917impl FftPlanner {
4918    /// Create a new FFT planner with empty cache.
4919    #[inline]
4920    #[must_use]
4921    pub fn new() -> Self {
4922        Self {
4923            #[cfg(feature = "realfft")]
4924            inner: crate::RealFftPlanner::new(),
4925            #[cfg(feature = "fftw")]
4926            inner: crate::FftwPlanner::new(),
4927        }
4928    }
4929
4930    /// Compute forward FFT, reusing cached plans.
4931    ///
4932    /// This is more efficient than calling the standalone `fft()` function
4933    /// repeatedly for the same FFT size.
4934    ///
4935    /// # Automatic Zero-Padding
4936    ///
4937    /// If the input signal is shorter than `n_fft`, it will be automatically
4938    /// zero-padded to the required length.
4939    ///
4940    /// # Errors
4941    ///
4942    /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
4943    ///
4944    /// # Examples
4945    ///
4946    /// ```
4947    /// use spectrograms::*;
4948    /// use non_empty_slice::non_empty_vec;
4949    ///
4950    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
4951    /// let mut planner = FftPlanner::new();
4952    ///
4953    /// let signal = non_empty_vec![1.0; nzu!(512)];
4954    /// let spectrum = planner.fft(&signal, nzu!(512))?;
4955    ///
4956    /// assert_eq!(spectrum.len(), 257); // 512/2 + 1
4957    /// # Ok(())
4958    /// # }
4959    /// ```
4960    #[inline]
4961    pub fn fft(
4962        &mut self,
4963        samples: &NonEmptySlice<f64>,
4964        n_fft: NonZeroUsize,
4965    ) -> SpectrogramResult<Array1<Complex<f64>>> {
4966        use crate::fft_backend::{R2cPlan, R2cPlanner, r2c_output_size};
4967
4968        if samples.len() > n_fft {
4969            return Err(SpectrogramError::invalid_input(format!(
4970                "Input length ({}) exceeds FFT size ({})",
4971                samples.len(),
4972                n_fft
4973            )));
4974        }
4975
4976        let out_len = r2c_output_size(n_fft.get());
4977        let mut plan = self.inner.plan_r2c(n_fft.get())?;
4978
4979        let input = if samples.len() < n_fft {
4980            let mut padded = vec![0.0; n_fft.get()];
4981            padded[..samples.len().get()].copy_from_slice(samples);
4982
4983            // safety: samples.len() < n_fft checked above and n_fft > 0
4984            unsafe { NonEmptyVec::new_unchecked(padded) }
4985        } else {
4986            samples.to_non_empty_vec()
4987        };
4988
4989        let mut output = vec![Complex::new(0.0, 0.0); out_len];
4990        plan.process(&input, &mut output)?;
4991
4992        let output = Array1::from_vec(output);
4993        Ok(output)
4994    }
4995
4996    /// Compute forward real FFT magnitude
4997    ///
4998    /// # Errors
4999    ///
5000    /// Returns an error if:
5001    /// - `n_fft` doesn't match the samples length
5002    ///
5003    ///
5004    #[inline]
5005    pub fn rfft(
5006        &mut self,
5007        samples: &NonEmptySlice<f64>,
5008        n_fft: NonZeroUsize,
5009    ) -> SpectrogramResult<Array1<f64>> {
5010        let fft_with_complex = fft(samples, n_fft)?;
5011        Ok(fft_with_complex.mapv(Complex::norm))
5012    }
5013
5014    /// Compute inverse FFT, reusing cached plans.
5015    ///
5016    /// This is more efficient than calling the standalone `irfft()` function
5017    /// repeatedly for the same FFT size.
5018    ///
5019    /// # Errors
5020    /// Returns an error if:
5021    ///
5022    /// - The calculated expected length of `spectrum` doesn't match its actual length
5023    ///
5024    /// # Examples
5025    ///
5026    /// ```
5027    /// use spectrograms::*;
5028    /// use non_empty_slice::{non_empty_vec, NonEmptySlice};
5029    ///
5030    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
5031    /// let mut planner = FftPlanner::new();
5032    ///
5033    /// // Forward FFT
5034    /// let signal = non_empty_vec![1.0; nzu!(512)];
5035    /// let spectrum = planner.fft(&signal, nzu!(512))?;
5036    ///
5037    /// // Inverse FFT
5038    /// let spectrum_slice = NonEmptySlice::new(spectrum.as_slice().unwrap()).unwrap();
5039    /// let reconstructed = planner.irfft(spectrum_slice, nzu!(512))?;
5040    ///
5041    /// assert_eq!(reconstructed.len(), nzu!(512));
5042    /// # Ok(())
5043    /// # }
5044    /// ```
5045    #[inline]
5046    pub fn irfft(
5047        &mut self,
5048        spectrum: &NonEmptySlice<Complex<f64>>,
5049        n_fft: NonZeroUsize,
5050    ) -> SpectrogramResult<NonEmptyVec<f64>> {
5051        use crate::fft_backend::{C2rPlan, C2rPlanner, r2c_output_size};
5052
5053        let expected_len = r2c_output_size(n_fft.get());
5054        if spectrum.len().get() != expected_len {
5055            return Err(SpectrogramError::dimension_mismatch(
5056                expected_len,
5057                spectrum.len().get(),
5058            ));
5059        }
5060
5061        let mut plan = self.inner.plan_c2r(n_fft.get())?;
5062        let mut output = vec![0.0; n_fft.get()];
5063        plan.process(spectrum, &mut output)?;
5064        // Safety: output is non-empty since n_fft > 0
5065        let output = unsafe { NonEmptyVec::new_unchecked(output) };
5066        Ok(output)
5067    }
5068
5069    /// Compute power spectrum with optional windowing, reusing cached plans.
5070    ///
5071    /// # Automatic Zero-Padding
5072    ///
5073    /// If the input signal is shorter than `n_fft`, it will be automatically
5074    /// zero-padded to the required length.
5075    ///
5076    /// # Errors
5077    /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
5078    ///
5079    /// # Examples
5080    ///
5081    /// ```
5082    /// use spectrograms::*;
5083    /// use non_empty_slice::non_empty_vec;
5084    ///
5085    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
5086    /// let mut planner = FftPlanner::new();
5087    ///
5088    /// let signal = non_empty_vec![1.0; nzu!(512)];
5089    /// let power = planner.power_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
5090    ///
5091    /// assert_eq!(power.len(), nzu!(257));
5092    /// # Ok(())
5093    /// # }
5094    /// ```
5095    #[inline]
5096    pub fn power_spectrum(
5097        &mut self,
5098        samples: &NonEmptySlice<f64>,
5099        n_fft: NonZeroUsize,
5100        window: Option<WindowType>,
5101    ) -> SpectrogramResult<NonEmptyVec<f64>> {
5102        if samples.len() > n_fft {
5103            return Err(SpectrogramError::invalid_input(format!(
5104                "Input length ({}) exceeds FFT size ({})",
5105                samples.len(),
5106                n_fft
5107            )));
5108        }
5109
5110        let mut windowed = vec![0.0; n_fft.get()];
5111        windowed[..samples.len().get()].copy_from_slice(samples);
5112        if let Some(win_type) = window {
5113            let window_samples = make_window(win_type, n_fft);
5114            for i in 0..n_fft.get() {
5115                windowed[i] *= window_samples[i];
5116            }
5117        }
5118
5119        // safety: windowed is non-empty since n_fft > 0
5120        let windowed = unsafe { NonEmptySlice::new_unchecked(&windowed) };
5121        let fft_result = self.fft(windowed, n_fft)?;
5122        let f = fft_result
5123            .iter()
5124            .map(num_complex::Complex::norm_sqr)
5125            .collect();
5126        // safety: fft_result is non-empty since fft returned successfully
5127        Ok(unsafe { NonEmptyVec::new_unchecked(f) })
5128    }
5129
5130    /// Compute magnitude spectrum with optional windowing, reusing cached plans.
5131    ///
5132    /// # Automatic Zero-Padding
5133    ///
5134    /// If the input signal is shorter than `n_fft`, it will be automatically
5135    /// zero-padded to the required length.
5136    ///
5137    /// # Errors
5138    /// Returns `InvalidInput` error if the input length exceeds `n_fft`.
5139    ///
5140    /// # Examples
5141    ///
5142    /// ```
5143    /// use spectrograms::*;
5144    /// use non_empty_slice::non_empty_vec;
5145    ///
5146    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
5147    /// let mut planner = FftPlanner::new();
5148    ///
5149    /// let signal = non_empty_vec![1.0; nzu!(512)];
5150    /// let magnitude = planner.magnitude_spectrum(&signal, nzu!(512), Some(WindowType::Hanning))?;
5151    ///
5152    /// assert_eq!(magnitude.len(), nzu!(257));
5153    /// # Ok(())
5154    /// # }
5155    /// ```
5156    #[inline]
5157    pub fn magnitude_spectrum(
5158        &mut self,
5159        samples: &NonEmptySlice<f64>,
5160        n_fft: NonZeroUsize,
5161        window: Option<WindowType>,
5162    ) -> SpectrogramResult<NonEmptyVec<f64>> {
5163        let power = self.power_spectrum(samples, n_fft, window)?;
5164        let power = power.iter().map(|&p| p.sqrt()).collect::<Vec<f64>>();
5165        // safety: power is non-empty since power_spectrum returned successfully
5166        Ok(unsafe { NonEmptyVec::new_unchecked(power) })
5167    }
5168}
5169
5170impl Default for FftPlanner {
5171    #[inline]
5172    fn default() -> Self {
5173        Self::new()
5174    }
5175}
5176
#[cfg(test)]
mod tests {
    use super::*;

    /// A hand-built 3x5 sparse matrix times a dense vector yields the expected rows.
    #[test]
    fn test_sparse_matrix_basic() {
        // (row, col, value) triplets, inserted in this order.
        let entries = [
            (0, 1, 2.0),
            (1, 2, 0.5),
            (1, 3, 1.5),
            (2, 0, 3.0),
            (2, 4, 1.0),
        ];
        let mut m = SparseMatrix::new(3, 5);
        for &(r, c, v) in &entries {
            m.set(r, c, v);
        }

        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut y = vec![0.0; 3];
        m.multiply_vec(&x, &mut y);

        // row 0: 2.0*2.0              = 4.0
        // row 1: 0.5*3.0 + 1.5*4.0    = 7.5
        // row 2: 3.0*1.0 + 1.0*5.0    = 8.0
        assert_eq!(y, vec![4.0, 7.5, 8.0]);
    }

    /// Explicit zeros passed to `set` must not be stored.
    #[test]
    fn test_sparse_matrix_zeros_ignored() {
        let mut m = SparseMatrix::new(2, 3);
        for (col, val) in [(0, 1.0), (1, 0.0), (2, 2.0)] {
            m.set(0, col, val);
        }

        // Only columns 0 and 2 survive in row 0.
        assert_eq!(m.values[0].len(), 2);
        assert_eq!(m.indices[0].len(), 2);
        assert_eq!(m.indices[0], vec![0, 2]);
        assert_eq!(m.values[0], vec![1.0, 2.0]);
    }

    /// LogHz rows interpolate between neighbouring bins: 1 or 2 non-zeros each.
    #[test]
    fn test_loghz_matrix_sparsity() {
        let sample_rate = 16000.0;
        let n_fft = nzu!(512);
        let n_bins = nzu!(128);
        let (f_min, f_max) = (20.0, sample_rate / 2.0);

        let (matrix, _freqs) =
            build_loghz_matrix(sample_rate, n_fft, n_bins, f_min, f_max).unwrap();

        for row_idx in 0..matrix.nrows() {
            let nnz = matrix.values[row_idx].len();
            assert!(nnz <= 2, "Row {} has {} non-zeros, expected at most 2", row_idx, nnz);
            assert!(nnz >= 1, "Row {} has no non-zeros", row_idx);
        }

        // Between one and two entries per output row overall.
        let total_nnz: usize = matrix.values.iter().map(Vec::len).sum();
        assert!(total_nnz <= n_bins.get() * 2);
        assert!(total_nnz >= n_bins.get());
    }

    /// Triangular mel filters leave most of the matrix empty.
    #[test]
    fn test_mel_matrix_sparsity() {
        let sample_rate = 16000.0;
        let n_fft = nzu!(512);
        let n_mels = nzu!(40);
        let (f_min, f_max) = (0.0, sample_rate / 2.0);

        let matrix =
            build_mel_filterbank_matrix(sample_rate, n_fft, n_mels, f_min, f_max, MelNorm::None)
                .unwrap();

        let n_fft_bins = r2c_output_size(n_fft.get());
        let total_nnz: usize = matrix.values.iter().map(Vec::len).sum();
        let total_elements = n_mels.get() * n_fft_bins;
        let sparsity = 1.0 - total_nnz as f64 / total_elements as f64;

        assert!(
            sparsity > 0.8,
            "Mel matrix sparsity is only {:.1}%, expected >80%",
            sparsity * 100.0
        );

        // No single filter may cover half of the FFT bins.
        for row_idx in 0..matrix.nrows() {
            let nnz = matrix.values[row_idx].len();
            assert!(nnz < n_fft_bins / 2, "Mel filter {} is not sparse enough", row_idx);
        }
    }
}