//! Model-level ternary quantization: walks a collection of named weight
//! tensors, quantizes eligible linear layers, and tracks size and sparsity
//! statistics for the result.

use super::config::TernaryConfig;
use super::linear::TernaryLinear;
use super::quantize::quantize_tensor;
use crate::error::{Result, UnslothError};
use candle_core::{Device, Tensor};
use std::collections::HashMap;

/// Statistics gathered while quantizing a model.
#[derive(Debug, Clone, Default)]
pub struct QuantizationStats {
    /// Number of layers converted to ternary weights.
    pub layers_quantized: usize,
    /// Number of layers kept in full precision.
    pub layers_skipped: usize,
    /// Total parameter count of the original model.
    pub original_params: usize,
    /// Parameter count of the quantized layers.
    pub quantized_params: usize,
    /// Estimated size of the original (f32) model in bytes.
    pub original_bytes: usize,
    /// Estimated size after quantization in bytes.
    pub quantized_bytes: usize,
    /// Mean sparsity (fraction of zero weights) across quantized layers.
    pub average_sparsity: f32,
    /// Per-layer sparsity, keyed by layer name.
    pub layer_sparsities: HashMap<String, f32>,
    /// Guards `finalize_stats` against double-counting.
    finalized: bool,
}

impl QuantizationStats {
    /// Ratio of original size to quantized size; returns `1.0` when
    /// nothing has been quantized yet.
    #[must_use]
    pub fn compression_ratio(&self) -> f32 {
        if self.quantized_bytes == 0 {
            1.0
        } else {
            #[allow(clippy::cast_precision_loss)]
            {
                self.original_bytes as f32 / self.quantized_bytes as f32
            }
        }
    }

    /// Prints a human-readable summary of the quantization run to stdout.
    #[allow(clippy::cast_precision_loss)]
    pub fn print_summary(&self) {
        println!("=== Quantization Summary ===");
        println!("Layers quantized: {}", self.layers_quantized);
        println!("Layers skipped: {}", self.layers_skipped);
        println!("Original params: {}", self.original_params);
        println!("Quantized params: {}", self.quantized_params);
        println!(
            "Size reduction: {:.2}x ({:.2} MB -> {:.2} MB)",
            self.compression_ratio(),
            self.original_bytes as f64 / 1e6,
            self.quantized_bytes as f64 / 1e6
        );
        println!("Average sparsity: {:.1}%", self.average_sparsity * 100.0);
    }
}

/// A single quantized layer together with its metadata.
#[derive(Debug)]
pub struct QuantizedLayer {
    /// The quantized linear layer.
    pub layer: TernaryLinear,
    /// Name of the layer in the original model.
    pub name: String,
    /// Fraction of weights that quantized to zero.
    pub sparsity: f32,
}

/// Configuration for model-level quantization.
#[derive(Debug, Clone)]
pub struct ModelQuantizationConfig {
    /// Per-tensor ternary quantization settings.
    pub ternary_config: TernaryConfig,
    /// Layers with fewer parameters than this are left unquantized.
    pub min_layer_size: usize,
    /// Layers whose name contains any of these substrings are skipped.
    pub skip_patterns: Vec<String>,
    /// Print per-layer progress and a final summary.
    pub verbose: bool,
}

impl Default for ModelQuantizationConfig {
    fn default() -> Self {
        Self {
            ternary_config: TernaryConfig::default(),
            min_layer_size: 1024,
            skip_patterns: vec![
                "embed".to_string(),
                "norm".to_string(),
                "lm_head".to_string(),
            ],
            verbose: false,
        }
    }
}

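/// Quantizes a single 2-D weight tensor into a [`TernaryLinear`] layer.
///
/// Returns `Ok(None)` when the layer is skipped: either it has fewer than
/// `config.min_layer_size` parameters, or its name matches one of
/// `config.skip_patterns` (case-insensitive substring match).
///
/// # Errors
///
/// Returns [`UnslothError::ShapeMismatch`] if `weight` is not 2-D, and
/// propagates any error from `quantize_tensor` or layer construction.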
pub fn quantize_linear_layer(
    weight: &Tensor,
    bias: Option<&Tensor>,
    name: &str,
    config: &ModelQuantizationConfig,
    _device: &Device,
) -> Result<Option<QuantizedLayer>> {
    let dims = weight.dims();
    if dims.len() != 2 {
        return Err(UnslothError::ShapeMismatch {
            expected: vec![2],
            actual: dims.to_vec(),
        });
    }

    let (out_features, in_features) = (dims[0], dims[1]);
    let num_params = out_features * in_features;

    // Skip layers too small to benefit from quantization.
    if num_params < config.min_layer_size {
        if config.verbose {
            println!("Skipping {name} (too small: {num_params} params)");
        }
        return Ok(None);
    }

    // Skip layers whose names match a configured pattern (embeddings, norms, ...).
    for pattern in &config.skip_patterns {
        if name.to_lowercase().contains(&pattern.to_lowercase()) {
            if config.verbose {
                println!("Skipping {name} (matches pattern: {pattern})");
            }
            return Ok(None);
        }
    }

    let (ternary_weights, _scale) = quantize_tensor(weight, &config.ternary_config)?;

    let sparsity = ternary_weights.sparsity();

    if config.verbose {
        println!(
            "Quantizing {}: [{}, {}] -> sparsity {:.1}%",
            name,
            out_features,
            in_features,
            sparsity * 100.0
        );
    }

    let layer = TernaryLinear::with_config(ternary_weights, bias.cloned(), config.ternary_config)?;

    Ok(Some(QuantizedLayer {
        layer,
        name: name.to_string(),
        sparsity,
    }))
}

/// A model whose eligible linear layers have been ternary-quantized.
#[derive(Debug)]
pub struct TernaryModel {
    /// Quantized layers, keyed by their original names.
    pub layers: HashMap<String, TernaryLinear>,
    /// Tensors preserved in full precision (embeddings, norms, etc.).
    pub preserved_tensors: HashMap<String, Tensor>,
    /// Accumulated quantization statistics.
    pub stats: QuantizationStats,
    /// The configuration used for quantization.
    pub config: ModelQuantizationConfig,
}

impl TernaryModel {
    /// Creates an empty model with the given configuration.
    #[must_use]
    pub fn new(config: ModelQuantizationConfig) -> Self {
        Self {
            layers: HashMap::new(),
            preserved_tensors: HashMap::new(),
            stats: QuantizationStats::default(),
            config,
        }
    }

    /// Registers a quantized layer and updates the running statistics.
    pub fn add_layer(&mut self, name: String, layer: TernaryLinear, sparsity: f32) {
        let (out_features, in_features) = layer.dims();
        let num_params = out_features * in_features;

        self.stats.layers_quantized += 1;
        self.stats.quantized_params += num_params;
        self.stats.quantized_bytes += layer.memory_bytes();
        self.stats.layer_sparsities.insert(name.clone(), sparsity);

        self.layers.insert(name, layer);
    }

    /// Preserves a tensor in full precision and counts it in the statistics.
    pub fn add_preserved(&mut self, name: String, tensor: Tensor) {
        let num_params = tensor.elem_count();
        self.stats.layers_skipped += 1;
        self.stats.original_params += num_params;
        // Preserved tensors stay in f32, i.e. 4 bytes per parameter.
        self.stats.quantized_bytes += num_params * 4;
        self.preserved_tensors.insert(name, tensor);
    }

    /// Finalizes derived statistics. Idempotent: repeated calls have no
    /// further effect.
    #[allow(clippy::cast_precision_loss)]
    pub fn finalize_stats(&mut self) {
        if self.stats.finalized {
            return;
        }

        // Preserved params were counted in `add_preserved`; add the quantized
        // ones so `original_params` covers the whole model.
        self.stats.original_params += self.stats.quantized_params;
        // The original model is assumed to store every parameter as f32.
        self.stats.original_bytes = self.stats.original_params * 4;

        if !self.stats.layer_sparsities.is_empty() {
            self.stats.average_sparsity = self.stats.layer_sparsities.values().sum::<f32>()
                / self.stats.layer_sparsities.len() as f32;
        }

        self.stats.finalized = true;
    }

    /// Returns the quantized layer with the given name, if any.
    #[must_use]
    pub fn get_layer(&self, name: &str) -> Option<&TernaryLinear> {
        self.layers.get(name)
    }

    /// Returns the preserved full-precision tensor with the given name, if any.
    #[must_use]
    pub fn get_preserved(&self, name: &str) -> Option<&Tensor> {
        self.preserved_tensors.get(name)
    }
}

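/// Quantizes every eligible weight in `weights` and preserves the rest in
/// full precision, returning the assembled [`TernaryModel`]. Biases are
/// looked up in `biases` under the same name as their weight.
///
/// A minimal usage sketch (illustrative only; import paths depend on this
/// crate's layout):
///
/// ```ignore
/// let device = Device::Cpu;
/// let mut weights = HashMap::new();
/// weights.insert(
///     "layer1".to_string(),
///     Tensor::randn(0.0f32, 1.0, (64, 128), &device)?,
/// );
/// let model = quantize_weights_collection(
///     weights,
///     HashMap::new(),
///     ModelQuantizationConfig::default(),
///     &device,
/// )?;
/// model.stats.print_summary();
/// ```
///
/// # Errors
///
/// Propagates any error from quantizing an individual layer.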
pub fn quantize_weights_collection(
    weights: HashMap<String, Tensor>,
    biases: HashMap<String, Tensor>,
    config: ModelQuantizationConfig,
    device: &Device,
) -> Result<TernaryModel> {
    let mut model = TernaryModel::new(config);

    for (name, weight) in weights {
        let bias = biases.get(&name);

        if let Some(quantized) = quantize_linear_layer(&weight, bias, &name, &model.config, device)?
        {
            model.add_layer(quantized.name, quantized.layer, quantized.sparsity);
        } else {
            // Keep skipped layers (and any biases) in full precision.
            model.add_preserved(format!("{name}.weight"), weight);
            if let Some(b) = bias {
                model.add_preserved(format!("{name}.bias"), b.clone());
            }
        }
    }

    model.finalize_stats();

    if model.config.verbose {
        model.stats.print_summary();
    }

    Ok(model)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quantization_stats() {
        let stats = QuantizationStats {
            original_bytes: 1000,
            quantized_bytes: 100,
            ..Default::default()
        };

        assert!((stats.compression_ratio() - 10.0).abs() < 0.001);
    }

    #[test]
    fn test_model_quantization_config_default() {
        let config = ModelQuantizationConfig::default();
        assert_eq!(config.min_layer_size, 1024);
        assert!(config.skip_patterns.contains(&"embed".to_string()));
    }

    #[test]
    fn test_quantize_linear_layer() -> Result<()> {
        let device = Device::Cpu;
        let config = ModelQuantizationConfig {
            min_layer_size: 0,
            skip_patterns: vec![],
            ..Default::default()
        };

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device)?;

        let result = quantize_linear_layer(&weight, None, "test_layer", &config, &device)?;

        assert!(result.is_some());
        let quantized = result.unwrap();
        assert_eq!(quantized.name, "test_layer");
        assert!(quantized.sparsity >= 0.0 && quantized.sparsity <= 1.0);

        Ok(())
    }

    #[test]
    fn test_skip_small_layer() -> Result<()> {
        let device = Device::Cpu;
        let config = ModelQuantizationConfig {
            min_layer_size: 10000,
            ..Default::default()
        };

        let weight = Tensor::randn(0.0f32, 1.0, (8, 8), &device)?;

        let result = quantize_linear_layer(&weight, None, "small_layer", &config, &device)?;

        assert!(result.is_none());

        Ok(())
    }

    #[test]
    fn test_skip_pattern() -> Result<()> {
        let device = Device::Cpu;
        let config = ModelQuantizationConfig::default();

        let weight = Tensor::randn(0.0f32, 1.0, (128, 128), &device)?;

        let result = quantize_linear_layer(&weight, None, "model.embed_tokens", &config, &device)?;

        assert!(result.is_none());

        Ok(())
    }

    #[test]
    fn test_ternary_model() -> Result<()> {
        let device = Device::Cpu;
        let config = ModelQuantizationConfig {
            min_layer_size: 0,
            skip_patterns: vec![],
            verbose: false,
            ..Default::default()
        };

        let mut weights = HashMap::new();
        weights.insert(
            "layer1".to_string(),
            Tensor::randn(0.0f32, 1.0, (64, 128), &device)?,
        );
        weights.insert(
            "layer2".to_string(),
            Tensor::randn(0.0f32, 1.0, (128, 64), &device)?,
        );

        let model = quantize_weights_collection(weights, HashMap::new(), config, &device)?;

        assert_eq!(model.stats.layers_quantized, 2);
        assert!(model.get_layer("layer1").is_some());
        assert!(model.get_layer("layer2").is_some());

        // Both layers were quantized, so every parameter counts as both
        // original and quantized.
        let expected_params = 64 * 128 + 128 * 64;
        assert_eq!(model.stats.original_params, expected_params);
        assert_eq!(model.stats.quantized_params, expected_params);
        assert_eq!(model.stats.original_bytes, expected_params * 4);

        Ok(())
    }

    #[test]
    fn test_accounting_with_preserved() -> Result<()> {
        let device = Device::Cpu;
        let config = ModelQuantizationConfig {
            min_layer_size: 10000,
            skip_patterns: vec![],
            verbose: false,
            ..Default::default()
        };

        let mut weights = HashMap::new();
        weights.insert(
            "large".to_string(),
            Tensor::randn(0.0f32, 1.0, (256, 256), &device)?,
        );
        weights.insert(
            "small".to_string(),
            Tensor::randn(0.0f32, 1.0, (8, 8), &device)?,
        );

        let model = quantize_weights_collection(weights, HashMap::new(), config, &device)?;

        assert_eq!(model.stats.layers_quantized, 1);
        assert_eq!(model.stats.layers_skipped, 1);

        let large_params = 256 * 256; // quantized
        let small_params = 8 * 8; // preserved in f32
        let total_params = large_params + small_params;

        assert_eq!(model.stats.quantized_params, large_params);
        assert_eq!(model.stats.original_params, total_params);
        assert_eq!(model.stats.original_bytes, total_params * 4);

        Ok(())
    }

    #[test]
    fn test_finalize_stats_idempotent() -> Result<()> {
        let device = Device::Cpu;
        let config = ModelQuantizationConfig {
            min_layer_size: 0,
            skip_patterns: vec![],
            verbose: false,
            ..Default::default()
        };

        let mut weights = HashMap::new();
        weights.insert(
            "layer1".to_string(),
            Tensor::randn(0.0f32, 1.0, (64, 128), &device)?,
        );

        let mut model = quantize_weights_collection(weights, HashMap::new(), config, &device)?;

        let initial_original_params = model.stats.original_params;
        let initial_original_bytes = model.stats.original_bytes;

        // `quantize_weights_collection` already finalized the stats;
        // further calls must not change them.
        model.finalize_stats();

        assert_eq!(model.stats.original_params, initial_original_params);
        assert_eq!(model.stats.original_bytes, initial_original_bytes);

        model.finalize_stats();
        assert_eq!(model.stats.original_params, initial_original_params);

        Ok(())
    }

    #[test]
    fn test_compression_ratio_no_quantization() {
        let stats = QuantizationStats::default();
        assert!((stats.compression_ratio() - 1.0).abs() < 0.001);
    }
}