// zeta_quantization/lib.rs

// Copyright 2025 ZETA RETICULA INC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Unified Quantization Engine for Zeta Reticula
//!
//! This module consolidates all quantization functionality from:
//! - zeta-quantize/ (entire crate)
//! - quantize-cli/ (core logic)
//! - salience-engine/src/quantizer.rs
//! - llm-rs/src/quantizer.rs
//! - agentflow-rs/src/quantizer.rs
//! - shared/src/quantization.rs

use std::collections::HashMap;
use serde::{Serialize, Deserialize};
use thiserror::Error;

#[derive(Error, Debug)]
pub enum QuantizationError {
    #[error("Invalid precision level: {0}")]
    InvalidPrecision(String),
    #[error("Tensor operation failed: {0}")]
    TensorError(String),
    #[error("Model loading failed: {0}")]
    ModelError(String),
    #[error("Memory allocation failed: {0}")]
    MemoryError(String),
    #[error("Validation failed: {0}")]
    ValidationError(String),
    #[error("Configuration error: {0}")]
    ConfigError(String),
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
pub enum PrecisionLevel {
    Int1,
    Int2,
    Int4,
    Int8,
    FP16,
    FP32,
}

impl PrecisionLevel {
    pub fn bits(&self) -> u8 {
        match self {
            PrecisionLevel::Int1 => 1,
            PrecisionLevel::Int2 => 2,
            PrecisionLevel::Int4 => 4,
            PrecisionLevel::Int8 => 8,
            PrecisionLevel::FP16 => 16,
            PrecisionLevel::FP32 => 32,
        }
    }

    /// Largest representable quantized value: 2^bits - 1 for the unsigned
    /// integer levels, the largest finite value for the float levels.
    pub fn max_value(&self) -> f32 {
        match self {
            PrecisionLevel::Int1 => 1.0,
            PrecisionLevel::Int2 => 3.0,
            PrecisionLevel::Int4 => 15.0,
            PrecisionLevel::Int8 => 255.0,
            PrecisionLevel::FP16 => 65504.0, // largest finite f16 value
            PrecisionLevel::FP32 => f32::MAX,
        }
    }
}
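
// A quick illustrative check of the mapping above: for the integer levels the
// quantization range is [0, 2^bits - 1]. This is a documentation sketch, not
// an exhaustive test.
#[cfg(test)]
mod precision_level_example {
    use super::*;

    #[test]
    fn integer_levels_span_two_pow_bits_minus_one() {
        for (level, bits) in [
            (PrecisionLevel::Int1, 1u32),
            (PrecisionLevel::Int2, 2),
            (PrecisionLevel::Int4, 4),
            (PrecisionLevel::Int8, 8),
        ] {
            assert_eq!(level.bits() as u32, bits);
            assert_eq!(level.max_value(), ((1u32 << bits) - 1) as f32);
        }
    }
}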

#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuantizationAlgorithm {
    /// Uniform affine quantization over the full value range.
    Linear,
    /// Cluster values into up to 256 centroids and store centroid indices.
    KMeans,
    /// Learned (model-driven) quantization; not yet implemented.
    Learned,
    /// Per-block affine quantization with independent scales.
    BlockWise,
    /// Precision allocation weighted by per-element salience scores.
    SalienceBased,
    /// Chooses among the other algorithms based on data statistics.
    Adaptive,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationConfig {
    pub precision: PrecisionLevel,
    pub algorithm: QuantizationAlgorithm,
    pub block_size: usize,
    pub salience_threshold: f32,
    pub preserve_outliers: bool,
    pub use_symmetric: bool,
    pub calibration_samples: usize,
    pub validation_threshold: f32,
}

impl Default for QuantizationConfig {
    fn default() -> Self {
        Self {
            precision: PrecisionLevel::Int4,
            algorithm: QuantizationAlgorithm::SalienceBased,
            block_size: 128,
            salience_threshold: 0.7,
            preserve_outliers: true,
            use_symmetric: false,
            calibration_samples: 1000,
            validation_threshold: 0.95,
        }
    }
}
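
// Minimal usage sketch: override a couple of defaults with struct-update
// syntax. The particular field values here are arbitrary illustrations.
#[cfg(test)]
mod config_example {
    use super::*;

    #[test]
    fn override_defaults() {
        let cfg = QuantizationConfig {
            precision: PrecisionLevel::Int8,
            block_size: 64,
            ..Default::default()
        };
        assert_eq!(cfg.precision, PrecisionLevel::Int8);
        assert_eq!(cfg.block_size, 64);
        // Unspecified fields keep their defaults.
        assert!(matches!(cfg.algorithm, QuantizationAlgorithm::SalienceBased));
    }
}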

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationParameters {
    pub scale: f32,
    pub zero_point: i32,
    pub min_val: f32,
    pub max_val: f32,
}

impl QuantizationParameters {
    pub fn new(min_val: f32, max_val: f32, precision: &PrecisionLevel) -> Self {
        let qmin = 0.0;
        let qmax = precision.max_value();
        // Guard against a degenerate range (all values equal), which would
        // otherwise yield a zero scale and a division by zero below.
        let range = (max_val - min_val).max(f32::EPSILON);
        let scale = range / (qmax - qmin);
        let zero_point = (qmin - min_val / scale).round() as i32;

        Self {
            scale,
            zero_point,
            min_val,
            max_val,
        }
    }
}
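
// Worked example of the affine mapping above, chosen so the arithmetic is
// exact in f32: data in [0, 255] at Int8 gives scale = 1 and zero_point = 0.
// A documentation sketch rather than a thorough test.
#[cfg(test)]
mod quant_params_example {
    use super::*;

    #[test]
    fn int8_params_for_zero_to_255() {
        let p = QuantizationParameters::new(0.0, 255.0, &PrecisionLevel::Int8);
        assert_eq!(p.scale, 1.0);
        assert_eq!(p.zero_point, 0);
        // q = (x - min) / scale + zero_point, so x = 100.0 maps to q = 100.
        assert_eq!(((100.0 - p.min_val) / p.scale + p.zero_point as f32).round() as i32, 100);
    }
}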

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizationResult {
    pub quantized_data: Vec<i32>,
    pub parameters: QuantizationParameters,
    /// Storage reduction relative to f32 (32 / bits).
    pub compression_ratio: f32,
    pub error_metrics: ErrorMetrics,
    /// Fraction of total salience mass kept at full precision (0.0..=1.0).
    pub salience_preserved: f32,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorMetrics {
    /// Mean squared error of the dequantized reconstruction.
    pub mse: f32,
    /// Mean absolute error.
    pub mae: f32,
    /// Largest absolute per-element error.
    pub max_error: f32,
    /// Signal-to-noise ratio in dB.
    pub snr: f32,
}
/// Unified Quantization Engine
pub struct UnifiedQuantizer {
    config: QuantizationConfig,
    salience_weights: HashMap<usize, f32>,
}

impl UnifiedQuantizer {
    pub fn new(config: QuantizationConfig) -> Self {
        Self {
            config,
            salience_weights: HashMap::new(),
        }
    }

    pub fn set_salience_weights(&mut self, weights: HashMap<usize, f32>) {
        self.salience_weights = weights;
    }

    pub fn quantize(&self, data: &[f32]) -> Result<QuantizationResult, QuantizationError> {
        // An empty tensor has no min/max, so reject it up front rather than
        // letting the per-algorithm paths produce NaN parameters.
        if data.is_empty() {
            return Err(QuantizationError::ValidationError(
                "cannot quantize an empty tensor".to_string(),
            ));
        }
        match self.config.algorithm {
            QuantizationAlgorithm::Linear => self.linear_quantize(data),
            QuantizationAlgorithm::KMeans => self.kmeans_quantize(data),
            QuantizationAlgorithm::Learned => self.learned_quantize(data),
            QuantizationAlgorithm::BlockWise => self.blockwise_quantize(data),
            QuantizationAlgorithm::SalienceBased => self.salience_quantize(data),
            QuantizationAlgorithm::Adaptive => self.adaptive_quantize(data),
        }
    }

    fn linear_quantize(&self, data: &[f32]) -> Result<QuantizationResult, QuantizationError> {
        let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

        let params = QuantizationParameters::new(min_val, max_val, &self.config.precision);
        let mut quantized_data = Vec::with_capacity(data.len());

        for &value in data {
            let quantized = ((value - min_val) / params.scale + params.zero_point as f32)
                .round()
                .clamp(0.0, self.config.precision.max_value()) as i32;
            quantized_data.push(quantized);
        }

        let error_metrics = self.calculate_error_metrics(data, &quantized_data, &params);
        let compression_ratio = 32.0 / self.config.precision.bits() as f32;

        Ok(QuantizationResult {
            quantized_data,
            parameters: params,
            compression_ratio,
            error_metrics,
            salience_preserved: 1.0, // Linear doesn't consider salience
        })
    }

    fn salience_quantize(&self, data: &[f32]) -> Result<QuantizationResult, QuantizationError> {
        // Salience-aware quantization: keep high-salience values as-is and
        // pre-round low-salience values onto a coarser grid before the shared
        // affine quantization step.
        let mut weighted_data = Vec::with_capacity(data.len());
        let mut salience_preserved = 0.0;
        let mut total_salience = 0.0;

        for (i, &value) in data.iter().enumerate() {
            let salience = self.salience_weights.get(&i).copied().unwrap_or(1.0);
            total_salience += salience;

            if salience >= self.config.salience_threshold {
                // High salience: preserve at full precision
                weighted_data.push(value);
                salience_preserved += salience;
            } else {
                // Low salience: snap onto a coarser grid (spacing 1/0.9 ≈ 1.11)
                // so these values cost less precision after quantization
                let reduced_precision_value = (value * 0.9).round() / 0.9;
                weighted_data.push(reduced_precision_value);
            }
        }

        salience_preserved = if total_salience > 0.0 { salience_preserved / total_salience } else { 0.0 };

        // Apply linear quantization to the salience-weighted data
        let min_val = weighted_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let max_val = weighted_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

        let params = QuantizationParameters::new(min_val, max_val, &self.config.precision);
        let mut quantized_data = Vec::with_capacity(weighted_data.len());

        for &value in &weighted_data {
            let quantized = ((value - min_val) / params.scale + params.zero_point as f32)
                .round()
                .clamp(0.0, self.config.precision.max_value()) as i32;
            quantized_data.push(quantized);
        }

        // Error is measured against the original (unweighted) data.
        let error_metrics = self.calculate_error_metrics(data, &quantized_data, &params);
        let compression_ratio = 32.0 / self.config.precision.bits() as f32;

        Ok(QuantizationResult {
            quantized_data,
            parameters: params,
            compression_ratio,
            error_metrics,
            salience_preserved,
        })
    }

    fn blockwise_quantize(&self, data: &[f32]) -> Result<QuantizationResult, QuantizationError> {
        let mut quantized_data = Vec::with_capacity(data.len());
        let mut all_params = Vec::new();

        for chunk in data.chunks(self.config.block_size) {
            let min_val = chunk.iter().fold(f32::INFINITY, |a, &b| a.min(b));
            let max_val = chunk.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

            let params = QuantizationParameters::new(min_val, max_val, &self.config.precision);
            all_params.push(params.clone());

            for &value in chunk {
                let quantized = ((value - min_val) / params.scale + params.zero_point as f32)
                    .round()
                    .clamp(0.0, self.config.precision.max_value()) as i32;
                quantized_data.push(quantized);
            }
        }

        // Summarize with averaged parameters. Note this is an approximation:
        // exact reconstruction would require the per-block parameters.
        let avg_params = if !all_params.is_empty() {
            let avg_scale = all_params.iter().map(|p| p.scale).sum::<f32>() / all_params.len() as f32;
            let avg_zero_point = all_params.iter().map(|p| p.zero_point).sum::<i32>() / all_params.len() as i32;
            let avg_min = all_params.iter().map(|p| p.min_val).sum::<f32>() / all_params.len() as f32;
            let avg_max = all_params.iter().map(|p| p.max_val).sum::<f32>() / all_params.len() as f32;

            QuantizationParameters {
                scale: avg_scale,
                zero_point: avg_zero_point,
                min_val: avg_min,
                max_val: avg_max,
            }
        } else {
            QuantizationParameters::new(0.0, 1.0, &self.config.precision)
        };

        let error_metrics = self.calculate_error_metrics(data, &quantized_data, &avg_params);
        let compression_ratio = 32.0 / self.config.precision.bits() as f32;

        Ok(QuantizationResult {
            quantized_data,
            parameters: avg_params,
            compression_ratio,
            error_metrics,
            salience_preserved: 0.8, // Blockwise preserves some structure
        })
    }

    fn kmeans_quantize(&self, data: &[f32]) -> Result<QuantizationResult, QuantizationError> {
        // Simplified K-means quantization with at most 256 codebook entries.
        // Capping the shift at 8 bits also avoids overflow for FP16/FP32.
        let k = 1usize << self.config.precision.bits().min(8);
        let mut centroids = self.initialize_centroids(data, k);

        // Run a fixed number of K-means iterations
        for _ in 0..10 {
            let assignments = self.assign_to_centroids(data, &centroids);
            centroids = self.update_centroids(data, &assignments, k);
        }

        // Quantize data using the final centroids
        let mut quantized_data = Vec::with_capacity(data.len());
        for &value in data {
            let closest_idx = self.find_closest_centroid(value, &centroids);
            quantized_data.push(closest_idx as i32);
        }

        let min_val = centroids.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let max_val = centroids.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let params = QuantizationParameters::new(min_val, max_val, &self.config.precision);

        let error_metrics = self.calculate_kmeans_error_metrics(data, &quantized_data, &centroids);
        let compression_ratio = 32.0 / self.config.precision.bits() as f32;

        Ok(QuantizationResult {
            quantized_data,
            parameters: params,
            compression_ratio,
            error_metrics,
            salience_preserved: 0.9, // K-means preserves data distribution
        })
    }

    fn learned_quantize(&self, _data: &[f32]) -> Result<QuantizationResult, QuantizationError> {
        // Placeholder for learned quantization, which would require an ML model
        Err(QuantizationError::ConfigError("Learned quantization not yet implemented".to_string()))
    }

    fn adaptive_quantize(&self, data: &[f32]) -> Result<QuantizationResult, QuantizationError> {
        // Adaptive quantization picks an approach based on data characteristics
        let variance = self.calculate_variance(data);
        let has_outliers = self.detect_outliers(data);

        if variance > 1.0 && has_outliers {
            // High variance with outliers: use blockwise
            self.blockwise_quantize(data)
        } else if !self.salience_weights.is_empty() {
            // Salience information available: use salience-based
            self.salience_quantize(data)
        } else {
            // Default: use linear
            self.linear_quantize(data)
        }
    }

    fn calculate_error_metrics(&self, original: &[f32], quantized: &[i32], params: &QuantizationParameters) -> ErrorMetrics {
        let mut mse = 0.0;
        let mut mae = 0.0;
        let mut max_error: f32 = 0.0;
        let mut signal_power = 0.0;
        let mut noise_power = 0.0;

        for (&orig, &quant) in original.iter().zip(quantized.iter()) {
            let dequantized = (quant as f32 - params.zero_point as f32) * params.scale + params.min_val;
            let error = orig - dequantized;

            mse += error * error;
            mae += error.abs();
            max_error = max_error.max(error.abs());

            signal_power += orig * orig;
            noise_power += error * error;
        }

        let n = original.len() as f32;
        mse /= n;
        mae /= n;

        let snr = if noise_power > 0.0 {
            10.0 * (signal_power / noise_power).log10()
        } else {
            f32::INFINITY
        };

        ErrorMetrics {
            mse,
            mae,
            max_error,
            snr,
        }
    }

    fn calculate_kmeans_error_metrics(&self, original: &[f32], assignments: &[i32], centroids: &[f32]) -> ErrorMetrics {
        let mut mse = 0.0;
        let mut mae = 0.0;
        let mut max_error: f32 = 0.0;
        let mut signal_power = 0.0;
        let mut noise_power = 0.0;

        for (&orig, &assignment) in original.iter().zip(assignments.iter()) {
            let centroid = centroids.get(assignment as usize).copied().unwrap_or(0.0);
            let error = orig - centroid;

            mse += error * error;
            mae += error.abs();
            max_error = max_error.max(error.abs());

            signal_power += orig * orig;
            noise_power += error * error;
        }

        let n = original.len() as f32;
        mse /= n;
        mae /= n;

        let snr = if noise_power > 0.0 {
            10.0 * (signal_power / noise_power).log10()
        } else {
            f32::INFINITY
        };

        ErrorMetrics {
            mse,
            mae,
            max_error,
            snr,
        }
    }

    fn initialize_centroids(&self, data: &[f32], k: usize) -> Vec<f32> {
        let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

        if k <= 1 {
            // Avoid the division by (k - 1) below when there is one centroid
            return vec![(min_val + max_val) / 2.0];
        }

        // Spread the initial centroids evenly across the data range
        (0..k).map(|i| {
            min_val + (max_val - min_val) * (i as f32) / (k as f32 - 1.0)
        }).collect()
    }

    fn assign_to_centroids(&self, data: &[f32], centroids: &[f32]) -> Vec<usize> {
        data.iter().map(|&value| {
            self.find_closest_centroid(value, centroids)
        }).collect()
    }

    fn find_closest_centroid(&self, value: f32, centroids: &[f32]) -> usize {
        centroids.iter()
            .enumerate()
            // total_cmp avoids the panic partial_cmp().unwrap() hits on NaN
            .min_by(|(_, &a), (_, &b)| {
                (value - a).abs().total_cmp(&(value - b).abs())
            })
            .map(|(i, _)| i)
            .unwrap_or(0)
    }

    fn update_centroids(&self, data: &[f32], assignments: &[usize], k: usize) -> Vec<f32> {
        let mut new_centroids = vec![0.0; k];
        let mut counts = vec![0; k];

        for (&value, &assignment) in data.iter().zip(assignments.iter()) {
            new_centroids[assignment] += value;
            counts[assignment] += 1;
        }

        for i in 0..k {
            if counts[i] > 0 {
                new_centroids[i] /= counts[i] as f32;
            }
        }

        new_centroids
    }

    fn calculate_variance(&self, data: &[f32]) -> f32 {
        let mean = data.iter().sum::<f32>() / data.len() as f32;
        data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32
    }

    fn detect_outliers(&self, data: &[f32]) -> bool {
        let mut sorted_data = data.to_vec();
        // total_cmp is NaN-safe, unlike partial_cmp().unwrap()
        sorted_data.sort_by(f32::total_cmp);

        let q1_idx = sorted_data.len() / 4;
        let q3_idx = 3 * sorted_data.len() / 4;

        if q1_idx < sorted_data.len() && q3_idx < sorted_data.len() {
            // Standard 1.5 * IQR fence for outlier detection
            let q1 = sorted_data[q1_idx];
            let q3 = sorted_data[q3_idx];
            let iqr = q3 - q1;
            let lower_bound = q1 - 1.5 * iqr;
            let upper_bound = q3 + 1.5 * iqr;

            data.iter().any(|&x| x < lower_bound || x > upper_bound)
        } else {
            false
        }
    }

    pub fn dequantize(&self, quantized: &[i32], params: &QuantizationParameters) -> Vec<f32> {
        quantized.iter().map(|&q| {
            (q as f32 - params.zero_point as f32) * params.scale + params.min_val
        }).collect()
    }
}
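
// End-to-end usage sketch: linear-quantize a small tensor and dequantize it
// back. With Int8 over [0, 255] the affine map is exact, so the round trip
// is lossless; real data will incur the error reported in `error_metrics`.
#[cfg(test)]
mod quantizer_usage_example {
    use super::*;

    #[test]
    fn linear_round_trip() {
        let config = QuantizationConfig {
            precision: PrecisionLevel::Int8,
            algorithm: QuantizationAlgorithm::Linear,
            ..Default::default()
        };
        let quantizer = UnifiedQuantizer::new(config);

        let data = [0.0f32, 64.0, 128.0, 255.0];
        let result = quantizer.quantize(&data).expect("non-empty input");
        assert_eq!(result.compression_ratio, 4.0); // 32 bits -> 8 bits

        let restored = quantizer.dequantize(&result.quantized_data, &result.parameters);
        for (orig, rest) in data.iter().zip(restored.iter()) {
            assert!((orig - rest).abs() < 1e-3);
        }
    }
}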

/// Factory function to create quantizer instances
pub fn create_quantizer(config: QuantizationConfig) -> UnifiedQuantizer {
    UnifiedQuantizer::new(config)
}

/// Convenience function for quantizing a tensor with the default configuration
/// at a given precision.
pub fn quantize_tensor(data: &[f32], precision: PrecisionLevel) -> Result<QuantizationResult, QuantizationError> {
    let config = QuantizationConfig {
        precision,
        ..Default::default()
    };
    let quantizer = UnifiedQuantizer::new(config);
    quantizer.quantize(data)
}

/// Convenience function for salience-aware quantization with per-index weights.
pub fn quantize_with_salience(
    data: &[f32],
    salience_weights: HashMap<usize, f32>,
    precision: PrecisionLevel,
) -> Result<QuantizationResult, QuantizationError> {
    let config = QuantizationConfig {
        precision,
        algorithm: QuantizationAlgorithm::SalienceBased,
        ..Default::default()
    };
    let mut quantizer = UnifiedQuantizer::new(config);
    quantizer.set_salience_weights(salience_weights);
    quantizer.quantize(data)
}
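
// Usage sketch for the salience-aware entry point: indices with weights at or
// above the default `salience_threshold` (0.7) are kept at full precision.
// The weights here are arbitrary illustrative values.
#[cfg(test)]
mod salience_usage_example {
    use super::*;

    #[test]
    fn quantize_with_salience_weights() {
        let data = [0.1f32, 2.5, -1.0, 7.75];
        let mut weights = HashMap::new();
        weights.insert(1, 0.9); // high salience: preserved as-is
        weights.insert(2, 0.2); // low salience: snapped to a coarser grid

        let result = quantize_with_salience(&data, weights, PrecisionLevel::Int4)
            .expect("non-empty input");
        assert_eq!(result.quantized_data.len(), data.len());
        // Some salience mass is preserved: unlisted indices default to 1.0.
        assert!(result.salience_preserved > 0.0);
    }
}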