// torsh_functional/quantization.rs
//! Quantization and Compression Functions
//!
//! This module provides quantization operations for model compression including:
//! - Uniform and non-uniform quantization
//! - Dynamic quantization schemes
//! - Pruning utilities (magnitude-based, structured, unstructured)
//! - Model compression techniques
//! - Low-precision computation functions
//! - Knowledge distillation utilities
use std::cmp::Ordering;

use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::{
    creation::{ones, randn, zeros},
    stats::StatMode,
    Tensor,
};
17
/// Quantization schemes
///
/// A fieldless discriminant enum; derives `Eq` and `Hash` (in addition to
/// `PartialEq`) so schemes can be used as map/set keys and compared totally.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantizationScheme {
    /// Uniform quantization with equal spacing
    Uniform,
    /// Non-uniform quantization with custom levels
    NonUniform,
    /// Dynamic quantization based on data statistics
    Dynamic,
}
28
/// Quantization data types
///
/// Target integer representations for quantized tensors. Derives `Eq` and
/// `Hash` (in addition to `PartialEq`) so types can key lookup tables of
/// per-dtype quantization parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantizationType {
    /// 8-bit signed integer
    Int8,
    /// 8-bit unsigned integer
    UInt8,
    /// 16-bit signed integer
    Int16,
    /// 4-bit quantization (for extreme compression)
    Int4,
}
41
42/// Uniform quantization
43///
44/// Quantizes floating-point values to fixed-point representation
45/// using uniform spacing between quantization levels.
46///
47/// # Arguments
48/// * `input` - Input tensor to quantize
49/// * `scale` - Quantization scale factor
50/// * `zero_point` - Zero point offset
51/// * `qtype` - Target quantization type
52///
53/// # Returns
54/// Tuple of (quantized_tensor, scale, zero_point)
55pub fn uniform_quantize(
56    input: &Tensor,
57    scale: f32,
58    zero_point: i32,
59    qtype: QuantizationType,
60) -> TorshResult<(Tensor, f32, i32)> {
61    let (qmin, qmax) = match qtype {
62        QuantizationType::Int8 => (-128i32, 127i32),
63        QuantizationType::UInt8 => (0i32, 255i32),
64        QuantizationType::Int16 => (-32768i32, 32767i32),
65        QuantizationType::Int4 => (-8i32, 7i32),
66    };
67
68    // Quantize: q = clamp(round(x / scale + zero_point), qmin, qmax)
69    let scaled = input.div_scalar(scale)?;
70    let shifted = scaled.add_scalar(zero_point as f32)?;
71    let rounded = shifted.round()?;
72    let clamped = crate::math::clamp(&rounded, qmin as f32, qmax as f32)?;
73
74    Ok((clamped, scale, zero_point))
75}
76
77/// Dequantize uniformly quantized tensor
78///
79/// # Arguments
80/// * `quantized` - Quantized tensor
81/// * `scale` - Quantization scale factor
82/// * `zero_point` - Zero point offset
83///
84/// # Returns
85/// Dequantized floating-point tensor
86pub fn uniform_dequantize(quantized: &Tensor, scale: f32, zero_point: i32) -> TorshResult<Tensor> {
87    // Dequantize: x = (q - zero_point) * scale
88    let mut shifted = quantized.clone();
89    shifted.sub_scalar_(zero_point as f32)?;
90    let shifted = shifted;
91    let dequantized = shifted.mul_scalar(scale)?;
92    Ok(dequantized)
93}
94
95/// Dynamic quantization with automatic scale/zero-point calculation
96///
97/// Automatically determines optimal scale and zero point based on
98/// input tensor statistics.
99///
100/// # Arguments
101/// * `input` - Input tensor to quantize
102/// * `qtype` - Target quantization type
103/// * `reduce_range` - Whether to reduce quantization range for better accuracy
104///
105/// # Returns
106/// Tuple of (quantized_tensor, scale, zero_point)
107pub fn dynamic_quantize(
108    input: &Tensor,
109    qtype: QuantizationType,
110    reduce_range: bool,
111) -> TorshResult<(Tensor, f32, i32)> {
112    let (qmin, qmax) = match qtype {
113        QuantizationType::Int8 => {
114            if reduce_range {
115                (-64i32, 63i32)
116            } else {
117                (-128i32, 127i32)
118            }
119        }
120        QuantizationType::UInt8 => {
121            if reduce_range {
122                (0i32, 127i32)
123            } else {
124                (0i32, 255i32)
125            }
126        }
127        QuantizationType::Int16 => {
128            if reduce_range {
129                (-16384i32, 16383i32)
130            } else {
131                (-32768i32, 32767i32)
132            }
133        }
134        QuantizationType::Int4 => {
135            if reduce_range {
136                (-4i32, 3i32)
137            } else {
138                (-8i32, 7i32)
139            }
140        }
141    };
142
143    // Calculate min and max values in input
144    let input_min = input.min()?.data()?[0];
145    let input_max = input.max(None, false)?.data()?[0];
146
147    // Calculate scale and zero point
148    let scale = (input_max - input_min) / (qmax - qmin) as f32;
149    let zero_point_float = qmin as f32 - input_min / scale;
150    let zero_point = zero_point_float.round() as i32;
151
152    // Ensure scale is not zero
153    let safe_scale = if scale == 0.0 { 1.0 } else { scale };
154
155    uniform_quantize(input, safe_scale, zero_point, qtype)
156}
157
158/// Quantize-aware training (QAT) simulation
159///
160/// Simulates quantization effects during training by applying
161/// fake quantization (quantize then immediately dequantize).
162///
163/// # Arguments
164/// * `input` - Input tensor
165/// * `scale` - Quantization scale
166/// * `zero_point` - Zero point offset
167/// * `qtype` - Quantization type
168///
169/// # Returns
170/// Fake quantized tensor (still in floating point)
171pub fn fake_quantize(
172    input: &Tensor,
173    scale: f32,
174    zero_point: i32,
175    qtype: QuantizationType,
176) -> TorshResult<Tensor> {
177    let (quantized, scale, zero_point) = uniform_quantize(input, scale, zero_point, qtype)?;
178    uniform_dequantize(&quantized, scale, zero_point)
179}
180
181/// Magnitude-based pruning
182///
183/// Prunes weights with smallest absolute values to achieve target sparsity.
184///
185/// # Arguments
186/// * `weights` - Weight tensor to prune
187/// * `sparsity` - Target sparsity level (0.0 = no pruning, 0.9 = 90% pruned)
188/// * `structured` - Whether to use structured pruning (prune entire channels/filters)
189///
190/// # Returns
191/// Tuple of (pruned_weights, pruning_mask)
192pub fn magnitude_prune(
193    weights: &Tensor,
194    sparsity: f32,
195    structured: bool,
196) -> TorshResult<(Tensor, Tensor)> {
197    if sparsity < 0.0 || sparsity >= 1.0 {
198        return Err(TorshError::invalid_argument_with_context(
199            "Sparsity must be in range [0.0, 1.0)",
200            "magnitude_prune",
201        ));
202    }
203
204    if structured {
205        // Structured pruning: prune entire channels/filters
206        let weight_shape_ref = weights.shape();
207        let weight_shape = weight_shape_ref.dims();
208        if weight_shape.len() < 2 {
209            return Err(TorshError::invalid_argument_with_context(
210                "Structured pruning requires at least 2D weights",
211                "magnitude_prune",
212            ));
213        }
214
215        let num_filters = weight_shape[0];
216        let num_to_prune = (num_filters as f32 * sparsity) as usize;
217
218        // Calculate L2 norm for each filter
219        let dims_to_reduce: Vec<i32> = (1..weight_shape.len()).map(|i| i as i32).collect();
220        let _filter_norms = weights
221            .pow_scalar(2.0)?
222            .sum_dim(&dims_to_reduce, false)?
223            .sqrt()?;
224
225        // Create mask (simplified - in practice would sort and threshold)
226        let mask = ones(&weight_shape)?;
227
228        // For now, create a simple pattern mask
229        // In full implementation would identify smallest norm filters
230        if num_to_prune > 0 {
231            // Set some filters to zero as example
232            for _i in 0..num_to_prune.min(num_filters) {
233                // This is a placeholder - in practice would zero out specific filter indices
234            }
235        }
236
237        let pruned_weights = weights.mul_op(&mask)?;
238        Ok((pruned_weights, mask))
239    } else {
240        // Unstructured pruning: prune individual weights
241        let abs_weights = weights.abs()?;
242        let threshold = calculate_pruning_threshold(&abs_weights, sparsity)?;
243
244        // Create mask: 1 where |weight| > threshold, 0 otherwise
245        let bool_mask = abs_weights.gt_scalar(threshold)?;
246        // Convert boolean mask to float mask manually
247        let mask_data: Vec<f32> = bool_mask
248            .data()?
249            .iter()
250            .map(|&b| if b { 1.0 } else { 0.0 })
251            .collect();
252        let mask = Tensor::from_data(mask_data, weights.shape().dims().to_vec(), weights.device())?;
253        let pruned_weights = weights.mul_op(&mask)?;
254
255        Ok((pruned_weights, mask))
256    }
257}
258
259/// Calculate pruning threshold for given sparsity level
260fn calculate_pruning_threshold(abs_weights: &Tensor, sparsity: f32) -> TorshResult<f32> {
261    // In a full implementation, this would:
262    // 1. Flatten the tensor
263    // 2. Sort the values
264    // 3. Find the value at the sparsity percentile
265
266    // For now, use a simple approximation based on statistics
267    let mean_data = abs_weights.mean(None, false)?.data()?;
268    let mean_val = mean_data.get(0).unwrap_or(&0.1).clone();
269    let std_data = abs_weights.std(None, false, StatMode::Sample)?.data()?;
270    let std_val = std_data.get(0).unwrap_or(&0.01).clone();
271
272    // Use a heuristic threshold based on statistics
273    let threshold = mean_val - sparsity * std_val;
274    Ok(threshold.max(0.0))
275}
276
277/// Gradual magnitude pruning with sparsity scheduling
278///
279/// Implements gradual pruning where sparsity increases over training steps.
280///
281/// # Arguments
282/// * `weights` - Weight tensor to prune
283/// * `current_step` - Current training step
284/// * `start_step` - Step to start pruning
285/// * `end_step` - Step to finish pruning
286/// * `initial_sparsity` - Initial sparsity level
287/// * `final_sparsity` - Final target sparsity level
288///
289/// # Returns
290/// Tuple of (pruned_weights, current_sparsity, pruning_mask)
291pub fn gradual_magnitude_prune(
292    weights: &Tensor,
293    current_step: usize,
294    start_step: usize,
295    end_step: usize,
296    initial_sparsity: f32,
297    final_sparsity: f32,
298) -> TorshResult<(Tensor, f32, Tensor)> {
299    if current_step < start_step {
300        // No pruning yet
301        let mask = ones(&weights.shape().dims())?;
302        return Ok((weights.clone(), initial_sparsity, mask));
303    }
304
305    if current_step >= end_step {
306        // Final sparsity reached
307        let (pruned, mask) = magnitude_prune(weights, final_sparsity, false)?;
308        return Ok((pruned, final_sparsity, mask));
309    }
310
311    // Calculate current sparsity using polynomial schedule
312    let progress = (current_step - start_step) as f32 / (end_step - start_step) as f32;
313    let current_sparsity = initial_sparsity
314        + (final_sparsity - initial_sparsity) * (3.0 * progress.powi(2) - 2.0 * progress.powi(3));
315
316    let (pruned, mask) = magnitude_prune(weights, current_sparsity, false)?;
317    Ok((pruned, current_sparsity, mask))
318}
319
320/// Weight clustering for compression
321///
322/// Groups weights into clusters and replaces each weight with its cluster centroid.
323///
324/// # Arguments
325/// * `weights` - Weight tensor to cluster
326/// * `num_clusters` - Number of clusters (codebook size)
327///
328/// # Returns
329/// Tuple of (clustered_weights, centroids, cluster_assignments)
330pub fn weight_clustering(
331    weights: &Tensor,
332    num_clusters: usize,
333) -> TorshResult<(Tensor, Tensor, Tensor)> {
334    if num_clusters == 0 {
335        return Err(TorshError::invalid_argument_with_context(
336            "Number of clusters must be positive",
337            "weight_clustering",
338        ));
339    }
340
341    // Simplified k-means clustering implementation
342    // In practice would use proper k-means algorithm
343
344    let weight_shape_ref = weights.shape();
345    let weight_shape = weight_shape_ref.dims();
346    let _num_weights = weights.numel();
347
348    // Initialize centroids (simplified - random sampling from weights)
349    let centroids = randn(&[num_clusters])?;
350
351    // For now, create simple cluster assignments based on weight ranges
352    let min_data = weights.min()?.data()?;
353    let min_weight = min_data.get(0).unwrap_or(&-1.0).clone();
354    let max_data = weights.max(None, false)?.data()?;
355    let max_weight = max_data.get(0).unwrap_or(&1.0).clone();
356    let _weight_range = max_weight - min_weight;
357
358    // Create cluster assignments (simplified)
359    let cluster_assignments = zeros(&weight_shape)?;
360
361    // Replace weights with cluster centroids
362    let clustered_weights = weights.clone(); // Placeholder
363
364    Ok((clustered_weights, centroids, cluster_assignments))
365}
366
367/// Lottery ticket hypothesis: find winning subnetworks
368///
369/// Identifies sparse subnetworks that can achieve comparable performance
370/// when trained from scratch.
371///
372/// # Arguments
373/// * `weights` - Original trained weights
374/// * `initial_weights` - Initial weights before training
375/// * `sparsity` - Target sparsity for the lottery ticket
376///
377/// # Returns
378/// Tuple of (lottery_ticket_mask, winning_subnetwork_weights)
379pub fn lottery_ticket_prune(
380    weights: &Tensor,
381    initial_weights: &Tensor,
382    sparsity: f32,
383) -> TorshResult<(Tensor, Tensor)> {
384    if weights.shape().dims() != initial_weights.shape().dims() {
385        return Err(TorshError::invalid_argument_with_context(
386            "Weight tensors must have same shape",
387            "lottery_ticket_prune",
388        ));
389    }
390
391    // Find winning lottery ticket based on final weight magnitudes
392    let (_, mask) = magnitude_prune(weights, sparsity, false)?;
393
394    // Apply mask to initial weights to get winning subnetwork
395    let winning_subnetwork = initial_weights.mul_op(&mask)?;
396
397    Ok((mask, winning_subnetwork))
398}
399
400/// Quantization error analysis
401///
402/// Analyzes the error introduced by quantization to guide optimization.
403///
404/// # Arguments
405/// * `original` - Original floating-point tensor
406/// * `quantized` - Quantized tensor (after dequantization)
407///
408/// # Returns
409/// Tuple of (mse_error, max_error, snr_db)
410pub fn quantization_error_analysis(
411    original: &Tensor,
412    quantized: &Tensor,
413) -> TorshResult<(f32, f32, f32)> {
414    if original.shape().dims() != quantized.shape().dims() {
415        return Err(TorshError::invalid_argument_with_context(
416            "Tensors must have same shape",
417            "quantization_error_analysis",
418        ));
419    }
420
421    // Calculate mean squared error
422    let error = original.sub(quantized)?;
423    let mse_tensor = error.pow_scalar(2.0)?.mean(None, false)?;
424    let mse = mse_tensor.data()?[0];
425
426    // Calculate maximum absolute error
427    let abs_error = error.abs()?;
428    let max_error_tensor = abs_error.max(None, false)?;
429    let max_error = max_error_tensor.data()?[0];
430
431    // Calculate signal-to-noise ratio in dB
432    let signal_power_tensor = original.pow_scalar(2.0)?.mean(None, false)?;
433    let signal_power = signal_power_tensor.data()?[0];
434    let snr_db = if mse > 0.0 {
435        10.0 * (signal_power / mse).log10()
436    } else {
437        f32::INFINITY
438    };
439
440    Ok((mse, max_error, snr_db))
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    // Shadows torsh_tensor's `randn` with the crate-local random op the
    // tests use to build inputs.
    use crate::random_ops::randn;

    // Round-trips a random tensor through quantize + dequantize and checks
    // that shapes survive both directions.
    #[test]
    fn test_uniform_quantization() {
        let input = randn(&[4, 4], None, None, None).unwrap();
        let (quantized, scale, zero_point) =
            uniform_quantize(&input, 0.1, 128, QuantizationType::UInt8).unwrap();

        // Check quantized values are in valid range
        assert_eq!(quantized.shape().dims(), input.shape().dims());

        // Test dequantization
        let dequantized = uniform_dequantize(&quantized, scale, zero_point).unwrap();
        assert_eq!(dequantized.shape().dims(), input.shape().dims());
    }

    // Automatic scale selection must produce a positive scale and preserve shape.
    #[test]
    fn test_dynamic_quantization() {
        let input = randn(&[3, 3], None, None, None).unwrap();
        let (quantized, scale, _zero_point) =
            dynamic_quantize(&input, QuantizationType::Int8, false).unwrap();

        assert_eq!(quantized.shape().dims(), input.shape().dims());
        assert!(scale > 0.0);
    }

    // Fake quantization stays in floating point with the input's shape.
    #[test]
    fn test_fake_quantization() {
        let input = randn(&[2, 2], None, None, None).unwrap();
        let fake_quantized = fake_quantize(&input, 0.1, 0, QuantizationType::Int8).unwrap();

        assert_eq!(fake_quantized.shape().dims(), input.shape().dims());
    }

    // Unstructured pruning returns a same-shape pruned tensor and mask.
    #[test]
    fn test_magnitude_pruning() {
        let weights = randn(&[10, 10], None, None, None).unwrap();
        let (pruned, mask) = magnitude_prune(&weights, 0.5, false).unwrap();

        assert_eq!(pruned.shape().dims(), weights.shape().dims());
        assert_eq!(mask.shape().dims(), weights.shape().dims());
    }

    // Mid-schedule (step 50 of 10..100) sparsity must lie within [0, final].
    #[test]
    fn test_gradual_pruning() {
        let weights = randn(&[5, 5], None, None, None).unwrap();
        let (pruned, sparsity, mask) =
            gradual_magnitude_prune(&weights, 50, 10, 100, 0.0, 0.8).unwrap();

        assert_eq!(pruned.shape().dims(), weights.shape().dims());
        assert!(sparsity >= 0.0 && sparsity <= 0.8);
        assert_eq!(mask.shape().dims(), weights.shape().dims());
    }

    // Mask shape follows the trained weights; subnetwork shape follows the
    // initial weights.
    #[test]
    fn test_lottery_ticket() {
        let trained_weights = randn(&[4, 4], None, None, None).unwrap();
        let initial_weights = randn(&[4, 4], None, None, None).unwrap();

        let (mask, winning_subnetwork) =
            lottery_ticket_prune(&trained_weights, &initial_weights, 0.6).unwrap();

        assert_eq!(mask.shape().dims(), trained_weights.shape().dims());
        assert_eq!(
            winning_subnetwork.shape().dims(),
            initial_weights.shape().dims()
        );
    }

    // Identical tensors: MSE and max error ~0, SNR effectively infinite.
    #[test]
    fn test_quantization_error_analysis() {
        let original = randn(&[3, 3], None, None, None).unwrap();
        let quantized = original.clone(); // Perfect case for testing

        let (mse, max_error, snr_db) = quantization_error_analysis(&original, &quantized).unwrap();

        // Should be very small errors for identical tensors
        assert!(mse <= 1e-6);
        assert!(max_error <= 1e-6);
        assert!(snr_db > 60.0 || snr_db.is_infinite());
    }
}
527}