1use anyhow::{Result, anyhow};
2use audioadapter_buffers::direct::SequentialSliceOfVecs;
3use burn::tensor::{Tensor, backend::Backend};
4use rubato::{Async, FixedAsync, Resampler, SincInterpolationParameters, WindowFunction};
5use rustfft::{FftPlanner, num_complex::Complex};
6use std::f32::consts::PI;
7
8#[cfg(feature = "file-io")]
9use std::path::Path;
10#[cfg(feature = "file-io")]
11use symphonia::core::codecs::audio::CODEC_ID_NULL_AUDIO;
12#[cfg(feature = "file-io")]
13use symphonia::core::{
14 codecs::audio::AudioDecoderOptions,
15 formats::{FormatOptions, probe::Hint},
16 io::MediaSourceStream,
17 meta::MetadataOptions,
18};
19
/// Sample rate (Hz) used by the Whisper front end; the mel helpers convert input to this rate.
pub const WHISPER_SAMPLE_RATE: u32 = 16000;
/// FFT size, in samples, of each STFT window.
pub const WHISPER_N_FFT: usize = 400;
/// Hop, in samples, between successive STFT frames.
pub const WHISPER_HOP_LENGTH: usize = 160;
/// Number of mel bands in the spectrogram.
pub const WHISPER_N_MELS: usize = 80;
/// Length of one Whisper input window, presumably in seconds (30 s of 16 kHz audio).
pub const WHISPER_CHUNK_LENGTH: usize = 30;
/// A buffer of decoded PCM audio.
#[derive(Debug, Clone)]
pub struct AudioData {
    /// Samples, interleaved across channels (ch0, ch1, ..., ch0, ch1, ...).
    pub samples: Vec<f32>,
    /// Sample rate in Hz.
    pub sample_rate: u32,
    /// Number of interleaved channels.
    pub channels: u16,
}
33
34impl AudioData {
35 pub fn new(samples: Vec<f32>, sample_rate: u32, channels: u16) -> Self {
36 Self {
37 samples,
38 sample_rate,
39 channels,
40 }
41 }
42
43 pub fn duration(&self) -> f32 {
44 self.samples.len() as f32 / (self.sample_rate as f32 * self.channels as f32)
45 }
46
47 pub fn to_mono(&self) -> AudioData {
48 if self.channels == 1 {
49 return self.clone();
50 }
51
52 let mono_samples: Vec<f32> = self
53 .samples
54 .chunks(self.channels as usize)
55 .map(|chunk| chunk.iter().sum::<f32>() / self.channels as f32)
56 .collect();
57
58 AudioData {
59 samples: mono_samples,
60 sample_rate: self.sample_rate,
61 channels: 1,
62 }
63 }
64
65 pub fn resample(&self, target_sample_rate: u32) -> Result<AudioData> {
66 if self.sample_rate == target_sample_rate {
67 return Ok(self.clone());
68 }
69
70 let f_ratio = target_sample_rate as f64 / self.sample_rate as f64;
71 let params = SincInterpolationParameters {
72 sinc_len: 256,
73 f_cutoff: 0.95,
74 interpolation: rubato::SincInterpolationType::Cubic,
75 oversampling_factor: 256,
76 window: WindowFunction::BlackmanHarris2,
77 };
78
79 let chunk_size = 1024;
81
82 let mut resampler = Async::<f32>::new_sinc(
83 f_ratio,
84 2.0,
85 ¶ms,
86 chunk_size,
87 self.channels as usize,
88 FixedAsync::Input,
89 )
90 .map_err(|e| anyhow!("Failed to create resampler: {}", e))?;
91
92 let frames_per_channel = self.samples.len() / self.channels as usize;
94 let mut input_channels: Vec<Vec<f32>> =
95 vec![Vec::with_capacity(frames_per_channel); self.channels as usize];
96
97 for (i, &sample) in self.samples.iter().enumerate() {
99 let channel = i % self.channels as usize;
100 input_channels[channel].push(sample);
101 }
102
103 let input_adapter =
104 SequentialSliceOfVecs::new(&input_channels, self.channels as usize, frames_per_channel)
105 .map_err(|e| anyhow!("Failed to create input adapter: {}", e))?;
106
107 let estimated_output_frames = (frames_per_channel as f64 * f_ratio) as usize;
109
110 let mut output_channels: Vec<Vec<f32>> =
111 vec![vec![0.0f32; estimated_output_frames]; self.channels as usize];
112 let mut output_adapter = SequentialSliceOfVecs::new_mut(
113 &mut output_channels,
114 self.channels as usize,
115 estimated_output_frames,
116 )
117 .map_err(|e| anyhow!("Failed to create output adapter: {}", e))?;
118
119 let mut indexing = rubato::Indexing {
121 input_offset: 0,
122 output_offset: 0,
123 active_channels_mask: None,
124 partial_len: None,
125 };
126
127 let mut input_frames_left = frames_per_channel;
128 let mut input_frames_next = resampler.input_frames_next();
129
130 while input_frames_left >= input_frames_next {
132 let (frames_read, frames_written) = resampler
133 .process_into_buffer(&input_adapter, &mut output_adapter, Some(&indexing))
134 .map_err(|e| anyhow!("Resampling failed: {}", e))?;
135
136 indexing.input_offset += frames_read;
137 indexing.output_offset += frames_written;
138 input_frames_left -= frames_read;
139 input_frames_next = resampler.input_frames_next();
140 }
141
142 let actual_output_frames = indexing.output_offset;
148 let mut output_samples = Vec::with_capacity(actual_output_frames * self.channels as usize);
149
150 for frame in 0..actual_output_frames {
151 for ch in &output_channels {
152 output_samples.push(ch[frame]);
153 }
154 }
155
156 Ok(AudioData {
157 samples: output_samples,
158 sample_rate: target_sample_rate,
159 channels: self.channels,
160 })
161 }
162
163 pub fn to_16khz_mono(&self) -> Result<AudioData> {
164 let mono = self.to_mono();
165 mono.resample(16000)
166 }
167}
168
/// Decodes the first non-null audio track of a file into interleaved `f32` samples.
///
/// The file extension, if any, is passed to symphonia's probe as a format hint.
/// The returned [`AudioData`] keeps the track's native sample rate and channel
/// count; no resampling or down-mixing happens here.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the file cannot be opened or its format is not recognised;
/// - it has no non-null audio track;
/// - the sample rate or channel layout is unknown;
/// - a decoder cannot be created;
/// - reading the container fails with anything other than `ResetRequired`;
/// - decoding a packet fails with anything other than an I/O error.
#[cfg(feature = "file-io")]
pub fn load_audio_file<P: AsRef<Path>>(path: P) -> Result<AudioData> {
    let path = path.as_ref();
    let file = std::fs::File::open(path)
        .map_err(|e| anyhow!("Failed to open audio file '{}': {}", path.display(), e))?;
    let mss = MediaSourceStream::new(Box::new(file), Default::default());

    let mut hint = Hint::new();
    if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
        hint.with_extension(ext);
    }

    let mut format = symphonia::default::get_probe()
        .probe(
            &hint,
            mss,
            FormatOptions::default(),
            MetadataOptions::default(),
        )
        .map_err(|e| anyhow!("Unsupported audio format '{}': {}", path.display(), e))?;

    let track = format
        .tracks()
        .iter()
        .find(|t| {
            t.codec_params
                .as_ref()
                .and_then(|cp| cp.audio())
                .map(|ap| ap.codec != CODEC_ID_NULL_AUDIO)
                .unwrap_or(false)
        })
        .ok_or_else(|| anyhow!("No audio tracks found in '{}'", path.display()))?;

    let track_id = track.id;
    let codec_params = track
        .codec_params
        .as_ref()
        .and_then(|cp| cp.audio())
        .ok_or_else(|| anyhow!("Missing codec parameters in '{}'", path.display()))?;

    let sample_rate = codec_params
        .sample_rate
        .ok_or_else(|| anyhow!("Unknown sample rate in '{}'", path.display()))?;
    let channels = codec_params
        .channels
        .as_ref()
        .ok_or_else(|| anyhow!("Unknown channel count in '{}'", path.display()))?
        .count() as u16;

    let mut decoder = symphonia::default::get_codecs()
        .make_audio_decoder(codec_params, &AudioDecoderOptions::default())
        .map_err(|e| anyhow!("Failed to create decoder for '{}': {}", path.display(), e))?;

    let mut samples: Vec<f32> = Vec::new();
    // NOTE(review): symphonia's `copy_to_vec_interleaved` appears to resize its
    // destination to the packet length, overwriting earlier contents — confirm.
    // Decoding into a cleared per-packet buffer and appending from it keeps every
    // packet whether the helper appends or overwrites.
    let mut packet_samples: Vec<f32> = Vec::new();

    loop {
        let packet = match format.next_packet() {
            Ok(Some(p)) => p,
            Ok(None) => break,
            // NOTE(review): reading continues with the existing decoder rather than
            // re-creating it after a reset request — confirm this is acceptable.
            Err(symphonia::core::errors::Error::ResetRequired) => continue,
            Err(e) => {
                return Err(anyhow!("Error reading '{}': {}", path.display(), e));
            }
        };

        if packet.track_id != track_id {
            continue;
        }

        let decoded = match decoder.decode(&packet) {
            Ok(d) => d,
            // A packet-level I/O error skips that packet.
            Err(symphonia::core::errors::Error::IoError(_)) => continue,
            Err(e) => return Err(anyhow!("Decode error in '{}': {}", path.display(), e)),
        };

        packet_samples.clear();
        decoded.copy_to_vec_interleaved::<f32>(&mut packet_samples);
        samples.extend_from_slice(&packet_samples);
    }

    Ok(AudioData {
        samples,
        sample_rate,
        channels,
    })
}
269
/// Reflect-pads `samples` by `n_fft / 2` on each side, so that STFT frames are
/// centered on the original sample positions.
///
/// Reflection excludes the edge sample: `[1, 2, 3, 4, 5]` padded by 2 becomes
/// `[3, 2, 1, 2, 3, 4, 5, 4, 3]`. When the input is not longer than the pad,
/// the reflection repeats periodically. Two special cases apply:
/// - a single sample is repeated;
/// - an empty input yields `2 * (n_fft / 2)` zeros.
///
/// The output length is always `samples.len() + 2 * (n_fft / 2)`. This function never panics.
pub fn prepare_centered_samples_raw(samples: &[f32], n_fft: usize) -> Vec<f32> {
    let pad_len = n_fft / 2;
    let n = samples.len();

    // Common case: the input is longer than the pad, so each side is a single
    // reversed slice.
    if n > pad_len {
        let mut centered = Vec::with_capacity(n + 2 * pad_len);
        centered.extend(samples[1..=pad_len].iter().rev());
        centered.extend_from_slice(samples);
        centered.extend(samples[n - 1 - pad_len..n - 1].iter().rev());
        return centered;
    }

    match samples {
        [] => vec![0.0; 2 * pad_len],
        [only] => vec![*only; 1 + 2 * pad_len],
        _ => {
            // Reflection is periodic with period 2 * (n - 1); fold each padded
            // index back into [0, n).
            let n = n as isize;
            let period = 2 * (n - 1);
            let pad = pad_len as isize;
            (-pad..n + pad)
                .map(|j| {
                    let m = j.rem_euclid(period);
                    let idx = if m < n { m } else { period - m };
                    samples[idx as usize]
                })
                .collect()
        }
    }
}
288
289fn prepare_centered_samples(audio: &AudioData, n_fft: usize) -> Result<Vec<f32>> {
294 let audio = if audio.channels != 1 || audio.sample_rate != WHISPER_SAMPLE_RATE {
295 audio.to_16khz_mono()?
296 } else {
297 audio.clone()
298 };
299
300 let target_samples = 30 * WHISPER_SAMPLE_RATE as usize;
301 let mut padded = audio.samples;
302 if padded.len() > target_samples {
303 padded.truncate(target_samples);
304 } else {
305 padded.resize(target_samples, 0.0);
306 }
307
308 Ok(prepare_centered_samples_raw(&padded, n_fft))
309}
310
311fn mel_compress<B: Backend>(
316 ps: Tensor<B, 2>,
317 n_mels: usize,
318 n_fft: usize,
319 device: &B::Device,
320) -> Tensor<B, 3> {
321 let [n_freqs, n_frames] = ps.dims();
322 let mel_filters = create_mel_filter_bank(n_fft, n_mels, WHISPER_SAMPLE_RATE as f32);
323 let mf_flat: Vec<f32> = mel_filters.into_iter().flatten().collect();
324 let mf_tensor: Tensor<B, 2> =
325 Tensor::<B, 1>::from_floats(mf_flat.as_slice(), device).reshape([n_mels, n_freqs]);
326
327 let mel: Tensor<B, 2> = mf_tensor.matmul(ps);
328
329 let log10_e = std::f32::consts::LOG10_E;
330 let log_mel: Tensor<B, 2> = mel.clamp_min(1e-10_f32).log().mul_scalar(log10_e);
331
332 let max_val: Tensor<B, 2> = log_mel.clone().max().reshape([1, 1]);
333 let log_mel = (log_mel - max_val.clone())
334 .clamp_min(-8.0_f32)
335 .add(max_val)
336 .add_scalar(4.0_f32)
337 .div_scalar(4.0_f32);
338
339 log_mel.reshape([1, n_mels, n_frames])
340}
341
342pub fn compute_mel_spectrogram<B: Backend>(
357 audio: &AudioData,
358 n_fft: usize,
359 hop_length: usize,
360 n_mels: usize,
361 device: &B::Device,
362) -> Result<Tensor<B, 3>> {
363 let centered = prepare_centered_samples(audio, n_fft)?;
364 let magnitudes = compute_stft_magnitudes(¢ered, n_fft, hop_length);
365
366 let n_freqs = n_fft / 2 + 1;
367 let n_frames = magnitudes[0].len().saturating_sub(1);
368
369 let ps_flat: Vec<f32> = (0..n_freqs)
370 .flat_map(|f| magnitudes[f][..n_frames].iter().copied())
371 .collect();
372 let ps_tensor: Tensor<B, 2> =
373 Tensor::<B, 1>::from_floats(ps_flat.as_slice(), device).reshape([n_freqs, n_frames]);
374
375 Ok(mel_compress(ps_tensor, n_mels, n_fft, device))
376}
377
378pub fn compute_mel_from_samples<B: Backend>(
383 samples: &[f32],
384 n_fft: usize,
385 hop_length: usize,
386 n_mels: usize,
387 device: &B::Device,
388) -> Result<Tensor<B, 3>> {
389 let expected = 30 * WHISPER_SAMPLE_RATE as usize;
390 anyhow::ensure!(
391 samples.len() == expected,
392 "compute_mel_from_samples: expected {} samples, got {}",
393 expected,
394 samples.len()
395 );
396 let centered = prepare_centered_samples_raw(samples, n_fft);
397 let magnitudes = compute_stft_magnitudes(¢ered, n_fft, hop_length);
398 let n_freqs = n_fft / 2 + 1;
399 let n_frames = magnitudes[0].len().saturating_sub(1);
400 let ps_flat: Vec<f32> = (0..n_freqs)
401 .flat_map(|f| magnitudes[f][..n_frames].iter().copied())
402 .collect();
403 let ps_tensor: Tensor<B, 2> =
404 Tensor::<B, 1>::from_floats(ps_flat.as_slice(), device).reshape([n_freqs, n_frames]);
405 Ok(mel_compress(ps_tensor, n_mels, n_fft, device))
406}
407
/// GPU variant of [`compute_mel_spectrogram`]: the STFT power spectrum is
/// computed by `crate::stft_gpu::compute_stft_power_gpu` on the WGPU device.
///
/// # Errors
/// Propagates resampling failures.
#[cfg(feature = "cubecl-stft")]
pub fn compute_mel_spectrogram_wgpu(
    audio: &AudioData,
    n_fft: usize,
    hop_length: usize,
    n_mels: usize,
    device: &burn_wgpu::WgpuDevice,
) -> Result<Tensor<WgpuBackend, 3>> {
    let centered = prepare_centered_samples(audio, n_fft)?;
    let power = stft_power_tensor_wgpu(centered, n_fft, hop_length, device);
    Ok(mel_compress(power, n_mels, n_fft, device))
}

/// GPU variant of [`compute_mel_from_samples`].
///
/// # Errors
/// Fails if `samples` is not exactly `WHISPER_CHUNK_LENGTH * WHISPER_SAMPLE_RATE` long.
#[cfg(feature = "cubecl-stft")]
pub fn compute_mel_from_samples_wgpu(
    samples: &[f32],
    n_fft: usize,
    hop_length: usize,
    n_mels: usize,
    device: &burn_wgpu::WgpuDevice,
) -> Result<Tensor<WgpuBackend, 3>> {
    let expected = WHISPER_CHUNK_LENGTH * WHISPER_SAMPLE_RATE as usize;
    anyhow::ensure!(
        samples.len() == expected,
        "compute_mel_from_samples_wgpu: expected {} samples, got {}",
        expected,
        samples.len()
    );
    let centered = prepare_centered_samples_raw(samples, n_fft);
    let power = stft_power_tensor_wgpu(centered, n_fft, hop_length, device);
    Ok(mel_compress(power, n_mels, n_fft, device))
}

/// Runs the GPU STFT over already-centered samples and returns the power
/// spectrum as `[n_freqs, n_frames]`, dropping the final frame.
///
/// The GPU output is read as frame-major `[n_frames, n_freqs]` and transposed.
/// Panics (usize underflow) if `centered.len() < n_fft`.
#[cfg(feature = "cubecl-stft")]
fn stft_power_tensor_wgpu(
    centered: Vec<f32>,
    n_fft: usize,
    hop_length: usize,
    device: &burn_wgpu::WgpuDevice,
) -> Tensor<WgpuBackend, 2> {
    use cubecl::prelude::Runtime;

    let n_freqs = n_fft / 2 + 1;
    let n_frames_total = (centered.len() - n_fft) / hop_length + 1;
    let n_frames = n_frames_total.saturating_sub(1);

    let client = burn_wgpu::WgpuRuntime::client(device);
    let gpu_out = crate::stft_gpu::compute_stft_power_gpu(
        &client,
        &centered,
        n_fft,
        hop_length,
        n_frames_total,
    );

    Tensor::<WgpuBackend, 1>::from_floats(&gpu_out[..n_frames * n_freqs], device)
        .reshape([n_frames, n_freqs])
        .transpose()
}
484
#[cfg(feature = "cubecl-stft")]
/// Burn backend used by the `*_wgpu` mel helpers: the CubeCL backend on the WGPU
/// runtime, with element type parameters `f32`, `i32` and `u32`.
pub type WgpuBackend = burn_wgpu::CubeBackend<burn_wgpu::WgpuRuntime, f32, i32, u32>;
488
489pub fn batch_mel_spectrograms<B: Backend>(
498 chunks: &[AudioData],
499 n_fft: usize,
500 hop_length: usize,
501 n_mels: usize,
502 device: &B::Device,
503) -> Result<Tensor<B, 3>> {
504 anyhow::ensure!(!chunks.is_empty(), "batch_mel_spectrograms: no chunks");
505 let mels: Vec<Tensor<B, 3>> = chunks
506 .iter()
507 .map(|c| compute_mel_spectrogram(c, n_fft, hop_length, n_mels, device))
508 .collect::<Result<_>>()?;
509 Ok(Tensor::cat(mels, 0))
510}
511
/// GPU variant of [`batch_mel_spectrograms`]: computes each chunk with
/// [`compute_mel_spectrogram_wgpu`] and concatenates along the batch axis.
///
/// # Errors
/// Fails if `chunks` is empty, or with the first error any chunk produces.
#[cfg(feature = "cubecl-stft")]
pub fn batch_mel_spectrograms_wgpu(
    chunks: &[AudioData],
    n_fft: usize,
    hop_length: usize,
    n_mels: usize,
    device: &burn_wgpu::WgpuDevice,
) -> Result<Tensor<WgpuBackend, 3>> {
    if chunks.is_empty() {
        return Err(anyhow!("batch_mel_spectrograms_wgpu: no chunks"));
    }

    let mut batch: Vec<Tensor<WgpuBackend, 3>> = Vec::with_capacity(chunks.len());
    for chunk in chunks {
        batch.push(compute_mel_spectrogram_wgpu(
            chunk, n_fft, hop_length, n_mels, device,
        )?);
    }
    Ok(Tensor::cat(batch, 0))
}
531
532#[allow(clippy::needless_range_loop)]
535fn compute_stft_magnitudes(samples: &[f32], n_fft: usize, hop_length: usize) -> Vec<Vec<f32>> {
536 let n_freqs = n_fft / 2 + 1;
537 let n_frames = if samples.len() >= n_fft {
538 (samples.len() - n_fft) / hop_length + 1
539 } else {
540 0
541 };
542
543 if n_frames == 0 {
544 return vec![vec![0.0]; n_freqs];
545 }
546
547 let window: Vec<f32> = (0..n_fft)
549 .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / n_fft as f32).cos()))
550 .collect();
551
552 let mut planner = FftPlanner::<f32>::new();
554 let fft = planner.plan_fft_forward(n_fft);
555
556 let mut magnitudes = vec![vec![0.0f32; n_frames]; n_freqs];
558
559 let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n_fft];
561
562 for frame_idx in 0..n_frames {
563 let start = frame_idx * hop_length;
564
565 for i in 0..n_fft {
567 let sample = if start + i < samples.len() {
568 samples[start + i]
569 } else {
570 0.0
571 };
572 buffer[i] = Complex::new(sample * window[i], 0.0);
573 }
574
575 fft.process(&mut buffer);
577
578 for freq in 0..n_freqs {
581 magnitudes[freq][frame_idx] = buffer[freq].norm_sqr();
582 }
583 }
584
585 magnitudes
586}
587
/// Builds an `[n_mels][n_fft / 2 + 1]` triangular mel filter bank.
///
/// The filter edges are `n_mels + 2` points, evenly spaced on the Slaney mel
/// scale (linear below 1 kHz, logarithmic above) between 0 Hz and
/// `sample_rate / 2`. Each triangle is scaled by `2 / (upper_hz - lower_hz)`,
/// so that filters have roughly equal area.
fn create_mel_filter_bank(n_fft: usize, n_mels: usize, sample_rate: f32) -> Vec<Vec<f32>> {
    let linear_step: f32 = 200.0 / 3.0;
    let log_start_hz: f32 = 1000.0;
    let log_start_mel: f32 = log_start_hz / linear_step;
    let log_step: f32 = 6.4f32.ln() / 27.0;

    let hz_to_mel = |hz: f32| -> f32 {
        if hz >= log_start_hz {
            log_start_mel + (hz / log_start_hz).ln() / log_step
        } else {
            hz / linear_step
        }
    };
    let mel_to_hz = |mel: f32| -> f32 {
        if mel >= log_start_mel {
            log_start_hz * ((mel - log_start_mel) * log_step).exp()
        } else {
            linear_step * mel
        }
    };

    let n_bins = n_fft / 2 + 1;
    let mel_lo = hz_to_mel(0.0);
    let mel_hi = hz_to_mel(sample_rate / 2.0);

    // Filter edge frequencies in Hz: lower/center/upper of filter i are edges[i..i + 3].
    let edges: Vec<f32> = (0..n_mels + 2)
        .map(|i| mel_to_hz(mel_lo + (mel_hi - mel_lo) * i as f32 / (n_mels + 1) as f32))
        .collect();

    let bin_hz: Vec<f32> = (0..n_bins)
        .map(|k| k as f32 * sample_rate / n_fft as f32)
        .collect();

    edges
        .windows(3)
        .map(|tri| {
            let (lo, mid, hi) = (tri[0], tri[1], tri[2]);
            let scale = 2.0 / (hi - lo).max(1e-8);
            bin_hz
                .iter()
                .map(|&f| {
                    let up = if mid > lo {
                        ((f - lo) / (mid - lo)).max(0.0)
                    } else {
                        0.0
                    };
                    let down = if hi > mid {
                        ((hi - f) / (hi - mid)).max(0.0)
                    } else {
                        0.0
                    };
                    up.min(down) * scale
                })
                .collect()
        })
        .collect()
}
656
657pub fn pad_or_trim_audio(audio: &AudioData, length_samples: usize) -> AudioData {
659 let mut samples = audio.samples.clone();
660
661 if samples.len() > length_samples {
662 samples.truncate(length_samples);
663 } else if samples.len() < length_samples {
664 samples.resize(length_samples, 0.0);
665 }
666
667 AudioData {
668 samples,
669 sample_rate: audio.sample_rate,
670 channels: audio.channels,
671 }
672}
673
#[cfg(test)]
mod tests {
    use super::*;
    use burn_flex::FlexDevice;

    #[test]
    fn test_audio_data_creation() {
        let audio = AudioData::new(vec![0.0, 0.5, -0.5, 1.0], 44100, 2);
        assert_eq!(audio.duration(), 2.0 / 44100.0);
        assert_eq!(audio.channels, 2);
    }

    #[test]
    fn test_mono_conversion() {
        let stereo = AudioData::new(vec![1.0, 2.0, 3.0, 4.0], 44100, 2);
        let mono = stereo.to_mono();
        assert_eq!(mono.samples, vec![1.5, 3.5]);
        assert_eq!(mono.channels, 1);
    }

    #[test]
    fn test_to_16khz_mono_at_target_rate_only_downmixes() {
        let stereo = AudioData::new(vec![1.0, 3.0, 2.0, 4.0], WHISPER_SAMPLE_RATE, 2);
        let mono = stereo.to_16khz_mono().unwrap();
        assert_eq!(mono.samples, vec![2.0, 3.0]);
        assert_eq!(mono.sample_rate, WHISPER_SAMPLE_RATE);
        assert_eq!(mono.channels, 1);
    }

    #[test]
    fn test_resample_same_rate_is_identity() {
        let audio = AudioData::new(vec![0.1, -0.2, 0.3, -0.4], 16000, 2);
        let out = audio.resample(16000).unwrap();
        assert_eq!(out.samples, audio.samples);
        assert_eq!(out.sample_rate, 16000);
        assert_eq!(out.channels, 2);
    }

    #[test]
    fn test_reflect_padding() {
        let centered = prepare_centered_samples_raw(&[1.0, 2.0, 3.0, 4.0, 5.0], 4);
        assert_eq!(centered, vec![3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0]);
    }

    #[test]
    fn test_hann_window() {
        // NOTE: checks a local copy of the periodic Hann formula used by
        // `compute_stft_magnitudes`, not the production window itself.
        let n = 400;
        let window: Vec<f32> = (0..n)
            .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / n as f32).cos()))
            .collect();

        assert!(window[0] < 0.01);
        assert!(window[n - 1] < 0.01);

        let center = n / 2;
        assert!(window[center] > 0.99);
    }

    #[test]
    fn test_mel_filter_bank() {
        let filters = create_mel_filter_bank(400, 80, 16000.0);

        assert_eq!(filters.len(), 80);
        assert_eq!(filters[0].len(), 201);

        for filter in &filters {
            for &val in filter {
                assert!(val >= 0.0, "Filter values should be non-negative");
            }
        }

        for filter in &filters {
            let max_val = filter.iter().cloned().fold(0.0f32, f32::max);
            assert!(
                max_val > 0.0,
                "Filter should have at least one non-zero bin"
            );
            assert!(
                max_val < 1.0,
                "Filter peak should be less than 1.0 (Slaney norm)"
            );
        }
    }

    #[test]
    fn test_stft_magnitudes() {
        let sample_rate = 16000.0;
        let freq = 440.0;
        let duration = 0.1;
        let n_samples = (sample_rate * duration) as usize;

        let samples: Vec<f32> = (0..n_samples)
            .map(|i| (2.0 * PI * freq * i as f32 / sample_rate).sin())
            .collect();

        let magnitudes = compute_stft_magnitudes(&samples, 400, 160);

        assert_eq!(magnitudes.len(), 201);
        assert!(!magnitudes[0].is_empty(), "Should have at least one frame");

        for freq_bin in &magnitudes {
            for &mag in freq_bin {
                assert!(mag >= 0.0, "Magnitudes should be non-negative");
            }
        }
    }

    #[test]
    fn test_mel_spectrogram() {
        let audio = AudioData::new(vec![0.0; 16000], 16000, 1);
        let result = compute_mel_spectrogram::<burn_flex::Flex<f32>>(
            &audio,
            WHISPER_N_FFT,
            WHISPER_HOP_LENGTH,
            WHISPER_N_MELS,
            &FlexDevice,
        );

        assert!(result.is_ok());
        let tensor = result.unwrap();
        let dims = tensor.dims();

        assert_eq!(dims[0], 1, "Batch size should be 1");
        assert_eq!(
            dims[1], WHISPER_N_MELS,
            "Should have {} mel bins",
            WHISPER_N_MELS
        );
        assert_eq!(dims[2], 3000, "Should always return 3000 mel frames");
    }

    #[test]
    fn test_mel_spectrogram_with_sine() {
        let sample_rate = 16000;
        let freq = 440.0;
        let duration = 1.0;
        let n_samples = (sample_rate as f32 * duration) as usize;

        let samples: Vec<f32> = (0..n_samples)
            .map(|i| 0.5 * (2.0 * PI * freq * i as f32 / sample_rate as f32).sin())
            .collect();

        let audio = AudioData::new(samples, sample_rate, 1);
        let result = compute_mel_spectrogram::<burn_flex::Flex<f32>>(
            &audio,
            WHISPER_N_FFT,
            WHISPER_HOP_LENGTH,
            WHISPER_N_MELS,
            &FlexDevice,
        );

        assert!(result.is_ok());
        let tensor = result.unwrap();

        let data = tensor.to_data();
        let values: Vec<f32> = data.to_vec().unwrap();

        for &val in &values {
            assert!(val.is_finite(), "All values should be finite");
            assert!(
                (-2.0..=2.0).contains(&val),
                "Values should be in reasonable range, got {}",
                val
            );
        }
    }

    #[test]
    fn test_pad_or_trim() {
        let audio = AudioData::new(vec![1.0, 2.0, 3.0], 16000, 1);

        let trimmed = pad_or_trim_audio(&audio, 2);
        assert_eq!(trimmed.samples, vec![1.0, 2.0]);

        let padded = pad_or_trim_audio(&audio, 5);
        assert_eq!(padded.samples, vec![1.0, 2.0, 3.0, 0.0, 0.0]);
    }

    #[test]
    fn test_compute_mel_from_samples_matches_audio_data() {
        let samples: Vec<f32> = (0..480_000)
            .map(|i| 0.5 * (2.0 * PI * 440.0 * i as f32 / 16000.0).sin())
            .collect();

        let audio = AudioData::new(samples.clone(), WHISPER_SAMPLE_RATE, 1);

        let mel_via_audio = compute_mel_spectrogram::<burn_flex::Flex<f32>>(
            &audio,
            WHISPER_N_FFT,
            WHISPER_HOP_LENGTH,
            WHISPER_N_MELS,
            &FlexDevice,
        )
        .unwrap();

        let mel_via_raw = compute_mel_from_samples::<burn_flex::Flex<f32>>(
            &samples,
            WHISPER_N_FFT,
            WHISPER_HOP_LENGTH,
            WHISPER_N_MELS,
            &FlexDevice,
        )
        .unwrap();

        let a: Vec<f32> = mel_via_audio.to_data().to_vec().unwrap();
        let b: Vec<f32> = mel_via_raw.to_data().to_vec().unwrap();
        assert_eq!(a.len(), b.len(), "shape mismatch");
        for (i, (x, y)) in a.iter().zip(b.iter()).enumerate() {
            assert_eq!(x, y, "mismatch at index {i}");
        }
    }
}