Skip to main content

ruvllm/gguf/
quantization.rs

1//! GGUF Quantization Types and Dequantization Kernels
2//!
//! This module defines the GGUF quantization types used by llama.cpp and
//! provides dequantization routines for the subset of formats handled by `dequantize_tensor`.
5//!
6//! ## Quantization Format Overview
7//!
8//! GGUF supports multiple quantization formats with different tradeoffs:
9//!
10//! | Format | Bits/Weight | Block Size | Description |
11//! |--------|-------------|------------|-------------|
12//! | F32 | 32 | 1 | Full precision |
13//! | F16 | 16 | 1 | Half precision |
14//! | Q8_0 | 8.5 | 32 | 8-bit symmetric |
//! | Q8_1 | 9 | 32 | 8-bit symmetric with block sum |
16//! | Q4_0 | 4.5 | 32 | 4-bit symmetric |
17//! | Q4_1 | 5 | 32 | 4-bit with offset |
18//! | Q5_0 | 5.5 | 32 | 5-bit symmetric |
19//! | Q5_1 | 6 | 32 | 5-bit with offset |
//! | Q2_K | 2.625 | 256 | 2-bit k-quant |
21//! | Q3_K | 3.44 | 256 | 3-bit k-quant |
22//! | Q4_K | 4.5 | 256 | 4-bit k-quant |
23//! | Q5_K | 5.5 | 256 | 5-bit k-quant |
24//! | Q6_K | 6.56 | 256 | 6-bit k-quant |
25//! | IQ2_XXS | 2.06 | 256 | i-quant extreme |
26//! | IQ2_XS | 2.31 | 256 | i-quant |
27//! | IQ3_XXS | 3.06 | 256 | i-quant 3-bit |
28//! | IQ1_S | 1.56 | 256 | i-quant 1-bit |
29//! | IQ4_NL | 4.5 | 32 | i-quant 4-bit non-linear |
30
31use crate::bitnet::dequantize_bitnet_t158;
32use crate::error::{Result, RuvLLMError};
33
34// ============================================================================
35// Quantization Types
36// ============================================================================
37
/// GGUF quantization type identifiers.
///
/// These correspond to the GGML quantization types used in llama.cpp. Each
/// variant's discriminant is the `u32` tensor-type id stored in a GGUF file.
///
/// NOTE(review): upstream ggml assigns id 29 to `IQ1_M` and id 30 to `BF16`.
/// This enum maps 29 to `Bf16` and uses 30 for `BitnetT158`, so upstream files
/// containing BF16 (30) or IQ1_M (29) tensors would be misinterpreted. Confirm
/// these ids are intentional before reading third-party GGUF files.
#[allow(non_camel_case_types)] // variant names deliberately mirror ggml's GGML_TYPE_* names
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum GgufQuantType {
    /// 32-bit floating point (no quantization)
    F32 = 0,
    /// 16-bit floating point
    F16 = 1,
    /// 4-bit quantization (32-element blocks, symmetric)
    Q4_0 = 2,
    /// 4-bit quantization with a per-block min
    Q4_1 = 3,
    /// Legacy 4-bit format (deprecated)
    Q4_2 = 4,
    /// Legacy 4-bit format (deprecated)
    Q4_3 = 5,
    /// 5-bit quantization (symmetric)
    Q5_0 = 6,
    /// 5-bit quantization with a per-block min
    Q5_1 = 7,
    /// 8-bit quantization (symmetric)
    Q8_0 = 8,
    /// 8-bit quantization (symmetric) that also stores `d * sum(q)` per block
    Q8_1 = 9,
    /// 2-bit k-quant
    Q2_K = 10,
    /// 3-bit k-quant
    Q3_K = 11,
    /// 4-bit k-quant
    Q4_K = 12,
    /// 5-bit k-quant
    Q5_K = 13,
    /// 6-bit k-quant
    Q6_K = 14,
    /// 8-bit k-quant
    Q8_K = 15,
    /// I-quant 2-bit extreme extra small
    IQ2_XXS = 16,
    /// I-quant 2-bit extra small
    IQ2_XS = 17,
    /// I-quant 3-bit extra extra small
    IQ3_XXS = 18,
    /// I-quant 1-bit small
    IQ1_S = 19,
    /// I-quant 4-bit non-linear
    IQ4_NL = 20,
    /// I-quant 3-bit small
    IQ3_S = 21,
    /// I-quant 2-bit small
    IQ2_S = 22,
    /// I-quant 4-bit extra small
    IQ4_XS = 23,
    /// 8-bit integer
    I8 = 24,
    /// 16-bit integer
    I16 = 25,
    /// 32-bit integer
    I32 = 26,
    /// 64-bit integer
    I64 = 27,
    /// 64-bit floating point
    F64 = 28,
    /// BF16 brain float
    Bf16 = 29,
    /// BitNet b1.58 ternary quantization (2-bit packed)
    BitnetT158 = 30,
}
107
108impl TryFrom<u32> for GgufQuantType {
109    type Error = RuvLLMError;
110
111    fn try_from(value: u32) -> Result<Self> {
112        match value {
113            0 => Ok(Self::F32),
114            1 => Ok(Self::F16),
115            2 => Ok(Self::Q4_0),
116            3 => Ok(Self::Q4_1),
117            4 => Ok(Self::Q4_2),
118            5 => Ok(Self::Q4_3),
119            6 => Ok(Self::Q5_0),
120            7 => Ok(Self::Q5_1),
121            8 => Ok(Self::Q8_0),
122            9 => Ok(Self::Q8_1),
123            10 => Ok(Self::Q2_K),
124            11 => Ok(Self::Q3_K),
125            12 => Ok(Self::Q4_K),
126            13 => Ok(Self::Q5_K),
127            14 => Ok(Self::Q6_K),
128            15 => Ok(Self::Q8_K),
129            16 => Ok(Self::IQ2_XXS),
130            17 => Ok(Self::IQ2_XS),
131            18 => Ok(Self::IQ3_XXS),
132            19 => Ok(Self::IQ1_S),
133            20 => Ok(Self::IQ4_NL),
134            21 => Ok(Self::IQ3_S),
135            22 => Ok(Self::IQ2_S),
136            23 => Ok(Self::IQ4_XS),
137            24 => Ok(Self::I8),
138            25 => Ok(Self::I16),
139            26 => Ok(Self::I32),
140            27 => Ok(Self::I64),
141            28 => Ok(Self::F64),
142            29 => Ok(Self::Bf16),
143            30 => Ok(Self::BitnetT158),
144            _ => Err(RuvLLMError::Model(format!(
145                "Unknown GGUF quantization type: {}",
146                value
147            ))),
148        }
149    }
150}
151
152impl GgufQuantType {
153    /// Get the block size for this quantization type.
154    ///
155    /// Quantization operates on blocks of elements. Non-quantized types
156    /// have a block size of 1.
157    pub fn block_size(&self) -> usize {
158        match self {
159            Self::F32 | Self::F16 | Self::Bf16 | Self::F64 => 1,
160            Self::I8 | Self::I16 | Self::I32 | Self::I64 => 1,
161            Self::Q4_0 | Self::Q4_1 | Self::Q4_2 | Self::Q4_3 => 32,
162            Self::Q5_0 | Self::Q5_1 => 32,
163            Self::Q8_0 | Self::Q8_1 => 32,
164            Self::Q2_K | Self::Q3_K | Self::Q4_K | Self::Q5_K | Self::Q6_K | Self::Q8_K => 256,
165            Self::IQ2_XXS | Self::IQ2_XS | Self::IQ2_S => 256,
166            Self::IQ3_XXS | Self::IQ3_S => 256,
167            Self::IQ1_S => 256,
168            Self::IQ4_NL => 32,
169            Self::IQ4_XS => 256,
170            Self::BitnetT158 => 256,
171        }
172    }
173
174    /// Get the size in bytes for one block of this type.
175    ///
176    /// This is the storage size for `block_size()` elements.
177    pub fn type_size(&self) -> usize {
178        match self {
179            Self::F32 => 4,
180            Self::F16 => 2,
181            Self::Bf16 => 2,
182            Self::F64 => 8,
183            Self::I8 => 1,
184            Self::I16 => 2,
185            Self::I32 => 4,
186            Self::I64 => 8,
187            // Q4_0: 32 elements -> half (16 bytes) + scale (2 bytes f16) = 18 bytes
188            Self::Q4_0 => 18,
189            // Q4_1: 32 elements -> half (16 bytes) + scale (2 bytes) + min (2 bytes) = 20 bytes
190            Self::Q4_1 => 20,
191            Self::Q4_2 => 18, // Deprecated
192            Self::Q4_3 => 20, // Deprecated
193            // Q5_0: 32 elements -> scale (2) + quants (20) = 22 bytes
194            Self::Q5_0 => 22,
195            // Q5_1: 32 elements -> scale (2) + min (2) + quants (20) = 24 bytes
196            Self::Q5_1 => 24,
197            // Q8_0: 32 elements -> scale (2) + quants (32) = 34 bytes
198            Self::Q8_0 => 34,
199            // Q8_1: 32 elements -> scale (2) + offset (2) + quants (32) = 36 bytes
200            Self::Q8_1 => 36,
201            // Q2_K: 256 elements -> superblock structure
202            Self::Q2_K => 84,
203            // Q3_K: 256 elements
204            Self::Q3_K => 110,
205            // Q4_K: 256 elements -> d (2) + dmin (2) + scales (12) + qs (128) = 144 bytes
206            Self::Q4_K => 144,
207            // Q5_K: 256 elements
208            Self::Q5_K => 176,
209            // Q6_K: 256 elements
210            Self::Q6_K => 210,
211            // Q8_K: 256 elements
212            Self::Q8_K => 292,
213            // I-quants (approximate sizes)
214            Self::IQ2_XXS => 66,
215            Self::IQ2_XS => 74,
216            Self::IQ2_S => 82,
217            Self::IQ3_XXS => 98,
218            Self::IQ3_S => 110,
219            Self::IQ1_S => 50,
220            Self::IQ4_NL => 18,
221            Self::IQ4_XS => 136,
222            // BitNet b1.58: 256 elements -> 64 bytes (2-bit packed) + 2 bytes (FP16 scale) = 66 bytes
223            Self::BitnetT158 => 66,
224        }
225    }
226
227    /// Calculate the total byte size for a tensor with this dtype.
228    pub fn tensor_size(&self, num_elements: usize) -> usize {
229        let block_size = self.block_size();
230        let type_size = self.type_size();
231        let num_blocks = (num_elements + block_size - 1) / block_size;
232        num_blocks * type_size
233    }
234
235    /// Check if this is a quantized type.
236    pub fn is_quantized(&self) -> bool {
237        !matches!(
238            self,
239            Self::F32
240                | Self::F16
241                | Self::Bf16
242                | Self::F64
243                | Self::I8
244                | Self::I16
245                | Self::I32
246                | Self::I64
247        )
248    }
249
250    /// Get approximate bits per weight.
251    pub fn bits_per_weight(&self) -> f32 {
252        let type_size = self.type_size() as f32;
253        let block_size = self.block_size() as f32;
254        (type_size * 8.0) / block_size
255    }
256
257    /// Get the name as used in GGUF files.
258    pub fn name(&self) -> &'static str {
259        match self {
260            Self::F32 => "F32",
261            Self::F16 => "F16",
262            Self::Bf16 => "BF16",
263            Self::F64 => "F64",
264            Self::I8 => "I8",
265            Self::I16 => "I16",
266            Self::I32 => "I32",
267            Self::I64 => "I64",
268            Self::Q4_0 => "Q4_0",
269            Self::Q4_1 => "Q4_1",
270            Self::Q4_2 => "Q4_2",
271            Self::Q4_3 => "Q4_3",
272            Self::Q5_0 => "Q5_0",
273            Self::Q5_1 => "Q5_1",
274            Self::Q8_0 => "Q8_0",
275            Self::Q8_1 => "Q8_1",
276            Self::Q2_K => "Q2_K",
277            Self::Q3_K => "Q3_K",
278            Self::Q4_K => "Q4_K",
279            Self::Q5_K => "Q5_K",
280            Self::Q6_K => "Q6_K",
281            Self::Q8_K => "Q8_K",
282            Self::IQ2_XXS => "IQ2_XXS",
283            Self::IQ2_XS => "IQ2_XS",
284            Self::IQ2_S => "IQ2_S",
285            Self::IQ3_XXS => "IQ3_XXS",
286            Self::IQ3_S => "IQ3_S",
287            Self::IQ1_S => "IQ1_S",
288            Self::IQ4_NL => "IQ4_NL",
289            Self::IQ4_XS => "IQ4_XS",
290            Self::BitnetT158 => "BITNET_T158",
291        }
292    }
293}
294
295// ============================================================================
296// Quantized Tensor Container
297// ============================================================================
298
299/// Container for quantized tensor data.
300///
301/// This struct holds the raw quantized bytes along with metadata
302/// needed for dequantization.
303#[derive(Debug, Clone)]
304pub struct QuantizedTensor {
305    /// Raw quantized data bytes
306    pub data: Vec<u8>,
307    /// Quantization type
308    pub dtype: GgufQuantType,
309    /// Tensor shape
310    pub shape: Vec<usize>,
311    /// Total number of elements
312    pub num_elements: usize,
313}
314
315impl QuantizedTensor {
316    /// Dequantize to FP32.
317    pub fn dequantize(&self) -> Result<Vec<f32>> {
318        dequantize_tensor(&self.data, self.dtype, self.num_elements)
319    }
320
321    /// Get the block count.
322    pub fn block_count(&self) -> usize {
323        let block_size = self.dtype.block_size();
324        (self.num_elements + block_size - 1) / block_size
325    }
326}
327
328// ============================================================================
329// Dequantization Functions
330// ============================================================================
331
332/// Dequantize a tensor from raw bytes to FP32.
333///
334/// # Arguments
335///
336/// * `data` - Raw quantized bytes
337/// * `dtype` - Quantization type
338/// * `num_elements` - Total number of output elements
339///
340/// # Returns
341///
342/// Vector of FP32 values
343pub fn dequantize_tensor(
344    data: &[u8],
345    dtype: GgufQuantType,
346    num_elements: usize,
347) -> Result<Vec<f32>> {
348    let mut output = vec![0.0f32; num_elements];
349
350    match dtype {
351        GgufQuantType::F32 => dequantize_f32(data, &mut output),
352        GgufQuantType::F16 => dequantize_f16(data, &mut output),
353        GgufQuantType::Bf16 => dequantize_bf16(data, &mut output),
354        GgufQuantType::Q4_0 => dequantize_q4_0(data, &mut output),
355        GgufQuantType::Q4_1 => dequantize_q4_1(data, &mut output),
356        GgufQuantType::Q5_0 => dequantize_q5_0(data, &mut output),
357        GgufQuantType::Q5_1 => dequantize_q5_1(data, &mut output),
358        GgufQuantType::Q8_0 => dequantize_q8_0(data, &mut output),
359        GgufQuantType::Q8_1 => dequantize_q8_1(data, &mut output),
360        GgufQuantType::Q2_K => dequantize_q2_k(data, &mut output),
361        GgufQuantType::Q3_K => dequantize_q3_k(data, &mut output),
362        GgufQuantType::Q4_K => dequantize_q4_k(data, &mut output),
363        GgufQuantType::Q5_K => dequantize_q5_k(data, &mut output),
364        GgufQuantType::Q6_K => dequantize_q6_k(data, &mut output),
365        GgufQuantType::IQ4_NL => dequantize_iq4_nl(data, &mut output),
366        GgufQuantType::BitnetT158 => dequantize_bitnet_t158_wrapper(data, &mut output),
367        GgufQuantType::IQ1_S => {
368            return Err(RuvLLMError::Model(
369                "IQ1_S dequantization requires codebook lookup tables (not yet implemented). \
370                 For BitNet ternary quantization, use BITNET_T158 type instead."
371                    .to_string(),
372            ));
373        }
374        _ => {
375            return Err(RuvLLMError::Model(format!(
376                "Dequantization not implemented for {:?}",
377                dtype
378            )));
379        }
380    }
381
382    Ok(output)
383}
384
385/// Dequantize a single block.
386///
387/// # Arguments
388///
389/// * `data` - Raw block bytes
390/// * `dtype` - Quantization type
391/// * `output` - Output buffer (must have capacity for block_size elements)
392pub fn dequantize_block(data: &[u8], dtype: GgufQuantType, output: &mut [f32]) {
393    match dtype {
394        GgufQuantType::Q4_0 => dequantize_q4_0_block(data, output),
395        GgufQuantType::Q4_1 => dequantize_q4_1_block(data, output),
396        GgufQuantType::Q8_0 => dequantize_q8_0_block(data, output),
397        GgufQuantType::Q4_K => dequantize_q4_k_block(data, output),
398        GgufQuantType::BitnetT158 => dequantize_bitnet_t158_block_wrapper(data, output),
399        _ => {
400            // Fallback: fill with zeros
401            output.fill(0.0);
402        }
403    }
404}
405
406/// Dequantize a single BITNET_T158 block from GGUF format.
407///
408/// Block format (66 bytes):
409/// - 64 bytes: packed 2-bit ternary data
410/// - 2 bytes: FP16 scale
411fn dequantize_bitnet_t158_block_wrapper(data: &[u8], output: &mut [f32]) {
412    if data.len() < BITNET_T158_TYPE_SIZE {
413        output.fill(0.0);
414        return;
415    }
416
417    // Extract packed data (first 64 bytes)
418    let packed = &data[..64];
419
420    // Extract scale (last 2 bytes)
421    let scale = f16_to_f32(u16::from_le_bytes([data[64], data[65]]));
422
423    // Dequantize using bitnet module (expects 256 elements)
424    let min_output_len = output.len().min(BITNET_T158_BLOCK_SIZE);
425    let dequantized = dequantize_bitnet_t158(packed, &[scale], min_output_len);
426
427    // Copy to output
428    output[..dequantized.len()].copy_from_slice(&dequantized);
429}
430
431// ============================================================================
432// F32/F16/BF16 (No Quantization)
433// ============================================================================
434
/// Decode little-endian IEEE-754 `f32` values from `data` into `output`.
///
/// Stops at whichever runs out first: complete 4-byte groups in `data` or
/// slots in `output`. Trailing bytes that do not form a whole value are ignored.
fn dequantize_f32(data: &[u8], output: &mut [f32]) {
    let words = data.chunks_exact(4);
    output
        .iter_mut()
        .zip(words)
        .for_each(|(slot, w)| *slot = f32::from_le_bytes([w[0], w[1], w[2], w[3]]));
}
443
444fn dequantize_f16(data: &[u8], output: &mut [f32]) {
445    for (i, chunk) in data.chunks_exact(2).enumerate() {
446        if i >= output.len() {
447            break;
448        }
449        let bits = u16::from_le_bytes([chunk[0], chunk[1]]);
450        output[i] = f16_to_f32(bits);
451    }
452}
453
/// Decode little-endian bfloat16 values from `data` into `output`.
///
/// Decodes `min(data.len() / 2, output.len())` values; a trailing odd byte is ignored.
fn dequantize_bf16(data: &[u8], output: &mut [f32]) {
    // bfloat16 is the high half of an IEEE-754 f32, so widening is a 16-bit shift.
    for (slot, pair) in output.iter_mut().zip(data.chunks_exact(2)) {
        let upper = u32::from(u16::from_le_bytes([pair[0], pair[1]]));
        *slot = f32::from_bits(upper << 16);
    }
}
464
465// ============================================================================
466// Q4_0: 4-bit Symmetric Quantization
467// ============================================================================
468
469/// Q4_0 block structure: scale (f16) + 16 bytes (32 4-bit values)
470const Q4_0_BLOCK_SIZE: usize = 32;
471const Q4_0_TYPE_SIZE: usize = 18; // 2 + 16
472
473fn dequantize_q4_0(data: &[u8], output: &mut [f32]) {
474    let num_blocks = output.len() / Q4_0_BLOCK_SIZE;
475
476    for block_idx in 0..num_blocks {
477        let block_start = block_idx * Q4_0_TYPE_SIZE;
478        let out_start = block_idx * Q4_0_BLOCK_SIZE;
479
480        if block_start + Q4_0_TYPE_SIZE > data.len() {
481            break;
482        }
483
484        let block = &data[block_start..block_start + Q4_0_TYPE_SIZE];
485        let out = &mut output[out_start..out_start + Q4_0_BLOCK_SIZE];
486
487        dequantize_q4_0_block(block, out);
488    }
489}
490
491fn dequantize_q4_0_block(block: &[u8], output: &mut [f32]) {
492    // Scale is stored as f16 in first 2 bytes
493    let scale = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
494
495    // 16 bytes of packed 4-bit values (2 values per byte)
496    for i in 0..16 {
497        let byte = block[2 + i];
498        let q0 = (byte & 0x0F) as i8 - 8; // Q4_0 uses offset of 8
499        let q1 = ((byte >> 4) & 0x0F) as i8 - 8;
500
501        output[i * 2] = (q0 as f32) * scale;
502        output[i * 2 + 1] = (q1 as f32) * scale;
503    }
504}
505
506// ============================================================================
507// Q4_1: 4-bit Asymmetric Quantization
508// ============================================================================
509
510const Q4_1_BLOCK_SIZE: usize = 32;
511const Q4_1_TYPE_SIZE: usize = 20; // 2 + 2 + 16
512
513fn dequantize_q4_1(data: &[u8], output: &mut [f32]) {
514    let num_blocks = output.len() / Q4_1_BLOCK_SIZE;
515
516    for block_idx in 0..num_blocks {
517        let block_start = block_idx * Q4_1_TYPE_SIZE;
518        let out_start = block_idx * Q4_1_BLOCK_SIZE;
519
520        if block_start + Q4_1_TYPE_SIZE > data.len() {
521            break;
522        }
523
524        let block = &data[block_start..block_start + Q4_1_TYPE_SIZE];
525        let out = &mut output[out_start..out_start + Q4_1_BLOCK_SIZE];
526
527        dequantize_q4_1_block(block, out);
528    }
529}
530
531fn dequantize_q4_1_block(block: &[u8], output: &mut [f32]) {
532    let scale = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
533    let min = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
534
535    for i in 0..16 {
536        let byte = block[4 + i];
537        let q0 = (byte & 0x0F) as f32;
538        let q1 = ((byte >> 4) & 0x0F) as f32;
539
540        output[i * 2] = q0 * scale + min;
541        output[i * 2 + 1] = q1 * scale + min;
542    }
543}
544
545// ============================================================================
546// Q5_0: 5-bit Symmetric Quantization
547// ============================================================================
548
549const Q5_0_BLOCK_SIZE: usize = 32;
550const Q5_0_TYPE_SIZE: usize = 22; // 2 + 4 (high bits) + 16 (low bits)
551
552fn dequantize_q5_0(data: &[u8], output: &mut [f32]) {
553    let num_blocks = output.len() / Q5_0_BLOCK_SIZE;
554
555    for block_idx in 0..num_blocks {
556        let block_start = block_idx * Q5_0_TYPE_SIZE;
557        let out_start = block_idx * Q5_0_BLOCK_SIZE;
558
559        if block_start + Q5_0_TYPE_SIZE > data.len() {
560            break;
561        }
562
563        let scale = f16_to_f32(u16::from_le_bytes([
564            data[block_start],
565            data[block_start + 1],
566        ]));
567
568        // 4 bytes for high bits (32 values, 1 bit each)
569        let qh = u32::from_le_bytes([
570            data[block_start + 2],
571            data[block_start + 3],
572            data[block_start + 4],
573            data[block_start + 5],
574        ]);
575
576        // 16 bytes for low 4 bits
577        for i in 0..16 {
578            let byte = data[block_start + 6 + i];
579            let h0 = ((qh >> (i * 2)) & 1) as i8;
580            let h1 = ((qh >> (i * 2 + 1)) & 1) as i8;
581
582            let q0 = ((byte & 0x0F) as i8 | (h0 << 4)) - 16;
583            let q1 = (((byte >> 4) & 0x0F) as i8 | (h1 << 4)) - 16;
584
585            output[out_start + i * 2] = (q0 as f32) * scale;
586            output[out_start + i * 2 + 1] = (q1 as f32) * scale;
587        }
588    }
589}
590
591// ============================================================================
592// Q5_1: 5-bit Asymmetric Quantization
593// ============================================================================
594
595const Q5_1_BLOCK_SIZE: usize = 32;
596const Q5_1_TYPE_SIZE: usize = 24; // 2 + 2 + 4 + 16
597
598fn dequantize_q5_1(data: &[u8], output: &mut [f32]) {
599    let num_blocks = output.len() / Q5_1_BLOCK_SIZE;
600
601    for block_idx in 0..num_blocks {
602        let block_start = block_idx * Q5_1_TYPE_SIZE;
603        let out_start = block_idx * Q5_1_BLOCK_SIZE;
604
605        if block_start + Q5_1_TYPE_SIZE > data.len() {
606            break;
607        }
608
609        let scale = f16_to_f32(u16::from_le_bytes([
610            data[block_start],
611            data[block_start + 1],
612        ]));
613        let min = f16_to_f32(u16::from_le_bytes([
614            data[block_start + 2],
615            data[block_start + 3],
616        ]));
617
618        let qh = u32::from_le_bytes([
619            data[block_start + 4],
620            data[block_start + 5],
621            data[block_start + 6],
622            data[block_start + 7],
623        ]);
624
625        for i in 0..16 {
626            let byte = data[block_start + 8 + i];
627            let h0 = ((qh >> (i * 2)) & 1) as u8;
628            let h1 = ((qh >> (i * 2 + 1)) & 1) as u8;
629
630            let q0 = ((byte & 0x0F) | (h0 << 4)) as f32;
631            let q1 = (((byte >> 4) & 0x0F) | (h1 << 4)) as f32;
632
633            output[out_start + i * 2] = q0 * scale + min;
634            output[out_start + i * 2 + 1] = q1 * scale + min;
635        }
636    }
637}
638
639// ============================================================================
640// Q8_0: 8-bit Symmetric Quantization
641// ============================================================================
642
643const Q8_0_BLOCK_SIZE: usize = 32;
644const Q8_0_TYPE_SIZE: usize = 34; // 2 + 32
645
646fn dequantize_q8_0(data: &[u8], output: &mut [f32]) {
647    let num_blocks = output.len() / Q8_0_BLOCK_SIZE;
648
649    for block_idx in 0..num_blocks {
650        let block_start = block_idx * Q8_0_TYPE_SIZE;
651        let out_start = block_idx * Q8_0_BLOCK_SIZE;
652
653        if block_start + Q8_0_TYPE_SIZE > data.len() {
654            break;
655        }
656
657        let block = &data[block_start..block_start + Q8_0_TYPE_SIZE];
658        let out = &mut output[out_start..out_start + Q8_0_BLOCK_SIZE];
659
660        dequantize_q8_0_block(block, out);
661    }
662}
663
664fn dequantize_q8_0_block(block: &[u8], output: &mut [f32]) {
665    let scale = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
666
667    for i in 0..32 {
668        let q = block[2 + i] as i8;
669        output[i] = (q as f32) * scale;
670    }
671}
672
673// ============================================================================
674// Q8_1: 8-bit Asymmetric Quantization
675// ============================================================================
676
677const Q8_1_BLOCK_SIZE: usize = 32;
678const Q8_1_TYPE_SIZE: usize = 36; // 2 + 2 + 32
679
680fn dequantize_q8_1(data: &[u8], output: &mut [f32]) {
681    let num_blocks = output.len() / Q8_1_BLOCK_SIZE;
682
683    for block_idx in 0..num_blocks {
684        let block_start = block_idx * Q8_1_TYPE_SIZE;
685        let out_start = block_idx * Q8_1_BLOCK_SIZE;
686
687        if block_start + Q8_1_TYPE_SIZE > data.len() {
688            break;
689        }
690
691        let scale = f16_to_f32(u16::from_le_bytes([
692            data[block_start],
693            data[block_start + 1],
694        ]));
695        let offset = f16_to_f32(u16::from_le_bytes([
696            data[block_start + 2],
697            data[block_start + 3],
698        ]));
699
700        for i in 0..32 {
701            let q = data[block_start + 4 + i] as i8;
702            output[out_start + i] = (q as f32) * scale + offset;
703        }
704    }
705}
706
707// ============================================================================
708// Q2_K: 2-bit K-Quant
709// ============================================================================
710
711const Q2_K_BLOCK_SIZE: usize = 256;
712const Q2_K_TYPE_SIZE: usize = 84;
713
714fn dequantize_q2_k(data: &[u8], output: &mut [f32]) {
715    let num_blocks = output.len() / Q2_K_BLOCK_SIZE;
716
717    for block_idx in 0..num_blocks {
718        let block_start = block_idx * Q2_K_TYPE_SIZE;
719        let out_start = block_idx * Q2_K_BLOCK_SIZE;
720
721        if block_start + Q2_K_TYPE_SIZE > data.len() {
722            break;
723        }
724
725        // Q2_K structure:
726        // scales: [16] 4-bit scales
727        // d: f16 super scale
728        // dmin: f16 super min
729        // qs: [64] 2-bit values (4 per byte)
730
731        let block = &data[block_start..];
732
733        let d = f16_to_f32(u16::from_le_bytes([block[16], block[17]]));
734        let dmin = f16_to_f32(u16::from_le_bytes([block[18], block[19]]));
735
736        for j in 0..16 {
737            // Each sub-block of 16 elements
738            let sc = (block[j / 2] >> ((j % 2) * 4)) & 0x0F;
739            let scale = d * (sc as f32);
740            let min = dmin * (sc as f32);
741
742            for k in 0..16 {
743                let idx = j * 16 + k;
744                let byte_idx = 20 + idx / 4;
745                let bit_idx = (idx % 4) * 2;
746                let q = (block[byte_idx] >> bit_idx) & 0x03;
747                output[out_start + idx] = (q as f32) * scale - min;
748            }
749        }
750    }
751}
752
753// ============================================================================
754// Q3_K: 3-bit K-Quant
755// ============================================================================
756
757const Q3_K_BLOCK_SIZE: usize = 256;
758const Q3_K_TYPE_SIZE: usize = 110;
759
760fn dequantize_q3_k(data: &[u8], output: &mut [f32]) {
761    let num_blocks = output.len() / Q3_K_BLOCK_SIZE;
762
763    for block_idx in 0..num_blocks {
764        let block_start = block_idx * Q3_K_TYPE_SIZE;
765        let out_start = block_idx * Q3_K_BLOCK_SIZE;
766
767        if block_start + Q3_K_TYPE_SIZE > data.len() {
768            break;
769        }
770
771        // Simplified Q3_K dequantization
772        let block = &data[block_start..];
773        let d = f16_to_f32(u16::from_le_bytes([block[104], block[105]]));
774
775        // High bits, scales, and low bits are interleaved in complex way
776        // This is a simplified implementation
777        for i in 0..256 {
778            let byte_idx = i * 3 / 8;
779            let bit_offset = (i * 3) % 8;
780
781            if byte_idx < 96 {
782                let q = ((block[byte_idx] >> bit_offset) & 0x07) as i8 - 4;
783                output[out_start + i] = (q as f32) * d;
784            }
785        }
786    }
787}
788
789// ============================================================================
790// Q4_K: 4-bit K-Quant (Most Common)
791// ============================================================================
792
793const Q4_K_BLOCK_SIZE: usize = 256;
794const Q4_K_TYPE_SIZE: usize = 144; // d(2) + dmin(2) + scales(12) + qs(128)
795
796fn dequantize_q4_k(data: &[u8], output: &mut [f32]) {
797    let num_blocks = output.len() / Q4_K_BLOCK_SIZE;
798
799    for block_idx in 0..num_blocks {
800        let block_start = block_idx * Q4_K_TYPE_SIZE;
801        let out_start = block_idx * Q4_K_BLOCK_SIZE;
802
803        if block_start + Q4_K_TYPE_SIZE > data.len() {
804            break;
805        }
806
807        let block = &data[block_start..block_start + Q4_K_TYPE_SIZE];
808        let out = &mut output[out_start..out_start + Q4_K_BLOCK_SIZE];
809
810        dequantize_q4_k_block(block, out);
811    }
812}
813
814fn dequantize_q4_k_block(block: &[u8], output: &mut [f32]) {
815    // Block layout: d (2) + dmin (2) + scales (12) + qs (128)
816    let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
817    let dmin = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
818
819    // Process each of 8 sub-blocks of 32 elements
820    for sb in 0..8 {
821        // Extract 6-bit scale for this sub-block
822        let scale_idx = sb * 6 / 8;
823        let scale_shift = (sb * 6) % 8;
824
825        let mut sc = (block[4 + scale_idx] >> scale_shift) & 0x3F;
826        if scale_shift > 2 && scale_idx + 1 < 12 {
827            sc |= (block[4 + scale_idx + 1] << (8 - scale_shift)) & 0x3F;
828        }
829
830        let scale = d * (sc as f32);
831
832        // Dequantize 32 elements in this sub-block
833        let qs_start = 16 + sb * 16; // 16 bytes header + 16 bytes per sub-block
834        for i in 0..16 {
835            let byte = block[qs_start + i];
836            let q0 = (byte & 0x0F) as f32;
837            let q1 = ((byte >> 4) & 0x0F) as f32;
838
839            output[sb * 32 + i * 2] = q0 * scale + dmin;
840            output[sb * 32 + i * 2 + 1] = q1 * scale + dmin;
841        }
842    }
843}
844
845// ============================================================================
846// Q5_K: 5-bit K-Quant
847// ============================================================================
848
849const Q5_K_BLOCK_SIZE: usize = 256;
850const Q5_K_TYPE_SIZE: usize = 176;
851
852fn dequantize_q5_k(data: &[u8], output: &mut [f32]) {
853    let num_blocks = output.len() / Q5_K_BLOCK_SIZE;
854
855    for block_idx in 0..num_blocks {
856        let block_start = block_idx * Q5_K_TYPE_SIZE;
857        let out_start = block_idx * Q5_K_BLOCK_SIZE;
858
859        if block_start + Q5_K_TYPE_SIZE > data.len() {
860            break;
861        }
862
863        let block = &data[block_start..];
864        let d = f16_to_f32(u16::from_le_bytes([block[0], block[1]]));
865        let dmin = f16_to_f32(u16::from_le_bytes([block[2], block[3]]));
866
867        // Simplified Q5_K - similar structure to Q4_K but with 5 bits
868        for i in 0..256 {
869            let byte_idx = 16 + (i * 5) / 8;
870            let bit_offset = (i * 5) % 8;
871
872            if byte_idx < Q5_K_TYPE_SIZE {
873                let mut q = (block[byte_idx] >> bit_offset) & 0x1F;
874                if bit_offset > 3 && byte_idx + 1 < Q5_K_TYPE_SIZE {
875                    q |= (block[byte_idx + 1] << (8 - bit_offset)) & 0x1F;
876                }
877                output[out_start + i] = (q as f32) * d + dmin;
878            }
879        }
880    }
881}
882
883// ============================================================================
884// Q6_K: 6-bit K-Quant
885// ============================================================================
886
887const Q6_K_BLOCK_SIZE: usize = 256;
888const Q6_K_TYPE_SIZE: usize = 210;
889
890fn dequantize_q6_k(data: &[u8], output: &mut [f32]) {
891    let num_blocks = output.len() / Q6_K_BLOCK_SIZE;
892
893    for block_idx in 0..num_blocks {
894        let block_start = block_idx * Q6_K_TYPE_SIZE;
895        let out_start = block_idx * Q6_K_BLOCK_SIZE;
896
897        if block_start + Q6_K_TYPE_SIZE > data.len() {
898            break;
899        }
900
901        let block = &data[block_start..];
902        let d = f16_to_f32(u16::from_le_bytes([block[208], block[209]]));
903
904        // Q6_K has complex bit packing
905        // Low 4 bits: ql[128]
906        // High 2 bits: qh[64]
907        // Scales: scales[16]
908        for i in 0..256 {
909            let ql_idx = i / 2;
910            let is_high = i % 2 == 1;
911
912            if ql_idx < 128 {
913                let ql = if is_high {
914                    (block[ql_idx] >> 4) & 0x0F
915                } else {
916                    block[ql_idx] & 0x0F
917                };
918
919                let qh_idx = 128 + i / 4;
920                let qh_shift = (i % 4) * 2;
921                let qh = if qh_idx < 192 {
922                    (block[qh_idx] >> qh_shift) & 0x03
923                } else {
924                    0
925                };
926
927                let q = ((qh << 4) | ql) as i8 - 32;
928                let scale_idx = i / 16;
929                let sc = if scale_idx < 16 {
930                    (block[192 + scale_idx / 2] >> ((scale_idx % 2) * 4)) & 0x0F
931                } else {
932                    1
933                };
934
935                output[out_start + i] = (q as f32) * d * (sc as f32);
936            }
937        }
938    }
939}
940
941// ============================================================================
942// IQ4_NL: I-Quant 4-bit Non-Linear
943// ============================================================================
944
945const IQ4_NL_BLOCK_SIZE: usize = 32;
946const IQ4_NL_TYPE_SIZE: usize = 18;
947
948// Non-linear quantization lookup table (simplified version)
949const IQ4_NL_LUT: [f32; 16] = [
950    -1.0, -0.75, -0.5, -0.375, -0.25, -0.125, 0.0, 0.125, 0.25, 0.375, 0.5, 0.75, 1.0, 1.5, 2.0,
951    3.0,
952];
953
954fn dequantize_iq4_nl(data: &[u8], output: &mut [f32]) {
955    let num_blocks = output.len() / IQ4_NL_BLOCK_SIZE;
956
957    for block_idx in 0..num_blocks {
958        let block_start = block_idx * IQ4_NL_TYPE_SIZE;
959        let out_start = block_idx * IQ4_NL_BLOCK_SIZE;
960
961        if block_start + IQ4_NL_TYPE_SIZE > data.len() {
962            break;
963        }
964
965        let scale = f16_to_f32(u16::from_le_bytes([
966            data[block_start],
967            data[block_start + 1],
968        ]));
969
970        for i in 0..16 {
971            let byte = data[block_start + 2 + i];
972            let q0 = (byte & 0x0F) as usize;
973            let q1 = ((byte >> 4) & 0x0F) as usize;
974
975            output[out_start + i * 2] = IQ4_NL_LUT[q0] * scale;
976            output[out_start + i * 2 + 1] = IQ4_NL_LUT[q1] * scale;
977        }
978    }
979}
980
981// ============================================================================
982// BITNET_T158: BitNet b1.58 Ternary Quantization
983// ============================================================================
984
985const BITNET_T158_BLOCK_SIZE: usize = 256;
986const BITNET_T158_TYPE_SIZE: usize = 66; // 64 bytes packed + 2 bytes FP16 scale
987
988/// Wrapper for BitNet T158 dequantization from GGUF format.
989///
990/// GGUF BITNET_T158 block layout (66 bytes per 256 elements):
991/// - 64 bytes: packed 2-bit ternary data (256 values × 2 bits = 512 bits = 64 bytes)
992/// - 2 bytes: FP16 scale factor
993///
994/// This wrapper extracts scales from the interleaved GGUF format and passes
995/// them to the bitnet module's dequantization function.
996fn dequantize_bitnet_t158_wrapper(data: &[u8], output: &mut [f32]) {
997    let num_blocks = output.len() / BITNET_T158_BLOCK_SIZE;
998
999    // Extract scales from GGUF format (interleaved with packed data)
1000    let mut scales = Vec::with_capacity(num_blocks);
1001    let mut packed_data = Vec::with_capacity(num_blocks * 64);
1002
1003    for block_idx in 0..num_blocks {
1004        let block_start = block_idx * BITNET_T158_TYPE_SIZE;
1005
1006        if block_start + BITNET_T158_TYPE_SIZE > data.len() {
1007            break;
1008        }
1009
1010        // Extract 64 bytes of packed ternary data
1011        packed_data.extend_from_slice(&data[block_start..block_start + 64]);
1012
1013        // Extract FP16 scale (last 2 bytes of block)
1014        let scale_f16 = f16_to_f32(u16::from_le_bytes([
1015            data[block_start + 64],
1016            data[block_start + 65],
1017        ]));
1018        scales.push(scale_f16);
1019    }
1020
1021    // Call bitnet module's dequantization function
1022    let dequantized = dequantize_bitnet_t158(&packed_data, &scales, output.len());
1023
1024    // Copy to output buffer
1025    output[..dequantized.len()].copy_from_slice(&dequantized);
1026}
1027
1028// ============================================================================
1029// F16 Conversion Helper
1030// ============================================================================
1031
/// Convert IEEE 754 binary16 (half-precision) bits to an `f32`.
///
/// Every binary16 value is exactly representable as `f32`, so the conversion
/// is lossless. Signed zeros, subnormals, normals, infinities and NaNs are all
/// handled. NaN payload bits are carried into the high mantissa bits.
#[inline(always)]
fn f16_to_f32(bits: u16) -> f32 {
    let sign = u32::from(bits & 0x8000) << 16;
    let exp = u32::from((bits >> 10) & 0x1F);
    let frac = u32::from(bits & 0x03FF);

    match exp {
        // ±0
        0 if frac == 0 => f32::from_bits(sign),
        // Subnormal: value = frac * 2^-24. Shift the leading 1 up to the
        // implicit-bit position (bit 10). Each shift lowers the exponent by one,
        // starting from the f16 minimum normal exponent -14 (biased: 127 - 14).
        0 => {
            // frac is in 1..=0x3FF, so leading_zeros is in 22..=31 and shift in 1..=10.
            let shift = frac.leading_zeros() - 21;
            let exp32 = 127 - 14 - shift;
            let mant = (frac << shift) & 0x03FF;
            f32::from_bits(sign | (exp32 << 23) | (mant << 13))
        }
        // Inf (frac == 0) or NaN
        31 => f32::from_bits(sign | 0x7F80_0000 | (frac << 13)),
        // Normal: rebias exponent from 15 to 127 and widen the mantissa.
        _ => f32::from_bits(sign | ((exp + 127 - 15) << 23) | (frac << 13)),
    }
}
1061
1062// ============================================================================
1063// Tests
1064// ============================================================================
1065
#[cfg(test)]
mod tests {
    use super::*;

    /// FP16 encoding of 1.0 (0x3C00) in little-endian byte order.
    const F16_ONE_LE: [u8; 2] = [0x00, 0x3C];

    #[test]
    fn test_quant_type_sizes() {
        // (type, elements per block, bytes per block)
        let expected = [
            (GgufQuantType::F32, 1, 4),
            (GgufQuantType::Q4_0, 32, 18),
            (GgufQuantType::Q4_K, 256, 144),
        ];
        for &(ty, block_size, type_size) in &expected {
            assert_eq!(ty.block_size(), block_size, "block_size of {:?}", ty);
            assert_eq!(ty.type_size(), type_size, "type_size of {:?}", ty);
        }
    }

    #[test]
    fn test_quant_type_bits() {
        // bits/weight = type_size * 8 / block_size:
        // F32 = 32, Q4_0 = 18*8/32 = 4.5, Q8_0 = 34*8/32 = 8.5.
        let expected = [
            (GgufQuantType::F32, 32.0),
            (GgufQuantType::Q4_0, 4.5),
            (GgufQuantType::Q8_0, 8.5),
        ];
        for &(ty, bits) in &expected {
            assert!((ty.bits_per_weight() - bits).abs() < 0.1, "{:?}", ty);
        }
    }

    #[test]
    fn test_f16_conversion() {
        // Exactly representable encodings: +0, +1, -1.
        for &(bits, value) in &[(0x0000u16, 0.0f32), (0x3C00, 1.0), (0xBC00, -1.0)] {
            assert_eq!(f16_to_f32(bits), value);
        }

        // 0x3800 encodes 0.5.
        assert!((f16_to_f32(0x3800) - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_q4_0_dequantize() {
        // scale = 1.0; every nibble is 8, which the Q4_0 offset maps to 0.
        let mut block = vec![0x88u8; 18];
        block[..2].copy_from_slice(&F16_ONE_LE);

        let mut output = vec![0.0f32; 32];
        dequantize_q4_0_block(&block, &mut output);

        assert!(output.iter().all(|v| v.abs() < 0.001));
    }

    #[test]
    fn test_q8_0_dequantize() {
        // scale = 1.0; quants are 1, 2, ..., 32.
        let mut block = vec![0u8; 34];
        block[..2].copy_from_slice(&F16_ONE_LE);
        for (slot, q) in block[2..].iter_mut().zip(1u8..) {
            *slot = q;
        }

        let mut output = vec![0.0f32; 32];
        dequantize_q8_0_block(&block, &mut output);

        for (i, &v) in output.iter().enumerate() {
            let expected = (i + 1) as f32;
            assert!((v - expected).abs() < 0.001);
        }
    }

    #[test]
    fn test_quant_type_try_from() {
        let f32_ty = GgufQuantType::try_from(0).unwrap();
        let q4k_ty = GgufQuantType::try_from(12).unwrap();
        assert_eq!(f32_ty, GgufQuantType::F32);
        assert_eq!(q4k_ty, GgufQuantType::Q4_K);
        assert!(GgufQuantType::try_from(100).is_err());
    }

    #[test]
    fn test_quantized_tensor() {
        // One Q4_K super-block: 256 elements stored in 144 bytes.
        let tensor = QuantizedTensor {
            data: vec![0u8; 144],
            dtype: GgufQuantType::Q4_K,
            shape: vec![256],
            num_elements: 256,
        };

        assert!(tensor.dtype.is_quantized());
        assert_eq!(tensor.block_count(), 1);
    }
}