1use anyhow::{Context, Result};
5use burn::{
6 module::{Module, ModuleMapper, Param},
7 record::{FullPrecisionSettings, NamedMpkBytesRecorder, Recorder},
8 tensor::{Tensor, backend::Backend},
9};
10use burn_flex::{Flex, FlexDevice};
11use serde::Deserialize;
12
13#[cfg(feature = "file-io")]
14use std::path::Path;
15
16use crate::model::{AudioEncoderConfig, TextDecoderConfig, Whisper, WhisperConfig};
17
18#[derive(Debug, Clone, Copy, Deserialize)]
20#[serde(rename_all = "lowercase")]
21pub enum ModelPrecision {
22 Fp32,
23 Int8,
24}
25
26#[derive(Debug, Clone, Deserialize)]
28pub struct WhisperModelConfig {
29 pub audio_encoder_config: AudioEncoderConfigFile,
30 pub text_decoder_config: TextDecoderConfigFile,
31 #[serde(default)]
32 pub precision: Option<ModelPrecision>,
33}
34
35#[derive(Debug, Clone, Deserialize)]
36pub struct AudioEncoderConfigFile {
37 pub n_mels: usize,
38 pub n_audio_ctx: usize,
39 pub n_audio_state: usize,
40 pub n_audio_head: usize,
41 pub n_audio_layer: usize,
42}
43
44#[derive(Debug, Clone, Deserialize)]
45pub struct TextDecoderConfigFile {
46 pub n_vocab: usize,
47 pub n_text_ctx: usize,
48 pub n_text_state: usize,
49 pub n_text_head: usize,
50 pub n_text_layer: usize,
51}
52
53impl From<WhisperModelConfig> for WhisperConfig {
54 fn from(cfg: WhisperModelConfig) -> Self {
55 WhisperConfig {
56 audio_encoder_config: AudioEncoderConfig {
57 n_mels: cfg.audio_encoder_config.n_mels,
58 n_audio_ctx: cfg.audio_encoder_config.n_audio_ctx,
59 n_audio_state: cfg.audio_encoder_config.n_audio_state,
60 n_audio_head: cfg.audio_encoder_config.n_audio_head,
61 n_audio_layer: cfg.audio_encoder_config.n_audio_layer,
62 },
63 text_decoder_config: TextDecoderConfig {
64 n_vocab: cfg.text_decoder_config.n_vocab,
65 n_text_ctx: cfg.text_decoder_config.n_text_ctx,
66 n_text_state: cfg.text_decoder_config.n_text_state,
67 n_text_head: cfg.text_decoder_config.n_text_head,
68 n_text_layer: cfg.text_decoder_config.n_text_layer,
69 },
70 }
71 }
72}
73
74pub fn load_config_from_bytes(bytes: &[u8]) -> Result<WhisperConfig> {
79 let file_config: WhisperModelConfig =
80 serde_json::from_slice(bytes).with_context(|| "Failed to parse config JSON from bytes")?;
81 Ok(file_config.into())
82}
83
84struct Dequantizer;
86
87impl<B: Backend> ModuleMapper<B> for Dequantizer {
88 fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
89 let (id, tensor, mapper) = param.consume();
90 Param::from_mapped_value(id, tensor.dequantize(), mapper)
91 }
92}
93
94fn dequantize_weights_to_fp32(config: &WhisperConfig, bytes: Vec<u8>) -> Result<Vec<u8>> {
101 type Cpu = Flex<f32>;
102 let device = FlexDevice;
103 let recorder = NamedMpkBytesRecorder::<FullPrecisionSettings>::new();
104
105 let cpu_model: Whisper<Cpu> = config.init::<Cpu>(&device);
106 let record = recorder
107 .load(bytes, &device)
108 .map_err(|e| anyhow::anyhow!("Failed to load quantized weights on CPU: {:?}", e))?;
109 let cpu_model = cpu_model.load_record(record).map(&mut Dequantizer);
110
111 recorder
112 .record(cpu_model.into_record(), ())
113 .map_err(|e| anyhow::anyhow!("Failed to re-serialize dequantized weights: {:?}", e))
114}
115
116pub fn load_whisper_from_bytes<B: Backend>(
125 config: &WhisperConfig,
126 weights: Vec<u8>,
127 precision: Option<ModelPrecision>,
128 device: &B::Device,
129) -> Result<Whisper<B>> {
130 let weights = match precision {
131 Some(ModelPrecision::Int8) => dequantize_weights_to_fp32(config, weights)?,
132 _ => weights,
133 };
134 let model = config.init::<B>(device);
135 let recorder = NamedMpkBytesRecorder::<FullPrecisionSettings>::new();
136 let model = model.load_record(
137 recorder
138 .load(weights, device)
139 .map_err(|e| anyhow::anyhow!("Failed to load model weights from bytes: {:?}", e))?,
140 );
141 Ok(model)
142}
143
/// Loads a [`Whisper`] model from a `<base>.cfg` / `<base>.mpk` file pair.
///
/// `model_path` may name either file or the bare base path. Any extension is stripped
/// before `.cfg` and `.mpk` are appended. The `precision` recorded in the config decides
/// whether the weights are dequantized before loading.
///
/// # Errors
///
/// Returns an error if either file cannot be read, the config is not valid JSON, or the
/// weights cannot be loaded.
#[cfg(feature = "file-io")]
pub fn load_whisper<B: Backend>(model_path: &str, device: &B::Device) -> Result<Whisper<B>> {
    let model_path = Path::new(model_path);

    // Drop a trailing extension (e.g. `.mpk`) so both sibling files share one base path.
    let base_path = if model_path.extension().is_some() {
        model_path.with_extension("")
    } else {
        model_path.to_path_buf()
    };

    let config_path = base_path.with_extension("cfg");
    let weights_path = base_path.with_extension("mpk");

    // Read and parse the small config before touching the weights file (typically far
    // larger), so a missing or malformed config fails fast without that read.
    let config_bytes = std::fs::read(&config_path)
        .with_context(|| format!("Failed to read config file: {}", config_path.display()))?;
    let file_config: WhisperModelConfig =
        serde_json::from_slice(&config_bytes).with_context(|| "Failed to parse config JSON")?;
    // Capture precision before `into()` consumes the file config and discards it.
    let precision = file_config.precision;
    let config: WhisperConfig = file_config.into();

    let weights_bytes = std::fs::read(&weights_path)
        .with_context(|| format!("Failed to read weights file: {}", weights_path.display()))?;
    load_whisper_from_bytes(&config, weights_bytes, precision, device)
}
173
/// Reads and parses a model config file.
///
/// A path already ending in `.cfg` is used as given. Otherwise its extension is replaced
/// by (or extended with) `.cfg`. The parsed config omits any `precision` entry.
///
/// # Errors
///
/// Returns an error if the file cannot be read or its contents are not a valid config.
#[cfg(feature = "file-io")]
pub fn load_config(config_path: &str) -> Result<WhisperConfig> {
    let requested = Path::new(config_path);

    let already_cfg = matches!(requested.extension(), Some(ext) if ext == "cfg");
    let resolved = if already_cfg {
        requested.to_path_buf()
    } else {
        requested.with_extension("cfg")
    };

    let raw = std::fs::read(&resolved)
        .with_context(|| format!("Failed to read config file: {}", resolved.display()))?;
    load_config_from_bytes(&raw)
}
189
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's private imports (`Flex`, `FlexDevice`,
    // anyhow's `Result`, ...), so no per-test `use` lines are needed.
    use super::*;
    use std::path::PathBuf;

    /// `<crate manifest dir>/../models`, where converted model files are expected.
    fn models_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .parent()
            .expect("CARGO_MANIFEST_DIR has a parent directory")
            .join("models")
    }

    #[test]
    #[ignore = "slow: initialises base model (6-layer 512-dim) on Flex CPU (~10 min)"]
    fn test_layer_norm_dims_match_loaded_config() {
        let device = FlexDevice;
        let config = WhisperConfig::base();
        assert_ne!(
            config.audio_encoder_config.n_audio_state,
            WhisperConfig::tiny_en().audio_encoder_config.n_audio_state,
            "test precondition: base and tiny_en must have different state dims"
        );

        // Swap in the final layer norms from a second, independently initialised model;
        // the encoder forward must still produce base-sized (512-wide) outputs.
        let mut model = config.init::<Flex<f32>>(&device);
        let fresh = config.init::<Flex<f32>>(&device);
        model.decoder.ln = fresh.decoder.ln;
        model.encoder.ln_post = fresh.encoder.ln_post;

        let mel = Tensor::<Flex<f32>, 3>::zeros([1, 80, 3000], &device);
        let out = model.forward_encoder(mel);
        assert_eq!(out.dims(), [1, 1500, 512]);
    }

    // `load_config` only exists with the `file-io` feature.
    #[cfg(feature = "file-io")]
    #[test]
    fn test_load_config() {
        let config_path = models_dir().join("tiny_en.cfg");
        if !config_path.exists() {
            eprintln!("Skipping test: model files not found at {:?}", config_path);
            return;
        }

        let config = load_config(config_path.to_str().unwrap()).unwrap();

        assert_eq!(config.audio_encoder_config.n_mels, 80);
        assert_eq!(config.audio_encoder_config.n_audio_state, 384);
        assert_eq!(config.audio_encoder_config.n_audio_head, 6);
        assert_eq!(config.audio_encoder_config.n_audio_layer, 4);
        assert_eq!(config.text_decoder_config.n_vocab, 51864);
        assert_eq!(config.text_decoder_config.n_text_state, 384);
    }

    // `load_whisper` only exists with the `file-io` feature.
    #[cfg(feature = "file-io")]
    #[test]
    fn test_load_whisper_model() {
        let model_path = models_dir().join("tiny_en_converted").join("model");
        if !model_path.with_extension("mpk").exists() {
            eprintln!("Skipping test: model files not found at {:?}", model_path);
            eprintln!(
                "Run `cargo test -p whisperforge test_convert_tiny_en` first to generate the model"
            );
            return;
        }

        let device = FlexDevice;
        let m = load_whisper::<Flex<f32>>(model_path.to_str().unwrap(), &device)
            .unwrap_or_else(|e| panic!("Failed to load model: {:?}", e));

        assert_eq!(m.encoder.n_mels, 80);
        assert_eq!(m.decoder.n_vocab, 51864);
        println!("Model loaded successfully!");
    }

    #[cfg(feature = "file-io")]
    #[test]
    #[ignore = "requires models/base_converted/model.{mpk,cfg} — git-ignored; convert from HuggingFace first"]
    fn test_load_base_model_and_encoder_forward() -> Result<()> {
        let model_path = models_dir().join("base_converted").join("model");
        if !model_path.with_extension("mpk").exists() {
            eprintln!(
                "Skipping: {:?}.mpk not found. Convert from HuggingFace first.",
                model_path
            );
            return Ok(());
        }

        let device = FlexDevice;
        let m = load_whisper::<Flex<f32>>(model_path.to_str().unwrap(), &device)?;
        assert_eq!(m.encoder.n_mels, 80);

        let mel = Tensor::<Flex<f32>, 3>::zeros([1, 80, 3000], &device);
        let out = m.forward_encoder(mel);
        assert_eq!(out.dims(), [1, 1500, 512]);

        Ok(())
    }

    #[cfg(feature = "file-io")]
    #[test]
    #[ignore = "requires models/small_converted/model.{mpk,cfg} — git-ignored; convert from HuggingFace first"]
    fn test_load_small_model_and_encoder_forward() -> Result<()> {
        let model_path = models_dir().join("small_converted").join("model");
        if !model_path.with_extension("mpk").exists() {
            eprintln!(
                "Skipping: {:?}.mpk not found. Convert from HuggingFace first.",
                model_path
            );
            return Ok(());
        }

        let device = FlexDevice;
        let m = load_whisper::<Flex<f32>>(model_path.to_str().unwrap(), &device)?;
        assert_eq!(m.encoder.n_mels, 80);

        let mel = Tensor::<Flex<f32>, 3>::zeros([1, 80, 3000], &device);
        let out = m.forward_encoder(mel);
        assert_eq!(out.dims(), [1, 1500, 768]);

        Ok(())
    }
}