// torsh_quantization/lib.rs — crate root for the ToRSh quantization library.
//! # ToRSh Quantization Library
//!
//! A comprehensive quantization library for deep learning tensor operations, providing
//! state-of-the-art quantization algorithms, configuration management, performance
//! metrics, and utility functions.
//!
//! ## Key Features
//!
//! - **Multiple Quantization Schemes**: INT8, INT4, binary, ternary, group-wise quantization
//! - **Advanced Observers**: MinMax, Histogram, Percentile, MovingAverage calibration
//! - **Backend Support**: Native, FBGEMM, QNNPACK for optimized execution
//! - **Comprehensive Metrics**: PSNR, SNR, compression ratio analysis
//! - **Configuration Tools**: Builder patterns, validation, JSON serialization
//! - **Utility Functions**: Batch processing, error diagnostics, auto-calibration
//!
//! ## Architecture
//!
//! The library is organized into specialized modules:
//!
//! - **config**: Configuration types and builder patterns
//! - **algorithms**: Core quantization and dequantization algorithms
//! - **observers**: Calibration system for parameter estimation
//! - **specialized**: Advanced algorithms (INT4, binary, ternary, group-wise)
//! - **metrics**: Performance analysis and benchmarking tools
//! - **utils**: Utility functions for validation, batch processing, and reporting
//!
//! ## Quick Start
//!
//! ```rust
//! use torsh_quantization::{QuantConfig, quantize_with_config};
//! use torsh_tensor::creation::tensor_1d;
//!
//! // Create a simple quantization configuration
//! let config = QuantConfig::int8();
//!
//! // Create a tensor to quantize
//! let data = vec![0.0, 1.0, 2.0, 3.0];
//! let tensor = tensor_1d(&data).unwrap();
//!
//! // Quantize the tensor
//! let (quantized, scale, zero_point) = quantize_with_config(&tensor, &config).unwrap();
//! ```
//!
//! ## Advanced Usage
//!
//! ### Custom Configuration
//!
//! ```rust
//! use torsh_quantization::{QuantConfig, ObserverType, QuantBackend};
//!
//! let config = QuantConfig::int8()
//!     .with_observer(ObserverType::Histogram)
//!     .with_backend(QuantBackend::Fbgemm);
//! ```
//!
//! ### Batch Processing
//!
//! ```rust
//! use torsh_quantization::{quantize_batch_consistent, QuantConfig};
//! use torsh_tensor::creation::tensor_1d;
//!
//! let tensor1 = tensor_1d(&[0.0, 1.0, 2.0]).unwrap();
//! let tensor2 = tensor_1d(&[1.0, 2.0, 3.0]).unwrap();
//! let tensor3 = tensor_1d(&[2.0, 3.0, 4.0]).unwrap();
//! let tensors = vec![&tensor1, &tensor2, &tensor3];
//! let config = QuantConfig::int8();
//! let results = quantize_batch_consistent(&tensors, &config).unwrap();
//! ```
//!
//! ### Performance Analysis
//!
//! ```rust
//! use torsh_quantization::{compare_quantization_configs, QuantConfig};
//! use torsh_tensor::creation::tensor_1d;
//!
//! let tensor = tensor_1d(&[0.0, 1.0, 2.0, 3.0]).unwrap();
//! let configs = vec![
//!     QuantConfig::int8(),
//!     QuantConfig::int4(),
//!     QuantConfig::per_channel(0),
//! ];
//! let comparison = compare_quantization_configs(&tensor, &configs).unwrap();
//! ```
//!
//! ## Export Support
//!
//! The library supports exporting quantized models to various formats:
//! - **ONNX**: Industry-standard format for cross-platform deployment
//! - **TensorRT**: NVIDIA's high-performance inference engine
//! - **TensorFlow Lite**: Mobile and edge deployment
//! - **Core ML**: Apple's machine learning framework
//! - **Custom formats**: Extensible architecture for new backends
94// ============================================================================
95// Core Quantization Infrastructure
96// ============================================================================
97
98/// Core configuration types and builders
99pub mod config;
100pub use config::*;
101
102/// Core quantization algorithms
103pub mod algorithms;
104pub use algorithms::*;
105
106/// Observer system for calibration
107pub mod observers;
108pub use observers::*;
109
110// ============================================================================
111// Quantization Schemes and Techniques
112// ============================================================================
113
114/// Specialized quantization schemes (INT4, binary, ternary, group-wise)
115pub mod specialized;
116pub use specialized::*;
117
118// ============================================================================
119// Analysis and Performance
120// ============================================================================
121
122/// Performance metrics and analysis
123pub mod metrics;
124pub use metrics::*;
125
126/// Advanced analysis tools
127pub mod analysis;
128pub use analysis::*;
129
130/// Memory pool management
131pub mod memory_pool;
132pub use memory_pool::*;
133
134/// SIMD-accelerated operations
135pub mod simd_ops;
136// Selective re-export to avoid ambiguity with auto_config::TensorStats
137pub use simd_ops::{
138    calculate_tensor_stats_simd, dequantize_per_tensor_affine_simd, find_min_max_simd,
139    get_mobile_optimization_hints, get_simd_width, is_simd_available,
140    quantize_batch_consistent_simd, quantize_mobile_optimized, quantize_per_channel_simd,
141    quantize_per_tensor_affine_simd, quantize_to_int8_simd, MobileOptimizationHints,
142    TensorStats as SimdTensorStats,
143};
144
145// ARM NEON-specific operations (only available on aarch64)
146#[cfg(target_arch = "aarch64")]
147pub use simd_ops::{find_min_max_neon, quantize_neon_optimized};
148
149// ============================================================================
150// Advanced and Research Features
151// ============================================================================
152
153/// Quantum-inspired quantization
154pub mod quantum;
155pub use quantum::*;
156
157/// Enhanced quantum-inspired quantization
158pub mod quantum_enhanced;
159pub use quantum_enhanced::*;
160
161/// Comprehensive benchmark suite
162pub mod benchmarks;
163pub use benchmarks::{
164    BaselineMetrics, BenchmarkConfig as SuiteBenchmarkConfig,
165    BenchmarkResult as SuiteBenchmarkResult, HardwareInfo, QuantizationBenchmarkSuite,
166};
167
168// ============================================================================
169// Utility Functions
170// ============================================================================
171
172/// Utility functions and helpers
173pub mod utils;
174pub use utils::*;
175
176/// ML-powered auto-configuration system
177pub mod auto_config;
178pub use auto_config::*;
179
180// ============================================================================
181// Additional Modules (Advanced - May require fixes)
182// ============================================================================
183// The following modules are available but may have internal compilation issues
184// or require additional dependencies. They are exposed for advanced users.
185
186/// Quantization operations (high-level API)
187#[cfg(feature = "experimental")]
188pub mod quantize;
189
190/// Dequantization operations
191#[cfg(feature = "experimental")]
192pub mod dequantize;
193
194/// Advanced quantization techniques
195#[cfg(feature = "experimental")]
196pub mod advanced;
197
198/// Compression techniques (sub-byte, vector, sparse)
199#[cfg(feature = "experimental")]
200pub mod compression;
201
202/// Fake quantization for QAT
203#[cfg(feature = "experimental")]
204pub mod fake_quantize;
205
206/// Quantization-aware training (QAT)
207#[cfg(feature = "experimental")]
208pub mod qat;
209
210/// Post-training quantization (PTQ)
211#[cfg(feature = "experimental")]
212pub mod post_training;
213
214/// Quantization optimizer
215#[cfg(feature = "experimental")]
216pub mod optimizer;
217
218/// Real-time adaptive quantization
219#[cfg(feature = "experimental")]
220pub mod realtime_adaptive;
221
222/// Hardware-optimized backends
223#[cfg(feature = "experimental")]
224pub mod hardware;
225
226/// Operation fusion for performance
227#[cfg(feature = "experimental")]
228pub mod fusion;
229
230/// Performance profiling
231#[cfg(feature = "experimental")]
232pub mod profiler;
233
234/// Debugging utilities
235#[cfg(feature = "experimental")]
236pub mod debugging;
237
238/// Neural codecs for learned quantization
239#[cfg(feature = "experimental")]
240pub mod neural_codecs;
241
242/// Research and experimental features
243#[cfg(feature = "experimental")]
244pub mod research;
245
246/// Model export functionality (ONNX, TensorRT, TFLite, CoreML)
247#[cfg(feature = "experimental")]
248pub mod export;
249
250// Re-export commonly used types from other crates
251pub use torsh_core::{error::Result as TorshResult, DType, TorshError};
252pub use torsh_tensor::Tensor;
253
254// Version information
255pub const VERSION: &str = env!("CARGO_PKG_VERSION");
256pub const VERSION_MAJOR: u32 = 0;
257pub const VERSION_MINOR: u32 = 1;
258pub const VERSION_PATCH: u32 = 0;
259
260/// Prelude module for convenient imports
261pub mod prelude {
262    pub use crate::algorithms::*;
263    pub use crate::analysis::*;
264    pub use crate::auto_config::*;
265    pub use crate::config::*;
266    pub use crate::memory_pool::*;
267    pub use crate::metrics::*;
268    pub use crate::observers::*;
269    pub use crate::specialized::*;
270    pub use crate::utils::*;
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_tensor::creation::tensor_1d;

    /// End-to-end quantize/dequantize round trip with the default INT8 config.
    #[test]
    fn test_basic_quantization_workflow() {
        let data = vec![0.0, 1.0, 2.0, 3.0];
        let tensor = tensor_1d(&data).unwrap();

        // Test with basic INT8 configuration
        let config = QuantConfig::int8();
        let result = quantize_with_config(&tensor, &config);
        assert!(result.is_ok());

        let (quantized, scale, zero_point) = result.unwrap();
        // Verify quantization worked correctly - values should be in quantized range
        let quantized_data = quantized.data().unwrap();
        let all_in_range = quantized_data
            .iter()
            .all(|x| (-128.0..=127.0).contains(x));
        assert!(
            all_in_range,
            "Quantized values should be in I8 range [-128, 127]"
        );
        assert!(scale > 0.0);

        // Test dequantization restores an F32 tensor
        let dequantized = dequantize(&quantized, scale, zero_point).unwrap();
        assert_eq!(dequantized.dtype(), DType::F32);
    }

    /// The built-in configurations must pass their own validation.
    #[test]
    fn test_configuration_validation() {
        let valid_config = QuantConfig::int8();
        assert!(valid_config.validate().is_ok());

        let per_channel_config = QuantConfig::per_channel(0);
        assert!(per_channel_config.validate().is_ok());
    }

    /// Specialized scheme configs (INT4, binary, ternary) validate cleanly.
    #[test]
    fn test_specialized_quantization() {
        let data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let _tensor = tensor_1d(&data).unwrap();

        // Test INT4 quantization
        let int4_config = QuantConfig::int4();
        assert!(int4_config.validate().is_ok());

        // Test binary quantization
        let binary_config = QuantConfig::binary();
        assert!(binary_config.validate().is_ok());

        // Test ternary quantization
        let ternary_config = QuantConfig::ternary();
        assert!(ternary_config.validate().is_ok());
    }

    /// Exercises suggestion generation, optimization hints, and JSON round-trip.
    #[test]
    fn test_utils_functionality() {
        let data = vec![0.0, 1.0, 2.0, 3.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = QuantConfig::int8();

        // Test configuration validation with suggestions
        let suggestions = validate_config_with_suggestions(&config).unwrap();
        assert!(!suggestions.is_empty());

        // Optimization hints may legitimately be empty for simple tensors;
        // just verify the call succeeds without panicking.
        let _hints = get_optimization_hints(&tensor, &config);

        // Test JSON serialization round-trips the key fields
        let json = export_config_to_json(&config).unwrap();
        let imported_config = import_config_from_json(&json).unwrap();
        assert_eq!(config.dtype, imported_config.dtype);
        assert_eq!(config.scheme, imported_config.scheme);
    }

    /// Batch quantization must use one shared scale/zero-point across tensors.
    #[test]
    fn test_batch_processing() {
        let data1 = vec![0.0, 1.0, 2.0, 3.0];
        let data2 = vec![4.0, 5.0, 6.0, 7.0];
        let tensor1 = tensor_1d(&data1).unwrap();
        let tensor2 = tensor_1d(&data2).unwrap();

        let tensors = vec![&tensor1, &tensor2];
        let config = QuantConfig::int8();

        let results = quantize_batch_consistent(&tensors, &config).unwrap();
        assert_eq!(results.len(), 2);

        // Verify consistent scale and zero point
        let (_, scale1, zp1) = &results[0];
        let (_, scale2, zp2) = &results[1];
        assert_eq!(scale1, scale2);
        assert_eq!(zp1, zp2);
    }

    /// Sanity-checks the PSNR/SNR/compression/cosine metrics ranges.
    #[test]
    fn test_metrics_calculation() {
        let data = vec![0.0, 1.0, 2.0, 3.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = QuantConfig::int8();

        let (quantized, scale, zero_point) = quantize_with_config(&tensor, &config).unwrap();
        let dequantized = dequantize(&quantized, scale, zero_point).unwrap();

        // 32 = original bit width (F32), 8 = quantized bit width (INT8)
        let metrics = calculate_quantization_metrics(&tensor, &dequantized, 32, 8).unwrap();

        assert!(metrics.psnr > 0.0);
        assert!(metrics.snr > 0.0);
        assert!(metrics.compression_ratio > 1.0);
        assert!((0.0..=1.0).contains(&metrics.cosine_similarity));
    }

    /// Comparison results must come back sorted by PSNR, best first.
    #[test]
    fn test_configuration_comparison() {
        let data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let tensor = tensor_1d(&data).unwrap();

        let configs = vec![
            QuantConfig::int8(),
            QuantConfig::per_channel(0),
            QuantConfig::int4(),
        ];

        let comparison = compare_quantization_configs(&tensor, &configs).unwrap();
        assert_eq!(comparison.len(), 3);

        // Results should be sorted by PSNR (higher is better)
        for i in 1..comparison.len() {
            assert!(comparison[i - 1].1.psnr >= comparison[i].1.psnr);
        }
    }

    /// Auto-calibration should produce a config that passes validation.
    #[test]
    fn test_auto_calibration() {
        let data1 = vec![0.0, 1.0, 2.0, 3.0];
        let data2 = vec![4.0, 5.0, 6.0, 7.0];
        let tensor1 = tensor_1d(&data1).unwrap();
        let tensor2 = tensor_1d(&data2).unwrap();

        let calibration_tensors = vec![&tensor1, &tensor2];
        let target_psnr = 30.0;
        let max_compression = 8.0;

        let optimal_config =
            auto_calibrate_quantization(&calibration_tensors, target_psnr, max_compression)
                .unwrap();

        assert!(optimal_config.validate().is_ok());
    }

    /// The markdown report must contain every expected section heading.
    #[test]
    fn test_report_generation() {
        let data = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let tensor = tensor_1d(&data).unwrap();

        let configs = vec![QuantConfig::int8(), QuantConfig::int4()];

        let report = generate_quantization_report(&tensor, &configs).unwrap();

        // Verify report contains expected sections
        assert!(report.contains("# Quantization Analysis Report"));
        assert!(report.contains("## Quantization Configuration Comparison"));
        assert!(report.contains("## Detailed Metrics"));
        assert!(report.contains("## Recommendations"));
    }

    /// Failure diagnosis should include all of its analysis sections.
    #[test]
    fn test_error_diagnostics() {
        let data = vec![0.0, 1.0, 2.0, 3.0];
        let tensor = tensor_1d(&data).unwrap();
        let config = QuantConfig::int8();

        // Simulate an error (this is a mock example)
        let error = TorshError::InvalidArgument("Test error".to_string());
        let diagnosis = diagnose_quantization_failure(&tensor, &config, &error);

        assert!(diagnosis.contains("Quantization failed with error"));
        assert!(diagnosis.contains("Tensor Analysis"));
        assert!(diagnosis.contains("Configuration Analysis"));
        assert!(diagnosis.contains("Recovery Suggestions"));
    }

    /// Use-case/target presets validate; unknown use cases are rejected.
    #[test]
    fn test_optimized_config_creation() {
        // Test different use cases
        let inference_config = create_optimized_config("inference_cpu", "x86").unwrap();
        assert!(inference_config.validate().is_ok());

        let mobile_config = create_optimized_config("inference_mobile", "arm").unwrap();
        assert!(mobile_config.validate().is_ok());

        let training_config = create_optimized_config("training", "gpu").unwrap();
        assert!(training_config.validate().is_ok());

        // Test invalid use case
        let invalid_result = create_optimized_config("invalid_use_case", "x86");
        assert!(invalid_result.is_err());
    }
}