1use std::collections::HashMap;
26use serde::{Serialize, Deserialize};
27use anyhow::Result;
28use thiserror::Error;
29
30#[derive(Error, Debug)]
31pub enum QuantizationError {
32 #[error("Invalid precision level: {0}")]
33 InvalidPrecision(String),
34 #[error("Tensor operation failed: {0}")]
35 TensorError(String),
36 #[error("Model loading failed: {0}")]
37 ModelError(String),
38 #[error("Memory allocation failed: {0}")]
39 MemoryError(String),
40 #[error("Validation failed: {0}")]
41 ValidationError(String),
42 #[error("Configuration error: {0}")]
43 ConfigError(String),
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
47#[cfg_attr(feature = "clap", derive(clap::ValueEnum))]
48pub enum PrecisionLevel {
49 Int1,
50 Int2,
51 Int4,
52 Int8,
53 FP16,
54 FP32,
55}
56
57impl PrecisionLevel {
58 pub fn bits(&self) -> u8 {
59 match self {
60 PrecisionLevel::Int1 => 1,
61 PrecisionLevel::Int2 => 2,
62 PrecisionLevel::Int4 => 4,
63 PrecisionLevel::Int8 => 8,
64 PrecisionLevel::FP16 => 16,
65 PrecisionLevel::FP32 => 32,
66 }
67 }
68
69 pub fn max_value(&self) -> f32 {
70 match self {
71 PrecisionLevel::Int1 => 1.0,
72 PrecisionLevel::Int2 => 3.0,
73 PrecisionLevel::Int4 => 15.0,
74 PrecisionLevel::Int8 => 255.0,
75 PrecisionLevel::FP16 => f32::MAX,
76 PrecisionLevel::FP32 => f32::MAX,
77 }
78 }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub enum QuantizationAlgorithm {
83 Linear,
84 KMeans,
85 Learned,
86 BlockWise,
87 SalienceBased,
88 Adaptive,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct QuantizationConfig {
93 pub precision: PrecisionLevel,
94 pub algorithm: QuantizationAlgorithm,
95 pub block_size: usize,
96 pub salience_threshold: f32,
97 pub preserve_outliers: bool,
98 pub use_symmetric: bool,
99 pub calibration_samples: usize,
100 pub validation_threshold: f32,
101}
102
103impl Default for QuantizationConfig {
104 fn default() -> Self {
105 Self {
106 precision: PrecisionLevel::Int4,
107 algorithm: QuantizationAlgorithm::SalienceBased,
108 block_size: 128,
109 salience_threshold: 0.7,
110 preserve_outliers: true,
111 use_symmetric: false,
112 calibration_samples: 1000,
113 validation_threshold: 0.95,
114 }
115 }
116}
117
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct QuantizationParameters {
120 pub scale: f32,
121 pub zero_point: i32,
122 pub min_val: f32,
123 pub max_val: f32,
124}
125
126impl QuantizationParameters {
127 pub fn new(min_val: f32, max_val: f32, precision: &PrecisionLevel) -> Self {
128 let qmin = 0.0;
129 let qmax = precision.max_value();
130 let scale = (max_val - min_val) / (qmax - qmin);
131 let zero_point = (qmin - min_val / scale).round() as i32;
132
133 Self {
134 scale,
135 zero_point,
136 min_val,
137 max_val,
138 }
139 }
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct QuantizationResult {
144 pub quantized_data: Vec<i32>,
145 pub parameters: QuantizationParameters,
146 pub compression_ratio: f32,
147 pub error_metrics: ErrorMetrics,
148 pub salience_preserved: f32,
149}
150
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct ErrorMetrics {
153 pub mse: f32,
154 pub mae: f32,
155 pub max_error: f32,
156 pub snr: f32,
157}
158
159pub struct UnifiedQuantizer {
161 config: QuantizationConfig,
162 salience_weights: HashMap<usize, f32>,
163}
164
165impl UnifiedQuantizer {
166 pub fn new(config: QuantizationConfig) -> Self {
167 Self {
168 config,
169 salience_weights: HashMap::new(),
170 }
171 }
172
173 pub fn set_salience_weights(&mut self, weights: HashMap<usize, f32>) {
174 self.salience_weights = weights;
175 }
176
177 pub fn quantize(&self, data: &[f32]) -> Result<QuantizationResult, QuantizationError> {
178 match self.config.algorithm {
179 QuantizationAlgorithm::Linear => self.linear_quantize(data),
180 QuantizationAlgorithm::KMeans => self.kmeans_quantize(data),
181 QuantizationAlgorithm::Learned => self.learned_quantize(data),
182 QuantizationAlgorithm::BlockWise => self.blockwise_quantize(data),
183 QuantizationAlgorithm::SalienceBased => self.salience_quantize(data),
184 QuantizationAlgorithm::Adaptive => self.adaptive_quantize(data),
185 }
186 }
187
188 fn linear_quantize(&self, data: &[f32]) -> Result<QuantizationResult, QuantizationError> {
189 let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
190 let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
191
192 let params = QuantizationParameters::new(min_val, max_val, &self.config.precision);
193 let mut quantized_data = Vec::with_capacity(data.len());
194
195 for &value in data {
196 let quantized = ((value - min_val) / params.scale + params.zero_point as f32)
197 .round()
198 .clamp(0.0, self.config.precision.max_value()) as i32;
199 quantized_data.push(quantized);
200 }
201
202 let error_metrics = self.calculate_error_metrics(data, &quantized_data, ¶ms);
203 let compression_ratio = (32.0 / self.config.precision.bits() as f32);
204
205 Ok(QuantizationResult {
206 quantized_data,
207 parameters: params,
208 compression_ratio,
209 error_metrics,
210 salience_preserved: 1.0, })
212 }
213
214 fn salience_quantize(&self, data: &[f32]) -> Result<QuantizationResult, QuantizationError> {
215 let mut weighted_data = Vec::with_capacity(data.len());
217 let mut salience_preserved = 0.0;
218 let mut total_salience = 0.0;
219
220 for (i, &value) in data.iter().enumerate() {
221 let salience = self.salience_weights.get(&i).copied().unwrap_or(1.0);
222 total_salience += salience;
223
224 if salience >= self.config.salience_threshold {
225 weighted_data.push(value);
227 salience_preserved += salience;
228 } else {
229 let reduced_precision_value = (value * 0.9).round() / 0.9; weighted_data.push(reduced_precision_value);
232 }
233 }
234
235 salience_preserved = if total_salience > 0.0 { salience_preserved / total_salience } else { 0.0 };
236
237 let min_val = weighted_data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
239 let max_val = weighted_data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
240
241 let params = QuantizationParameters::new(min_val, max_val, &self.config.precision);
242 let mut quantized_data = Vec::with_capacity(weighted_data.len());
243
244 for &value in &weighted_data {
245 let quantized = ((value - min_val) / params.scale + params.zero_point as f32)
246 .round()
247 .clamp(0.0, self.config.precision.max_value()) as i32;
248 quantized_data.push(quantized);
249 }
250
251 let error_metrics = self.calculate_error_metrics(data, &quantized_data, ¶ms);
252 let compression_ratio = (32.0 / self.config.precision.bits() as f32);
253
254 Ok(QuantizationResult {
255 quantized_data,
256 parameters: params,
257 compression_ratio,
258 error_metrics,
259 salience_preserved,
260 })
261 }
262
263 fn blockwise_quantize(&self, data: &[f32]) -> Result<QuantizationResult, QuantizationError> {
264 let mut quantized_data = Vec::with_capacity(data.len());
265 let mut all_params = Vec::new();
266 let mut total_error = 0.0;
267
268 for chunk in data.chunks(self.config.block_size) {
269 let min_val = chunk.iter().fold(f32::INFINITY, |a, &b| a.min(b));
270 let max_val = chunk.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
271
272 let params = QuantizationParameters::new(min_val, max_val, &self.config.precision);
273 all_params.push(params.clone());
274
275 for &value in chunk {
276 let quantized = ((value - min_val) / params.scale + params.zero_point as f32)
277 .round()
278 .clamp(0.0, self.config.precision.max_value()) as i32;
279 quantized_data.push(quantized);
280
281 let dequantized = (quantized as f32 - params.zero_point as f32) * params.scale + min_val;
283 total_error += (value - dequantized).powi(2);
284 }
285 }
286
287 let avg_params = if !all_params.is_empty() {
289 let avg_scale = all_params.iter().map(|p| p.scale).sum::<f32>() / all_params.len() as f32;
290 let avg_zero_point = all_params.iter().map(|p| p.zero_point).sum::<i32>() / all_params.len() as i32;
291 let avg_min = all_params.iter().map(|p| p.min_val).sum::<f32>() / all_params.len() as f32;
292 let avg_max = all_params.iter().map(|p| p.max_val).sum::<f32>() / all_params.len() as f32;
293
294 QuantizationParameters {
295 scale: avg_scale,
296 zero_point: avg_zero_point,
297 min_val: avg_min,
298 max_val: avg_max,
299 }
300 } else {
301 QuantizationParameters::new(0.0, 1.0, &self.config.precision)
302 };
303
304 let error_metrics = self.calculate_error_metrics(data, &quantized_data, &avg_params);
305 let compression_ratio = (32.0 / self.config.precision.bits() as f32);
306
307 Ok(QuantizationResult {
308 quantized_data,
309 parameters: avg_params,
310 compression_ratio,
311 error_metrics,
312 salience_preserved: 0.8, })
314 }
315
316 fn kmeans_quantize(&self, data: &[f32]) -> Result<QuantizationResult, QuantizationError> {
317 let k = (1 << self.config.precision.bits()).min(256) as usize;
319 let mut centroids = self.initialize_centroids(data, k);
320
321 for _ in 0..10 {
323 let assignments = self.assign_to_centroids(data, ¢roids);
324 centroids = self.update_centroids(data, &assignments, k);
325 }
326
327 let mut quantized_data = Vec::with_capacity(data.len());
329 for &value in data {
330 let closest_idx = self.find_closest_centroid(value, ¢roids);
331 quantized_data.push(closest_idx as i32);
332 }
333
334 let min_val = centroids.iter().fold(f32::INFINITY, |a, &b| a.min(b));
335 let max_val = centroids.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
336 let params = QuantizationParameters::new(min_val, max_val, &self.config.precision);
337
338 let error_metrics = self.calculate_kmeans_error_metrics(data, &quantized_data, ¢roids);
339 let compression_ratio = (32.0 / self.config.precision.bits() as f32);
340
341 Ok(QuantizationResult {
342 quantized_data,
343 parameters: params,
344 compression_ratio,
345 error_metrics,
346 salience_preserved: 0.9, })
348 }
349
350 fn learned_quantize(&self, _data: &[f32]) -> Result<QuantizationResult, QuantizationError> {
351 Err(QuantizationError::ConfigError("Learned quantization not yet implemented".to_string()))
353 }
354
355 fn adaptive_quantize(&self, data: &[f32]) -> Result<QuantizationResult, QuantizationError> {
356 let variance = self.calculate_variance(data);
358 let has_outliers = self.detect_outliers(data);
359
360 if variance > 1.0 && has_outliers {
361 self.blockwise_quantize(data)
363 } else if !self.salience_weights.is_empty() {
364 self.salience_quantize(data)
366 } else {
367 self.linear_quantize(data)
369 }
370 }
371
372 fn calculate_error_metrics(&self, original: &[f32], quantized: &[i32], params: &QuantizationParameters) -> ErrorMetrics {
373 let mut mse = 0.0;
374 let mut mae = 0.0;
375 let mut max_error: f32 = 0.0;
376 let mut signal_power = 0.0;
377 let mut noise_power = 0.0;
378
379 for (_i, (&orig, &quant)) in original.iter().zip(quantized.iter()).enumerate() {
380 let dequantized = (quant as f32 - params.zero_point as f32) * params.scale + params.min_val;
381 let error = orig - dequantized;
382
383 mse += error * error;
384 mae += error.abs();
385 max_error = max_error.max(error.abs());
386
387 signal_power += orig * orig;
388 noise_power += error * error;
389 }
390
391 let n = original.len() as f32;
392 mse /= n;
393 mae /= n;
394
395 let snr = if noise_power > 0.0 {
396 10.0 * (signal_power / noise_power).log10()
397 } else {
398 f32::INFINITY
399 };
400
401 ErrorMetrics {
402 mse,
403 mae,
404 max_error,
405 snr,
406 }
407 }
408
409 fn calculate_kmeans_error_metrics(&self, original: &[f32], assignments: &[i32], centroids: &[f32]) -> ErrorMetrics {
410 let mut mse = 0.0;
411 let mut mae = 0.0;
412 let mut max_error: f32 = 0.0;
413 let mut signal_power = 0.0;
414 let mut noise_power = 0.0;
415
416 for (&orig, &assignment) in original.iter().zip(assignments.iter()) {
417 let centroid = centroids.get(assignment as usize).copied().unwrap_or(0.0);
418 let error = orig - centroid;
419
420 mse += error * error;
421 mae += error.abs();
422 max_error = max_error.max(error.abs());
423
424 signal_power += orig * orig;
425 noise_power += error * error;
426 }
427
428 let n = original.len() as f32;
429 mse /= n;
430 mae /= n;
431
432 let snr = if noise_power > 0.0 {
433 10.0 * (signal_power / noise_power).log10()
434 } else {
435 f32::INFINITY
436 };
437
438 ErrorMetrics {
439 mse,
440 mae,
441 max_error,
442 snr,
443 }
444 }
445
446 fn initialize_centroids(&self, data: &[f32], k: usize) -> Vec<f32> {
447 let min_val = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
448 let max_val = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
449
450 (0..k).map(|i| {
451 min_val + (max_val - min_val) * (i as f32) / (k as f32 - 1.0)
452 }).collect()
453 }
454
455 fn assign_to_centroids(&self, data: &[f32], centroids: &[f32]) -> Vec<usize> {
456 data.iter().map(|&value| {
457 self.find_closest_centroid(value, centroids)
458 }).collect()
459 }
460
461 fn find_closest_centroid(&self, value: f32, centroids: &[f32]) -> usize {
462 centroids.iter()
463 .enumerate()
464 .min_by(|(_, &a), (_, &b)| {
465 (value - a).abs().partial_cmp(&(value - b).abs()).unwrap()
466 })
467 .map(|(i, _)| i)
468 .unwrap_or(0)
469 }
470
471 fn update_centroids(&self, data: &[f32], assignments: &[usize], k: usize) -> Vec<f32> {
472 let mut new_centroids = vec![0.0; k];
473 let mut counts = vec![0; k];
474
475 for (&value, &assignment) in data.iter().zip(assignments.iter()) {
476 new_centroids[assignment] += value;
477 counts[assignment] += 1;
478 }
479
480 for i in 0..k {
481 if counts[i] > 0 {
482 new_centroids[i] /= counts[i] as f32;
483 }
484 }
485
486 new_centroids
487 }
488
489 fn calculate_variance(&self, data: &[f32]) -> f32 {
490 let mean = data.iter().sum::<f32>() / data.len() as f32;
491 let variance = data.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
492 variance
493 }
494
495 fn detect_outliers(&self, data: &[f32]) -> bool {
496 let mut sorted_data = data.to_vec();
497 sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
498
499 let q1_idx = sorted_data.len() / 4;
500 let q3_idx = 3 * sorted_data.len() / 4;
501
502 if q1_idx < sorted_data.len() && q3_idx < sorted_data.len() {
503 let q1 = sorted_data[q1_idx];
504 let q3 = sorted_data[q3_idx];
505 let iqr = q3 - q1;
506 let lower_bound = q1 - 1.5 * iqr;
507 let upper_bound = q3 + 1.5 * iqr;
508
509 data.iter().any(|&x| x < lower_bound || x > upper_bound)
510 } else {
511 false
512 }
513 }
514
515 pub fn dequantize(&self, quantized: &[i32], params: &QuantizationParameters) -> Vec<f32> {
516 quantized.iter().map(|&q| {
517 (q as f32 - params.zero_point as f32) * params.scale + params.min_val
518 }).collect()
519 }
520}
521
522pub fn create_quantizer(config: QuantizationConfig) -> UnifiedQuantizer {
524 UnifiedQuantizer::new(config)
525}
526
527pub fn quantize_tensor(data: &[f32], precision: PrecisionLevel) -> Result<QuantizationResult, QuantizationError> {
529 let config = QuantizationConfig {
530 precision,
531 ..Default::default()
532 };
533 let quantizer = UnifiedQuantizer::new(config);
534 quantizer.quantize(data)
535}
536
537pub fn quantize_with_salience(
538 data: &[f32],
539 salience_weights: HashMap<usize, f32>,
540 precision: PrecisionLevel
541) -> Result<QuantizationResult, QuantizationError> {
542 let config = QuantizationConfig {
543 precision,
544 algorithm: QuantizationAlgorithm::SalienceBased,
545 ..Default::default()
546 };
547 let mut quantizer = UnifiedQuantizer::new(config);
548 quantizer.set_salience_weights(salience_weights);
549 quantizer.quantize(data)
550}