// trustformers_mobile/optimization/quantization.rs

1//! Mobile Quantization Module
2//!
3//! Provides efficient quantization implementations for mobile deployment including:
4//! - INT4 quantization for ultra-low memory
5//! - INT8 quantization for balanced performance
6//! - FP16 quantization for GPU acceleration
7//! - Dynamic quantization for runtime adaptation
8
9use half::f16;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::fs::File;
13use std::io::BufReader;
14use std::path::{Path, PathBuf};
15use trustformers_core::errors::{invalid_config, runtime_error, tensor_op_error, Result};
16use trustformers_core::Tensor;
17
/// Quantization scheme types
// Variant names like `GGUF_Q4_K` intentionally mirror the upstream GGUF
// type identifiers, hence the lint allowance below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[allow(non_camel_case_types)]
pub enum QuantizationScheme {
    /// 4-bit signed integer quantization: lowest memory, ~8x compression.
    Int4,
    /// 8-bit signed integer quantization: balanced size/accuracy, ~4x compression.
    Int8,
    /// Half-precision floating point: ~2x compression, calibration-free.
    FP16,
    /// Scheme resolved at runtime from external config or tensor statistics.
    Dynamic,
    /// GGUF Q2_K: 2.5625 bits per weight, ultra-low memory
    GGUF_Q2_K,
    /// GGUF Q3_K: 3.4375 bits per weight, balanced
    GGUF_Q3_K,
    /// GGUF Q4_K: 4.5 bits per weight, high quality
    GGUF_Q4_K,
    /// GGUF Q5_0: 5.5 bits per weight, very high quality
    GGUF_Q5_0,
    /// GGUF Q6_K: 6.5 bits per weight, near-lossless
    GGUF_Q6_K,
}
37
38impl std::fmt::Display for QuantizationScheme {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        match self {
41            QuantizationScheme::Int4 => write!(f, "INT4"),
42            QuantizationScheme::Int8 => write!(f, "INT8"),
43            QuantizationScheme::FP16 => write!(f, "FP16"),
44            QuantizationScheme::Dynamic => write!(f, "Dynamic"),
45            QuantizationScheme::GGUF_Q2_K => write!(f, "GGUF_Q2_K"),
46            QuantizationScheme::GGUF_Q3_K => write!(f, "GGUF_Q3_K"),
47            QuantizationScheme::GGUF_Q4_K => write!(f, "GGUF_Q4_K"),
48            QuantizationScheme::GGUF_Q5_0 => write!(f, "GGUF_Q5_0"),
49            QuantizationScheme::GGUF_Q6_K => write!(f, "GGUF_Q6_K"),
50        }
51    }
52}
53
/// Calibration method for quantization
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CalibrationMethod {
    /// Use the raw minimum/maximum of the calibration data.
    MinMax,
    /// Clip the range to a configurable percentile of the sorted values
    /// (see `QuantizationContext::percentile`).
    Percentile,
    /// Exponential moving average of observed ranges
    /// (see `QuantizationContext::smooth_factor`); the INT8 calibration
    /// path currently falls back to min-max for this method.
    MovingAverage,
    /// KL-divergence based range selection; the INT8 calibration path
    /// currently falls back to min-max for this method.
    KLDivergence,
}
62
/// Quantization context with calibration data
#[derive(Debug, Clone)]
pub struct QuantizationContext {
    /// How min/max ranges are derived from calibration samples.
    pub method: CalibrationMethod,
    /// Number of samples expected during calibration.
    pub num_calibration_samples: usize,
    pub percentile: f32,    // For percentile method
    pub smooth_factor: f32, // For moving average
}
71
72impl Default for QuantizationContext {
73    fn default() -> Self {
74        Self {
75            method: CalibrationMethod::MinMax,
76            num_calibration_samples: 100,
77            percentile: 99.9,
78            smooth_factor: 0.999,
79        }
80    }
81}
82
/// Quantization calibration data
///
/// All maps are keyed by a tensor identifier; the quantizers in this module
/// currently store everything under the single key `"global"`.
#[derive(Debug, Clone, Default)]
pub struct QuantizationCalibration {
    /// Observed minimum value per tensor key.
    pub min_values: HashMap<String, f32>,
    /// Observed maximum value per tensor key.
    pub max_values: HashMap<String, f32>,
    /// Quantization scale per tensor key.
    pub scales: HashMap<String, f32>,
    /// Quantization zero point per tensor key.
    pub zero_points: HashMap<String, i32>,
    /// Histogram bins per tensor key (not populated by the quantizers in
    /// this module).
    pub histogram_bins: HashMap<String, Vec<f32>>,
}
92
/// External quantization scheme configuration
///
/// Serialized to/from JSON by `QuantizationSchemeStorage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationSchemeConfig {
    /// Default quantization scheme
    pub default_scheme: QuantizationScheme,
    /// Layer-specific quantization schemes
    pub layer_schemes: HashMap<String, QuantizationScheme>,
    /// Tensor-specific quantization schemes (by hash or name)
    pub tensor_schemes: HashMap<String, QuantizationScheme>,
    /// Model-specific schemes
    pub model_schemes: HashMap<String, QuantizationScheme>,
    /// Performance-based scheme mappings
    // NOTE(review): not consulted by `determine_scheme` anywhere in this
    // file — presumably reserved for future use; confirm.
    pub performance_schemes: HashMap<String, QuantizationScheme>,
}
107
108impl Default for QuantizationSchemeConfig {
109    fn default() -> Self {
110        Self {
111            default_scheme: QuantizationScheme::Int8,
112            layer_schemes: HashMap::new(),
113            tensor_schemes: HashMap::new(),
114            model_schemes: HashMap::new(),
115            performance_schemes: HashMap::new(),
116        }
117    }
118}
119
/// Quantization scheme storage manager
#[derive(Debug, Clone)]
pub struct QuantizationSchemeStorage {
    /// Configuration file path; `None` means purely in-memory
    /// (`save_config` becomes a no-op).
    pub config_path: Option<PathBuf>,
    /// In-memory configuration
    pub config: QuantizationSchemeConfig,
    /// Cache for recently determined schemes
    pub scheme_cache: HashMap<String, QuantizationScheme>,
}
130
131impl Default for QuantizationSchemeStorage {
132    fn default() -> Self {
133        Self::new()
134    }
135}
136
137impl QuantizationSchemeStorage {
138    /// Create a new storage manager
139    pub fn new() -> Self {
140        Self {
141            config_path: None,
142            config: QuantizationSchemeConfig::default(),
143            scheme_cache: HashMap::new(),
144        }
145    }
146
147    /// Create storage manager with config file
148    pub fn with_config_file<P: AsRef<Path>>(path: P) -> Result<Self> {
149        let config_path = path.as_ref().to_path_buf();
150        let config = Self::load_config(&config_path)?;
151
152        Ok(Self {
153            config_path: Some(config_path),
154            config,
155            scheme_cache: HashMap::new(),
156        })
157    }
158
159    /// Load configuration from file
160    pub fn load_config<P: AsRef<Path>>(path: P) -> Result<QuantizationSchemeConfig> {
161        let file = File::open(path.as_ref())
162            .map_err(|e| runtime_error(format!("Failed to open config file: {}", e)))?;
163        let reader = BufReader::new(file);
164
165        serde_json::from_reader(reader)
166            .map_err(|e| invalid_config("load_config", format!("Failed to parse config: {}", e)))
167    }
168
169    /// Save configuration to file
170    pub fn save_config(&self) -> Result<()> {
171        if let Some(ref path) = self.config_path {
172            let file = File::create(path)
173                .map_err(|e| runtime_error(format!("Failed to create config file: {}", e)))?;
174
175            serde_json::to_writer_pretty(file, &self.config)
176                .map_err(|e| runtime_error(format!("Failed to write config: {}", e)))?;
177        }
178        Ok(())
179    }
180
181    /// Determine quantization scheme for a tensor
182    pub fn determine_scheme(
183        &mut self,
184        tensor_id: &str,
185        layer_name: Option<&str>,
186        model_name: Option<&str>,
187    ) -> QuantizationScheme {
188        // Check cache first
189        if let Some(&scheme) = self.scheme_cache.get(tensor_id) {
190            return scheme;
191        }
192
193        // Check tensor-specific scheme
194        if let Some(&scheme) = self.config.tensor_schemes.get(tensor_id) {
195            self.scheme_cache.insert(tensor_id.to_string(), scheme);
196            return scheme;
197        }
198
199        // Check layer-specific scheme
200        if let Some(layer) = layer_name {
201            if let Some(&scheme) = self.config.layer_schemes.get(layer) {
202                self.scheme_cache.insert(tensor_id.to_string(), scheme);
203                return scheme;
204            }
205        }
206
207        // Check model-specific scheme
208        if let Some(model) = model_name {
209            if let Some(&scheme) = self.config.model_schemes.get(model) {
210                self.scheme_cache.insert(tensor_id.to_string(), scheme);
211                return scheme;
212            }
213        }
214
215        // Use default scheme
216        let default_scheme = self.config.default_scheme;
217        self.scheme_cache.insert(tensor_id.to_string(), default_scheme);
218        default_scheme
219    }
220
221    /// Set scheme for specific tensor
222    pub fn set_tensor_scheme(&mut self, tensor_id: String, scheme: QuantizationScheme) {
223        self.config.tensor_schemes.insert(tensor_id.clone(), scheme);
224        self.scheme_cache.insert(tensor_id, scheme);
225    }
226
227    /// Set scheme for specific layer
228    pub fn set_layer_scheme(&mut self, layer_name: String, scheme: QuantizationScheme) {
229        self.config.layer_schemes.insert(layer_name, scheme);
230    }
231
232    /// Set scheme for specific model
233    pub fn set_model_scheme(&mut self, model_name: String, scheme: QuantizationScheme) {
234        self.config.model_schemes.insert(model_name, scheme);
235    }
236
237    /// Clear cache
238    pub fn clear_cache(&mut self) {
239        self.scheme_cache.clear();
240    }
241
242    /// Generate tensor ID from tensor properties
243    pub fn generate_tensor_id(tensor: &Tensor, layer_name: Option<&str>) -> String {
244        let shape_str = tensor.shape().iter().map(|&s| s.to_string()).collect::<Vec<_>>().join("x");
245
246        let data_hash = {
247            if let Ok(data) = tensor.data() {
248                let sample_size = (data.len() / 100).max(1).min(1000); // Sample for hash
249                let mut hash = 0u64;
250                for i in (0..data.len()).step_by(sample_size) {
251                    hash = hash.wrapping_mul(31).wrapping_add(data[i].to_bits() as u64);
252                }
253                hash
254            } else {
255                0u64 // Default hash if data access fails
256            }
257        };
258
259        match layer_name {
260            Some(layer) => format!("{}:{}:{:x}", layer, shape_str, data_hash),
261            None => format!("tensor:{}:{:x}", shape_str, data_hash),
262        }
263    }
264}
265
/// Trait for mobile quantizers
///
/// Implementations in this module currently carry quantized values inside
/// f32 tensors of the original shape (the tensor type offers no dedicated
/// quantized storage or metadata), so quantize/dequantize are shape-preserving.
pub trait MobileQuantizer: Send + Sync {
    /// Get the quantization scheme
    fn get_scheme(&self) -> QuantizationScheme;

    /// Check if calibration is required
    fn requires_calibration(&self) -> bool;

    /// Calibrate the quantizer with sample data
    /// (takes `&self`: implementations use interior mutability).
    fn calibrate(&self, data: &[Tensor]) -> Result<()>;

    /// Quantize a tensor
    fn quantize_tensor(&self, tensor: &Tensor) -> Result<Tensor>;

    /// Dequantize a tensor
    fn dequantize_tensor(&self, tensor: &Tensor) -> Result<Tensor>;
}
283
/// INT4 Quantizer - Ultra-low memory for mobile
pub struct Int4Quantizer {
    // NOTE(review): `context` is never read by the INT4 code paths in this
    // file — presumably kept for parity with `Int8Quantizer`; confirm before
    // removing.
    context: QuantizationContext,
    // RwLock provides interior mutability so `calibrate` can take `&self`
    // as required by the MobileQuantizer trait.
    calibration: std::sync::RwLock<QuantizationCalibration>,
}
289
290impl Default for Int4Quantizer {
291    fn default() -> Self {
292        Self::new()
293    }
294}
295
296impl Int4Quantizer {
297    pub fn new() -> Self {
298        Self {
299            context: QuantizationContext::default(),
300            calibration: std::sync::RwLock::new(QuantizationCalibration::default()),
301        }
302    }
303
304    fn compute_scale_zero_point(&self, min_val: f32, max_val: f32) -> (f32, i32) {
305        let qmin = -8.0; // 4-bit signed: -8 to 7
306        let qmax = 7.0;
307
308        let scale = (max_val - min_val) / (qmax - qmin);
309        let zero_point = ((qmin - min_val / scale).round() as i32).clamp(-8, 7);
310
311        (scale, zero_point)
312    }
313
314    fn quantize_value(&self, value: f32, scale: f32, zero_point: i32) -> i8 {
315        let quantized = (value / scale).round() as i32 + zero_point;
316        quantized.clamp(-8, 7) as i8
317    }
318
319    fn dequantize_value(&self, quantized: i8, scale: f32, zero_point: i32) -> f32 {
320        (quantized as i32 - zero_point) as f32 * scale
321    }
322}
323
324impl MobileQuantizer for Int4Quantizer {
325    fn get_scheme(&self) -> QuantizationScheme {
326        QuantizationScheme::Int4
327    }
328
329    fn requires_calibration(&self) -> bool {
330        true
331    }
332
333    fn calibrate(&self, data: &[Tensor]) -> Result<()> {
334        let mut calibration = self.calibration.write().expect("RwLock poisoned");
335
336        for tensor in data {
337            let tensor_data = tensor.data()?;
338            let min_val = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
339            let max_val = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
340
341            let (scale, zero_point) = self.compute_scale_zero_point(min_val, max_val);
342
343            // Store as global calibration parameters (simplified for single tensor case)
344            calibration.min_values.insert("global".to_string(), min_val);
345            calibration.max_values.insert("global".to_string(), max_val);
346            calibration.scales.insert("global".to_string(), scale);
347            calibration.zero_points.insert("global".to_string(), zero_point);
348        }
349
350        Ok(())
351    }
352
353    fn quantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
354        let calibration = self.calibration.read().expect("RwLock poisoned");
355        let tensor_data = tensor.data()?;
356
357        // Get or compute scale and zero point
358        let (scale, zero_point) = if let Some(&scale) = calibration.scales.get("global") {
359            (
360                scale,
361                *calibration.zero_points.get("global").expect("No global zero point"),
362            )
363        } else {
364            // Compute on the fly if not calibrated
365            let min_val = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
366            let max_val = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
367            self.compute_scale_zero_point(min_val, max_val)
368        };
369
370        // Quantize to INT4 (stored as individual i8 values for compatibility)
371        let quantized_data: Vec<i8> =
372            tensor_data.iter().map(|&x| self.quantize_value(x, scale, zero_point)).collect();
373
374        // Convert to f32 for tensor storage (maintain same shape)
375        let quantized_f32: Vec<f32> = quantized_data.iter().map(|&x| x as f32).collect();
376
377        // Create quantized tensor with same shape as original
378        let quantized_tensor = Tensor::from_vec(quantized_f32, &tensor.shape())?;
379
380        // Note: Quantization parameters stored separately (tensor doesn't support metadata)
381
382        Ok(quantized_tensor)
383    }
384
385    fn dequantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
386        let calibration = self.calibration.read().expect("RwLock poisoned");
387        let tensor_data = tensor.data()?;
388
389        // Get quantization parameters from calibration data
390        let (scale, zero_point) = if let Some(&scale) = calibration.scales.get("global") {
391            (
392                scale,
393                *calibration.zero_points.get("global").expect("No global zero point"),
394            )
395        } else {
396            // Fallback: estimate from quantized data range
397            let min_q = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b)) as i8;
398            let max_q = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)) as i8;
399            let range = (max_q - min_q) as f32;
400            let scale = if range > 0.0 { 15.0 / range } else { 1.0 }; // 4-bit range is -8 to 7 = 15
401            (scale, 0)
402        };
403
404        // Dequantize i8 values back to f32
405        let dequantized_data: Vec<f32> = tensor_data
406            .iter()
407            .map(|&x| self.dequantize_value(x as i8, scale, zero_point))
408            .collect();
409
410        Tensor::from_vec(dequantized_data, &tensor.shape())
411    }
412}
413
/// INT8 Quantizer - Balanced performance and accuracy
pub struct Int8Quantizer {
    // Calibration settings (method, percentile, smoothing).
    context: QuantizationContext,
    // RwLock provides interior mutability so `calibrate` can take `&self`
    // as required by the MobileQuantizer trait.
    calibration: std::sync::RwLock<QuantizationCalibration>,
    // When true, use symmetric quantization (zero point fixed at 0).
    symmetric: bool,
}
420
421impl Default for Int8Quantizer {
422    fn default() -> Self {
423        Self::new()
424    }
425}
426
427impl Int8Quantizer {
428    pub fn new() -> Self {
429        Self {
430            context: QuantizationContext::default(),
431            calibration: std::sync::RwLock::new(QuantizationCalibration::default()),
432            symmetric: true, // Symmetric quantization is often better for mobile
433        }
434    }
435
436    fn compute_scale_zero_point(&self, min_val: f32, max_val: f32) -> (f32, i32) {
437        if self.symmetric {
438            // Symmetric quantization (zero point = 0)
439            let abs_max = min_val.abs().max(max_val.abs());
440            let scale = abs_max / 127.0;
441            (scale, 0)
442        } else {
443            // Asymmetric quantization
444            let qmin = -128.0;
445            let qmax = 127.0;
446            let scale = (max_val - min_val) / (qmax - qmin);
447            let zero_point = ((qmin - min_val / scale).round() as i32).clamp(-128, 127);
448            (scale, zero_point)
449        }
450    }
451}
452
453impl MobileQuantizer for Int8Quantizer {
454    fn get_scheme(&self) -> QuantizationScheme {
455        QuantizationScheme::Int8
456    }
457
458    fn requires_calibration(&self) -> bool {
459        true
460    }
461
462    fn calibrate(&self, data: &[Tensor]) -> Result<()> {
463        let mut calibration = self.calibration.write().expect("RwLock poisoned");
464
465        for tensor in data {
466            let tensor_data = tensor.data()?;
467
468            let (min_val, max_val) = match self.context.method {
469                CalibrationMethod::MinMax => {
470                    let min = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
471                    let max = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
472                    (min, max)
473                },
474                CalibrationMethod::Percentile => {
475                    let mut sorted = tensor_data.to_vec();
476                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
477                    let percentile_idx =
478                        (sorted.len() as f32 * self.context.percentile / 100.0) as usize;
479                    let min_idx =
480                        (sorted.len() as f32 * (100.0 - self.context.percentile) / 100.0) as usize;
481                    (
482                        sorted[min_idx],
483                        sorted[percentile_idx.min(sorted.len() - 1)],
484                    )
485                },
486                _ => {
487                    // For other methods, fall back to min-max
488                    let min = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
489                    let max = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
490                    (min, max)
491                },
492            };
493
494            let (scale, zero_point) = self.compute_scale_zero_point(min_val, max_val);
495
496            // Store as global calibration parameters (simplified for single tensor case)
497            calibration.min_values.insert("global".to_string(), min_val);
498            calibration.max_values.insert("global".to_string(), max_val);
499            calibration.scales.insert("global".to_string(), scale);
500            calibration.zero_points.insert("global".to_string(), zero_point);
501        }
502
503        Ok(())
504    }
505
506    fn quantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
507        let calibration = self.calibration.read().expect("RwLock poisoned");
508        let tensor_data = tensor.data()?;
509
510        let (scale, zero_point) = if let Some(&scale) = calibration.scales.get("global") {
511            (
512                scale,
513                *calibration.zero_points.get("global").expect("No global zero point"),
514            )
515        } else {
516            let min_val = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
517            let max_val = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
518            self.compute_scale_zero_point(min_val, max_val)
519        };
520
521        let quantized_data: Vec<i8> = tensor_data
522            .iter()
523            .map(|&x| {
524                let q = (x / scale).round() as i32 + zero_point;
525                q.clamp(-128, 127) as i8
526            })
527            .collect();
528
529        // Convert to f32 for tensor storage (temporary)
530        let quantized_f32: Vec<f32> = quantized_data.iter().map(|&x| x as f32).collect();
531
532        let quantized_tensor = Tensor::from_vec(quantized_f32, &tensor.shape())?;
533        // Note: Quantization parameters stored separately (tensor doesn't support metadata)
534
535        Ok(quantized_tensor)
536    }
537
538    fn dequantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
539        let calibration = self.calibration.read().expect("RwLock poisoned");
540        let tensor_data = tensor.data()?;
541
542        // Get quantization parameters from calibration data
543        let (scale, zero_point) = if let Some(&scale) = calibration.scales.get("global") {
544            (
545                scale,
546                *calibration.zero_points.get("global").expect("No global zero point"),
547            )
548        } else {
549            // Fallback: estimate from quantized data range
550            let min_q = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b)) as i32;
551            let max_q = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)) as i32;
552            let range = (max_q - min_q) as f32;
553            let scale = if range > 0.0 { 255.0 / range } else { 1.0 }; // 8-bit range is 255
554            (scale, 0)
555        };
556
557        let dequantized_data: Vec<f32> =
558            tensor_data.iter().map(|&x| ((x as i32) - zero_point) as f32 * scale).collect();
559
560        Tensor::from_vec(dequantized_data, &tensor.shape())
561    }
562}
563
/// FP16 Quantizer - Hardware-accelerated on modern mobile GPUs
///
/// Stateless: half-precision conversion needs no calibration, so the type
/// carries no data.
pub struct FP16Quantizer;
566
567impl Default for FP16Quantizer {
568    fn default() -> Self {
569        Self::new()
570    }
571}
572
573impl FP16Quantizer {
574    pub fn new() -> Self {
575        Self
576    }
577}
578
579impl MobileQuantizer for FP16Quantizer {
580    fn get_scheme(&self) -> QuantizationScheme {
581        QuantizationScheme::FP16
582    }
583
584    fn requires_calibration(&self) -> bool {
585        false // FP16 doesn't require calibration
586    }
587
588    fn calibrate(&self, _data: &[Tensor]) -> Result<()> {
589        Ok(()) // No calibration needed
590    }
591
592    fn quantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
593        let tensor_data = tensor.data()?;
594
595        // Convert to FP16
596        let fp16_data: Vec<f16> = tensor_data.iter().map(|&x| f16::from_f32(x)).collect();
597
598        // Convert back to f32 for storage (temporary - in real implementation would store as f16)
599        let quantized_data: Vec<f32> = fp16_data.iter().map(|&x| f32::from(x)).collect();
600
601        let quantized_tensor = Tensor::from_vec(quantized_data, &tensor.shape())?;
602        // Note: Quantization parameters stored separately (tensor doesn't support metadata)
603
604        Ok(quantized_tensor)
605    }
606
607    fn dequantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
608        // FP16 quantization is lossless within range, so just return clone
609        Ok(tensor.clone())
610    }
611}
612
/// Dynamic Quantizer - Adapts quantization based on runtime statistics
pub struct DynamicQuantizer {
    // INT8 path (requires calibration).
    int8_quantizer: Int8Quantizer,
    // FP16 path (calibration-free).
    fp16_quantizer: FP16Quantizer,
    // NOTE(review): set to 0.1 in `new` but never read anywhere in this
    // module — confirm intended use before relying on it.
    selection_threshold: f32,
    // External per-tensor/layer/model scheme configuration.
    scheme_storage: QuantizationSchemeStorage,
    // Optional layer name used for scheme resolution and tensor-id generation.
    layer_context: Option<String>,
    // Optional model name used for scheme resolution.
    model_context: Option<String>,
}
622
623impl Default for DynamicQuantizer {
624    fn default() -> Self {
625        Self::new()
626    }
627}
628
629impl DynamicQuantizer {
630    pub fn new() -> Self {
631        Self {
632            int8_quantizer: Int8Quantizer::new(),
633            fp16_quantizer: FP16Quantizer::new(),
634            selection_threshold: 0.1, // 10% error threshold
635            scheme_storage: QuantizationSchemeStorage::new(),
636            layer_context: None,
637            model_context: None,
638        }
639    }
640
641    pub fn with_config_file<P: AsRef<Path>>(path: P) -> Result<Self> {
642        let scheme_storage = QuantizationSchemeStorage::with_config_file(path)?;
643        Ok(Self {
644            int8_quantizer: Int8Quantizer::new(),
645            fp16_quantizer: FP16Quantizer::new(),
646            selection_threshold: 0.1,
647            scheme_storage,
648            layer_context: None,
649            model_context: None,
650        })
651    }
652
653    pub fn set_layer_context(&mut self, layer_name: String) {
654        self.layer_context = Some(layer_name);
655    }
656
657    pub fn set_model_context(&mut self, model_name: String) {
658        self.model_context = Some(model_name);
659    }
660
661    /// Get mutable reference to scheme storage for configuration
662    pub fn scheme_storage_mut(&mut self) -> &mut QuantizationSchemeStorage {
663        &mut self.scheme_storage
664    }
665
666    /// Get reference to scheme storage
667    pub fn scheme_storage(&self) -> &QuantizationSchemeStorage {
668        &self.scheme_storage
669    }
670
671    fn select_quantization_scheme(&self, tensor: &Tensor) -> Result<QuantizationScheme> {
672        let tensor_data = tensor.data()?;
673
674        // Compute dynamic range
675        let min_val = tensor_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
676        let max_val = tensor_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
677        let range = max_val - min_val;
678
679        // Compute variance
680        let mean = tensor_data.iter().sum::<f32>() / tensor_data.len() as f32;
681        let variance =
682            tensor_data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / tensor_data.len() as f32;
683
684        // Decision logic
685        if range < 1.0 && variance < 0.01 {
686            // Small range and low variance - INT8 is sufficient
687            Ok(QuantizationScheme::Int8)
688        } else {
689            // Large range or high variance - use FP16
690            Ok(QuantizationScheme::FP16)
691        }
692    }
693}
694
695impl MobileQuantizer for DynamicQuantizer {
696    fn get_scheme(&self) -> QuantizationScheme {
697        QuantizationScheme::Dynamic
698    }
699
700    fn requires_calibration(&self) -> bool {
701        true // INT8 path requires calibration
702    }
703
704    fn calibrate(&self, data: &[Tensor]) -> Result<()> {
705        // Calibrate INT8 quantizer
706        self.int8_quantizer.calibrate(data)
707    }
708
709    fn quantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
710        // Generate tensor ID for scheme determination
711        let tensor_id =
712            QuantizationSchemeStorage::generate_tensor_id(tensor, self.layer_context.as_deref());
713
714        // First check external storage for quantization scheme
715        let mut storage = self.scheme_storage.clone();
716        let scheme = storage.determine_scheme(
717            &tensor_id,
718            self.layer_context.as_deref(),
719            self.model_context.as_deref(),
720        );
721
722        // If external storage returns Dynamic, fall back to selection logic
723        let final_scheme = if scheme == QuantizationScheme::Dynamic {
724            self.select_quantization_scheme(tensor)?
725        } else {
726            scheme
727        };
728
729        match final_scheme {
730            QuantizationScheme::Int4 => {
731                // For int4, we need to create a quantizer instance
732                let int4_quantizer = Int4Quantizer::new();
733                int4_quantizer.quantize_tensor(tensor)
734            },
735            QuantizationScheme::Int8 => self.int8_quantizer.quantize_tensor(tensor),
736            QuantizationScheme::FP16 => self.fp16_quantizer.quantize_tensor(tensor),
737            // GGUF schemes fall back to INT8 (will be handled by dedicated GGUF quantizer)
738            QuantizationScheme::GGUF_Q2_K
739            | QuantizationScheme::GGUF_Q3_K
740            | QuantizationScheme::GGUF_Q4_K
741            | QuantizationScheme::GGUF_Q5_0
742            | QuantizationScheme::GGUF_Q6_K => {
743                // Note: GGUF quantization should use MobileGGUFQuantizer directly
744                // For now, fall back to INT8
745                self.int8_quantizer.quantize_tensor(tensor)
746            },
747            QuantizationScheme::Dynamic => {
748                // This shouldn't happen after our check above, but handle gracefully
749                let selected_scheme = self.select_quantization_scheme(tensor)?;
750                match selected_scheme {
751                    QuantizationScheme::Int4 => {
752                        let int4_quantizer = Int4Quantizer::new();
753                        int4_quantizer.quantize_tensor(tensor)
754                    },
755                    QuantizationScheme::Int8 => self.int8_quantizer.quantize_tensor(tensor),
756                    QuantizationScheme::FP16 => self.fp16_quantizer.quantize_tensor(tensor),
757                    // GGUF schemes fall back to INT8
758                    QuantizationScheme::GGUF_Q2_K
759                    | QuantizationScheme::GGUF_Q3_K
760                    | QuantizationScheme::GGUF_Q4_K
761                    | QuantizationScheme::GGUF_Q5_0
762                    | QuantizationScheme::GGUF_Q6_K => self.int8_quantizer.quantize_tensor(tensor),
763                    QuantizationScheme::Dynamic => {
764                        // If still Dynamic, default to Int8 as fallback
765                        self.int8_quantizer.quantize_tensor(tensor)
766                    },
767                }
768            },
769        }
770    }
771
772    fn dequantize_tensor(&self, tensor: &Tensor) -> Result<Tensor> {
773        // Generate tensor ID for scheme determination
774        let tensor_id =
775            QuantizationSchemeStorage::generate_tensor_id(tensor, self.layer_context.as_deref());
776
777        // Determine scheme from external storage
778        let mut storage = self.scheme_storage.clone();
779        let scheme = storage.determine_scheme(
780            &tensor_id,
781            self.layer_context.as_deref(),
782            self.model_context.as_deref(),
783        );
784
785        match scheme {
786            QuantizationScheme::Int8 => self.int8_quantizer.dequantize_tensor(tensor),
787            QuantizationScheme::FP16 => self.fp16_quantizer.dequantize_tensor(tensor),
788            QuantizationScheme::Int4 => {
789                // For int4, we need to create a quantizer instance
790                let int4_quantizer = Int4Quantizer::new();
791                int4_quantizer.dequantize_tensor(tensor)
792            },
793            // GGUF schemes fall back to INT8 dequantization
794            QuantizationScheme::GGUF_Q2_K
795            | QuantizationScheme::GGUF_Q3_K
796            | QuantizationScheme::GGUF_Q4_K
797            | QuantizationScheme::GGUF_Q5_0
798            | QuantizationScheme::GGUF_Q6_K => {
799                // Note: GGUF dequantization should use MobileGGUFQuantizer directly
800                // For now, fall back to INT8
801                self.int8_quantizer.dequantize_tensor(tensor)
802            },
803            QuantizationScheme::Dynamic => {
804                // For dynamic schemes, fall back to the selection logic
805                let selected_scheme = self.select_quantization_scheme(tensor)?;
806                match selected_scheme {
807                    QuantizationScheme::Int8 => self.int8_quantizer.dequantize_tensor(tensor),
808                    QuantizationScheme::FP16 => self.fp16_quantizer.dequantize_tensor(tensor),
809                    _ => self.int8_quantizer.dequantize_tensor(tensor), // Default fallback
810                }
811            },
812        }
813    }
814}
815
/// Stateless quantization utilities.
///
/// Zero-sized carrier type for associated helper functions that evaluate
/// quantization quality (RMSE) and memory characteristics (compression
/// ratio, memory savings) for the supported [`QuantizationScheme`] variants.
pub struct QuantizationUtils;
818
819impl QuantizationUtils {
820    /// Compute quantization error
821    pub fn compute_error(original: &Tensor, quantized: &Tensor) -> Result<f32> {
822        let orig_data = original.data()?;
823        let quant_data = quantized.data()?;
824
825        if orig_data.len() != quant_data.len() {
826            return Err(tensor_op_error(
827                "compute_error",
828                "Tensors must have same size for error computation",
829            ));
830        }
831
832        let mse = orig_data
833            .iter()
834            .zip(quant_data.iter())
835            .map(|(&o, &q)| (o - q).powi(2))
836            .sum::<f32>()
837            / orig_data.len() as f32;
838
839        Ok(mse.sqrt())
840    }
841
842    /// Get compression ratio
843    pub fn compression_ratio(scheme: QuantizationScheme) -> f32 {
844        match scheme {
845            QuantizationScheme::Int4 => 8.0,                // 32-bit to 4-bit
846            QuantizationScheme::Int8 => 4.0,                // 32-bit to 8-bit
847            QuantizationScheme::FP16 => 2.0,                // 32-bit to 16-bit
848            QuantizationScheme::Dynamic => 3.0,             // Average
849            QuantizationScheme::GGUF_Q2_K => 32.0 / 2.5625, // 32-bit to 2.5625-bit
850            QuantizationScheme::GGUF_Q3_K => 32.0 / 3.4375, // 32-bit to 3.4375-bit
851            QuantizationScheme::GGUF_Q4_K => 32.0 / 4.5,    // 32-bit to 4.5-bit
852            QuantizationScheme::GGUF_Q5_0 => 32.0 / 5.5,    // 32-bit to 5.5-bit
853            QuantizationScheme::GGUF_Q6_K => 32.0 / 6.5,    // 32-bit to 6.5-bit
854        }
855    }
856
857    /// Estimate memory savings
858    pub fn memory_savings_percent(scheme: QuantizationScheme) -> f32 {
859        let ratio = Self::compression_ratio(scheme);
860        (1.0 - 1.0 / ratio) * 100.0
861    }
862}
863
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trip INT4 quantization: shapes preserved, error bounded.
    #[test]
    fn test_int4_quantization() {
        let quantizer = Int4Quantizer::new();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 4])
            .expect("Failed to create tensor");

        // Calibrate
        quantizer.calibrate(std::slice::from_ref(&tensor)).expect("Calibration failed");

        // Quantize
        let quantized = quantizer.quantize_tensor(&tensor).expect("Quantization failed");
        assert_eq!(quantized.shape(), tensor.shape());

        // Dequantize
        let dequantized = quantizer.dequantize_tensor(&quantized).expect("Dequantization failed");
        assert_eq!(dequantized.shape(), tensor.shape());

        // Check error is reasonable
        let error = QuantizationUtils::compute_error(&tensor, &dequantized)
            .expect("Error computation failed");
        assert!(error < 1.0); // Error should be small
    }

    /// Round-trip INT8 quantization over a symmetric range.
    #[test]
    fn test_int8_quantization() {
        let quantizer = Int8Quantizer::new();
        let tensor = Tensor::from_vec(vec![-10.0, -5.0, 0.0, 5.0, 10.0], &[5])
            .expect("Failed to create tensor");

        quantizer.calibrate(std::slice::from_ref(&tensor)).expect("Calibration failed");

        let quantized = quantizer.quantize_tensor(&tensor).expect("Quantization failed");
        let dequantized = quantizer.dequantize_tensor(&quantized).expect("Dequantization failed");

        let error = QuantizationUtils::compute_error(&tensor, &dequantized)
            .expect("Error computation failed");
        assert!(error < 0.1); // INT8 should have very low error
    }

    /// FP16 round-trip: no calibration needed and near-lossless for values
    /// inside the normal half-precision range.
    #[test]
    fn test_fp16_quantization() {
        let quantizer = FP16Quantizer::new();
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).expect("Operation failed");

        // FP16 doesn't require calibration
        assert!(!quantizer.requires_calibration());

        let quantized = quantizer.quantize_tensor(&tensor).expect("Quantization failed");
        let dequantized = quantizer.dequantize_tensor(&quantized).expect("Dequantization failed");

        // FP16 should have minimal error for normal range values
        let error = QuantizationUtils::compute_error(&tensor, &dequantized)
            .expect("Error computation failed");
        assert!(error < 0.001);
    }

    /// Dynamic quantization scheme selection plus external scheme storage:
    /// per-tensor overrides are honored and unknown tensors fall back to INT8.
    #[test]
    fn test_dynamic_quantization() {
        let mut quantizer = DynamicQuantizer::new();

        // Small range tensor - should use INT8
        let small_range =
            Tensor::from_vec(vec![0.1, 0.2, 0.3, 0.4], &[4]).expect("Operation failed");

        quantizer
            .calibrate(std::slice::from_ref(&small_range))
            .expect("Operation failed");
        let quantized = quantizer.quantize_tensor(&small_range).expect("Operation failed");
        // Use the result (was previously an unused binding): shape preserved.
        assert_eq!(quantized.shape(), small_range.shape());

        // Test external storage functionality
        let tensor_id = QuantizationSchemeStorage::generate_tensor_id(&small_range, None);

        // Configure external storage to use FP16 for this tensor
        quantizer
            .scheme_storage_mut()
            .set_tensor_scheme(tensor_id.clone(), QuantizationScheme::FP16);

        // Quantize again - should now use FP16 due to external storage
        let quantized_fp16 = quantizer.quantize_tensor(&small_range).expect("Operation failed");
        // Use the result (was previously an unused binding): shape preserved.
        assert_eq!(quantized_fp16.shape(), small_range.shape());

        // Verify the storage can retrieve the scheme
        let stored_scheme = quantizer.scheme_storage_mut().determine_scheme(&tensor_id, None, None);
        assert_eq!(stored_scheme, QuantizationScheme::FP16);

        // Test default fallback
        let unknown_tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).expect("Operation failed");
        let unknown_id = QuantizationSchemeStorage::generate_tensor_id(&unknown_tensor, None);
        let default_scheme =
            quantizer.scheme_storage_mut().determine_scheme(&unknown_id, None, None);
        assert_eq!(default_scheme, QuantizationScheme::Int8); // Default scheme
    }

    /// Compression ratios and derived memory savings are exact for the
    /// power-of-two schemes (values are exactly representable in f32).
    #[test]
    fn test_compression_ratios() {
        assert_eq!(
            QuantizationUtils::compression_ratio(QuantizationScheme::Int4),
            8.0
        );
        assert_eq!(
            QuantizationUtils::compression_ratio(QuantizationScheme::Int8),
            4.0
        );
        assert_eq!(
            QuantizationUtils::compression_ratio(QuantizationScheme::FP16),
            2.0
        );

        assert_eq!(
            QuantizationUtils::memory_savings_percent(QuantizationScheme::Int4),
            87.5
        );
        assert_eq!(
            QuantizationUtils::memory_savings_percent(QuantizationScheme::Int8),
            75.0
        );
        assert_eq!(
            QuantizationUtils::memory_savings_percent(QuantizationScheme::FP16),
            50.0
        );
    }
}