1#![allow(unused_variables)] use crate::{
13 errors::{invalid_input, Result},
14 tensor::{DType, Tensor},
15};
16use std::collections::HashMap;
17
#[derive(Debug, Clone)]
/// Settings for bitsandbytes-style blockwise quantization.
pub struct BitsAndBytesConfig {
    /// Target bit width; the (de)quantization entry points accept only 8 or 4.
    pub bits: u8,
    /// Selects dynamic-tree quantization.
    // NOTE(review): not read anywhere in this file — callers invoke
    // `quantize_dynamic_tree` directly; confirm whether this flag should gate it.
    pub dynamic_tree: bool,
    /// Number of elements quantized per block (halved internally for 4-bit packing).
    pub block_size: usize,
    /// Enables stochastic rounding.
    // NOTE(review): not read anywhere in this file — TODO confirm intended use.
    pub stochastic: bool,
    /// Quantile used by int8 quantization to flag outliers; values strictly
    /// below 1.0 enable outlier detection.
    pub outlier_threshold: f32,
    /// When true, the per-block scale tensor is itself 8-bit quantized.
    pub nested_quantization: bool,
}
34
35impl Default for BitsAndBytesConfig {
36 fn default() -> Self {
37 Self {
38 bits: 8,
39 dynamic_tree: false,
40 block_size: 256,
41 stochastic: false,
42 outlier_threshold: 0.99,
43 nested_quantization: false,
44 }
45 }
46}
47
#[derive(Debug, Clone)]
/// Everything produced by quantization that is needed to dequantize again.
pub struct QuantState {
    /// Quantized payload: int8 codes (stored as I64) in the original shape,
    /// packed 4-bit nibbles, or tree codes, depending on the scheme used.
    pub data: Tensor,
    /// Per-block scales (or, for tree quantization, the serialized thresholds).
    pub scale: Tensor,
    /// Per-block zero points; present for int8, absent for 4-bit and tree modes.
    pub zero_point: Option<Tensor>,
    /// Flat indices of values flagged as outliers during int8 quantization, if any.
    pub outliers: Option<Vec<usize>>,
    /// Dtype to restore on dequantization.
    pub original_dtype: DType,
    /// Block size(s) used during quantization; empty for tree quantization.
    pub block_sizes: Vec<usize>,
    /// Shape to restore on dequantization.
    pub original_shape: Vec<usize>,
}
66
67pub fn quantize_int8(tensor: &Tensor, config: &BitsAndBytesConfig) -> Result<QuantState> {
69 let original_dtype = tensor.dtype();
70 let shape = tensor.shape();
71
72 let total_elements = tensor.shape().iter().product::<usize>();
74 let flattened = tensor.reshape(&[total_elements])?;
75 let num_elements = flattened.shape()[0];
76
77 let num_blocks = num_elements.div_ceil(config.block_size);
79 let mut scales = Vec::with_capacity(num_blocks);
80 let mut zero_points = Vec::with_capacity(num_blocks);
81 let mut quantized_blocks = Vec::new();
82 let mut outlier_indices = Vec::new();
83
84 for block_idx in 0..num_blocks {
85 let start = block_idx * config.block_size;
86 let end = std::cmp::min(start + config.block_size, num_elements);
87 let block = flattened.slice_ranges(&[(start, end)])?;
88
89 let (min_val, max_val) = block.min_max()?;
91
92 if config.outlier_threshold < 1.0 {
94 let sorted = block.sort()?;
95 let lower_idx = ((1.0 - config.outlier_threshold) * (end - start) as f32) as usize;
96 let upper_idx = (config.outlier_threshold * (end - start) as f32) as usize;
97
98 let lower_bound = sorted.get_float(lower_idx)?;
99 let upper_bound = sorted.get_float(upper_idx)?;
100
101 for i in start..end {
103 let val = flattened.get_float(i)?;
104 if val < lower_bound || val > upper_bound {
105 outlier_indices.push(i);
106 }
107 }
108 }
109
110 let scale = (max_val - min_val) / 255.0;
112 let zero_point = -min_val / scale;
113
114 scales.push(scale);
115 zero_points.push(zero_point);
116
117 let quantized = block.sub_scalar(min_val)?.div_scalar(scale)?.round()?.clamp(0.0, 255.0)?;
119
120 quantized_blocks.push(quantized);
121 }
122
123 let quantized_data =
125 Tensor::concat(&quantized_blocks, 0)?.to_dtype(DType::I64)?.reshape(&shape)?;
126
127 let scale_tensor = Tensor::from_vec(scales, &[num_blocks])?;
129 let zero_point_tensor = Tensor::from_vec(zero_points, &[num_blocks])?;
130
131 let final_scale = if config.nested_quantization {
133 quantize_scales(&scale_tensor, 8)?
134 } else {
135 scale_tensor
136 };
137
138 Ok(QuantState {
139 data: quantized_data,
140 scale: final_scale,
141 zero_point: Some(zero_point_tensor),
142 outliers: if outlier_indices.is_empty() { None } else { Some(outlier_indices) },
143 original_dtype,
144 block_sizes: vec![config.block_size],
145 original_shape: shape.to_vec(),
146 })
147}
148
149pub fn quantize_4bit(tensor: &Tensor, config: &BitsAndBytesConfig) -> Result<QuantState> {
151 let original_dtype = tensor.dtype();
152 let shape = tensor.shape();
153
154 let block_size = config.block_size / 2;
156 let total_elements = tensor.shape().iter().product::<usize>();
157 let flattened = tensor.reshape(&[total_elements])?;
158 let num_elements = flattened.shape()[0];
159 let num_blocks = num_elements.div_ceil(block_size);
160
161 let mut scales = Vec::with_capacity(num_blocks);
162 let mut quantized_blocks = Vec::new();
163
164 let nf4_levels = get_nf4_quantization_levels();
166
167 for block_idx in 0..num_blocks {
168 let start = block_idx * block_size;
169 let end = std::cmp::min(start + block_size, num_elements);
170 let block = flattened.slice_ranges(&[(start, end)])?;
171
172 let mean = block.mean()?;
174 let std = block.std()?;
175 let mean_scalar = mean.get_float(0)?;
176 let std_scalar = std.get_float(0)?;
177 let normalized = block.sub_scalar(mean_scalar)?.div_scalar(std_scalar + 1e-8)?;
178
179 let abs_max = normalized.abs()?.max_value()?;
181 let scale = abs_max.get_float(0)?;
182 scales.push(scale);
183
184 let mut quantized_values = Vec::with_capacity(end - start);
186 for i in 0..(end - start) {
187 let val = normalized.get_float(i)? / scale;
188 let quantized_idx = find_nearest_nf4_level(val, &nf4_levels);
189 quantized_values.push(quantized_idx as f32);
190 }
191
192 let quantized = Tensor::from_vec(quantized_values, &[end - start])?;
193 quantized_blocks.push(quantized);
194 }
195
196 let quantized_concat = Tensor::concat(&quantized_blocks, 0)?;
198 let packed_data = pack_4bit_tensor(&quantized_concat)?;
199
200 let scale_tensor = Tensor::from_vec(scales, &[num_blocks])?;
201
202 Ok(QuantState {
203 data: packed_data,
204 scale: scale_tensor,
205 zero_point: None,
206 outliers: None,
207 original_dtype,
208 block_sizes: vec![block_size],
209 original_shape: shape.to_vec(),
210 })
211}
212
213pub fn quantize_dynamic_tree(tensor: &Tensor, config: &BitsAndBytesConfig) -> Result<QuantState> {
215 let total_elements = tensor.shape().iter().product::<usize>();
217 let flattened = tensor.reshape(&[total_elements])?;
218 let histogram = build_histogram(&flattened, 256)?;
219 let tree = build_quantization_tree(&histogram, config.bits)?;
220
221 let quantized = apply_tree_quantization(&flattened, &tree)?;
223
224 let scale_data = serialize_tree(&tree)?;
226
227 Ok(QuantState {
228 data: quantized.reshape(&tensor.shape())?,
229 scale: scale_data,
230 zero_point: None,
231 outliers: None,
232 original_dtype: tensor.dtype(),
233 block_sizes: vec![],
234 original_shape: tensor.shape().to_vec(),
235 })
236}
237
238pub fn dequantize_bitsandbytes(state: &QuantState, config: &BitsAndBytesConfig) -> Result<Tensor> {
240 match config.bits {
241 8 => dequantize_int8(state),
242 4 => dequantize_4bit(state),
243 _ => Err(invalid_input(format!(
244 "Unsupported bit width: {}",
245 config.bits
246 ))),
247 }
248}
249
250fn dequantize_int8(state: &QuantState) -> Result<Tensor> {
252 let shape = state.data.shape();
253 let total_elements = state.data.shape().iter().product::<usize>();
254 let flattened = state.data.reshape(&[total_elements])?;
255 let num_elements = flattened.shape()[0];
256
257 let block_size = state.block_sizes.first().copied().unwrap_or(256);
259 let num_blocks = num_elements.div_ceil(block_size);
260
261 let mut dequantized_blocks = Vec::new();
262
263 for block_idx in 0..num_blocks {
264 let start = block_idx * block_size;
265 let end = std::cmp::min(start + block_size, num_elements);
266 let block = flattened.slice_ranges(&[(start, end)])?;
267
268 let scale = state.scale.get_float(block_idx)?;
270 let zero_point = state
271 .zero_point
272 .as_ref()
273 .map(|zp| zp.get_float(block_idx))
274 .transpose()?
275 .unwrap_or(0.0);
276
277 let dequantized = block.to_dtype(DType::F32)?.sub_scalar(zero_point)?.scalar_mul(scale)?;
279
280 dequantized_blocks.push(dequantized);
281 }
282
283 Tensor::concat(&dequantized_blocks, 0)?
285 .reshape(&shape)?
286 .to_dtype(state.original_dtype)
287}
288
289fn dequantize_4bit(state: &QuantState) -> Result<Tensor> {
291 let unpacked = unpack_4bit_tensor(&state.data)?;
293 let nf4_levels = get_nf4_quantization_levels();
294
295 let original_shape = &state.original_shape;
296 let block_size = state.block_sizes.first().copied().unwrap_or(128);
297 let num_elements = unpacked.shape()[0];
298 let num_blocks = num_elements.div_ceil(block_size);
299
300 let mut dequantized_blocks = Vec::new();
301
302 for block_idx in 0..num_blocks {
303 let start = block_idx * block_size;
304 let end = std::cmp::min(start + block_size, num_elements);
305 let block = unpacked.slice(0, start, end)?;
306
307 let scale = state.scale.get_float(block_idx)?;
308
309 let mut values = Vec::with_capacity(end - start);
311 for i in 0..(end - start) {
312 let idx = block.get_float(i)? as usize;
313 let nf4_value = nf4_levels[idx];
314 values.push(nf4_value * scale);
315 }
316
317 let dequantized = Tensor::from_vec(values, &[end - start])?;
318 dequantized_blocks.push(dequantized);
319 }
320
321 Tensor::concat(&dequantized_blocks, 0)?
322 .reshape(original_shape)?
323 .to_dtype(state.original_dtype)
324}
325
/// Returns the 16 NF4 (normal-float 4-bit) quantization levels in ascending
/// order, spanning [-1.0, 1.0] with 0.0 at index 7.
fn get_nf4_quantization_levels() -> Vec<f32> {
    const NF4_LEVELS: [f32; 16] = [
        -1.0,
        -0.6961928009986877,
        -0.5250730514526367,
        -0.39491748809814453,
        -0.28444138169288635,
        -0.18477343022823334,
        -0.09105003625154495,
        0.0,
        0.07958029955625534,
        0.16093020141124725,
        0.24611230194568634,
        0.33791524171829224,
        0.44070982933044434,
        0.5626170039176941,
        0.7229568362236023,
        1.0,
    ];
    NF4_LEVELS.to_vec()
}
347
/// Returns the index of the level closest to `value` (linear scan).
///
/// Ties keep the earliest index, and a NaN `value` yields index 0, because the
/// strict `<` comparison never accepts a NaN distance.
fn find_nearest_nf4_level(value: f32, levels: &[f32]) -> usize {
    levels
        .iter()
        .enumerate()
        .fold((0usize, f32::INFINITY), |best, (idx, &level)| {
            let dist = (value - level).abs();
            if dist < best.1 { (idx, dist) } else { best }
        })
        .0
}
363
364fn pack_4bit_tensor(tensor: &Tensor) -> Result<Tensor> {
366 let data = tensor.to_vec_f32()?;
367 let mut packed = Vec::with_capacity(data.len().div_ceil(2));
368
369 for i in (0..data.len()).step_by(2) {
370 let low = data[i] as u8 & 0x0F;
371 let high = if i + 1 < data.len() { (data[i + 1] as u8 & 0x0F) << 4 } else { 0 };
372 packed.push(low | high);
373 }
374
375 let packed_f32: Vec<f32> = packed.into_iter().map(|x| x as f32).collect();
376 let len = packed_f32.len();
377 Tensor::from_vec(packed_f32, &[len])
378}
379
380fn unpack_4bit_tensor(tensor: &Tensor) -> Result<Tensor> {
382 let packed = tensor.to_vec_u8()?;
383 let mut unpacked = Vec::with_capacity(packed.len() * 2);
384
385 for byte in packed {
386 unpacked.push((byte & 0x0F) as f32);
387 unpacked.push(((byte >> 4) & 0x0F) as f32);
388 }
389
390 let len = unpacked.len();
391 Tensor::from_vec(unpacked, &[len])
392}
393
394fn build_histogram(tensor: &Tensor, bins: usize) -> Result<Vec<f32>> {
396 let data = tensor.to_vec_f32()?;
397 let (min_val, max_val) = tensor.min_max()?;
398 let range = max_val - min_val;
399 let bin_width = range / bins as f32;
400
401 let mut histogram = vec![0.0; bins];
402
403 for &value in &data {
404 let bin_idx = ((value - min_val) / bin_width).floor() as usize;
405 let bin_idx = bin_idx.min(bins - 1);
406 histogram[bin_idx] += 1.0;
407 }
408
409 let total: f32 = histogram.iter().sum();
411 for count in &mut histogram {
412 *count /= total;
413 }
414
415 Ok(histogram)
416}
417
#[derive(Debug, Clone)]
/// Node of the binary quantization decision tree: internal nodes carry a
/// `threshold` and both children; leaves carry the quantized code in `value`.
struct TreeNode {
    // Decision boundary; meaningful only on internal nodes (0.0 on leaves).
    threshold: f32,
    // Subtree taken when a value is below `threshold`.
    left: Option<Box<TreeNode>>,
    // Subtree taken when a value is at or above `threshold`.
    right: Option<Box<TreeNode>>,
    // Quantized code; `Some` exactly on leaves.
    value: Option<u8>,
}
426
/// Builds a balanced binary decision tree with `1 << bits` leaf codes whose
/// internal nodes hold uniformly spaced thresholds in (0, 1).
///
/// NOTE(review): the `histogram` argument is currently unused — thresholds are
/// uniform rather than distribution-aware. Confirm whether quantile-based
/// thresholds derived from the histogram were intended here.
/// NOTE(review): leaf codes are stored as `u8` (`start as u8`), so `bits > 8`
/// would silently truncate — confirm the accepted range of `bits`.
fn build_quantization_tree(histogram: &[f32], bits: u8) -> Result<TreeNode> {
    // Number of representable codes, e.g. 256 for 8 bits.
    let levels = 1 << bits;
    let mut thresholds = Vec::with_capacity(levels - 1);

    // Uniform thresholds i / levels for i in 1..levels.
    for i in 1..levels {
        thresholds.push(i as f32 / levels as f32);
    }

    // Recursively bisects the inclusive code range [start, end]: `start >= end`
    // is a leaf carrying code `start`; otherwise the midpoint threshold splits
    // the range into [start, mid] (left) and [mid + 1, end] (right).
    fn build_node(thresholds: &[f32], start: usize, end: usize) -> TreeNode {
        if start >= end {
            TreeNode {
                threshold: 0.0,
                left: None,
                right: None,
                value: Some(start as u8),
            }
        } else {
            let mid = (start + end) / 2;
            TreeNode {
                threshold: thresholds[mid],
                left: Some(Box::new(build_node(thresholds, start, mid))),
                right: Some(Box::new(build_node(thresholds, mid + 1, end))),
                value: None,
            }
        }
    }

    Ok(build_node(&thresholds, 0, levels - 1))
}
460
461fn apply_tree_quantization(tensor: &Tensor, tree: &TreeNode) -> Result<Tensor> {
463 let data = tensor.to_vec_f32()?;
464 let mut quantized = Vec::with_capacity(data.len());
465
466 for &value in &data {
467 let quantized_value = traverse_tree(value, tree);
468 quantized.push(quantized_value as f32);
469 }
470
471 Tensor::from_vec(quantized, &tensor.shape())
472}
473
474fn traverse_tree(value: f32, node: &TreeNode) -> u8 {
476 if let Some(leaf_value) = node.value {
477 leaf_value
478 } else if value < node.threshold {
479 traverse_tree(
480 value,
481 node.left.as_ref().expect("non-leaf node must have left child"),
482 )
483 } else {
484 traverse_tree(
485 value,
486 node.right.as_ref().expect("non-leaf node must have right child"),
487 )
488 }
489}
490
491fn serialize_tree(tree: &TreeNode) -> Result<Tensor> {
493 let mut thresholds = Vec::new();
495 collect_thresholds(tree, &mut thresholds);
496 let len = thresholds.len();
497 Tensor::from_vec(thresholds, &[len])
498}
499
500fn collect_thresholds(node: &TreeNode, thresholds: &mut Vec<f32>) {
502 if node.value.is_none() {
503 thresholds.push(node.threshold);
504 if let Some(left) = &node.left {
505 collect_thresholds(left, thresholds);
506 }
507 if let Some(right) = &node.right {
508 collect_thresholds(right, thresholds);
509 }
510 }
511}
512
513fn quantize_scales(scales: &Tensor, bits: u8) -> Result<Tensor> {
515 let (min_val, max_val) = scales.min_max()?;
517 let levels = (1 << bits) as f32;
518 let scale = (max_val - min_val) / (levels - 1.0);
519
520 scales.sub_scalar(min_val)?.div_scalar(scale)?.round()?.clamp(0.0, levels - 1.0)
521}
522
523pub fn to_bitsandbytes_format(
525 tensor: &Tensor,
526 config: &BitsAndBytesConfig,
527) -> Result<HashMap<String, Tensor>> {
528 let state = match config.bits {
529 8 => quantize_int8(tensor, config)?,
530 4 => quantize_4bit(tensor, config)?,
531 _ => {
532 return Err(invalid_input(format!(
533 "Unsupported bit width: {}",
534 config.bits
535 )))
536 },
537 };
538
539 let mut result = HashMap::new();
540 result.insert("data".to_string(), state.data);
541 result.insert("scale".to_string(), state.scale);
542
543 if let Some(zero_point) = state.zero_point {
544 result.insert("zero_point".to_string(), zero_point);
545 }
546
547 if let Some(outliers) = state.outliers {
548 let outlier_tensor = Tensor::from_vec(
549 outliers.iter().map(|&idx| idx as f32).collect(),
550 &[outliers.len()],
551 )?;
552 result.insert("outliers".to_string(), outlier_tensor);
553 }
554
555 Ok(result)
556}
557
558pub fn from_bitsandbytes_format(
560 data: HashMap<String, Tensor>,
561 config: &BitsAndBytesConfig,
562) -> Result<Tensor> {
563 let quantized_data = data
564 .get("data")
565 .ok_or_else(|| invalid_input("Missing 'data' tensor".to_string()))?;
566 let scale = data
567 .get("scale")
568 .ok_or_else(|| invalid_input("Missing 'scale' tensor".to_string()))?;
569 let zero_point = data.get("zero_point");
570 let outliers = data
571 .get("outliers")
572 .map(|t| t.to_vec_f32().map(|v| v.iter().map(|&x| x as usize).collect()))
573 .transpose()?;
574
575 let state = QuantState {
576 data: quantized_data.clone(),
577 scale: scale.clone(),
578 zero_point: zero_point.cloned(),
579 outliers,
580 original_dtype: DType::F32,
581 block_sizes: vec![config.block_size],
582 original_shape: quantized_data.shape().to_vec(),
583 };
584
585 dequantize_bitsandbytes(&state, config)
586}
587
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_int8_quantization() -> Result<()> {
        // Round-trip a random matrix through 8-bit blockwise quantization and
        // check shape preservation plus a mean-absolute reconstruction bound.
        let config = BitsAndBytesConfig::default();
        let input = Tensor::randn(&[64, 64])?;

        let quantized = quantize_int8(&input, &config)?;
        let restored = dequantize_int8(&quantized)?;
        assert_eq!(input.shape(), restored.shape());

        let mean_abs_error = input.sub(&restored)?.abs()?.mean()?.get_float(0)?;
        assert!(
            mean_abs_error < 0.1,
            "Reconstruction error too high: {}",
            mean_abs_error
        );

        Ok(())
    }

    #[test]
    fn test_4bit_quantization() -> Result<()> {
        // 4-bit NF4 round-trip; only shape preservation is asserted here.
        let config = BitsAndBytesConfig {
            bits: 4,
            ..Default::default()
        };
        let input = Tensor::randn(&[32, 32])?;

        let quantized = quantize_4bit(&input, &config)?;
        let restored = dequantize_4bit(&quantized)?;
        assert_eq!(input.shape(), restored.shape());
        Ok(())
    }

    #[test]
    fn test_bitsandbytes_format_conversion() -> Result<()> {
        // Export to the flat tensor-map format and reconstruct through it.
        let config = BitsAndBytesConfig::default();
        let input = Tensor::randn(&[128, 128])?;

        let exported = to_bitsandbytes_format(&input, &config)?;
        let restored = from_bitsandbytes_format(exported, &config)?;
        assert_eq!(input.shape(), restored.shape());
        Ok(())
    }
}
641}