Skip to main content

whisper_cpp_plus/
quantize.rs

1//! Model quantization for reducing model size and improving inference speed.
2//!
3//! Provides functionality to quantize Whisper models to various bit depths,
4//! reducing model size while maintaining reasonable accuracy. Quantization is
5//! particularly useful for deployment on resource-constrained devices.
6//!
7//! Enable with the `quantization` feature flag:
8//! ```toml
9//! whisper-cpp-plus = { version = "0.1.5", features = ["quantization"] }
10//! ```
11
12use std::ffi::CString;
13use std::path::Path;
14use std::sync::Arc;
15use std::sync::Mutex;
16use thiserror::Error;
17
18use whisper_cpp_plus_sys as ffi;
19
20/// Error type for quantization operations
21#[derive(Debug, Error)]
22pub enum QuantizeError {
23    #[error("Model file not found: {0}")]
24    FileNotFound(String),
25
26    #[error("Failed to open file: {0}")]
27    FileOpenError(String),
28
29    #[error("Failed to write file: {0}")]
30    FileWriteError(String),
31
32    #[error("Invalid model format")]
33    InvalidModel,
34
35    #[error("Invalid quantization type")]
36    InvalidQuantizationType,
37
38    #[error("Quantization failed: {0}")]
39    QuantizationFailed(String),
40}
41
42type Result<T> = std::result::Result<T, QuantizeError>;
43
44/// Quantization types supported by whisper.cpp
45#[derive(Debug, Clone, Copy, PartialEq, Eq)]
46#[repr(i32)]
47#[allow(non_camel_case_types)]
48pub enum QuantizationType {
49    /// 4-bit quantization (method 0) - ~3.5 GB for base model
50    Q4_0 = ffi::GGML_FTYPE_MOSTLY_Q4_0,
51
52    /// 4-bit quantization (method 1) - ~3.9 GB for base model
53    Q4_1 = ffi::GGML_FTYPE_MOSTLY_Q4_1,
54
55    /// 5-bit quantization (method 0) - ~4.3 GB for base model
56    Q5_0 = ffi::GGML_FTYPE_MOSTLY_Q5_0,
57
58    /// 5-bit quantization (method 1) - ~4.7 GB for base model
59    Q5_1 = ffi::GGML_FTYPE_MOSTLY_Q5_1,
60
61    /// 8-bit quantization - ~7.7 GB for base model
62    Q8_0 = ffi::GGML_FTYPE_MOSTLY_Q8_0,
63
64    /// 2-bit k-quantization
65    Q2_K = ffi::GGML_FTYPE_MOSTLY_Q2_K,
66
67    /// 3-bit k-quantization
68    Q3_K = ffi::GGML_FTYPE_MOSTLY_Q3_K,
69
70    /// 4-bit k-quantization
71    Q4_K = ffi::GGML_FTYPE_MOSTLY_Q4_K,
72
73    /// 5-bit k-quantization
74    Q5_K = ffi::GGML_FTYPE_MOSTLY_Q5_K,
75
76    /// 6-bit k-quantization
77    Q6_K = ffi::GGML_FTYPE_MOSTLY_Q6_K,
78}
79
80impl QuantizationType {
81    /// Get a human-readable name for the quantization type
82    pub fn name(&self) -> &'static str {
83        match self {
84            Self::Q4_0 => "Q4_0",
85            Self::Q4_1 => "Q4_1",
86            Self::Q5_0 => "Q5_0",
87            Self::Q5_1 => "Q5_1",
88            Self::Q8_0 => "Q8_0",
89            Self::Q2_K => "Q2_K",
90            Self::Q3_K => "Q3_K",
91            Self::Q4_K => "Q4_K",
92            Self::Q5_K => "Q5_K",
93            Self::Q6_K => "Q6_K",
94        }
95    }
96
97    /// Estimate the size reduction factor for this quantization type.
98    /// Returns the approximate size as a fraction of the original F32 model.
99    pub fn size_factor(&self) -> f32 {
100        match self {
101            Self::Q2_K => 0.19, // ~19% of original
102            Self::Q3_K => 0.26, // ~26% of original
103            Self::Q4_0 => 0.31, // ~31% of original
104            Self::Q4_1 => 0.35, // ~35% of original
105            Self::Q4_K => 0.33, // ~33% of original
106            Self::Q5_0 => 0.39, // ~39% of original
107            Self::Q5_1 => 0.43, // ~43% of original
108            Self::Q5_K => 0.41, // ~41% of original
109            Self::Q6_K => 0.49, // ~49% of original
110            Self::Q8_0 => 0.69, // ~69% of original
111        }
112    }
113
114    /// Get all available quantization types
115    pub fn all() -> &'static [QuantizationType] {
116        &[
117            Self::Q4_0,
118            Self::Q4_1,
119            Self::Q5_0,
120            Self::Q5_1,
121            Self::Q8_0,
122            Self::Q2_K,
123            Self::Q3_K,
124            Self::Q4_K,
125            Self::Q5_K,
126            Self::Q6_K,
127        ]
128    }
129}
130
131impl std::str::FromStr for QuantizationType {
132    type Err = QuantizeError;
133
134    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
135        match s.to_uppercase().as_str() {
136            "Q4_0" | "Q40" => Ok(Self::Q4_0),
137            "Q4_1" | "Q41" => Ok(Self::Q4_1),
138            "Q5_0" | "Q50" => Ok(Self::Q5_0),
139            "Q5_1" | "Q51" => Ok(Self::Q5_1),
140            "Q8_0" | "Q80" => Ok(Self::Q8_0),
141            "Q2_K" | "Q2K" => Ok(Self::Q2_K),
142            "Q3_K" | "Q3K" => Ok(Self::Q3_K),
143            "Q4_K" | "Q4K" => Ok(Self::Q4_K),
144            "Q5_K" | "Q5K" => Ok(Self::Q5_K),
145            "Q6_K" | "Q6K" => Ok(Self::Q6_K),
146            _ => Err(QuantizeError::InvalidQuantizationType),
147        }
148    }
149}
150
151impl std::fmt::Display for QuantizationType {
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        write!(f, "{}", self.name())
154    }
155}
156
157/// Progress callback for quantization operations
158pub type ProgressCallback = Box<dyn Fn(f32) + Send>;
159
160/// Model quantizer for converting Whisper models to different quantization formats
161pub struct WhisperQuantize;
162
163impl WhisperQuantize {
164    /// Quantize a model file to a specified quantization type
165    ///
166    /// # Arguments
167    /// * `input_path` - Path to the input model file (must be in GGML format)
168    /// * `output_path` - Path where the quantized model will be saved
169    /// * `qtype` - The quantization type to use
170    ///
171    /// # Example
172    /// ```no_run
173    /// use whisper_cpp_plus::{WhisperQuantize, QuantizationType};
174    ///
175    /// WhisperQuantize::quantize_model_file(
176    ///     "models/ggml-base.bin",
177    ///     "models/ggml-base-q5_0.bin",
178    ///     QuantizationType::Q5_0
179    /// ).expect("Failed to quantize model");
180    /// ```
181    pub fn quantize_model_file<P: AsRef<Path>>(
182        input_path: P,
183        output_path: P,
184        qtype: QuantizationType,
185    ) -> Result<()> {
186        Self::quantize_model_file_impl(input_path.as_ref(), output_path.as_ref(), qtype, None)
187    }
188
189    /// Quantize a model file with progress callback
190    ///
191    /// # Arguments
192    /// * `input_path` - Path to the input model file
193    /// * `output_path` - Path where the quantized model will be saved
194    /// * `qtype` - The quantization type to use
195    /// * `callback` - Progress callback function (receives values from 0.0 to 1.0)
196    ///
197    /// # Example
198    /// ```no_run
199    /// use whisper_cpp_plus::{WhisperQuantize, QuantizationType};
200    ///
201    /// WhisperQuantize::quantize_model_file_with_progress(
202    ///     "models/ggml-base.bin",
203    ///     "models/ggml-base-q4_0.bin",
204    ///     QuantizationType::Q4_0,
205    ///     |progress| {
206    ///         println!("Progress: {:.1}%", progress * 100.0);
207    ///     }
208    /// ).expect("Failed to quantize model");
209    /// ```
210    pub fn quantize_model_file_with_progress<P, F>(
211        input_path: P,
212        output_path: P,
213        qtype: QuantizationType,
214        callback: F,
215    ) -> Result<()>
216    where
217        P: AsRef<Path>,
218        F: Fn(f32) + Send + 'static,
219    {
220        Self::quantize_model_file_impl(
221            input_path.as_ref(),
222            output_path.as_ref(),
223            qtype,
224            Some(Box::new(callback)),
225        )
226    }
227
228    fn quantize_model_file_impl(
229        input_path: &Path,
230        output_path: &Path,
231        qtype: QuantizationType,
232        callback: Option<ProgressCallback>,
233    ) -> Result<()> {
234        // Validate input file exists
235        if !input_path.exists() {
236            return Err(QuantizeError::FileNotFound(format!(
237                "{}",
238                input_path.display()
239            )));
240        }
241
242        // Convert paths to C strings
243        let input_cstr = path_to_cstring(input_path)?;
244        let output_cstr = path_to_cstring(output_path)?;
245
246        // Set up progress callback if provided
247        let callback_data = callback.map(|cb| Arc::new(Mutex::new(cb)));
248        let callback_ptr = callback_data
249            .as_ref()
250            .map(|data| Arc::clone(data) as Arc<Mutex<dyn Fn(f32) + Send>>);
251
252        // Create the FFI callback function
253        let ffi_callback: ffi::whisper_quantize_progress_callback = if callback_ptr.is_some() {
254            Some(quantize_progress_callback)
255        } else {
256            None
257        };
258
259        // Store callback data in thread-local storage for the callback to access
260        if let Some(ptr) = callback_ptr {
261            CALLBACK_DATA.with(|data| {
262                *data.borrow_mut() = Some(ptr);
263            });
264        }
265
266        // Perform quantization
267        let result = unsafe {
268            ffi::whisper_model_quantize(
269                input_cstr.as_ptr(),
270                output_cstr.as_ptr(),
271                qtype as i32,
272                ffi_callback,
273            )
274        };
275
276        // Clear callback data
277        CALLBACK_DATA.with(|data| {
278            *data.borrow_mut() = None;
279        });
280
281        // Check result
282        match result {
283            ffi::WHISPER_QUANTIZE_OK => Ok(()),
284            ffi::WHISPER_QUANTIZE_ERROR_INVALID_MODEL => Err(QuantizeError::QuantizationFailed(
285                "Invalid model file".to_string(),
286            )),
287            ffi::WHISPER_QUANTIZE_ERROR_FILE_OPEN => Err(QuantizeError::QuantizationFailed(
288                format!("Failed to open input file: {}", input_path.display()),
289            )),
290            ffi::WHISPER_QUANTIZE_ERROR_FILE_WRITE => Err(QuantizeError::QuantizationFailed(
291                format!("Failed to write output file: {}", output_path.display()),
292            )),
293            ffi::WHISPER_QUANTIZE_ERROR_INVALID_FTYPE => Err(QuantizeError::QuantizationFailed(
294                format!("Invalid quantization type: {}", qtype),
295            )),
296            ffi::WHISPER_QUANTIZE_ERROR_QUANTIZATION_FAILED => Err(
297                QuantizeError::QuantizationFailed("Quantization failed".to_string()),
298            ),
299            _ => Err(QuantizeError::QuantizationFailed(format!(
300                "Unknown quantization error: {}",
301                result
302            ))),
303        }
304    }
305
306    /// Get the quantization type of an existing model file
307    ///
308    /// # Returns
309    /// * `Ok(Some(qtype))` - The quantization type if the model is quantized
310    /// * `Ok(None)` - If the model is in full precision (F32 or F16)
311    /// * `Err(_)` - If the file cannot be read or is not a valid model
312    ///
313    /// # Example
314    /// ```no_run
315    /// use whisper_cpp_plus::WhisperQuantize;
316    ///
317    /// match WhisperQuantize::get_model_quantization_type("models/ggml-base-q5_0.bin") {
318    ///     Ok(Some(qtype)) => println!("Model is quantized as: {}", qtype),
319    ///     Ok(None) => println!("Model is not quantized"),
320    ///     Err(e) => println!("Error reading model: {}", e),
321    /// }
322    /// ```
323    pub fn get_model_quantization_type<P: AsRef<Path>>(
324        model_path: P,
325    ) -> Result<Option<QuantizationType>> {
326        let path = model_path.as_ref();
327
328        if !path.exists() {
329            return Err(QuantizeError::FileNotFound(format!("{}", path.display())));
330        }
331
332        let path_cstr = path_to_cstring(path)?;
333
334        let ftype = unsafe { ffi::whisper_model_get_ftype(path_cstr.as_ptr()) };
335
336        if ftype < 0 {
337            return Err(QuantizeError::FileOpenError(format!("{}", path.display())));
338        }
339
340        // Map the ftype to our QuantizationType enum
341        let qtype = match ftype {
342            x if x == ffi::GGML_FTYPE_ALL_F32 => None,
343            x if x == ffi::GGML_FTYPE_MOSTLY_F16 => None,
344            x if x == QuantizationType::Q4_0 as i32 => Some(QuantizationType::Q4_0),
345            x if x == QuantizationType::Q4_1 as i32 => Some(QuantizationType::Q4_1),
346            x if x == QuantizationType::Q5_0 as i32 => Some(QuantizationType::Q5_0),
347            x if x == QuantizationType::Q5_1 as i32 => Some(QuantizationType::Q5_1),
348            x if x == QuantizationType::Q8_0 as i32 => Some(QuantizationType::Q8_0),
349            x if x == QuantizationType::Q2_K as i32 => Some(QuantizationType::Q2_K),
350            x if x == QuantizationType::Q3_K as i32 => Some(QuantizationType::Q3_K),
351            x if x == QuantizationType::Q4_K as i32 => Some(QuantizationType::Q4_K),
352            x if x == QuantizationType::Q5_K as i32 => Some(QuantizationType::Q5_K),
353            x if x == QuantizationType::Q6_K as i32 => Some(QuantizationType::Q6_K),
354            _ => None,
355        };
356
357        Ok(qtype)
358    }
359
360    /// Estimate the size of a quantized model given the original model path and target quantization type
361    ///
362    /// # Returns
363    /// Estimated size in bytes of the quantized model
364    ///
365    /// # Example
366    /// ```no_run
367    /// use whisper_cpp_plus::{WhisperQuantize, QuantizationType};
368    ///
369    /// let estimated_size = WhisperQuantize::estimate_quantized_size(
370    ///     "models/ggml-base.bin",
371    ///     QuantizationType::Q5_0
372    /// ).unwrap_or(0);
373    ///
374    /// println!("Estimated after Q5_0: {} MB", estimated_size / 1024 / 1024);
375    /// ```
376    pub fn estimate_quantized_size<P: AsRef<Path>>(
377        model_path: P,
378        qtype: QuantizationType,
379    ) -> Result<u64> {
380        let path = model_path.as_ref();
381        let metadata = std::fs::metadata(path).map_err(|e| {
382            QuantizeError::QuantizationFailed(format!("Failed to read model file: {}", e))
383        })?;
384
385        let original_size = metadata.len();
386        let estimated_size = (original_size as f64 * qtype.size_factor() as f64) as u64;
387
388        Ok(estimated_size)
389    }
390}
391
// Per-thread slot through which the C progress trampoline (`quantize_progress_callback`)
// reaches the Rust closure of the quantization running on this thread. The native callback
// type carries only the progress value, so there is no user-data pointer to use instead.
// Empty (`None`) whenever no callback is installed.
thread_local! {
    static CALLBACK_DATA: std::cell::RefCell<Option<Arc<Mutex<dyn Fn(f32) + Send>>>> =
        std::cell::RefCell::default();
}
397
398// FFI callback function that forwards to the Rust callback
399extern "C" fn quantize_progress_callback(progress: f32) {
400    CALLBACK_DATA.with(|data| {
401        if let Some(callback) = data.borrow().as_ref() {
402            if let Ok(cb) = callback.lock() {
403                cb(progress);
404            }
405        }
406    });
407}
408
409// Helper function to convert Path to CString
410fn path_to_cstring(path: &Path) -> Result<CString> {
411    let path_str = path
412        .to_str()
413        .ok_or_else(|| QuantizeError::QuantizationFailed("Invalid UTF-8 in path".to_string()))?;
414
415    CString::new(path_str)
416        .map_err(|_| QuantizeError::QuantizationFailed("Path contains null byte".to_string()))
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantization_type_names() {
        assert_eq!(QuantizationType::Q4_0.name(), "Q4_0");
        assert_eq!(QuantizationType::Q5_1.name(), "Q5_1");
        assert_eq!(QuantizationType::Q8_0.name(), "Q8_0");
        assert_eq!(QuantizationType::Q3_K.name(), "Q3_K");
    }

    #[test]
    fn test_quantization_type_from_str() {
        assert_eq!(
            "q4_0".parse::<QuantizationType>().unwrap(),
            QuantizationType::Q4_0
        );
        assert_eq!(
            "Q5_1".parse::<QuantizationType>().unwrap(),
            QuantizationType::Q5_1
        );
        assert_eq!(
            "q8_0".parse::<QuantizationType>().unwrap(),
            QuantizationType::Q8_0
        );
        assert_eq!(
            "Q3K".parse::<QuantizationType>().unwrap(),
            QuantizationType::Q3_K
        );
        assert!("invalid".parse::<QuantizationType>().is_err());
    }

    /// Every variant's name must be what `Display` prints and must parse back to the same
    /// variant in canonical, lower-case and compact (no underscore) spellings.
    #[test]
    fn test_every_type_round_trips_through_its_name() {
        for &qtype in QuantizationType::all() {
            let name = qtype.name();
            assert_eq!(qtype.to_string(), name);
            assert_eq!(name.parse::<QuantizationType>().unwrap(), qtype);
            assert_eq!(
                name.to_lowercase().parse::<QuantizationType>().unwrap(),
                qtype
            );
            assert_eq!(
                name.replace('_', "").parse::<QuantizationType>().unwrap(),
                qtype
            );
        }
    }

    #[test]
    fn test_all_lists_each_type_once() {
        let all = QuantizationType::all();
        assert_eq!(all.len(), 10);
        for (i, qtype) in all.iter().enumerate() {
            assert!(
                !all[i + 1..].contains(qtype),
                "{} is listed more than once",
                qtype
            );
        }
    }

    #[test]
    fn test_size_factors() {
        for qtype in QuantizationType::all() {
            let factor = qtype.size_factor();
            assert!(
                factor > 0.0 && factor < 1.0,
                "{} has invalid size factor: {}",
                qtype,
                factor
            );
        }
    }

    /// A missing input must be reported as `FileNotFound` before any native code runs.
    #[test]
    fn test_missing_input_is_reported() {
        let missing = "this/path/should/not/exist/ggml-model.bin";
        assert!(matches!(
            WhisperQuantize::get_model_quantization_type(missing),
            Err(QuantizeError::FileNotFound(_))
        ));
        assert!(matches!(
            WhisperQuantize::quantize_model_file(
                missing,
                "unused-output.bin",
                QuantizationType::Q5_0
            ),
            Err(QuantizeError::FileNotFound(_))
        ));
    }
}