tract_linalg/frame/block_quant/
mod.rs

1use downcast_rs::{impl_downcast, Downcast};
2use dyn_clone::{clone_box, DynClone};
3use dyn_hash::DynHash;
4use num_traits::Zero;
5use tract_data::internal::*;
6use tract_data::itertools::Itertools;
7
8use std::alloc::Layout;
9use std::borrow::Cow;
10use std::fmt::{Debug, Display};
11use std::hash::Hash;
12use std::sync::Arc;
13
14mod helpers;
15mod q4_0;
16mod value;
17
18pub use helpers::{NibbleReader, NibbleWriter};
19pub use q4_0::Q4_0;
20pub use value::{BlockQuantFact, BlockQuantValue, PackedBlockQuantFact};
21
22use crate::mmm::{EagerPackedInput, MMMInputFormat};
23use crate::pack::PackedFormat;
24
25use crate::WeightType;
26
27use super::mmm::MMMInputValue;
28
29pub trait BlockQuant: Debug + Display + Send + Sync + DynClone + DynHash + Downcast {
30    fn same_as(&self, other: &dyn BlockQuant) -> bool;
31
32    fn block_len(&self) -> usize;
33
34    fn block_bytes(&self) -> usize;
35
36    fn dequant_block_f32(&self, quant: &[u8], block: &mut [f32]);
37    fn dequant_block_f16(&self, quant: &[u8], block: &mut [f16]);
38    fn quant_block_f16(&self, block: &[f16], quant: &mut [u8]);
39    fn quant_block_f32(&self, block: &[f32], quant: &mut [u8]);
40
41    fn quant_f16(&self, input: &[f16]) -> TractResult<Blob> {
42        unsafe {
43            let blocks = input.len() / self.block_len();
44            let mut quant = Blob::for_layout(
45                Layout::from_size_align(blocks * self.block_bytes(), 128).unwrap(),
46            );
47            for b in 0..blocks {
48                let block = &input[b * self.block_len()..][..self.block_len()];
49                let qblock = &mut quant[b * self.block_bytes()..][..self.block_bytes()];
50                self.quant_block_f16(block, qblock);
51            }
52            Ok(quant)
53        }
54    }
55
56    fn quant_f32(&self, input: &[f32]) -> TractResult<Blob> {
57        unsafe {
58            let blocks = input.len() / self.block_len();
59            let mut quant = Blob::for_layout(
60                Layout::from_size_align(blocks * self.block_bytes(), 128).unwrap(),
61            );
62            for b in 0..blocks {
63                let block = &input[b * self.block_len()..][..self.block_len()];
64                let qblock = &mut quant[b * self.block_bytes()..][..self.block_bytes()];
65                self.quant_block_f32(block, qblock);
66            }
67            Ok(quant)
68        }
69    }
70
71    fn dequant_f32(&self, input: &[u8]) -> TractResult<Tensor> {
72        unsafe {
73            let blocks = input.len() / self.block_bytes();
74            let mut tensor = Tensor::uninitialized::<f32>(&[blocks * self.block_len()])?;
75            let slice = tensor.as_slice_mut::<f32>()?;
76            for b in 0..blocks {
77                let block = &mut slice[b * self.block_len()..][..self.block_len()];
78                let qblock = &input[b * self.block_bytes()..][..self.block_bytes()];
79                self.dequant_block_f32(qblock, block);
80            }
81            Ok(tensor)
82        }
83    }
84
85    fn dequant_f16(&self, input: &[u8]) -> TractResult<Tensor> {
86        unsafe {
87            let blocks = input.len() / self.block_bytes();
88            let mut tensor = Tensor::uninitialized::<f16>(&[blocks * self.block_len()])?;
89            let slice = tensor.as_slice_mut::<f16>()?;
90            for b in 0..blocks {
91                let block = &mut slice[b * self.block_len()..][..self.block_len()];
92                let qblock = &input[b * self.block_bytes()..][..self.block_bytes()];
93                self.dequant_block_f16(qblock, block);
94            }
95            Ok(tensor)
96        }
97    }
98
99    fn extract_at_offset_f16(&self, input: &[u8], offset: usize) -> f16 {
100        let len = self.block_len();
101        let block_id = offset / len;
102        let mut block = vec![f16::zero(); self.block_len()];
103        self.dequant_block_f16(
104            &input[block_id * self.block_bytes()..][..self.block_bytes()],
105            &mut block,
106        );
107        block[offset % len]
108    }
109
110    fn extract_at_offset_f32(&self, input: &[u8], offset: usize) -> f32 {
111        let len = self.block_len();
112        let block_id = offset / len;
113        let mut block = vec![f32::zero(); self.block_len()];
114        self.dequant_block_f32(
115            &input[block_id * self.block_bytes()..][..self.block_bytes()],
116            &mut block,
117        );
118        block[offset % len]
119    }
120
121    fn simulate_precision_loss(
122        &self,
123        mut tensor: Tensor,
124        block_axis: usize,
125    ) -> TractResult<Tensor> {
126        ensure!(block_axis == tensor.rank() - 1);
127        ensure!(tensor.shape()[block_axis] % self.block_len() == 0);
128        let mut scratch = vec![0u8; self.block_bytes()];
129        if tensor.datum_type() == f32::datum_type() {
130            for block in tensor.as_slice_mut::<f32>()?.chunks_mut(self.block_len()) {
131                self.quant_block_f32(block, &mut scratch);
132                self.dequant_block_f32(&scratch, block);
133            }
134            Ok(tensor)
135        } else if tensor.datum_type() == f16::datum_type() {
136            for block in tensor.as_slice_mut::<f16>()?.chunks_mut(self.block_len()) {
137                self.quant_block_f16(block, &mut scratch);
138                self.dequant_block_f16(&scratch, block);
139            }
140            Ok(tensor)
141        } else {
142            todo!()
143        }
144    }
145
146    fn pack(
147        &self,
148        input: &[u8],
149        k: usize,
150        r: usize,
151        zip: usize,
152        scales_at_end: bool,
153    ) -> TractResult<EagerPackedInput>;
154
155    unsafe fn extract_packed_panel(
156        &self,
157        value: &EagerPackedInput,
158        target: &PackedFormat,
159        panel: usize,
160        scratch: *mut u8,
161    ) -> TractResult<()>;
162
163    fn extract_at_mn_f16(
164        &self,
165        value: &EagerPackedInput,
166        mn: usize,
167        target: &mut [f16],
168    ) -> TractResult<()>;
169
170    fn extract_at_mn_f32(
171        &self,
172        value: &EagerPackedInput,
173        mn: usize,
174        target: &mut [f32],
175    ) -> TractResult<()>;
176}
177
// Trait-object plumbing for `dyn BlockQuant`: cloning and hashing of boxed schemes (both
// needed by `#[derive(Clone, Hash)]` on `PackedBlockQuantFormat`, which holds a
// `Box<dyn BlockQuant>`), plus downcasting support generated by `downcast_rs`.
dyn_clone::clone_trait_object!(BlockQuant);
dyn_hash::hash_trait_object!(BlockQuant);
impl_downcast!(BlockQuant);
181
/// A [`BlockQuant`] scheme bundled with the packing parameters that are forwarded to
/// [`BlockQuant::pack`].
// `Hash` is derived while `PartialEq` is written by hand (it compares `bq` through
// `BlockQuant::same_as`), hence the clippy allowance below.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Hash)]
pub struct PackedBlockQuantFormat {
    /// The quantization scheme used to pack and extract data.
    pub bq: Box<dyn BlockQuant>,
    /// Forwarded to `BlockQuant::pack` as `r`, and reported by `MMMInputFormat::r`.
    pub r: usize,
    /// Forwarded to `BlockQuant::pack` as `zip`; `Display` only prints it when non-zero.
    pub zip: usize,
    /// Forwarded to `BlockQuant::pack` as `scales_at_end`; `Display` renders it as an `Se` suffix.
    pub scales_at_end: bool,
}
190
191impl PartialEq for PackedBlockQuantFormat {
192    fn eq(&self, other: &Self) -> bool {
193        self.bq.same_as(&*other.bq)
194            && self.r == other.r
195            && self.zip == other.zip
196            && self.scales_at_end == other.scales_at_end
197    }
198}
199
// Marker impl: relies on the hand-written `PartialEq` being a full equivalence relation
// (presumably `BlockQuant::same_as` is reflexive — confirm for each implementor).
impl Eq for PackedBlockQuantFormat {}
201
202impl Display for PackedBlockQuantFormat {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        write!(f, "Packed{}[{}]", &*self.bq, self.r)?;
205        if self.zip != 0 {
206            write!(f, "Z{}", self.zip)?;
207        }
208        if self.scales_at_end {
209            write!(f, "Se")?;
210        }
211        Ok(())
212    }
213}
214
215impl Debug for PackedBlockQuantFormat {
216    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217        <Self as Display>::fmt(self, f)
218    }
219}
220
221impl PackedBlockQuantFormat {
222    pub fn new(bq: &dyn BlockQuant, r: usize, zip: usize, scales_at_end: bool) -> Self {
223        PackedBlockQuantFormat { bq: clone_box(bq), r, zip, scales_at_end }
224    }
225
226    pub fn simulate_precision_loss(
227        &self,
228        tensor: Tensor,
229        block_axis: usize,
230    ) -> TractResult<Tensor> {
231        self.bq.simulate_precision_loss(tensor, block_axis)
232    }
233
234    pub fn pack(&self, input: &[u8], k: usize) -> TractResult<EagerPackedInput> {
235        self.bq.pack(input, k, self.r, self.zip, self.scales_at_end)
236    }
237}
238
239impl MMMInputFormat for PackedBlockQuantFormat {
240    fn prepare_tensor(&self, t: &Tensor, _k_axis: usize, _mn_axis: usize) -> TractResult<Tensor> {
241        let packed = t
242            .as_slice::<Opaque>()?
243            .iter()
244            .map(|o| {
245                let bqv = o.downcast_ref::<BlockQuantValue>().unwrap();
246                let packed = self.pack(&bqv.value, bqv.fact.k())?;
247                Ok(Opaque(Arc::new(Box::new(packed) as Box<dyn MMMInputValue>)))
248            })
249            .collect::<TractResult<Vec<Opaque>>>()?;
250        tensor1(&packed).into_shape(t.shape())
251    }
252
253    fn prepare_one(
254        &self,
255        t: &Tensor,
256        k_axis: usize,
257        mn_axis: usize,
258    ) -> TractResult<Box<dyn MMMInputValue>> {
259        // this code path is essentially there for test scenarios
260        let t = if t.datum_type().is_number() {
261            let k = t.shape()[k_axis];
262            let m = t.shape()[mn_axis];
263            assert!(k % self.bq.block_len() == 0);
264            let t: Cow<Tensor> = if k_axis == 1 && mn_axis == 0 {
265                Cow::Borrowed(t)
266            } else {
267                Cow::Owned(t.clone().move_axis(1, 0)?)
268            };
269            let quant = if t.datum_type() == f32::datum_type() {
270                self.bq.quant_f32(t.as_slice()?)?
271            } else if t.datum_type() == f16::datum_type() {
272                self.bq.quant_f16(t.as_slice()?)?
273            } else {
274                todo!()
275            };
276            Cow::Owned(tensor0(Opaque(Arc::new(BlockQuantValue {
277                value: Arc::new(quant),
278                fact: BlockQuantFact::new(self.bq.clone(), tvec!(m, k)),
279            }))))
280        } else {
281            Cow::Borrowed(t)
282        };
283        ensure!(mn_axis == 0);
284        ensure!(k_axis == 1);
285        let bqv = t.to_scalar::<Opaque>()?.downcast_ref::<BlockQuantValue>().unwrap();
286        let packed = self.pack(&bqv.value, bqv.fact.k())?;
287        Ok(Box::new(packed))
288    }
289
290    fn precursor(&self) -> WeightType {
291        WeightType::BlockQuant(self.bq.clone())
292    }
293
294    fn k_alignment(&self) -> usize {
295        self.bq.block_len()
296    }
297
298    fn r(&self) -> usize {
299        self.r
300    }
301
302    fn mem_size(&self, k: TDim, mn: TDim) -> TDim {
303        k * mn * self.bq.block_bytes() / self.bq.block_len()
304    }
305
306    fn same_as(&self, other: &dyn MMMInputFormat) -> bool {
307        other.downcast_ref::<Self>().is_some_and(|other| self == other)
308    }
309
310    fn extract_at_mn_f16(
311        &self,
312        data: &EagerPackedInput,
313        mn: usize,
314        slice: &mut [f16],
315    ) -> TractResult<()> {
316        self.bq.extract_at_mn_f16(data, mn, slice)
317    }
318
319    fn extract_at_mn_f32(
320        &self,
321        data: &EagerPackedInput,
322        mn: usize,
323        slice: &mut [f32],
324    ) -> TractResult<()> {
325        self.bq.extract_at_mn_f32(data, mn, slice)
326    }
327}