#[cfg(feature = "std")]
use std::collections::HashMap;

#[cfg(not(feature = "std"))]
extern crate alloc;

#[cfg(not(feature = "std"))]
use alloc::{
    collections::BTreeMap as HashMap,
    string::{String, ToString},
};

use torsh_core::{
    dtype::DType,
    error::{Result as TorshResult, TorshError},
};
28
/// Quantization scheme: how scale and zero-point are shared across a tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum QScheme {
    /// One scale and zero-point for the whole tensor.
    PerTensorAffine,
    /// Independent scale/zero-point per channel (along `ch_axis`).
    PerChannelAffine,
    /// Per-tensor with symmetric range (zero-point fixed at zero).
    PerTensorSymmetric,
    /// Per-channel with symmetric range.
    PerChannelSymmetric,
    /// 4-bit quantization, one scale per tensor (range [-8, 7], see `get_qint_range`).
    Int4PerTensor,
    /// 4-bit quantization, one scale per channel.
    Int4PerChannel,
    /// Different layers use different precisions; see `MixedPrecisionConfig`.
    MixedPrecision,
    /// 1-bit quantization (range [-1, 1], see `get_qint_range`).
    Binary,
    /// Ternary quantization (range [-1, 1]).
    Ternary,
    /// Groups of `group_size` elements share quantization parameters.
    GroupWise,
}
53
/// Backend used to execute quantized kernels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum QuantBackend {
    /// FBGEMM backend.
    Fbgemm,
    /// QNNPACK backend.
    Qnnpack,
    /// Native (built-in) implementation; the default (see `QuantConfig::default`).
    Native,
    /// XNNPACK backend.
    Xnnpack,
}
66
/// Whether to shrink the usable integer range.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ReduceRange {
    /// Use the full integer range.
    None,
    /// Halve the range around its midpoint (see `QuantConfig::get_qint_range`);
    /// commonly needed for backend compatibility.
    Reduce,
}
75
/// Statistics collector used during calibration to pick quantization parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ObserverType {
    /// Track running min/max.
    MinMax,
    /// Exponential moving average of min/max (uses `averaging_constant`).
    MovingAverage,
    /// Histogram-based range selection.
    Histogram,
    /// Percentile-based clipping.
    Percentile,
    /// KL-divergence-based calibration (required for MixedPrecision, see `validate`).
    KLDivergence,
    /// Entropy-based calibration.
    Entropy,
}
92
/// Full configuration for a quantization pass.
///
/// Construct via the presets (`int8`, `int4`, …), the `with_*` setters, or
/// `QuantConfigBuilder`; call `validate()` before use.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct QuantConfig {
    /// Storage dtype for quantized values (I8 by default; U8 for `uint8()`).
    pub dtype: DType,
    /// How scale/zero-point are shared across the tensor.
    pub scheme: QScheme,
    /// Enable fake-quantization (used by the QAT preset).
    pub enable_fake_quant: bool,
    /// Calibration statistics collector.
    pub observer_type: ObserverType,
    /// Kernel backend for quantized ops.
    pub backend: QuantBackend,
    /// Whether to shrink the integer range (see `get_qint_range`).
    pub reduce_range: ReduceRange,
    /// Explicit lower quantized bound; when `None`, derived from dtype/scheme.
    pub qint_min: Option<i32>,
    /// Explicit upper quantized bound; when `None`, derived from dtype/scheme.
    pub qint_max: Option<i32>,
    /// Numerical-stability epsilon; must be > 0 (enforced by `validate`).
    pub eps: f32,
    /// EMA factor for moving-average observers; must lie in (0, 1)
    /// (enforced by `validate`).
    pub averaging_constant: f32,
    /// Channel axis — required by per-channel and group-wise schemes.
    pub ch_axis: Option<usize>,
    /// Group size — required (and must be > 0) for `QScheme::GroupWise`.
    pub group_size: Option<usize>,
}
110
111impl Default for QuantConfig {
112 fn default() -> Self {
113 Self {
114 dtype: DType::I8,
115 scheme: QScheme::PerTensorAffine,
116 enable_fake_quant: false,
117 observer_type: ObserverType::MinMax,
118 backend: QuantBackend::Native,
119 reduce_range: ReduceRange::None,
120 qint_min: None,
121 qint_max: None,
122 eps: 1e-8,
123 averaging_constant: 0.01,
124 ch_axis: None,
125 group_size: None,
126 }
127 }
128}
129
/// Per-layer precision assignment for mixed-precision quantization.
// NOTE(review): unlike the other config types this one has no serde derives —
// confirm whether that is intentional.
#[derive(Debug, Clone)]
pub struct MixedPrecisionConfig {
    /// Map from layer-name key to the precision used for that layer.
    pub layer_precision: HashMap<String, DType>,
    /// Precision used for layers not present in `layer_precision`.
    pub default_precision: DType,
    /// Sensitivity threshold; presumably used to decide when a layer needs
    /// higher precision — TODO confirm against the consumer of this config.
    pub sensitivity_threshold: f32,
}
140
141impl Default for MixedPrecisionConfig {
142 fn default() -> Self {
143 let mut layer_precision = HashMap::new();
144 layer_precision.insert("embedding".to_string(), DType::I8);
145 layer_precision.insert("attention".to_string(), DType::F16);
146 layer_precision.insert("output".to_string(), DType::F32);
147
148 Self {
149 layer_precision,
150 default_precision: DType::I8,
151 sensitivity_threshold: 0.1,
152 }
153 }
154}
155
156impl QuantConfig {
157 pub fn new() -> Self {
159 Self::default()
160 }
161
162 pub fn int8() -> Self {
164 Self {
165 dtype: DType::I8,
166 qint_min: Some(-128),
167 qint_max: Some(127),
168 ..Self::default()
169 }
170 }
171
172 pub fn int4() -> Self {
174 Self {
175 dtype: DType::I8, scheme: QScheme::Int4PerTensor,
177 qint_min: Some(-8),
178 qint_max: Some(7),
179 observer_type: ObserverType::Histogram,
180 ..Self::default()
181 }
182 }
183
184 pub fn binary() -> Self {
186 Self {
187 dtype: DType::I8,
188 scheme: QScheme::Binary,
189 qint_min: Some(-1),
190 qint_max: Some(1),
191 observer_type: ObserverType::MinMax,
192 ..Self::default()
193 }
194 }
195
196 pub fn ternary() -> Self {
198 Self {
199 dtype: DType::I8,
200 scheme: QScheme::Ternary,
201 qint_min: Some(-1),
202 qint_max: Some(1),
203 observer_type: ObserverType::MinMax,
204 ..Self::default()
205 }
206 }
207
208 pub fn mixed_precision() -> Self {
210 Self {
211 dtype: DType::I8, scheme: QScheme::MixedPrecision,
213 observer_type: ObserverType::KLDivergence,
214 ..Self::default()
215 }
216 }
217
218 pub fn uint8() -> Self {
220 Self {
221 dtype: DType::U8,
222 qint_min: Some(0),
223 qint_max: Some(255),
224 ..Self::default()
225 }
226 }
227
228 pub fn per_channel(ch_axis: usize) -> Self {
230 Self {
231 scheme: QScheme::PerChannelAffine,
232 ch_axis: Some(ch_axis),
233 ..Self::default()
234 }
235 }
236
237 pub fn group_wise(ch_axis: usize, group_size: usize) -> Self {
239 Self {
240 scheme: QScheme::GroupWise,
241 ch_axis: Some(ch_axis),
242 group_size: Some(group_size),
243 observer_type: ObserverType::Histogram,
244 ..Self::default()
245 }
246 }
247
248 pub fn qat() -> Self {
250 Self {
251 enable_fake_quant: true,
252 observer_type: ObserverType::MovingAverage,
253 ..Self::default()
254 }
255 }
256
257 pub fn with_backend(mut self, backend: QuantBackend) -> Self {
259 self.backend = backend;
260 self
261 }
262
263 pub fn with_observer(mut self, observer_type: ObserverType) -> Self {
265 self.observer_type = observer_type;
266 self
267 }
268
269 pub fn with_scheme(mut self, scheme: QScheme) -> Self {
271 self.scheme = scheme;
272 if matches!(
273 scheme,
274 QScheme::PerChannelAffine | QScheme::PerChannelSymmetric | QScheme::GroupWise
275 ) && self.ch_axis.is_none()
276 {
277 self.ch_axis = Some(0); }
279 if matches!(scheme, QScheme::GroupWise) && self.group_size.is_none() {
280 self.group_size = Some(32); }
282 self
283 }
284
285 pub fn with_fake_quant(mut self, enable: bool) -> Self {
287 self.enable_fake_quant = enable;
288 self
289 }
290
291 pub fn with_reduce_range(mut self, reduce_range: ReduceRange) -> Self {
293 self.reduce_range = reduce_range;
294 self
295 }
296
297 pub fn with_group_size(mut self, group_size: usize) -> Self {
299 self.group_size = Some(group_size);
300 self
301 }
302
303 pub fn get_qint_range(&self) -> (i32, i32) {
305 let (base_min, base_max) = match self.scheme {
306 QScheme::Int4PerTensor | QScheme::Int4PerChannel => (-8, 7),
307 QScheme::Binary => (-1, 1),
308 QScheme::Ternary => (-1, 1),
309 _ => match self.dtype {
310 DType::I8 => (-128, 127),
311 DType::U8 => (0, 255),
312 _ => (self.qint_min.unwrap_or(-128), self.qint_max.unwrap_or(127)),
313 },
314 };
315
316 let (qmin, qmax) = match self.reduce_range {
317 ReduceRange::None => (base_min, base_max),
318 ReduceRange::Reduce => {
319 let range = base_max - base_min;
321 let reduced_range = range / 2;
322 let mid = (base_min + base_max) / 2;
323 (mid - reduced_range / 2, mid + reduced_range / 2)
324 }
325 };
326
327 (qmin, qmax)
328 }
329
330 pub fn validate(&self) -> TorshResult<()> {
332 if matches!(
334 self.scheme,
335 QScheme::PerChannelAffine | QScheme::PerChannelSymmetric | QScheme::GroupWise
336 ) && self.ch_axis.is_none()
337 {
338 return Err(TorshError::InvalidArgument(
339 "Per-channel/Group-wise quantization requires channel axis".to_string(),
340 ));
341 }
342
343 if matches!(self.scheme, QScheme::GroupWise) {
345 if self.group_size.is_none() {
346 return Err(TorshError::InvalidArgument(
347 "Group-wise quantization requires group size".to_string(),
348 ));
349 }
350 if let Some(group_size) = self.group_size {
351 if group_size == 0 {
352 return Err(TorshError::InvalidArgument(
353 "Group size must be greater than 0".to_string(),
354 ));
355 }
356 }
357 }
358
359 if matches!(
361 self.scheme,
362 QScheme::PerTensorSymmetric | QScheme::PerChannelSymmetric
363 ) {
364 }
366
367 if matches!(self.scheme, QScheme::Binary | QScheme::Ternary)
369 && !matches!(
370 self.observer_type,
371 ObserverType::MinMax | ObserverType::MovingAverage
372 )
373 {
374 return Err(TorshError::InvalidArgument(
375 "Binary/ternary quantization requires MinMax or MovingAverage observer".to_string(),
376 ));
377 }
378
379 if matches!(self.scheme, QScheme::MixedPrecision)
381 && !matches!(
382 self.observer_type,
383 ObserverType::KLDivergence | ObserverType::Entropy
384 )
385 {
386 return Err(TorshError::InvalidArgument(
387 "Mixed precision quantization requires KLDivergence or Entropy observer"
388 .to_string(),
389 ));
390 }
391
392 if self.eps <= 0.0 {
394 return Err(TorshError::InvalidArgument(
395 "eps must be positive".to_string(),
396 ));
397 }
398
399 if self.averaging_constant <= 0.0 || self.averaging_constant >= 1.0 {
401 return Err(TorshError::InvalidArgument(
402 "averaging_constant must be in (0, 1)".to_string(),
403 ));
404 }
405
406 Ok(())
407 }
408}
409
/// Fluent builder for `QuantConfig`; `build()` runs `validate()`.
pub struct QuantConfigBuilder {
    // Config being assembled, starting from `QuantConfig::default()`.
    config: QuantConfig,
}
414
415impl QuantConfigBuilder {
416 pub fn new() -> Self {
418 Self {
419 config: QuantConfig::default(),
420 }
421 }
422
423 pub fn dtype(mut self, dtype: DType) -> Self {
425 self.config.dtype = dtype;
426 self
427 }
428
429 pub fn scheme(mut self, scheme: QScheme) -> Self {
431 self.config = self.config.with_scheme(scheme);
432 self
433 }
434
435 pub fn observer(mut self, observer_type: ObserverType) -> Self {
437 self.config.observer_type = observer_type;
438 self
439 }
440
441 pub fn backend(mut self, backend: QuantBackend) -> Self {
443 self.config.backend = backend;
444 self
445 }
446
447 pub fn fake_quant(mut self, enable: bool) -> Self {
449 self.config.enable_fake_quant = enable;
450 self
451 }
452
453 pub fn channel_axis(mut self, axis: usize) -> Self {
455 self.config.ch_axis = Some(axis);
456 self
457 }
458
459 pub fn group_size(mut self, size: usize) -> Self {
461 self.config.group_size = Some(size);
462 self
463 }
464
465 pub fn build(self) -> TorshResult<QuantConfig> {
467 self.config.validate()?;
468 Ok(self.config)
469 }
470}
471
// `Default` simply delegates to `new()` so the builder works with
// `..Default::default()`-style call sites.
impl Default for QuantConfigBuilder {
    fn default() -> Self {
        Self::new()
    }
}
477
#[cfg(test)]
mod tests {
    use super::*;

    // Default config should be the per-tensor affine INT8 baseline.
    #[test]
    fn test_quant_config_defaults() {
        let config = QuantConfig::default();
        assert_eq!(config.dtype, DType::I8);
        assert_eq!(config.scheme, QScheme::PerTensorAffine);
        assert!(!config.enable_fake_quant);
        assert_eq!(config.observer_type, ObserverType::MinMax);
        assert_eq!(config.backend, QuantBackend::Native);
        assert_eq!(config.reduce_range, ReduceRange::None);
    }

    // Preset constructors pin dtype, scheme, bounds, and observer choices.
    #[test]
    fn test_quant_config_presets() {
        let int8_config = QuantConfig::int8();
        assert_eq!(int8_config.dtype, DType::I8);
        assert_eq!(int8_config.qint_min, Some(-128));
        assert_eq!(int8_config.qint_max, Some(127));

        let binary_config = QuantConfig::binary();
        assert_eq!(binary_config.scheme, QScheme::Binary);
        assert_eq!(binary_config.qint_min, Some(-1));
        assert_eq!(binary_config.qint_max, Some(1));

        let int4_config = QuantConfig::int4();
        assert_eq!(int4_config.scheme, QScheme::Int4PerTensor);
        assert_eq!(int4_config.observer_type, ObserverType::Histogram);
    }

    // Builder chain should produce a validated config with every field set.
    #[test]
    fn test_quant_config_builder() {
        let config = QuantConfigBuilder::new()
            .dtype(DType::I8)
            .scheme(QScheme::PerChannelAffine)
            .observer(ObserverType::Histogram)
            .backend(QuantBackend::Fbgemm)
            .channel_axis(1)
            .build()
            .unwrap();

        assert_eq!(config.dtype, DType::I8);
        assert_eq!(config.scheme, QScheme::PerChannelAffine);
        assert_eq!(config.observer_type, ObserverType::Histogram);
        assert_eq!(config.backend, QuantBackend::Fbgemm);
        assert_eq!(config.ch_axis, Some(1));
    }

    // validate() rejects missing channel axis, missing group size,
    // non-positive eps, and out-of-range averaging constants.
    #[test]
    fn test_config_validation() {
        let valid_config = QuantConfig::per_channel(0);
        assert!(valid_config.validate().is_ok());

        let mut invalid_config = QuantConfig::default();
        invalid_config.scheme = QScheme::PerChannelAffine;
        invalid_config.ch_axis = None;
        assert!(invalid_config.validate().is_err());

        let mut invalid_group = QuantConfig::default();
        invalid_group.scheme = QScheme::GroupWise;
        invalid_group.ch_axis = Some(0);
        invalid_group.group_size = None;
        assert!(invalid_group.validate().is_err());

        let mut invalid_eps = QuantConfig::default();
        invalid_eps.eps = -1.0;
        assert!(invalid_eps.validate().is_err());

        let mut invalid_avg = QuantConfig::default();
        invalid_avg.averaging_constant = 1.5;
        assert!(invalid_avg.validate().is_err());
    }

    // get_qint_range(): scheme/dtype-derived ranges, and Reduce strictly
    // shrinks both ends of the INT8 range.
    #[test]
    fn test_get_qint_range() {
        let int8_config = QuantConfig::int8();
        assert_eq!(int8_config.get_qint_range(), (-128, 127));

        let uint8_config = QuantConfig::uint8();
        assert_eq!(uint8_config.get_qint_range(), (0, 255));

        let int4_config = QuantConfig::int4();
        assert_eq!(int4_config.get_qint_range(), (-8, 7));

        let binary_config = QuantConfig::binary();
        assert_eq!(binary_config.get_qint_range(), (-1, 1));

        let reduced_config = QuantConfig::int8().with_reduce_range(ReduceRange::Reduce);
        let (min, max) = reduced_config.get_qint_range();
        assert!(min > -128 && max < 127);
    }

    // Default mixed-precision table should contain the embedding preset.
    #[test]
    fn test_mixed_precision_config() {
        let mixed_config = MixedPrecisionConfig::default();
        assert_eq!(mixed_config.default_precision, DType::I8);
        assert_eq!(mixed_config.sensitivity_threshold, 0.1);
        assert!(mixed_config.layer_precision.contains_key("embedding"));
    }

    // Round-trip through serde_json preserves the serde-derived fields.
    #[test]
    fn test_config_serialization() {
        let config = QuantConfig::int8().with_observer(ObserverType::Histogram);

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: QuantConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.dtype, deserialized.dtype);
        assert_eq!(config.scheme, deserialized.scheme);
        assert_eq!(config.observer_type, deserialized.observer_type);
    }
}