Skip to main content

whisperforge_core/
audio.rs

1use anyhow::{Result, anyhow};
2use audioadapter_buffers::direct::SequentialSliceOfVecs;
3use burn::tensor::{Tensor, backend::Backend};
4use rubato::{Async, FixedAsync, Resampler, SincInterpolationParameters, WindowFunction};
5use rustfft::{FftPlanner, num_complex::Complex};
6use std::f32::consts::PI;
7
8#[cfg(feature = "file-io")]
9use std::path::Path;
10#[cfg(feature = "file-io")]
11use symphonia::core::codecs::audio::CODEC_ID_NULL_AUDIO;
12#[cfg(feature = "file-io")]
13use symphonia::core::{
14    codecs::audio::AudioDecoderOptions,
15    formats::{FormatOptions, probe::Hint},
16    io::MediaSourceStream,
17    meta::MetadataOptions,
18};
19
// Whisper audio parameters

/// Sample rate, in Hz, of the audio fed to the mel spectrogram pipeline.
pub const WHISPER_SAMPLE_RATE: u32 = 16000;
/// STFT window / FFT size, in samples.
pub const WHISPER_N_FFT: usize = 400;
/// Hop between consecutive STFT frames, in samples (16000 / 160 = 100 frames per second).
pub const WHISPER_HOP_LENGTH: usize = 160;
/// Number of mel bins in the spectrogram.
pub const WHISPER_N_MELS: usize = 80;
/// Length of one audio chunk, in seconds.
pub const WHISPER_CHUNK_LENGTH: usize = 30; // seconds
26
/// Decoded audio held in memory as `f32` samples.
#[derive(Debug, Clone)]
pub struct AudioData {
    /// Sample values; interleaved frame by frame when `channels > 1`
    /// (e.g. `[L, R, L, R, ...]` for stereo).
    pub samples: Vec<f32>,
    /// Sampling rate in Hz.
    pub sample_rate: u32,
    /// Number of interleaved channels stored in `samples`.
    pub channels: u16,
}
33
34impl AudioData {
35    pub fn new(samples: Vec<f32>, sample_rate: u32, channels: u16) -> Self {
36        Self {
37            samples,
38            sample_rate,
39            channels,
40        }
41    }
42
43    pub fn duration(&self) -> f32 {
44        self.samples.len() as f32 / (self.sample_rate as f32 * self.channels as f32)
45    }
46
47    pub fn to_mono(&self) -> AudioData {
48        if self.channels == 1 {
49            return self.clone();
50        }
51
52        let mono_samples: Vec<f32> = self
53            .samples
54            .chunks(self.channels as usize)
55            .map(|chunk| chunk.iter().sum::<f32>() / self.channels as f32)
56            .collect();
57
58        AudioData {
59            samples: mono_samples,
60            sample_rate: self.sample_rate,
61            channels: 1,
62        }
63    }
64
65    pub fn resample(&self, target_sample_rate: u32) -> Result<AudioData> {
66        if self.sample_rate == target_sample_rate {
67            return Ok(self.clone());
68        }
69
70        let f_ratio = target_sample_rate as f64 / self.sample_rate as f64;
71        let params = SincInterpolationParameters {
72            sinc_len: 256,
73            f_cutoff: 0.95,
74            interpolation: rubato::SincInterpolationType::Cubic,
75            oversampling_factor: 256,
76            window: WindowFunction::BlackmanHarris2,
77        };
78
79        // A reasonable chunk size in frames
80        let chunk_size = 1024;
81
82        let mut resampler = Async::<f32>::new_sinc(
83            f_ratio,
84            2.0,
85            &params,
86            chunk_size,
87            self.channels as usize,
88            FixedAsync::Input,
89        )
90        .map_err(|e| anyhow!("Failed to create resampler: {}", e))?;
91
92        // Convert interleaved samples to multi-channel format for rubato
93        let frames_per_channel = self.samples.len() / self.channels as usize;
94        let mut input_channels: Vec<Vec<f32>> =
95            vec![Vec::with_capacity(frames_per_channel); self.channels as usize];
96
97        // Deinterleave samples: [L,R,L,R,...] -> [[L,L,...], [R,R,...]]
98        for (i, &sample) in self.samples.iter().enumerate() {
99            let channel = i % self.channels as usize;
100            input_channels[channel].push(sample);
101        }
102
103        let input_adapter =
104            SequentialSliceOfVecs::new(&input_channels, self.channels as usize, frames_per_channel)
105                .map_err(|e| anyhow!("Failed to create input adapter: {}", e))?;
106
107        // The resampler processes data in chunks. We need to know the output size.
108        let estimated_output_frames = (frames_per_channel as f64 * f_ratio) as usize;
109
110        let mut output_channels: Vec<Vec<f32>> =
111            vec![vec![0.0f32; estimated_output_frames]; self.channels as usize];
112        let mut output_adapter = SequentialSliceOfVecs::new_mut(
113            &mut output_channels,
114            self.channels as usize,
115            estimated_output_frames,
116        )
117        .map_err(|e| anyhow!("Failed to create output adapter: {}", e))?;
118
119        // Use an indexing helper struct for chunked processing
120        let mut indexing = rubato::Indexing {
121            input_offset: 0,
122            output_offset: 0,
123            active_channels_mask: None,
124            partial_len: None,
125        };
126
127        let mut input_frames_left = frames_per_channel;
128        let mut input_frames_next = resampler.input_frames_next();
129
130        // Loop over all full chunks.
131        while input_frames_left >= input_frames_next {
132            let (frames_read, frames_written) = resampler
133                .process_into_buffer(&input_adapter, &mut output_adapter, Some(&indexing))
134                .map_err(|e| anyhow!("Resampling failed: {}", e))?;
135
136            indexing.input_offset += frames_read;
137            indexing.output_offset += frames_written;
138            input_frames_left -= frames_read;
139            input_frames_next = resampler.input_frames_next();
140        }
141
142        // Note: Any remaining frames < chunk_size are buffered internally by the resampler.
143        // Since we're processing a complete file here (not streaming), the resampler state
144        // persists and will output them naturally as part of the next/final processing.
145
146        // Interleave the output channels back into a single vector
147        let actual_output_frames = indexing.output_offset;
148        let mut output_samples = Vec::with_capacity(actual_output_frames * self.channels as usize);
149
150        for frame in 0..actual_output_frames {
151            for ch in &output_channels {
152                output_samples.push(ch[frame]);
153            }
154        }
155
156        Ok(AudioData {
157            samples: output_samples,
158            sample_rate: target_sample_rate,
159            channels: self.channels,
160        })
161    }
162
163    pub fn to_16khz_mono(&self) -> Result<AudioData> {
164        let mono = self.to_mono();
165        mono.resample(16000)
166    }
167}
168
/// Load an audio file into raw f32 samples.
///
/// Supports WAV, MP3, FLAC, OGG/Vorbis, AAC/M4A, and MKV/WebM audio via
/// symphonia. The file format is detected from the extension first; symphonia
/// falls back to content-based probing when the extension is absent or unknown.
///
/// Returns interleaved f32 samples in `[-1.0, 1.0]` at the file's native
/// sample rate and channel count. Call [`AudioData::to_16khz_mono`] to
/// normalise before passing to the mel spectrogram pipeline.
///
/// Packets that fail with an I/O or decode error are discarded and decoding
/// continues with the next packet, so one corrupt packet does not abort the load.
///
/// For WASM / no-filesystem environments, construct [`AudioData`] directly
/// from samples obtained via the Web Audio API.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or probed, has no decodable
/// audio track, lacks a sample rate or channel layout, has more channels than
/// fit in a `u16`, or if reading/decoding fails unrecoverably.
#[cfg(feature = "file-io")]
pub fn load_audio_file<P: AsRef<Path>>(path: P) -> Result<AudioData> {
    let path = path.as_ref();
    let file = std::fs::File::open(path)
        .map_err(|e| anyhow!("Failed to open audio file '{}': {}", path.display(), e))?;
    let mss = MediaSourceStream::new(Box::new(file), Default::default());

    let mut hint = Hint::new();
    if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
        hint.with_extension(ext);
    }

    let mut format = symphonia::default::get_probe()
        .probe(
            &hint,
            mss,
            FormatOptions::default(),
            MetadataOptions::default(),
        )
        .map_err(|e| anyhow!("Unsupported audio format '{}': {}", path.display(), e))?;

    // First track that carries real (non-null codec) audio.
    let track = format
        .tracks()
        .iter()
        .find(|t| {
            t.codec_params
                .as_ref()
                .and_then(|cp| cp.audio())
                .map(|ap| ap.codec != CODEC_ID_NULL_AUDIO)
                .unwrap_or(false)
        })
        .ok_or_else(|| anyhow!("No audio tracks found in '{}'", path.display()))?;

    let track_id = track.id;
    let codec_params = track
        .codec_params
        .as_ref()
        .and_then(|cp| cp.audio())
        .ok_or_else(|| anyhow!("Missing codec parameters in '{}'", path.display()))?;

    let sample_rate = codec_params
        .sample_rate
        .ok_or_else(|| anyhow!("Unknown sample rate in '{}'", path.display()))?;
    let channel_count = codec_params
        .channels
        .as_ref()
        .ok_or_else(|| anyhow!("Unknown channel count in '{}'", path.display()))?
        .count();
    // Checked narrowing: an `as u16` cast would silently wrap.
    let channels = u16::try_from(channel_count).map_err(|_| {
        anyhow!(
            "Unsupported channel count {} in '{}'",
            channel_count,
            path.display()
        )
    })?;

    let mut decoder = symphonia::default::get_codecs()
        .make_audio_decoder(codec_params, &AudioDecoderOptions::default())
        .map_err(|e| anyhow!("Failed to create decoder for '{}': {}", path.display(), e))?;

    let mut samples: Vec<f32> = Vec::new();
    // Per-packet scratch buffer, reused across iterations.
    let mut packet_samples: Vec<f32> = Vec::new();

    loop {
        let packet = match format.next_packet() {
            Ok(Some(p)) => p,
            // End of stream.
            Ok(None) => {
                break;
            }
            // NOTE(review): symphonia signals that tracks/decoders should be
            // re-examined here; this keeps the existing decoder and reads on.
            Err(symphonia::core::errors::Error::ResetRequired) => {
                continue;
            }
            Err(e) => {
                return Err(anyhow!("Error reading '{}': {}", path.display(), e));
            }
        };

        // Ignore packets that belong to other tracks (video, subtitles, ...).
        if packet.track_id != track_id {
            continue;
        }

        let decoded = match decoder.decode(&packet) {
            Ok(d) => d,
            // An undecodable packet is discarded; decoding continues with the next one.
            Err(symphonia::core::errors::Error::IoError(_))
            | Err(symphonia::core::errors::Error::DecodeError(_)) => continue,
            Err(e) => return Err(anyhow!("Decode error in '{}': {}", path.display(), e)),
        };

        // Decode into a cleared scratch buffer and append explicitly, so the
        // accumulated result does not depend on whether `copy_to_vec_interleaved`
        // appends to or overwrites its destination.
        packet_samples.clear();
        decoded.copy_to_vec_interleaved::<f32>(&mut packet_samples);
        samples.extend_from_slice(&packet_samples);
    }

    Ok(AudioData {
        samples,
        sample_rate,
        channels,
    })
}
269
/// Apply center=True STFT reflection padding to a 16 kHz mono sample slice.
///
/// `n_fft / 2` samples are mirrored onto each side without repeating the edge
/// sample, so `[a, b, c, d]` with a pad of 2 becomes `[c, b, a, b, c, d, c, b]`.
///
/// The caller is responsible for ensuring `samples` has already been padded/truncated
/// to the desired length (typically `30 * WHISPER_SAMPLE_RATE` = 480,000). No
/// resampling or conversion is performed. Returns a Vec ready for `compute_stft_magnitudes`.
///
/// Inputs shorter than the padding do not panic: the reflection is continued
/// periodically (the behaviour of `numpy.pad(..., mode="reflect")`), a single
/// sample is repeated, and an empty input yields `2 * (n_fft / 2)` zeros.
pub fn prepare_centered_samples_raw(samples: &[f32], n_fft: usize) -> Vec<f32> {
    let pad_len = n_fft / 2;
    let n = samples.len();
    if n == 0 {
        // Nothing to mirror; pad with silence.
        return vec![0.0; 2 * pad_len];
    }

    // Reflection repeats with period 2 * (n - 1); a lone sample reflects onto itself.
    let period = if n > 1 { 2 * (n - 1) } else { 1 };
    // Map a distance `k` from an edge sample to the reflected distance inside the slice.
    let fold = |k: usize| -> usize {
        let m = k % period;
        if m < n {
            m
        } else {
            period - m
        }
    };

    let mut centered = Vec::with_capacity(n + 2 * pad_len);
    // Left pad: distances pad_len..=1 measured from the first sample.
    centered.extend((1..=pad_len).rev().map(|k| samples[fold(k)]));
    centered.extend_from_slice(samples);
    // Right pad: distances 1..=pad_len measured from the last sample.
    centered.extend((1..=pad_len).map(|k| samples[n - 1 - fold(k)]));
    centered
}
288
289/// Resample/pad/center-reflect audio to produce centered samples ready for STFT.
290///
291/// Returns the reflection-padded sample vector that `compute_stft_magnitudes` and
292/// `compute_stft_power_gpu` both expect as input.
293fn prepare_centered_samples(audio: &AudioData, n_fft: usize) -> Result<Vec<f32>> {
294    let audio = if audio.channels != 1 || audio.sample_rate != WHISPER_SAMPLE_RATE {
295        audio.to_16khz_mono()?
296    } else {
297        audio.clone()
298    };
299
300    let target_samples = 30 * WHISPER_SAMPLE_RATE as usize;
301    let mut padded = audio.samples;
302    if padded.len() > target_samples {
303        padded.truncate(target_samples);
304    } else {
305        padded.resize(target_samples, 0.0);
306    }
307
308    Ok(prepare_centered_samples_raw(&padded, n_fft))
309}
310
311/// Apply mel filterbank, log10 compression, and Whisper normalization to a
312/// power-spectrum tensor of shape `[n_freqs, n_frames]`.
313///
314/// Returns `[1, n_mels, n_frames]`.
315fn mel_compress<B: Backend>(
316    ps: Tensor<B, 2>,
317    n_mels: usize,
318    n_fft: usize,
319    device: &B::Device,
320) -> Tensor<B, 3> {
321    let [n_freqs, n_frames] = ps.dims();
322    let mel_filters = create_mel_filter_bank(n_fft, n_mels, WHISPER_SAMPLE_RATE as f32);
323    let mf_flat: Vec<f32> = mel_filters.into_iter().flatten().collect();
324    let mf_tensor: Tensor<B, 2> =
325        Tensor::<B, 1>::from_floats(mf_flat.as_slice(), device).reshape([n_mels, n_freqs]);
326
327    let mel: Tensor<B, 2> = mf_tensor.matmul(ps);
328
329    let log10_e = std::f32::consts::LOG10_E;
330    let log_mel: Tensor<B, 2> = mel.clamp_min(1e-10_f32).log().mul_scalar(log10_e);
331
332    let max_val: Tensor<B, 2> = log_mel.clone().max().reshape([1, 1]);
333    let log_mel = (log_mel - max_val.clone())
334        .clamp_min(-8.0_f32)
335        .add(max_val)
336        .add_scalar(4.0_f32)
337        .div_scalar(4.0_f32);
338
339    log_mel.reshape([1, n_mels, n_frames])
340}
341
342/// Compute mel spectrogram for Whisper model input (CPU STFT path).
343///
344/// Matches OpenAI Whisper's Python preprocessing exactly:
345/// 1. Resample/convert to 16 kHz mono.
346/// 2. Pad (or trim) audio to exactly 30 seconds in **sample** space before the STFT.
347/// 3. Apply `center=True` reflection padding (`n_fft/2` samples each side) so each
348///    STFT frame is centred on its sample — matches `torch.stft` default.
349/// 4. Use power spectrum and Slaney-normalised mel filters.
350///
351/// Always returns a tensor of shape `[1, n_mels, 3000]` (30 s at 100 fps).
352///
353/// The mel filterbank matmul and log-compression run on the target `B` backend
354/// (GPU when using WGPU); only the STFT uses CPU `rustfft`.
355/// For full GPU STFT use `compute_mel_spectrogram_wgpu` (feature `cubecl-stft`).
356pub fn compute_mel_spectrogram<B: Backend>(
357    audio: &AudioData,
358    n_fft: usize,
359    hop_length: usize,
360    n_mels: usize,
361    device: &B::Device,
362) -> Result<Tensor<B, 3>> {
363    let centered = prepare_centered_samples(audio, n_fft)?;
364    let magnitudes = compute_stft_magnitudes(&centered, n_fft, hop_length);
365
366    let n_freqs = n_fft / 2 + 1;
367    let n_frames = magnitudes[0].len().saturating_sub(1);
368
369    let ps_flat: Vec<f32> = (0..n_freqs)
370        .flat_map(|f| magnitudes[f][..n_frames].iter().copied())
371        .collect();
372    let ps_tensor: Tensor<B, 2> =
373        Tensor::<B, 1>::from_floats(ps_flat.as_slice(), device).reshape([n_freqs, n_frames]);
374
375    Ok(mel_compress(ps_tensor, n_mels, n_fft, device))
376}
377
378/// Compute mel spectrogram from a raw 16 kHz mono sample slice.
379///
380/// `samples` must be exactly `30 * WHISPER_SAMPLE_RATE` (480,000) samples — caller is
381/// responsible for padding/truncating. Returns `[1, n_mels, 3000]`.
382pub fn compute_mel_from_samples<B: Backend>(
383    samples: &[f32],
384    n_fft: usize,
385    hop_length: usize,
386    n_mels: usize,
387    device: &B::Device,
388) -> Result<Tensor<B, 3>> {
389    let expected = 30 * WHISPER_SAMPLE_RATE as usize;
390    anyhow::ensure!(
391        samples.len() == expected,
392        "compute_mel_from_samples: expected {} samples, got {}",
393        expected,
394        samples.len()
395    );
396    let centered = prepare_centered_samples_raw(samples, n_fft);
397    let magnitudes = compute_stft_magnitudes(&centered, n_fft, hop_length);
398    let n_freqs = n_fft / 2 + 1;
399    let n_frames = magnitudes[0].len().saturating_sub(1);
400    let ps_flat: Vec<f32> = (0..n_freqs)
401        .flat_map(|f| magnitudes[f][..n_frames].iter().copied())
402        .collect();
403    let ps_tensor: Tensor<B, 2> =
404        Tensor::<B, 1>::from_floats(ps_flat.as_slice(), device).reshape([n_freqs, n_frames]);
405    Ok(mel_compress(ps_tensor, n_mels, n_fft, device))
406}
407
/// Compute the Whisper log-mel spectrogram with the STFT on the GPU (feature `cubecl-stft`).
///
/// Intended to produce the same output as `compute_mel_spectrogram`, but the
/// STFT runs on the WGPU device via CubeCL. Requires a bare
/// `CubeBackend<WgpuRuntime,f32,i32,u32>` — the Fusion-wrapped `Wgpu` alias
/// cannot expose the inner `Runtime`.
#[cfg(feature = "cubecl-stft")]
pub fn compute_mel_spectrogram_wgpu(
    audio: &AudioData,
    n_fft: usize,
    hop_length: usize,
    n_mels: usize,
    device: &burn_wgpu::WgpuDevice,
) -> Result<Tensor<WgpuBackend, 3>> {
    use cubecl::prelude::Runtime;

    let padded = prepare_centered_samples(audio, n_fft)?;

    let frames_total = (padded.len() - n_fft) / hop_length + 1;
    let frames_kept = frames_total.saturating_sub(1);
    let freq_bins = n_fft / 2 + 1;

    let client = burn_wgpu::WgpuRuntime::client(device);
    let power = crate::stft_gpu::compute_stft_power_gpu(
        &client,
        &padded,
        n_fft,
        hop_length,
        frames_total,
    );

    // The kernel output is frame-major [frames_total, n_freqs]. Keep all but the
    // last frame, then transpose to the [n_freqs, n_frames] layout mel_compress expects.
    let frame_major: Tensor<WgpuBackend, 2> =
        Tensor::<WgpuBackend, 1>::from_floats(&power[..frames_kept * freq_bins], device)
            .reshape([frames_kept, freq_bins]);
    let freq_major: Tensor<WgpuBackend, 2> = frame_major.transpose();

    Ok(mel_compress(freq_major, n_mels, n_fft, device))
}
445
/// GPU counterpart of `compute_mel_from_samples`, using the CubeCL STFT kernel.
///
/// `samples` must hold exactly `30 * WHISPER_SAMPLE_RATE` (480,000) values.
/// Returns `[1, n_mels, 3000]` computed on the WGPU device.
///
/// # Errors
///
/// Fails when `samples` has any other length.
#[cfg(feature = "cubecl-stft")]
pub fn compute_mel_from_samples_wgpu(
    samples: &[f32],
    n_fft: usize,
    hop_length: usize,
    n_mels: usize,
    device: &burn_wgpu::WgpuDevice,
) -> Result<Tensor<WgpuBackend, 3>> {
    use cubecl::prelude::Runtime;

    let expected = 30 * WHISPER_SAMPLE_RATE as usize;
    anyhow::ensure!(
        samples.len() == expected,
        "compute_mel_from_samples_wgpu: expected {} samples, got {}",
        expected,
        samples.len()
    );

    let padded = prepare_centered_samples_raw(samples, n_fft);

    let frames_total = (padded.len() - n_fft) / hop_length + 1;
    let frames_kept = frames_total.saturating_sub(1);
    let freq_bins = n_fft / 2 + 1;

    let client = burn_wgpu::WgpuRuntime::client(device);
    let power = crate::stft_gpu::compute_stft_power_gpu(
        &client,
        &padded,
        n_fft,
        hop_length,
        frames_total,
    );

    // Frame-major kernel output: keep all but the last frame, then transpose
    // to the [n_freqs, n_frames] layout mel_compress expects.
    let frame_major: Tensor<WgpuBackend, 2> =
        Tensor::<WgpuBackend, 1>::from_floats(&power[..frames_kept * freq_bins], device)
            .reshape([frames_kept, freq_bins]);
    let freq_major: Tensor<WgpuBackend, 2> = frame_major.transpose();

    Ok(mel_compress(freq_major, n_mels, n_fft, device))
}
484
/// Concrete backend type for the bare WGPU path (no Fusion wrapper).
///
/// This is the tensor backend returned by the `*_wgpu` mel spectrogram
/// functions in this module; it is only available with feature `cubecl-stft`.
#[cfg(feature = "cubecl-stft")]
pub type WgpuBackend = burn_wgpu::CubeBackend<burn_wgpu::WgpuRuntime, f32, i32, u32>;
488
489/// Stack mel spectrograms for multiple audio chunks into a single batch tensor.
490///
491/// Returns `[N, n_mels, 3000]` where N = `chunks.len()`. Each chunk is padded /
492/// trimmed to exactly 3000 frames by `compute_mel_spectrogram`, so all slices are
493/// the same shape and can be concatenated without further alignment.
494///
495/// Feeding the batch to the encoder in one call amortises GPU kernel launch
496/// overhead and keeps the GPU compute unit saturated across chunks.
497pub fn batch_mel_spectrograms<B: Backend>(
498    chunks: &[AudioData],
499    n_fft: usize,
500    hop_length: usize,
501    n_mels: usize,
502    device: &B::Device,
503) -> Result<Tensor<B, 3>> {
504    anyhow::ensure!(!chunks.is_empty(), "batch_mel_spectrograms: no chunks");
505    let mels: Vec<Tensor<B, 3>> = chunks
506        .iter()
507        .map(|c| compute_mel_spectrogram(c, n_fft, hop_length, n_mels, device))
508        .collect::<Result<_>>()?;
509    Ok(Tensor::cat(mels, 0))
510}
511
/// GPU counterpart of `batch_mel_spectrograms`, using the CubeCL STFT kernel.
///
/// Requires feature `cubecl-stft` and a bare `WgpuDevice`. Each chunk goes
/// through `compute_mel_spectrogram_wgpu`, so the STFT as well as the mel
/// filterbank matmul and normalization run on the GPU.
///
/// # Errors
///
/// Fails when `chunks` is empty or when any chunk fails to convert.
#[cfg(feature = "cubecl-stft")]
pub fn batch_mel_spectrograms_wgpu(
    chunks: &[AudioData],
    n_fft: usize,
    hop_length: usize,
    n_mels: usize,
    device: &burn_wgpu::WgpuDevice,
) -> Result<Tensor<WgpuBackend, 3>> {
    anyhow::ensure!(!chunks.is_empty(), "batch_mel_spectrograms_wgpu: no chunks");

    let mut per_chunk: Vec<Tensor<WgpuBackend, 3>> = Vec::with_capacity(chunks.len());
    for chunk in chunks {
        // Stop at the first chunk that fails, as a fallible collect would.
        per_chunk.push(compute_mel_spectrogram_wgpu(
            chunk, n_fft, hop_length, n_mels, device,
        )?);
    }

    Ok(Tensor::cat(per_chunk, 0))
}
531
532/// Compute STFT and return magnitude spectrogram.
533/// Returns [n_freqs][n_frames] where n_freqs = n_fft/2 + 1
534#[allow(clippy::needless_range_loop)]
535fn compute_stft_magnitudes(samples: &[f32], n_fft: usize, hop_length: usize) -> Vec<Vec<f32>> {
536    let n_freqs = n_fft / 2 + 1;
537    let n_frames = if samples.len() >= n_fft {
538        (samples.len() - n_fft) / hop_length + 1
539    } else {
540        0
541    };
542
543    if n_frames == 0 {
544        return vec![vec![0.0]; n_freqs];
545    }
546
547    // Create Hann window
548    let window: Vec<f32> = (0..n_fft)
549        .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / n_fft as f32).cos()))
550        .collect();
551
552    // Setup FFT
553    let mut planner = FftPlanner::<f32>::new();
554    let fft = planner.plan_fft_forward(n_fft);
555
556    // Output: [n_freqs][n_frames]
557    let mut magnitudes = vec![vec![0.0f32; n_frames]; n_freqs];
558
559    // Process each frame
560    let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n_fft];
561
562    for frame_idx in 0..n_frames {
563        let start = frame_idx * hop_length;
564
565        // Apply window and copy to buffer
566        for i in 0..n_fft {
567            let sample = if start + i < samples.len() {
568                samples[start + i]
569            } else {
570                0.0
571            };
572            buffer[i] = Complex::new(sample * window[i], 0.0);
573        }
574
575        // Compute FFT in-place
576        fft.process(&mut buffer);
577
578        // Extract power spectrum (magnitude squared) for positive frequencies.
579        // Whisper uses |STFT|^2, not |STFT|.
580        for freq in 0..n_freqs {
581            magnitudes[freq][frame_idx] = buffer[freq].norm_sqr();
582        }
583    }
584
585    magnitudes
586}
587
/// Build the triangular mel filterbank used by OpenAI Whisper / librosa defaults.
///
/// Band edges are spaced evenly on the Slaney mel scale (linear below 1 kHz,
/// logarithmic above) between 0 Hz and `sample_rate / 2`. Each triangle is
/// scaled by the Slaney factor `2 / (upper_hz - lower_hz)` and evaluated at
/// the FFT bin frequencies `k * sample_rate / n_fft`, following the
/// construction of `librosa.filters.mel`.
///
/// Returns `[n_mels][n_freqs]` where `n_freqs = n_fft / 2 + 1`.
fn create_mel_filter_bank(n_fft: usize, n_mels: usize, sample_rate: f32) -> Vec<Vec<f32>> {
    let n_freqs = n_fft / 2 + 1;
    let nyquist = sample_rate / 2.0;

    // Slaney mel scale parameters.
    let linear_hz_per_mel: f32 = 200.0 / 3.0;
    let knee_hz: f32 = 1000.0;
    let knee_mel: f32 = knee_hz / linear_hz_per_mel;
    let log_step: f32 = 6.4f32.ln() / 27.0;

    let to_mel = |hz: f32| -> f32 {
        if hz >= knee_hz {
            knee_mel + (hz / knee_hz).ln() / log_step
        } else {
            hz / linear_hz_per_mel
        }
    };
    let to_hz = |mel: f32| -> f32 {
        if mel >= knee_mel {
            knee_hz * ((mel - knee_mel) * log_step).exp()
        } else {
            linear_hz_per_mel * mel
        }
    };

    // n_mels + 2 band edges, evenly spaced in mel and converted back to Hz.
    let mel_lo = to_mel(0.0);
    let mel_hi = to_mel(nyquist);
    let edges_hz: Vec<f32> = (0..=n_mels + 1)
        .map(|i| to_hz(mel_lo + (mel_hi - mel_lo) * i as f32 / (n_mels + 1) as f32))
        .collect();

    // Frequency of each FFT bin (matches np.fft.rfftfreq).
    let bin_hz: Vec<f32> = (0..n_freqs)
        .map(|k| k as f32 * sample_rate / n_fft as f32)
        .collect();

    // One triangle per consecutive (lower, center, upper) edge triple.
    edges_hz
        .windows(3)
        .map(|edge| {
            let (lower, center, upper) = (edge[0], edge[1], edge[2]);
            // Slaney normalization factor for this band.
            let scale = 2.0 / (upper - lower).max(1e-8);

            bin_hz
                .iter()
                .map(|&hz| {
                    let ramp_up = if center > lower {
                        ((hz - lower) / (center - lower)).max(0.0)
                    } else {
                        0.0
                    };
                    let ramp_down = if upper > center {
                        ((upper - hz) / (upper - center)).max(0.0)
                    } else {
                        0.0
                    };
                    ramp_up.min(ramp_down) * scale
                })
                .collect()
        })
        .collect()
}
656
657/// Pad or trim audio to exactly 30 seconds (Whisper chunk length).
658pub fn pad_or_trim_audio(audio: &AudioData, length_samples: usize) -> AudioData {
659    let mut samples = audio.samples.clone();
660
661    if samples.len() > length_samples {
662        samples.truncate(length_samples);
663    } else if samples.len() < length_samples {
664        samples.resize(length_samples, 0.0);
665    }
666
667    AudioData {
668        samples,
669        sample_rate: audio.sample_rate,
670        channels: audio.channels,
671    }
672}
673
#[cfg(test)]
mod tests {
    use super::*;
    use burn_flex::FlexDevice;

    /// CPU backend used by the tensor-producing tests.
    type TestBackend = burn_flex::Flex<f32>;

    /// Generate `count` samples of a sine tone at `freq_hz`.
    fn sine_wave(freq_hz: f32, sample_rate: f32, amplitude: f32, count: usize) -> Vec<f32> {
        (0..count)
            .map(|i| amplitude * (2.0 * PI * freq_hz * i as f32 / sample_rate).sin())
            .collect()
    }

    /// Run the CPU mel pipeline with the Whisper default parameters.
    fn whisper_mel(audio: &AudioData) -> Result<Tensor<TestBackend, 3>> {
        compute_mel_spectrogram::<TestBackend>(
            audio,
            WHISPER_N_FFT,
            WHISPER_HOP_LENGTH,
            WHISPER_N_MELS,
            &FlexDevice,
        )
    }

    #[test]
    fn test_audio_data_creation() {
        // Four interleaved stereo samples are two frames.
        let clip = AudioData::new(vec![0.0, 0.5, -0.5, 1.0], 44100, 2);
        assert_eq!(clip.duration(), 2.0 / 44100.0);
        assert_eq!(clip.channels, 2);
    }

    #[test]
    fn test_mono_conversion() {
        let downmixed = AudioData::new(vec![1.0, 2.0, 3.0, 4.0], 44100, 2).to_mono();
        assert_eq!(downmixed.samples, vec![1.5, 3.5]);
        assert_eq!(downmixed.channels, 1);
    }

    #[test]
    fn test_hann_window() {
        // Same formula as compute_stft_magnitudes.
        let len = 400;
        let hann: Vec<f32> = (0..len)
            .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / len as f32).cos()))
            .collect();

        // Both ends sit near zero.
        assert!(hann[0] < 0.01);
        assert!(hann[len - 1] < 0.01);

        // The middle sits near one.
        assert!(hann[len / 2] > 0.99);
    }

    #[test]
    fn test_mel_filter_bank() {
        let bank = create_mel_filter_bank(400, 80, 16000.0);

        // 80 mel bins, each spanning n_fft/2+1 = 201 frequency bins.
        assert_eq!(bank.len(), 80);
        assert_eq!(bank[0].len(), 201);

        for band in &bank {
            for &weight in band {
                assert!(weight >= 0.0, "Filter values should be non-negative");
            }
        }

        // Slaney-normalized filters don't sum to 1; they are scaled by 2/(upper_hz-lower_hz).
        // So only check that every band touches at least one bin and has a bounded peak.
        for band in &bank {
            let peak = band.iter().cloned().fold(0.0f32, f32::max);
            assert!(peak > 0.0, "Filter should have at least one non-zero bin");
            assert!(
                peak < 1.0,
                "Filter peak should be less than 1.0 (Slaney norm)"
            );
        }
    }

    #[test]
    fn test_stft_magnitudes() {
        // 100 ms of a 440 Hz (A4) tone at 16 kHz.
        let rate = 16000.0;
        let count = (rate * 0.1) as usize;
        let tone = sine_wave(440.0, rate, 1.0, count);

        let spectrum = compute_stft_magnitudes(&tone, 400, 160);

        // n_fft/2+1 frequency bins.
        assert_eq!(spectrum.len(), 201);
        assert!(!spectrum[0].is_empty(), "Should have at least one frame");

        for row in &spectrum {
            for &value in row {
                assert!(value >= 0.0, "Magnitudes should be non-negative");
            }
        }
    }

    #[test]
    fn test_mel_spectrogram() {
        // One second of silence.
        let silence = AudioData::new(vec![0.0; 16000], 16000, 1);
        let outcome = whisper_mel(&silence);

        assert!(outcome.is_ok());
        let shape = outcome.unwrap().dims();

        // Expected layout: [1, n_mels, n_frames].
        assert_eq!(shape[0], 1, "Batch size should be 1");
        assert_eq!(
            shape[1], WHISPER_N_MELS,
            "Should have {} mel bins",
            WHISPER_N_MELS
        );

        // compute_mel_spectrogram always pads to 30 s (480000 samples) with center=True
        // STFT reflection padding, so it always yields exactly 3000 mel frames.
        assert_eq!(shape[2], 3000, "Should always return 3000 mel frames");
    }

    #[test]
    fn test_mel_spectrogram_with_sine() {
        // One second of a half-amplitude 440 Hz tone.
        let rate: u32 = 16000;
        let count = (rate as f32 * 1.0) as usize;
        let tone = AudioData::new(sine_wave(440.0, rate as f32, 0.5, count), rate, 1);

        let outcome = whisper_mel(&tone);
        assert!(outcome.is_ok());

        let values: Vec<f32> = outcome.unwrap().to_data().to_vec().unwrap();

        for &val in &values {
            assert!(val.is_finite(), "All values should be finite");
            // After Whisper normalization, values should be roughly in [-1, 1] range
            assert!(
                (-2.0..=2.0).contains(&val),
                "Values should be in reasonable range, got {}",
                val
            );
        }
    }

    #[test]
    fn test_pad_or_trim() {
        let clip = AudioData::new(vec![1.0, 2.0, 3.0], 16000, 1);

        // Shorter target: truncate.
        assert_eq!(pad_or_trim_audio(&clip, 2).samples, vec![1.0, 2.0]);

        // Longer target: zero-pad at the end.
        assert_eq!(
            pad_or_trim_audio(&clip, 5).samples,
            vec![1.0, 2.0, 3.0, 0.0, 0.0]
        );
    }

    #[test]
    fn test_compute_mel_from_samples_matches_audio_data() {
        // 30 s of 440 Hz sine at 16 kHz mono — exercises a non-trivial mel pattern.
        let tone = sine_wave(440.0, 16000.0, 0.5, 480_000);
        let clip = AudioData::new(tone.clone(), WHISPER_SAMPLE_RATE, 1);

        let via_audio = whisper_mel(&clip).unwrap();
        let via_raw = compute_mel_from_samples::<TestBackend>(
            &tone,
            WHISPER_N_FFT,
            WHISPER_HOP_LENGTH,
            WHISPER_N_MELS,
            &FlexDevice,
        )
        .unwrap();

        let lhs: Vec<f32> = via_audio.to_data().to_vec().unwrap();
        let rhs: Vec<f32> = via_raw.to_data().to_vec().unwrap();
        assert_eq!(lhs.len(), rhs.len(), "shape mismatch");
        for (i, (x, y)) in lhs.iter().zip(rhs.iter()).enumerate() {
            assert_eq!(x, y, "mismatch at index {i}");
        }
    }
}