// trustformers_wasm/optimization/quantization/config.rs

//! Quantization configuration and supporting types

use serde::{Deserialize, Serialize};
use std::vec::Vec;
use wasm_bindgen::prelude::*;

7/// Quantization strategies
8#[wasm_bindgen]
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
10pub enum QuantizationStrategy {
11    /// No quantization
12    None,
13    /// Dynamic quantization - quantize weights only
14    Dynamic,
15    /// Static quantization - quantize weights and activations
16    Static,
17    /// Post-training quantization
18    PostTraining,
19    /// Quantization-aware training (requires pre-quantized model)
20    QAT,
21    /// AWQ (Activation-aware Weight Quantization) - preserves important weights
22    AWQ,
23    /// GPTQ (Gradient-based Post-Training Quantization) - uses second-order information
24    GPTQ,
25    /// SmoothQuant - balances weights and activations difficulty
26    SmoothQuant,
27    /// LLM.int8() - mixed-precision quantization for large models
28    LLMInt8,
29    /// QLoRA - Quantized Low-Rank Adaptation
30    QLoRA,
31    /// GGML-style quantization for efficient inference
32    GGML,
33    /// Adaptive bitwidth quantization with dynamic allocation
34    AdaptiveBitwidth,
35    /// Outlier-aware quantization for handling activation spikes
36    OutlierAware,
37    /// HQQ (Half-Quadratic Quantization) - superior quality quantization using half-quadratic optimization
38    HQQ,
39    /// SpQR (Sparse-Quantized Representation) - ultra-sparse models with mixed precision
40    SpQR,
41    /// AQLM (Additive Quantization for Language Models) - additive quantization for transformers
42    AQLM,
43}
44
45/// Quantization precision levels
46#[wasm_bindgen]
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
48pub enum QuantizationPrecision {
49    /// 16-bit floating point
50    FP16,
51    /// 8-bit floating point (E4M3 or E5M2 format)
52    FP8,
53    /// 8-bit integer
54    INT8,
55    /// 4-bit integer
56    INT4,
57    /// 2-bit integer (experimental)
58    INT2,
59    /// 1-bit binary quantization
60    INT1,
61    /// Mixed precision (FP16 for outliers, INT8 for normal weights)
62    Mixed,
63    /// Adaptive precision based on layer importance
64    Adaptive,
65}
66
67/// Quantization configuration
68#[wasm_bindgen]
69#[derive(Debug, Clone)]
70pub struct QuantizationConfig {
71    strategy: QuantizationStrategy,
72    precision: QuantizationPrecision,
73    target_size_mb: f32,
74    performance_threshold: f32,
75    accuracy_threshold: f32,
76    auto_select: bool,
77}
78
79#[wasm_bindgen]
80impl QuantizationConfig {
81    /// Create a new quantization configuration
82    #[wasm_bindgen(constructor)]
83    pub fn new(strategy: QuantizationStrategy, precision: QuantizationPrecision) -> Self {
84        Self {
85            strategy,
86            precision,
87            target_size_mb: 50.0,       // Default target size
88            performance_threshold: 2.0, // 2x speedup minimum
89            accuracy_threshold: 0.95,   // 95% accuracy retention minimum
90            auto_select: false,
91        }
92    }
93
94    /// Create an automatic configuration that selects best settings
95    pub fn auto() -> Self {
96        Self {
97            strategy: QuantizationStrategy::Dynamic,
98            precision: QuantizationPrecision::INT8,
99            target_size_mb: 10.0,
100            performance_threshold: 1.5,
101            accuracy_threshold: 0.90,
102            auto_select: true,
103        }
104    }
105
106    /// Create a configuration optimized for mobile devices
107    pub fn mobile() -> Self {
108        Self {
109            strategy: QuantizationStrategy::PostTraining,
110            precision: QuantizationPrecision::INT8,
111            target_size_mb: 5.0,
112            performance_threshold: 3.0,
113            accuracy_threshold: 0.85,
114            auto_select: false,
115        }
116    }
117
118    /// Create a configuration for desktop/high-performance devices
119    pub fn desktop() -> Self {
120        Self {
121            strategy: QuantizationStrategy::Dynamic,
122            precision: QuantizationPrecision::FP16,
123            target_size_mb: 100.0,
124            performance_threshold: 1.2,
125            accuracy_threshold: 0.98,
126            auto_select: false,
127        }
128    }
129
130    /// Create a configuration for ultra-low latency inference
131    pub fn ultra_fast() -> Self {
132        Self {
133            strategy: QuantizationStrategy::GGML,
134            precision: QuantizationPrecision::FP8,
135            target_size_mb: 15.0,
136            performance_threshold: 4.0,
137            accuracy_threshold: 0.88,
138            auto_select: false,
139        }
140    }
141
142    /// Create a configuration for fine-tuning with QLoRA
143    pub fn qlora() -> Self {
144        Self {
145            strategy: QuantizationStrategy::QLoRA,
146            precision: QuantizationPrecision::Mixed,
147            target_size_mb: 8.0,
148            performance_threshold: 2.5,
149            accuracy_threshold: 0.92,
150            auto_select: false,
151        }
152    }
153
154    /// Create a configuration with adaptive bitwidth for optimal efficiency
155    pub fn adaptive() -> Self {
156        Self {
157            strategy: QuantizationStrategy::AdaptiveBitwidth,
158            precision: QuantizationPrecision::Adaptive,
159            target_size_mb: 12.0,
160            performance_threshold: 3.0,
161            accuracy_threshold: 0.93,
162            auto_select: true,
163        }
164    }
165
166    /// Create a configuration for models with activation outliers
167    pub fn outlier_aware() -> Self {
168        Self {
169            strategy: QuantizationStrategy::OutlierAware,
170            precision: QuantizationPrecision::Mixed,
171            target_size_mb: 20.0,
172            performance_threshold: 2.0,
173            accuracy_threshold: 0.96,
174            auto_select: false,
175        }
176    }
177
178    /// Set target model size in MB
179    pub fn set_target_size_mb(mut self, size_mb: f32) -> Self {
180        self.target_size_mb = size_mb;
181        self
182    }
183
184    /// Set performance threshold (minimum speedup factor)
185    pub fn set_performance_threshold(mut self, threshold: f32) -> Self {
186        self.performance_threshold = threshold;
187        self
188    }
189
190    /// Set accuracy threshold (minimum accuracy retention)
191    pub fn set_accuracy_threshold(mut self, threshold: f32) -> Self {
192        self.accuracy_threshold = threshold;
193        self
194    }
195
196    /// Enable automatic strategy selection
197    pub fn enable_auto_select(mut self) -> Self {
198        self.auto_select = true;
199        self
200    }
201
202    // Getters for private fields
203    #[wasm_bindgen(getter)]
204    pub fn strategy(&self) -> QuantizationStrategy {
205        self.strategy
206    }
207
208    #[wasm_bindgen(getter)]
209    pub fn precision(&self) -> QuantizationPrecision {
210        self.precision
211    }
212
213    #[wasm_bindgen(getter)]
214    pub fn target_size_mb(&self) -> f32 {
215        self.target_size_mb
216    }
217
218    #[wasm_bindgen(getter)]
219    pub fn performance_threshold(&self) -> f32 {
220        self.performance_threshold
221    }
222
223    #[wasm_bindgen(getter)]
224    pub fn accuracy_threshold(&self) -> f32 {
225        self.accuracy_threshold
226    }
227
228    #[wasm_bindgen(getter)]
229    pub fn auto_select(&self) -> bool {
230        self.auto_select
231    }
232}
233
234/// Quantization statistics
235#[wasm_bindgen]
236#[derive(Debug, Clone, Serialize, Deserialize)]
237pub struct QuantizationStats {
238    original_size_bytes: usize,
239    quantized_size_bytes: usize,
240    compression_ratio: f32,
241    size_reduction_percent: f32,
242    estimated_speedup: f32,
243    strategy_used: QuantizationStrategy,
244    precision_used: QuantizationPrecision,
245}
246
247#[wasm_bindgen]
248impl QuantizationStats {
249    #[wasm_bindgen(getter)]
250    pub fn original_size_bytes(&self) -> usize {
251        self.original_size_bytes
252    }
253
254    #[wasm_bindgen(getter)]
255    pub fn quantized_size_bytes(&self) -> usize {
256        self.quantized_size_bytes
257    }
258
259    #[wasm_bindgen(getter)]
260    pub fn compression_ratio(&self) -> f32 {
261        self.compression_ratio
262    }
263
264    #[wasm_bindgen(getter)]
265    pub fn size_reduction_percent(&self) -> f32 {
266        self.size_reduction_percent
267    }
268
269    #[wasm_bindgen(getter)]
270    pub fn estimated_speedup(&self) -> f32 {
271        self.estimated_speedup
272    }
273
274    #[wasm_bindgen(getter)]
275    pub fn strategy_used(&self) -> QuantizationStrategy {
276        self.strategy_used
277    }
278
279    #[wasm_bindgen(getter)]
280    pub fn precision_used(&self) -> QuantizationPrecision {
281        self.precision_used
282    }
283}
284
285impl QuantizationStats {
286    /// Create new quantization statistics
287    pub fn new(
288        original_size_bytes: usize,
289        quantized_size_bytes: usize,
290        compression_ratio: f32,
291        size_reduction_percent: f32,
292        estimated_speedup: f32,
293        strategy_used: QuantizationStrategy,
294        precision_used: QuantizationPrecision,
295    ) -> Self {
296        Self {
297            original_size_bytes,
298            quantized_size_bytes,
299            compression_ratio,
300            size_reduction_percent,
301            estimated_speedup,
302            strategy_used,
303            precision_used,
304        }
305    }
306}
307
308/// Runtime performance monitoring for adaptive quantization
309#[derive(Debug, Clone)]
310pub struct RuntimeMonitor {
311    pub inference_times: Vec<f64>,
312    pub memory_usage: Vec<usize>,
313    pub accuracy_scores: Vec<f32>,
314    pub thermal_state: ThermalState,
315    pub adaptation_history: Vec<AdaptationEvent>,
316}
317
/// Current device thermal state, ordered from coolest to hottest.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThermalState {
    /// Normal operating temperature.
    Nominal,
    /// Slightly elevated temperature.
    Fair,
    /// High temperature; throttling is recommended.
    Serious,
    /// Very high temperature; aggressive throttling is needed.
    Critical,
}

327/// Adaptive quantization state that changes at runtime
328#[derive(Debug, Clone)]
329pub struct AdaptiveQuantizationState {
330    pub current_strategy: QuantizationStrategy,
331    pub current_precision: QuantizationPrecision,
332    pub adaptation_rate: f32,
333    pub performance_target: f32,
334    pub accuracy_target: f32,
335    pub last_adaptation: f64,
336    pub confidence_score: f32,
337}
338
339/// Record of quantization adaptations
340#[derive(Debug, Clone)]
341pub struct AdaptationEvent {
342    pub timestamp: f64,
343    pub trigger: AdaptationTrigger,
344    pub old_strategy: QuantizationStrategy,
345    pub new_strategy: QuantizationStrategy,
346    pub old_precision: QuantizationPrecision,
347    pub new_precision: QuantizationPrecision,
348    pub improvement_ratio: f32,
349}
350
/// Condition that can trigger a quantization adaptation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AdaptationTrigger {
    /// Inference became too slow.
    PerformanceDrop,
    /// The process is running out of memory.
    MemoryPressure,
    /// Model accuracy fell below the threshold.
    AccuracyDrop,
    /// The device is overheating.
    ThermalThrottling,
    /// Battery is low; efficiency is needed.
    BatteryOptimization,
    /// A different type of inference requests arrived.
    WorkloadChange,
}

362/// Device capabilities for quantization optimization
363#[derive(Debug, Clone)]
364pub struct DeviceCapabilities {
365    pub supports_int8: bool,
366    pub supports_int4: bool,
367    pub supports_fp16: bool,
368    pub memory_bandwidth_gb_s: f32,
369    pub compute_capability: ComputeCapability,
370}
371
/// Coarse classification of a device's compute power.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ComputeCapability {
    /// Basic CPU.
    Low,
    /// High-end CPU or integrated GPU.
    Medium,
    /// Dedicated GPU.
    High,
}