use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::{creation::ones, Tensor};

/// Strategy used to map floating-point values onto a discrete grid.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QuantizationScheme {
    /// Evenly spaced quantization levels.
    Uniform,
    /// Levels adapted to the value distribution (e.g. via clustering).
    NonUniform,
    /// Range parameters computed at runtime from the observed data.
    Dynamic,
}

/// Integer target type for quantized values.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QuantizationType {
    /// Signed 8-bit integers in [-128, 127].
    Int8,
    /// Unsigned 8-bit integers in [0, 255].
    UInt8,
    /// Signed 16-bit integers in [-32768, 32767].
    Int16,
    /// Signed 4-bit integers in [-8, 7].
    Int4,
}

/// Applies uniform affine quantization: `q = clamp(round(x / scale + zero_point))`.
///
/// Returns the quantized values (stored as a float tensor) together with the
/// `scale` and `zero_point` that were used, so callers can dequantize later.
pub fn uniform_quantize(
    input: &Tensor,
    scale: f32,
    zero_point: i32,
    qtype: QuantizationType,
) -> TorshResult<(Tensor, f32, i32)> {
    let (qmin, qmax) = match qtype {
        QuantizationType::Int8 => (-128i32, 127i32),
        QuantizationType::UInt8 => (0i32, 255i32),
        QuantizationType::Int16 => (-32768i32, 32767i32),
        QuantizationType::Int4 => (-8i32, 7i32),
    };

    let scaled = input.div_scalar(scale)?;
    let shifted = scaled.add_scalar(zero_point as f32)?;
    let rounded = shifted.round()?;
    let clamped = crate::math::clamp(&rounded, qmin as f32, qmax as f32)?;

    Ok((clamped, scale, zero_point))
}

/// Reverses uniform affine quantization: `x ≈ (q - zero_point) * scale`.
pub fn uniform_dequantize(quantized: &Tensor, scale: f32, zero_point: i32) -> TorshResult<Tensor> {
    let mut shifted = quantized.clone();
    shifted.sub_scalar_(zero_point as f32)?;
    shifted.mul_scalar(scale)
}

/// Quantizes `input` with range parameters fitted to its observed min/max.
///
/// With `reduce_range` the integer range is halved, which some backends
/// require to avoid overflow in intermediate accumulations.
pub fn dynamic_quantize(
    input: &Tensor,
    qtype: QuantizationType,
    reduce_range: bool,
) -> TorshResult<(Tensor, f32, i32)> {
    let (qmin, qmax) = match qtype {
        QuantizationType::Int8 => {
            if reduce_range {
                (-64i32, 63i32)
            } else {
                (-128i32, 127i32)
            }
        }
        QuantizationType::UInt8 => {
            if reduce_range {
                (0i32, 127i32)
            } else {
                (0i32, 255i32)
            }
        }
        QuantizationType::Int16 => {
            if reduce_range {
                (-16384i32, 16383i32)
            } else {
                (-32768i32, 32767i32)
            }
        }
        QuantizationType::Int4 => {
            if reduce_range {
                (-4i32, 3i32)
            } else {
                (-8i32, 7i32)
            }
        }
    };

    let input_min = input.min()?.data()?[0];
    let input_max = input.max(None, false)?.data()?[0];

    let scale = (input_max - input_min) / (qmax - qmin) as f32;
    // Guard against a constant input, which would otherwise produce a zero
    // scale and a division by zero in the zero-point computation below.
    let safe_scale = if scale == 0.0 { 1.0 } else { scale };
    let zero_point = (qmin as f32 - input_min / safe_scale).round() as i32;

    uniform_quantize(input, safe_scale, zero_point, qtype)
}

/// Simulates quantization for quantization-aware training: values are
/// snapped onto the quantization grid but returned in floating point.
pub fn fake_quantize(
    input: &Tensor,
    scale: f32,
    zero_point: i32,
    qtype: QuantizationType,
) -> TorshResult<Tensor> {
    let (quantized, scale, zero_point) = uniform_quantize(input, scale, zero_point, qtype)?;
    uniform_dequantize(&quantized, scale, zero_point)
}

/// Prunes the smallest-magnitude weights until roughly a `sparsity`
/// fraction is zeroed out, returning the pruned weights and a binary mask.
///
/// With `structured = true`, whole filters (slices along the first
/// dimension) are removed based on their L2 norms instead of individual
/// weights.
pub fn magnitude_prune(
    weights: &Tensor,
    sparsity: f32,
    structured: bool,
) -> TorshResult<(Tensor, Tensor)> {
    if sparsity < 0.0 || sparsity >= 1.0 {
        return Err(TorshError::invalid_argument_with_context(
            "Sparsity must be in range [0.0, 1.0)",
            "magnitude_prune",
        ));
    }

    if structured {
        let weight_shape_ref = weights.shape();
        let weight_shape = weight_shape_ref.dims();
        if weight_shape.len() < 2 {
            return Err(TorshError::invalid_argument_with_context(
                "Structured pruning requires at least 2D weights",
                "magnitude_prune",
            ));
        }

        let num_filters = weight_shape[0];
        let num_to_prune = ((num_filters as f32 * sparsity) as usize).min(num_filters);

        // L2 norm of each filter, reducing over all but the first dimension.
        let dims_to_reduce: Vec<i32> = (1..weight_shape.len()).map(|i| i as i32).collect();
        let filter_norms = weights
            .pow_scalar(2.0)?
            .sum_dim(&dims_to_reduce, false)?
            .sqrt()?;
        let norms = filter_norms.data()?;

        // Sort filter indices by ascending norm and zero out the weakest ones.
        let mut indices: Vec<usize> = (0..num_filters).collect();
        indices.sort_by(|&a, &b| {
            norms[a]
                .partial_cmp(&norms[b])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Build the mask directly from data, assuming row-major layout
        // (consistent with `Tensor::from_data` usage elsewhere in this module).
        let filter_size: usize = weight_shape[1..].iter().product();
        let mut mask_data = vec![1.0f32; num_filters * filter_size];
        for &filter in indices.iter().take(num_to_prune) {
            for offset in 0..filter_size {
                mask_data[filter * filter_size + offset] = 0.0;
            }
        }
        let mask = Tensor::from_data(mask_data, weight_shape.to_vec(), weights.device())?;

        let pruned_weights = weights.mul_op(&mask)?;
        Ok((pruned_weights, mask))
    } else {
        let abs_weights = weights.abs()?;
        let threshold = calculate_pruning_threshold(&abs_weights, sparsity)?;

        // Keep only the weights whose magnitude exceeds the threshold.
        let bool_mask = abs_weights.gt_scalar(threshold)?;
        let mask_data: Vec<f32> = bool_mask
            .data()?
            .iter()
            .map(|&b| if b { 1.0 } else { 0.0 })
            .collect();
        let mask =
            Tensor::from_data(mask_data, weights.shape().dims().to_vec(), weights.device())?;
        let pruned_weights = weights.mul_op(&mask)?;

        Ok((pruned_weights, mask))
    }
}

/// Returns the magnitude threshold below which weights are pruned, chosen
/// as the `sparsity`-quantile of the absolute weight values so that roughly
/// a `sparsity` fraction of weights falls at or below it.
fn calculate_pruning_threshold(abs_weights: &Tensor, sparsity: f32) -> TorshResult<f32> {
    let mut values = abs_weights.data()?;
    values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
    let num_to_prune = (values.len() as f32 * sparsity).round() as usize;
    if num_to_prune == 0 {
        // A negative threshold keeps every weight, since magnitudes are >= 0.
        return Ok(-1.0);
    }
    Ok(values[num_to_prune.min(values.len()) - 1])
}

/// Gradually raises sparsity from `initial_sparsity` to `final_sparsity`
/// between `start_step` and `end_step`, following a smoothstep schedule.
pub fn gradual_magnitude_prune(
    weights: &Tensor,
    current_step: usize,
    start_step: usize,
    end_step: usize,
    initial_sparsity: f32,
    final_sparsity: f32,
) -> TorshResult<(Tensor, f32, Tensor)> {
    if current_step < start_step {
        let mask = ones(&weights.shape().dims())?;
        return Ok((weights.clone(), initial_sparsity, mask));
    }

    if current_step >= end_step {
        let (pruned, mask) = magnitude_prune(weights, final_sparsity, false)?;
        return Ok((pruned, final_sparsity, mask));
    }

    // Smoothstep interpolation (3p^2 - 2p^3) between the two sparsity levels.
    let progress = (current_step - start_step) as f32 / (end_step - start_step) as f32;
    let current_sparsity = initial_sparsity
        + (final_sparsity - initial_sparsity) * (3.0 * progress.powi(2) - 2.0 * progress.powi(3));

    let (pruned, mask) = magnitude_prune(weights, current_sparsity, false)?;
    Ok((pruned, current_sparsity, mask))
}

/// Clusters weights into `num_clusters` shared values with 1-D k-means
/// (Lloyd's algorithm), returning `(clustered_weights, centroids, assignments)`.
pub fn weight_clustering(
    weights: &Tensor,
    num_clusters: usize,
) -> TorshResult<(Tensor, Tensor, Tensor)> {
    if num_clusters == 0 {
        return Err(TorshError::invalid_argument_with_context(
            "Number of clusters must be positive",
            "weight_clustering",
        ));
    }
    let weight_shape_ref = weights.shape();
    let weight_shape = weight_shape_ref.dims();
    let values = weights.data()?;
    // Spread the initial centroids evenly across the observed weight range.
    let min_weight = values.iter().cloned().fold(f32::INFINITY, f32::min);
    let max_weight = values.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let range = max_weight - min_weight;
    let mut centroid_vals: Vec<f32> = (0..num_clusters)
        .map(|i| min_weight + range * (i as f32 + 0.5) / num_clusters as f32)
        .collect();
    let mut assignments = vec![0usize; values.len()];
    for _ in 0..10 {
        // Assignment step: attach each weight to its nearest centroid.
        for (i, &v) in values.iter().enumerate() {
            let mut best = 0;
            for j in 1..num_clusters {
                if (v - centroid_vals[j]).abs() < (v - centroid_vals[best]).abs() {
                    best = j;
                }
            }
            assignments[i] = best;
        }
        // Update step: move each centroid to the mean of its members.
        for j in 0..num_clusters {
            let (sum, count) = values
                .iter()
                .zip(&assignments)
                .filter(|&(_, &a)| a == j)
                .fold((0.0f32, 0usize), |(s, n), (&v, _)| (s + v, n + 1));
            if count > 0 {
                centroid_vals[j] = sum / count as f32;
            }
        }
    }
    let clustered_data: Vec<f32> = assignments.iter().map(|&a| centroid_vals[a]).collect();
    let assignment_data: Vec<f32> = assignments.iter().map(|&a| a as f32).collect();
    let clustered_weights =
        Tensor::from_data(clustered_data, weight_shape.to_vec(), weights.device())?;
    let centroids = Tensor::from_data(centroid_vals, vec![num_clusters], weights.device())?;
    let cluster_assignments =
        Tensor::from_data(assignment_data, weight_shape.to_vec(), weights.device())?;
    Ok((clustered_weights, centroids, cluster_assignments))
}

/// Lottery-ticket pruning: derives a mask from the trained weights'
/// magnitudes and applies it to the *initial* weights, yielding the
/// "winning ticket" subnetwork to be retrained.
pub fn lottery_ticket_prune(
    weights: &Tensor,
    initial_weights: &Tensor,
    sparsity: f32,
) -> TorshResult<(Tensor, Tensor)> {
    if weights.shape().dims() != initial_weights.shape().dims() {
        return Err(TorshError::invalid_argument_with_context(
            "Weight tensors must have same shape",
            "lottery_ticket_prune",
        ));
    }

    let (_, mask) = magnitude_prune(weights, sparsity, false)?;
    let winning_subnetwork = initial_weights.mul_op(&mask)?;

    Ok((mask, winning_subnetwork))
}

/// Compares a tensor against its quantized reconstruction, returning
/// `(mse, max_abs_error, snr_db)`.
pub fn quantization_error_analysis(
    original: &Tensor,
    quantized: &Tensor,
) -> TorshResult<(f32, f32, f32)> {
    if original.shape().dims() != quantized.shape().dims() {
        return Err(TorshError::invalid_argument_with_context(
            "Tensors must have same shape",
            "quantization_error_analysis",
        ));
    }

    let error = original.sub(quantized)?;
    let mse_tensor = error.pow_scalar(2.0)?.mean(None, false)?;
    let mse = mse_tensor.data()?[0];

    let abs_error = error.abs()?;
    let max_error_tensor = abs_error.max(None, false)?;
    let max_error = max_error_tensor.data()?[0];

    // Signal-to-noise ratio in decibels: 10 * log10(signal power / MSE).
    let signal_power_tensor = original.pow_scalar(2.0)?.mean(None, false)?;
    let signal_power = signal_power_tensor.data()?[0];
    let snr_db = if mse > 0.0 {
        10.0 * (signal_power / mse).log10()
    } else {
        f32::INFINITY
    };

    Ok((mse, max_error, snr_db))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::random_ops::randn;

    #[test]
    fn test_uniform_quantization() {
        let input = randn(&[4, 4], None, None, None).unwrap();
        let (quantized, scale, zero_point) =
            uniform_quantize(&input, 0.1, 128, QuantizationType::UInt8).unwrap();

        assert_eq!(quantized.shape().dims(), input.shape().dims());

        let dequantized = uniform_dequantize(&quantized, scale, zero_point).unwrap();
        assert_eq!(dequantized.shape().dims(), input.shape().dims());
    }
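
    // Round-trip sanity check (a sketch, assuming `randn` draws roughly
    // standard-normal values, which fit well inside Int8's representable
    // span of [-12.8, 12.7] at scale 0.1): dequantizing a quantized tensor
    // should land within half a quantization step of the original.
    #[test]
    fn test_quantize_dequantize_round_trip() {
        let input = randn(&[4, 4], None, None, None).unwrap();
        let (quantized, scale, zero_point) =
            uniform_quantize(&input, 0.1, 0, QuantizationType::Int8).unwrap();
        let dequantized = uniform_dequantize(&quantized, scale, zero_point).unwrap();

        let original = input.data().unwrap();
        let recovered = dequantized.data().unwrap();
        for (a, b) in original.iter().zip(recovered.iter()) {
            assert!((a - b).abs() <= 0.051, "round-trip error exceeds scale / 2");
        }
    }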

    #[test]
    fn test_dynamic_quantization() {
        let input = randn(&[3, 3], None, None, None).unwrap();
        let (quantized, scale, _zero_point) =
            dynamic_quantize(&input, QuantizationType::Int8, false).unwrap();

        assert_eq!(quantized.shape().dims(), input.shape().dims());
        assert!(scale > 0.0);
    }
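
    // With `reduce_range`, the Int8 grid shrinks from 255 steps to 127, so
    // fitting the same min/max must produce a strictly larger scale (a
    // sketch; assumes the random input is not constant).
    #[test]
    fn test_dynamic_quantization_reduce_range() {
        let input = randn(&[3, 3], None, None, None).unwrap();
        let (_, full_scale, _) =
            dynamic_quantize(&input, QuantizationType::Int8, false).unwrap();
        let (_, reduced_scale, _) =
            dynamic_quantize(&input, QuantizationType::Int8, true).unwrap();
        assert!(reduced_scale > full_scale);
    }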

    #[test]
    fn test_fake_quantization() {
        let input = randn(&[2, 2], None, None, None).unwrap();
        let fake_quantized = fake_quantize(&input, 0.1, 0, QuantizationType::Int8).unwrap();

        assert_eq!(fake_quantized.shape().dims(), input.shape().dims());
    }
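
    // Fake quantization is idempotent up to float rounding: the first pass
    // already snaps every value onto the quantization grid, so a second
    // pass with the same parameters is a no-op.
    #[test]
    fn test_fake_quantization_idempotent() {
        let input = randn(&[2, 2], None, None, None).unwrap();
        let once = fake_quantize(&input, 0.1, 0, QuantizationType::Int8).unwrap();
        let twice = fake_quantize(&once, 0.1, 0, QuantizationType::Int8).unwrap();
        for (a, b) in once.data().unwrap().iter().zip(twice.data().unwrap().iter()) {
            assert!((a - b).abs() < 1e-6);
        }
    }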

    #[test]
    fn test_magnitude_pruning() {
        let weights = randn(&[10, 10], None, None, None).unwrap();
        let (pruned, mask) = magnitude_prune(&weights, 0.5, false).unwrap();

        assert_eq!(pruned.shape().dims(), weights.shape().dims());
        assert_eq!(mask.shape().dims(), weights.shape().dims());
    }
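
    // The quantile-based threshold should hit the requested sparsity almost
    // exactly on continuous random weights (a sketch; ties in magnitude,
    // which have probability ~0 here, could prune slightly more).
    #[test]
    fn test_magnitude_pruning_achieves_sparsity() {
        let weights = randn(&[10, 10], None, None, None).unwrap();
        let (_, mask) = magnitude_prune(&weights, 0.5, false).unwrap();

        let zeroed = mask.data().unwrap().iter().filter(|&&m| m == 0.0).count();
        let achieved = zeroed as f32 / 100.0;
        assert!((achieved - 0.5).abs() < 0.05);
    }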

    #[test]
    fn test_gradual_pruning() {
        let weights = randn(&[5, 5], None, None, None).unwrap();
        let (pruned, sparsity, mask) =
            gradual_magnitude_prune(&weights, 50, 10, 100, 0.0, 0.8).unwrap();

        assert_eq!(pruned.shape().dims(), weights.shape().dims());
        assert!(sparsity >= 0.0 && sparsity <= 0.8);
        assert_eq!(mask.shape().dims(), weights.shape().dims());
    }
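
    // Boundary behaviour of the schedule: before `start_step` the reported
    // sparsity is exactly `initial_sparsity`, and from `end_step` onward it
    // is exactly `final_sparsity` (both are returned unmodified).
    #[test]
    fn test_gradual_pruning_boundaries() {
        let weights = randn(&[5, 5], None, None, None).unwrap();

        let (_, before, _) = gradual_magnitude_prune(&weights, 5, 10, 100, 0.0, 0.8).unwrap();
        assert_eq!(before, 0.0);

        let (_, after, _) = gradual_magnitude_prune(&weights, 100, 10, 100, 0.0, 0.8).unwrap();
        assert_eq!(after, 0.8);
    }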

    #[test]
    fn test_lottery_ticket() {
        let trained_weights = randn(&[4, 4], None, None, None).unwrap();
        let initial_weights = randn(&[4, 4], None, None, None).unwrap();

        let (mask, winning_subnetwork) =
            lottery_ticket_prune(&trained_weights, &initial_weights, 0.6).unwrap();

        assert_eq!(mask.shape().dims(), trained_weights.shape().dims());
        assert_eq!(
            winning_subnetwork.shape().dims(),
            initial_weights.shape().dims()
        );
    }
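
    // Every clustered weight must equal one of the centroids, so the output
    // can take at most `num_clusters` distinct values (a sketch exercising
    // the k-means implementation above).
    #[test]
    fn test_weight_clustering() {
        let weights = randn(&[4, 4], None, None, None).unwrap();
        let (clustered, centroids, assignments) = weight_clustering(&weights, 4).unwrap();

        assert_eq!(clustered.shape().dims(), weights.shape().dims());
        assert_eq!(centroids.numel(), 4);
        assert_eq!(assignments.shape().dims(), weights.shape().dims());

        let centroid_vals = centroids.data().unwrap();
        for v in clustered.data().unwrap() {
            assert!(centroid_vals.iter().any(|c| (c - v).abs() < 1e-6));
        }
    }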

    #[test]
    fn test_quantization_error_analysis() {
        let original = randn(&[3, 3], None, None, None).unwrap();
        let quantized = original.clone();
        let (mse, max_error, snr_db) = quantization_error_analysis(&original, &quantized).unwrap();

        assert!(mse <= 1e-6);
        assert!(max_error <= 1e-6);
        assert!(snr_db > 60.0 || snr_db.is_infinite());
    }
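
    // End-to-end check combining fake quantization with the error analysis:
    // genuine quantization noise should stay within half a quantization step
    // (a sketch; Int8 at scale 0.1 comfortably covers randn's typical range).
    #[test]
    fn test_error_analysis_on_fake_quantized() {
        let original = randn(&[3, 3], None, None, None).unwrap();
        let quantized = fake_quantize(&original, 0.1, 0, QuantizationType::Int8).unwrap();
        let (mse, max_error, _snr_db) =
            quantization_error_analysis(&original, &quantized).unwrap();

        assert!(mse >= 0.0);
        assert!(max_error <= 0.051, "error bounded by half a quantization step");
    }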
}