Skip to main content

whisperforge_core/
load.rs

1// Model loading utilities for Whisper models saved in MessagePack format
2// Compatible with whisper-burn pre-converted models from HuggingFace
3
4use anyhow::{Context, Result};
5use burn::{
6    module::{Module, ModuleMapper, Param},
7    record::{FullPrecisionSettings, NamedMpkBytesRecorder, Recorder},
8    tensor::{Tensor, backend::Backend},
9};
10use burn_flex::{Flex, FlexDevice};
11use serde::Deserialize;
12
13#[cfg(feature = "file-io")]
14use std::path::Path;
15
16use crate::model::{AudioEncoderConfig, TextDecoderConfig, Whisper, WhisperConfig};
17
18/// Model precision for quantization
19#[derive(Debug, Clone, Copy, Deserialize)]
20#[serde(rename_all = "lowercase")]
21pub enum ModelPrecision {
22    Fp32,
23    Int8,
24}
25
26/// Configuration file format for whisper-burn models (.cfg files)
27#[derive(Debug, Clone, Deserialize)]
28pub struct WhisperModelConfig {
29    pub audio_encoder_config: AudioEncoderConfigFile,
30    pub text_decoder_config: TextDecoderConfigFile,
31    #[serde(default)]
32    pub precision: Option<ModelPrecision>,
33}
34
35#[derive(Debug, Clone, Deserialize)]
36pub struct AudioEncoderConfigFile {
37    pub n_mels: usize,
38    pub n_audio_ctx: usize,
39    pub n_audio_state: usize,
40    pub n_audio_head: usize,
41    pub n_audio_layer: usize,
42}
43
44#[derive(Debug, Clone, Deserialize)]
45pub struct TextDecoderConfigFile {
46    pub n_vocab: usize,
47    pub n_text_ctx: usize,
48    pub n_text_state: usize,
49    pub n_text_head: usize,
50    pub n_text_layer: usize,
51}
52
53impl From<WhisperModelConfig> for WhisperConfig {
54    fn from(cfg: WhisperModelConfig) -> Self {
55        WhisperConfig {
56            audio_encoder_config: AudioEncoderConfig {
57                n_mels: cfg.audio_encoder_config.n_mels,
58                n_audio_ctx: cfg.audio_encoder_config.n_audio_ctx,
59                n_audio_state: cfg.audio_encoder_config.n_audio_state,
60                n_audio_head: cfg.audio_encoder_config.n_audio_head,
61                n_audio_layer: cfg.audio_encoder_config.n_audio_layer,
62            },
63            text_decoder_config: TextDecoderConfig {
64                n_vocab: cfg.text_decoder_config.n_vocab,
65                n_text_ctx: cfg.text_decoder_config.n_text_ctx,
66                n_text_state: cfg.text_decoder_config.n_text_state,
67                n_text_head: cfg.text_decoder_config.n_text_head,
68                n_text_layer: cfg.text_decoder_config.n_text_layer,
69            },
70        }
71    }
72}
73
74/// Load a Whisper config from raw JSON bytes.
75///
76/// Accepts the contents of a `.cfg` file as a byte slice. This is the
77/// WASM-compatible path — no filesystem access required.
78pub fn load_config_from_bytes(bytes: &[u8]) -> Result<WhisperConfig> {
79    let file_config: WhisperModelConfig =
80        serde_json::from_slice(bytes).with_context(|| "Failed to parse config JSON from bytes")?;
81    Ok(file_config.into())
82}
83
84/// Dequantizes INT8 tensor params back to FP32, used inside [`dequantize_weights_to_fp32`].
85struct Dequantizer;
86
87impl<B: Backend> ModuleMapper<B> for Dequantizer {
88    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
89        let (id, tensor, mapper) = param.consume();
90        Param::from_mapped_value(id, tensor.dequantize(), mapper)
91    }
92}
93
94/// Load an INT8-quantized model on the CPU (burn-flex), dequantize all weights to FP32,
95/// and re-serialize to in-memory bytes. The returned bytes are a plain FP32 NamedMpk
96/// that any backend can load without needing INT8 kernel support.
97///
98/// WGPU/WGSL has no i8 element type, so this indirection is required for quantized models
99/// on the GPU path. Only called when the .cfg reports `precision: int8`.
100fn dequantize_weights_to_fp32(config: &WhisperConfig, bytes: Vec<u8>) -> Result<Vec<u8>> {
101    type Cpu = Flex<f32>;
102    let device = FlexDevice;
103    let recorder = NamedMpkBytesRecorder::<FullPrecisionSettings>::new();
104
105    let cpu_model: Whisper<Cpu> = config.init::<Cpu>(&device);
106    let record = recorder
107        .load(bytes, &device)
108        .map_err(|e| anyhow::anyhow!("Failed to load quantized weights on CPU: {:?}", e))?;
109    let cpu_model = cpu_model.load_record(record).map(&mut Dequantizer);
110
111    recorder
112        .record(cpu_model.into_record(), ())
113        .map_err(|e| anyhow::anyhow!("Failed to re-serialize dequantized weights: {:?}", e))
114}
115
116/// Load a Whisper model from in-memory NamedMpk bytes.
117///
118/// `config` is the parsed model config (from [`load_config_from_bytes`]).
119/// `weights` is the raw contents of a `.mpk` file.
120/// `precision` is the optional precision from the `.cfg` sidecar; pass `None` when unknown.
121///
122/// INT8-quantized models are transparently dequantized on the CPU before being loaded
123/// onto `device`, so this works on backends (e.g. WGPU) that have no INT8 kernel support.
124pub fn load_whisper_from_bytes<B: Backend>(
125    config: &WhisperConfig,
126    weights: Vec<u8>,
127    precision: Option<ModelPrecision>,
128    device: &B::Device,
129) -> Result<Whisper<B>> {
130    let weights = match precision {
131        Some(ModelPrecision::Int8) => dequantize_weights_to_fp32(config, weights)?,
132        _ => weights,
133    };
134    let model = config.init::<B>(device);
135    let recorder = NamedMpkBytesRecorder::<FullPrecisionSettings>::new();
136    let model = model.load_record(
137        recorder
138            .load(weights, device)
139            .map_err(|e| anyhow::anyhow!("Failed to load model weights from bytes: {:?}", e))?,
140    );
141    Ok(model)
142}
143
/// Load a Whisper model from whisper-burn format (.mpk + .cfg files).
///
/// # Arguments
/// * `model_path` - Base path of the model files. A trailing `.cfg` or `.mpk` extension is
///   stripped; any other dotted suffix (e.g. `tiny.en`) is treated as part of the name.
///   `.cfg` and `.mpk` are then appended to locate the two files.
/// * `device` - The device to load the model onto
///
/// # Errors
/// Returns an error if either file cannot be read, the config is not valid JSON, or the
/// weights cannot be loaded (see [`load_whisper_from_bytes`]).
#[cfg(feature = "file-io")]
pub fn load_whisper<B: Backend>(model_path: &str, device: &B::Device) -> Result<Whisper<B>> {
    let model_path = Path::new(model_path);

    // Only strip the extensions this loader owns. Stripping arbitrary extensions would
    // turn a base name like `tiny.en` into `tiny`.
    let has_model_ext = matches!(
        model_path.extension().and_then(|ext| ext.to_str()),
        Some("cfg" | "mpk")
    );
    let base_path = if has_model_ext {
        model_path.with_extension("")
    } else {
        model_path.to_path_buf()
    };

    // Append the suffix to the file name instead of calling `with_extension`, which would
    // replace everything after the last dot of the base name.
    let sidecar = |ext: &str| match base_path.file_name() {
        Some(stem) => {
            let mut file_name = stem.to_os_string();
            file_name.push(".");
            file_name.push(ext);
            base_path.with_file_name(file_name)
        }
        // No final component to extend (e.g. an empty path); defer to `with_extension`.
        None => base_path.with_extension(ext),
    };
    let config_path = sidecar("cfg");
    let weights_path = sidecar("mpk");

    let config_bytes = std::fs::read(&config_path)
        .with_context(|| format!("Failed to read config file: {}", config_path.display()))?;
    let weights_bytes = std::fs::read(&weights_path)
        .with_context(|| format!("Failed to read weights file: {}", weights_path.display()))?;

    // Parse the file form directly (rather than via `load_config_from_bytes`) because the
    // `precision` field is needed and is not carried by `WhisperConfig`.
    let file_config: WhisperModelConfig =
        serde_json::from_slice(&config_bytes).with_context(|| "Failed to parse config JSON")?;
    let precision = file_config.precision;
    let config: WhisperConfig = file_config.into();
    load_whisper_from_bytes(&config, weights_bytes, precision, device)
}
173
/// Load just the config from a `.cfg` file path.
///
/// `config_path` may be given in three forms:
/// * ending in `.cfg` - used as-is;
/// * ending in `.mpk` - the sibling `.cfg` file is used;
/// * anything else (including dotted base names such as `tiny.en`) - `.cfg` is appended.
///
/// # Errors
/// Returns an error if the file cannot be read or does not contain valid config JSON.
#[cfg(feature = "file-io")]
pub fn load_config(config_path: &str) -> Result<WhisperConfig> {
    let config_path = Path::new(config_path);

    let config_path = match config_path.extension().and_then(|ext| ext.to_str()) {
        Some("cfg") => config_path.to_path_buf(),
        Some("mpk") => config_path.with_extension("cfg"),
        // Append rather than replace: `with_extension` would rewrite `tiny.en` to
        // `tiny.cfg` instead of `tiny.en.cfg`.
        _ => match config_path.file_name() {
            Some(stem) => {
                let mut file_name = stem.to_os_string();
                file_name.push(".cfg");
                config_path.with_file_name(file_name)
            }
            // No final component to extend (e.g. an empty path).
            None => config_path.with_extension("cfg"),
        },
    };

    let bytes = std::fs::read(&config_path)
        .with_context(|| format!("Failed to read config file: {}", config_path.display()))?;
    load_config_from_bytes(&bytes)
}
189
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(feature = "file-io")]
    use std::path::PathBuf;

    /// The `models` directory that sits next to this crate's manifest directory.
    #[cfg(feature = "file-io")]
    fn models_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .expect("crate manifest directory has a parent")
            .join("models")
    }

    /// Load `<models>/<model_dir>/model.{mpk,cfg}` on the burn-flex CPU backend, run the
    /// encoder on an all-zero `[1, 80, 3000]` input and check the output is
    /// `[1, 1500, expected_state]`.
    ///
    /// Returns `Ok(())` without asserting anything when the weights file is missing.
    #[cfg(feature = "file-io")]
    fn assert_encoder_forward_dims(model_dir: &str, expected_state: usize) -> Result<()> {
        use burn_flex::{Flex, FlexDevice};

        let model_path = models_dir().join(model_dir).join("model");
        let weights_path = model_path.with_extension("mpk");
        if !weights_path.exists() {
            eprintln!(
                "Skipping: {} not found. Convert from HuggingFace first.",
                weights_path.display()
            );
            return Ok(());
        }

        let device = FlexDevice;
        let model_path = model_path.to_str().expect("model path is valid UTF-8");
        let m = load_whisper::<Flex<f32>>(model_path, &device)?;
        assert_eq!(m.encoder.n_mels, 80);

        let mel = burn::tensor::Tensor::<Flex<f32>, 3>::zeros([1, 80, 3000], &device);
        let out = m.forward_encoder(mel);
        assert_eq!(out.dims(), [1, 1500, expected_state]);

        Ok(())
    }

    #[test]
    #[ignore = "slow: initialises base model (6-layer 512-dim) on burn-flex CPU (~10 min)"]
    fn test_layer_norm_dims_match_loaded_config() {
        use burn_flex::{Flex, FlexDevice};

        let device = FlexDevice;
        // base has n_audio_state=512; tiny_en has 384. The bug hardcoded tiny_en, so a base
        // model would get ln_post with gamma shape [384] instead of [512] and panic on forward.
        let config = WhisperConfig::base();
        assert_ne!(
            config.audio_encoder_config.n_audio_state,
            WhisperConfig::tiny_en().audio_encoder_config.n_audio_state,
            "test precondition: base and tiny_en must have different state dims"
        );

        let mut model = config.init::<Flex<f32>>(&device);
        let fresh = config.init::<Flex<f32>>(&device);
        model.decoder.ln = fresh.decoder.ln;
        model.encoder.ln_post = fresh.encoder.ln_post;

        // Encoder forward would panic if ln_post had wrong dims (384 vs expected 512).
        let mel = burn::tensor::Tensor::<Flex<f32>, 3>::zeros([1, 80, 3000], &device);
        let out = model.forward_encoder(mel);
        assert_eq!(out.dims(), [1, 1500, 512]);
    }

    // `load_config` only exists with the `file-io` feature, so the test is gated the same
    // way to keep `cargo test` compiling when the feature is off.
    #[cfg(feature = "file-io")]
    #[test]
    fn test_load_config() {
        let config_path = models_dir().join("tiny_en.cfg");
        if !config_path.exists() {
            eprintln!("Skipping test: model files not found at {:?}", config_path);
            return;
        }

        let config = load_config(config_path.to_str().expect("config path is valid UTF-8"))
            .expect("tiny_en.cfg should parse");

        assert_eq!(config.audio_encoder_config.n_mels, 80);
        assert_eq!(config.audio_encoder_config.n_audio_state, 384);
        assert_eq!(config.audio_encoder_config.n_audio_head, 6);
        assert_eq!(config.audio_encoder_config.n_audio_layer, 4);
        assert_eq!(config.text_decoder_config.n_vocab, 51864);
        assert_eq!(config.text_decoder_config.n_text_state, 384);
    }

    #[cfg(feature = "file-io")]
    #[test]
    fn test_load_whisper_model() {
        use burn_flex::{Flex, FlexDevice};

        // Per-model layout: weights/config live at `<name>/model.{mpk,cfg}`.
        let model_path = models_dir().join("tiny_en_converted").join("model");
        if !model_path.with_extension("mpk").exists() {
            eprintln!("Skipping test: model files not found at {:?}", model_path);
            eprintln!(
                "Run `cargo test -p whisperforge test_convert_tiny_en` first to generate the model"
            );
            return;
        }

        let device = FlexDevice;
        let model_path = model_path.to_str().expect("model path is valid UTF-8");
        // Debug-format the error so the full anyhow context chain is printed on failure.
        let m = load_whisper::<Flex<f32>>(model_path, &device)
            .unwrap_or_else(|e| panic!("Failed to load model: {:?}", e));

        assert_eq!(m.encoder.n_mels, 80);
        assert_eq!(m.decoder.n_vocab, 51864);
        println!("Model loaded successfully!");
    }

    /// Load base weights (6-layer, n_audio_state=512) and verify encoder output shape.
    ///
    /// Run locally with model files present (and the `file-io` feature enabled):
    /// `cargo test --release -p whisperforge-core -- --ignored test_load_base_model_and_encoder_forward --nocapture`
    ///
    /// medium and large-v2 are excluded from CI (too large for automated test infra).
    #[cfg(feature = "file-io")]
    #[test]
    #[ignore = "requires models/base_converted/model.{mpk,cfg} — git-ignored; convert from HuggingFace first"]
    fn test_load_base_model_and_encoder_forward() -> Result<()> {
        // base: n_audio_ctx=1500, n_audio_state=512
        assert_encoder_forward_dims("base_converted", 512)
    }

    /// Load small weights (12-layer, n_audio_state=768) and verify encoder output shape.
    ///
    /// Run locally with model files present (and the `file-io` feature enabled):
    /// `cargo test --release -p whisperforge-core -- --ignored test_load_small_model_and_encoder_forward --nocapture`
    ///
    /// medium and large-v2 are excluded from CI (too large for automated test infra).
    #[cfg(feature = "file-io")]
    #[test]
    #[ignore = "requires models/small_converted/model.{mpk,cfg} — git-ignored; convert from HuggingFace first"]
    fn test_load_small_model_and_encoder_forward() -> Result<()> {
        // small: n_audio_ctx=1500, n_audio_state=768
        assert_encoder_forward_dims("small_converted", 768)
    }
}
343}