//! tract_linalg/frame/block_quant/mod.rs
//!
//! Block quantization formats and their packing support for matmul kernels.
1use downcast_rs::{Downcast, impl_downcast};
2use dyn_clone::{DynClone, clone_box};
3use dyn_hash::DynHash;
4use num_traits::Zero;
5use tract_data::internal::*;
6use tract_data::itertools::Itertools;
7
8use std::alloc::Layout;
9use std::borrow::Cow;
10use std::fmt::{Debug, Display};
11use std::hash::Hash;
12use std::sync::Arc;
13
14mod helpers;
15mod q4_0;
16mod q8_1;
17mod value;
18
19pub use helpers::{NibbleReader, NibbleWriter};
20pub use q4_0::Q4_0;
21pub use q8_1::Q8_1;
22pub use value::{BlockQuantFact, PackedBlockQuantFact};
23
24use crate::mmm::{EagerPackedInput, MMMInputFormat};
25use crate::pack::PackedFormat;
26
27use crate::WeightType;
28
29use super::mmm::MMMInputValue;
30
/// A block quantization scheme: values are grouped in fixed-size blocks, each
/// block encoded into a compact run of bytes (concrete schemes: `Q4_0`, `Q8_1`).
///
/// Implementors provide the per-block codecs and packing; whole-buffer helpers
/// are derived as default methods.
pub trait BlockQuant: Debug + Display + Send + Sync + DynClone + DynHash + Downcast {
    /// True if `other` denotes the same quantization scheme (type-erased comparison).
    fn same_as(&self, other: &dyn BlockQuant) -> bool;

    /// Number of scalar values in one block.
    fn block_len(&self) -> usize;

    /// Number of bytes used to encode one block.
    fn block_bytes(&self) -> usize;

    // Per-block codecs: `quant` is `block_bytes()` bytes long, `block` is
    // `block_len()` values long.
    fn dequant_block_f32(&self, quant: &[u8], block: &mut [f32]);
    fn dequant_block_f16(&self, quant: &[u8], block: &mut [f16]);
    fn quant_block_f16(&self, block: &[f16], quant: &mut [u8]);
    fn quant_block_f32(&self, block: &[f32], quant: &mut [u8]);

    /// Quantizes a whole f16 buffer, block by block, into a 128-byte-aligned blob.
    ///
    /// NOTE(review): `input.len()` is divided by `block_len()`, so a trailing
    /// partial block is silently dropped — callers appear to guarantee
    /// divisibility (see the checks in `prepare_one` / `simulate_precision_loss`).
    fn quant_f16(&self, input: &[f16]) -> TractResult<Blob> {
        unsafe {
            let blocks = input.len() / self.block_len();
            // Blob is presumably uninitialized storage here; every byte region
            // is overwritten by the loop below (assuming `quant_block_f16`
            // fills its whole output slice) — TODO confirm Blob contract.
            let mut quant = Blob::for_layout(
                Layout::from_size_align(blocks * self.block_bytes(), 128).unwrap(),
            );
            for b in 0..blocks {
                let block = &input[b * self.block_len()..][..self.block_len()];
                let qblock = &mut quant[b * self.block_bytes()..][..self.block_bytes()];
                self.quant_block_f16(block, qblock);
            }
            Ok(quant)
        }
    }

    /// Quantizes a whole f32 buffer, block by block, into a 128-byte-aligned blob.
    /// Same trailing-partial-block caveat as `quant_f16`.
    fn quant_f32(&self, input: &[f32]) -> TractResult<Blob> {
        unsafe {
            let blocks = input.len() / self.block_len();
            let mut quant = Blob::for_layout(
                Layout::from_size_align(blocks * self.block_bytes(), 128).unwrap(),
            );
            for b in 0..blocks {
                let block = &input[b * self.block_len()..][..self.block_len()];
                let qblock = &mut quant[b * self.block_bytes()..][..self.block_bytes()];
                self.quant_block_f32(block, qblock);
            }
            Ok(quant)
        }
    }

    /// Dequantizes a full quantized byte buffer into a rank-1 f32 tensor of
    /// `blocks * block_len()` values.
    fn dequant_f32(&self, input: &[u8]) -> TractResult<Tensor> {
        unsafe {
            let blocks = input.len() / self.block_bytes();
            // Uninitialized output: fully written by the per-block loop below.
            let mut tensor = Tensor::uninitialized::<f32>(&[blocks * self.block_len()])?;
            let mut tensor_dense = tensor.try_as_dense_mut()?;
            let slice = tensor_dense.as_slice_mut::<f32>()?;
            for b in 0..blocks {
                let block = &mut slice[b * self.block_len()..][..self.block_len()];
                let qblock = &input[b * self.block_bytes()..][..self.block_bytes()];
                self.dequant_block_f32(qblock, block);
            }
            Ok(tensor)
        }
    }

    /// Dequantizes a full quantized byte buffer into a rank-1 f16 tensor of
    /// `blocks * block_len()` values.
    fn dequant_f16(&self, input: &[u8]) -> TractResult<Tensor> {
        unsafe {
            let blocks = input.len() / self.block_bytes();
            // Uninitialized output: fully written by the per-block loop below.
            let mut tensor = Tensor::uninitialized::<f16>(&[blocks * self.block_len()])?;
            let mut tensor_dense = tensor.try_as_dense_mut()?;
            let slice = tensor_dense.as_slice_mut::<f16>()?;
            for b in 0..blocks {
                let block = &mut slice[b * self.block_len()..][..self.block_len()];
                let qblock = &input[b * self.block_bytes()..][..self.block_bytes()];
                self.dequant_block_f16(qblock, block);
            }
            Ok(tensor)
        }
    }

    /// Reads the single f16 value at linear `offset` by decoding its whole
    /// containing block.
    fn extract_at_offset_f16(&self, input: &[u8], offset: usize) -> f16 {
        let len = self.block_len();
        let block_id = offset / len;
        let mut block = vec![f16::zero(); self.block_len()];
        self.dequant_block_f16(
            &input[block_id * self.block_bytes()..][..self.block_bytes()],
            &mut block,
        );
        block[offset % len]
    }

    /// Reads the single f32 value at linear `offset` by decoding its whole
    /// containing block.
    fn extract_at_offset_f32(&self, input: &[u8], offset: usize) -> f32 {
        let len = self.block_len();
        let block_id = offset / len;
        let mut block = vec![f32::zero(); self.block_len()];
        self.dequant_block_f32(
            &input[block_id * self.block_bytes()..][..self.block_bytes()],
            &mut block,
        );
        block[offset % len]
    }

    /// Emulates the precision loss of a quantization round-trip: each block is
    /// quantized then dequantized in place. Shape and dtype are unchanged.
    ///
    /// Requirements: `block_axis` is the last axis and its dimension is a
    /// multiple of `block_len()`; only f32 and f16 tensors are implemented.
    fn simulate_precision_loss(
        &self,
        mut tensor: Tensor,
        block_axis: usize,
    ) -> TractResult<Tensor> {
        ensure!(block_axis == tensor.rank() - 1);
        ensure!(tensor.shape()[block_axis] % self.block_len() == 0);
        // Single reusable scratch buffer for the encoded form of one block.
        let mut scratch = vec![0u8; self.block_bytes()];
        if tensor.datum_type() == f32::datum_type() {
            let mut tensor_dense = tensor.try_as_dense_mut()?;
            for block in tensor_dense.as_slice_mut::<f32>()?.chunks_mut(self.block_len()) {
                self.quant_block_f32(block, &mut scratch);
                self.dequant_block_f32(&scratch, block);
            }
            // Release the mutable dense view before handing the tensor back.
            drop(tensor_dense);
            Ok(tensor)
        } else if tensor.datum_type() == f16::datum_type() {
            let mut tensor_dense = tensor.try_as_dense_mut()?;
            for block in tensor_dense.as_slice_mut::<f16>()?.chunks_mut(self.block_len()) {
                self.quant_block_f16(block, &mut scratch);
                self.dequant_block_f16(&scratch, block);
            }
            drop(tensor_dense);
            Ok(tensor)
        } else {
            // Only float dtypes are implemented.
            todo!()
        }
    }

    /// Packs a quantized byte buffer into the panel layout the matmul kernels
    /// consume: `r` rows per panel, `zip` interleaving parameter, and scales
    /// optionally moved to the end of each panel.
    fn pack(
        &self,
        input: &[u8],
        k: usize,
        r: usize,
        zip: usize,
        scales_at_end: bool,
    ) -> TractResult<EagerPackedInput>;

    /// Re-encodes one panel of `value` into `target`'s plain packed format,
    /// writing through `scratch`.
    ///
    /// # Safety
    /// `scratch` must point to writable memory large enough for one panel in
    /// `target` format — TODO confirm the exact size contract with implementors.
    unsafe fn extract_packed_panel(
        &self,
        value: &EagerPackedInput,
        target: &PackedFormat,
        panel: usize,
        scratch: *mut u8,
    ) -> TractResult<()>;

    /// Dequantizes the mn-th lane (presumably one row/column of the packed
    /// operand — confirm with implementors) into `target` as f16.
    fn extract_at_mn_f16(
        &self,
        value: &EagerPackedInput,
        mn: usize,
        target: &mut [f16],
    ) -> TractResult<()>;

    /// Dequantizes the mn-th lane into `target` as f32.
    fn extract_at_mn_f32(
        &self,
        value: &EagerPackedInput,
        mn: usize,
        target: &mut [f32],
    ) -> TractResult<()>;
}
185
// Make `Box<dyn BlockQuant>` clonable and hashable, and allow downcasting a
// `dyn BlockQuant` back to its concrete type.
dyn_clone::clone_trait_object!(BlockQuant);
dyn_hash::hash_trait_object!(BlockQuant);
impl_downcast!(BlockQuant);
189
/// Description of a block-quantized weight packing: a quantization scheme plus
/// the panel parameters used by `pack`.
// `Hash` is derived while `PartialEq` is hand-written (it must go through
// `BlockQuant::same_as` for the trait object), hence the clippy allow.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Hash)]
pub struct PackedBlockQuantFormat {
    // The block quantization scheme of the payload.
    pub bq: Box<dyn BlockQuant>,
    // Rows per packing panel (the kernel's `r`).
    pub r: usize,
    // Interleaving parameter forwarded to `BlockQuant::pack`; 0 means none
    // (see `Display`) — exact semantics are scheme-specific, confirm with
    // implementors.
    pub zip: usize,
    // When true, scales are stored at the end of each panel instead of inline
    // — presumably; forwarded verbatim to `BlockQuant::pack`.
    pub scales_at_end: bool,
}
198
199impl PartialEq for PackedBlockQuantFormat {
200    fn eq(&self, other: &Self) -> bool {
201        self.bq.same_as(&*other.bq)
202            && self.r == other.r
203            && self.zip == other.zip
204            && self.scales_at_end == other.scales_at_end
205    }
206}
207
// Field-wise comparison above is a total equivalence, so `Eq` holds.
impl Eq for PackedBlockQuantFormat {}
209
210impl Display for PackedBlockQuantFormat {
211    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212        write!(f, "Packed{}[{}]", &*self.bq, self.r)?;
213        if self.zip != 0 {
214            write!(f, "Z{}", self.zip)?;
215        }
216        if self.scales_at_end {
217            write!(f, "Se")?;
218        }
219        Ok(())
220    }
221}
222
223impl Debug for PackedBlockQuantFormat {
224    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
225        <Self as Display>::fmt(self, f)
226    }
227}
228
impl PackedBlockQuantFormat {
    /// Builds a packing description from a (borrowed, cloned) quantization
    /// scheme and the panel parameters.
    pub fn new(bq: &dyn BlockQuant, r: usize, zip: usize, scales_at_end: bool) -> Self {
        PackedBlockQuantFormat { bq: clone_box(bq), r, zip, scales_at_end }
    }

    /// Round-trips `tensor` through quantization to emulate the scheme's
    /// precision loss; delegates to `BlockQuant::simulate_precision_loss`.
    pub fn simulate_precision_loss(
        &self,
        tensor: Tensor,
        block_axis: usize,
    ) -> TractResult<Tensor> {
        self.bq.simulate_precision_loss(tensor, block_axis)
    }

    /// Packs an already-quantized byte buffer with this format's parameters.
    pub fn pack(&self, input: &[u8], k: usize) -> TractResult<EagerPackedInput> {
        self.bq.pack(input, k, self.r, self.zip, self.scales_at_end)
    }
}
246
247impl MMMInputFormat for PackedBlockQuantFormat {
248    fn prepare_tensor(&self, t: &Tensor, _k_axis: usize, _mn_axis: usize) -> TractResult<Tensor> {
249        let packed = t
250            .try_as_dense()?
251            .as_slice::<Opaque>()?
252            .iter()
253            .map(|o| {
254                let bwf = o.downcast_ref::<BlobWithFact>().unwrap();
255                let bqf = bwf.fact.downcast_ref::<BlockQuantFact>().unwrap();
256                let packed = self.pack(&bwf.value, bqf.k())?;
257                Ok(Opaque(Arc::new(Box::new(packed) as Box<dyn MMMInputValue>)))
258            })
259            .collect::<TractResult<Vec<Opaque>>>()?;
260        tensor1(&packed).into_shape(t.shape())
261    }
262
263    fn prepare_one(
264        &self,
265        t: &Tensor,
266        k_axis: usize,
267        mn_axis: usize,
268    ) -> TractResult<Box<dyn MMMInputValue>> {
269        // this code path is essentially there for test scenarios
270        let t = if t.datum_type().is_number() {
271            let k = t.shape()[k_axis];
272            let m = t.shape()[mn_axis];
273            assert!(k % self.bq.block_len() == 0);
274            let t: Cow<Tensor> = if k_axis == 1 && mn_axis == 0 {
275                Cow::Borrowed(t)
276            } else {
277                Cow::Owned(t.clone().move_axis(1, 0)?)
278            };
279            let quant = if t.datum_type() == f32::datum_type() {
280                self.bq.quant_f32(t.try_as_dense()?.as_slice()?)?
281            } else if t.datum_type() == f16::datum_type() {
282                self.bq.quant_f16(t.try_as_dense()?.as_slice()?)?
283            } else {
284                todo!()
285            };
286            Cow::Owned(tensor0(Opaque(Arc::new(BlobWithFact {
287                value: Arc::new(quant),
288                fact: Box::new(BlockQuantFact::new(self.bq.clone(), tvec!(m, k))),
289            }))))
290        } else {
291            Cow::Borrowed(t)
292        };
293        ensure!(mn_axis == 0);
294        ensure!(k_axis == 1);
295        let bwf = t.try_as_dense()?.to_scalar::<Opaque>()?.downcast_ref::<BlobWithFact>().unwrap();
296        let bqf = bwf.fact.downcast_ref::<BlockQuantFact>().unwrap();
297        let packed = self.pack(&bwf.value, bqf.k())?;
298        Ok(Box::new(packed))
299    }
300
301    fn precursor(&self) -> WeightType {
302        WeightType::BlockQuant(self.bq.clone())
303    }
304
305    fn k_alignment(&self) -> usize {
306        self.bq.block_len()
307    }
308
309    fn r(&self) -> usize {
310        self.r
311    }
312
313    fn mem_size(&self, k: TDim, mn: TDim) -> TDim {
314        k * mn * self.bq.block_bytes() / self.bq.block_len()
315    }
316
317    fn same_as(&self, other: &dyn MMMInputFormat) -> bool {
318        other.downcast_ref::<Self>().is_some_and(|other| self == other)
319    }
320
321    fn extract_at_mn_f16(
322        &self,
323        data: &EagerPackedInput,
324        mn: usize,
325        slice: &mut [f16],
326    ) -> TractResult<()> {
327        self.bq.extract_at_mn_f16(data, mn, slice)
328    }
329
330    fn extract_at_mn_f32(
331        &self,
332        data: &EagerPackedInput,
333        mn: usize,
334        slice: &mut [f32],
335    ) -> TractResult<()> {
336        self.bq.extract_at_mn_f32(data, mn, slice)
337    }
338}