Skip to main content

tract_linalg/frame/block_quant/mod.rs

1use downcast_rs::{Downcast, impl_downcast};
2use dyn_clone::{DynClone, clone_box};
3use dyn_eq::DynEq;
4use dyn_hash::DynHash;
5use num_traits::Zero;
6use tract_data::internal::*;
7use tract_data::itertools::Itertools;
8
9use std::alloc::Layout;
10use std::borrow::Cow;
11use std::fmt::{Debug, Display};
12use std::hash::Hash;
13use std::sync::Arc;
14
15mod helpers;
16mod q4_0;
17mod q8_1;
18mod storage;
19mod value;
20
21pub use helpers::{NibbleReader, NibbleWriter};
22pub use q4_0::Q4_0;
23pub use q8_1::Q8_1;
24pub use storage::{BlockQuantStorage, block_quant_slice};
25pub use value::{BlockQuantFact, PackedBlockQuantFact};
26
27use crate::mmm::{EagerPackedInput, MMMInputFormat};
28use crate::pack::PackedFormat;
29
30use crate::WeightType;
31
32use super::mmm::MMMInputValue;
33
34pub trait BlockQuant:
35    Debug + Display + Send + Sync + DynClone + DynHash + dyn_eq::DynEq + Downcast
36{
37    fn block_len(&self) -> usize;
38
39    fn block_bytes(&self) -> usize;
40
41    fn dequant_block_f32(&self, quant: &[u8], block: &mut [f32]);
42    fn dequant_block_f16(&self, quant: &[u8], block: &mut [f16]);
43    fn quant_block_f16(&self, block: &[f16], quant: &mut [u8]);
44    fn quant_block_f32(&self, block: &[f32], quant: &mut [u8]);
45
46    fn quant_f16(&self, input: &[f16]) -> TractResult<Blob> {
47        unsafe {
48            let blocks = input.len() / self.block_len();
49            let mut quant = Blob::for_layout(
50                Layout::from_size_align(blocks * self.block_bytes(), 128).unwrap(),
51            );
52            for b in 0..blocks {
53                let block = &input[b * self.block_len()..][..self.block_len()];
54                let qblock = &mut quant[b * self.block_bytes()..][..self.block_bytes()];
55                self.quant_block_f16(block, qblock);
56            }
57            Ok(quant)
58        }
59    }
60
61    fn quant_f32(&self, input: &[f32]) -> TractResult<Blob> {
62        unsafe {
63            let blocks = input.len() / self.block_len();
64            let mut quant = Blob::for_layout(
65                Layout::from_size_align(blocks * self.block_bytes(), 128).unwrap(),
66            );
67            for b in 0..blocks {
68                let block = &input[b * self.block_len()..][..self.block_len()];
69                let qblock = &mut quant[b * self.block_bytes()..][..self.block_bytes()];
70                self.quant_block_f32(block, qblock);
71            }
72            Ok(quant)
73        }
74    }
75
76    fn dequant_f32(&self, input: &[u8]) -> TractResult<Tensor> {
77        unsafe {
78            let blocks = input.len() / self.block_bytes();
79            let mut tensor = Tensor::uninitialized::<f32>(&[blocks * self.block_len()])?;
80            let mut tensor_plain = tensor.try_as_plain_mut()?;
81            let slice = tensor_plain.as_slice_mut::<f32>()?;
82            for b in 0..blocks {
83                let block = &mut slice[b * self.block_len()..][..self.block_len()];
84                let qblock = &input[b * self.block_bytes()..][..self.block_bytes()];
85                self.dequant_block_f32(qblock, block);
86            }
87            Ok(tensor)
88        }
89    }
90
91    fn dequant_f16(&self, input: &[u8]) -> TractResult<Tensor> {
92        unsafe {
93            let blocks = input.len() / self.block_bytes();
94            let mut tensor = Tensor::uninitialized::<f16>(&[blocks * self.block_len()])?;
95            let mut tensor_plain = tensor.try_as_plain_mut()?;
96            let slice = tensor_plain.as_slice_mut::<f16>()?;
97            for b in 0..blocks {
98                let block = &mut slice[b * self.block_len()..][..self.block_len()];
99                let qblock = &input[b * self.block_bytes()..][..self.block_bytes()];
100                self.dequant_block_f16(qblock, block);
101            }
102            Ok(tensor)
103        }
104    }
105
106    fn extract_at_offset_f16(&self, input: &[u8], offset: usize) -> f16 {
107        let len = self.block_len();
108        let block_id = offset / len;
109        let mut block = vec![f16::zero(); self.block_len()];
110        self.dequant_block_f16(
111            &input[block_id * self.block_bytes()..][..self.block_bytes()],
112            &mut block,
113        );
114        block[offset % len]
115    }
116
117    fn extract_at_offset_f32(&self, input: &[u8], offset: usize) -> f32 {
118        let len = self.block_len();
119        let block_id = offset / len;
120        let mut block = vec![f32::zero(); self.block_len()];
121        self.dequant_block_f32(
122            &input[block_id * self.block_bytes()..][..self.block_bytes()],
123            &mut block,
124        );
125        block[offset % len]
126    }
127
128    fn simulate_precision_loss(
129        &self,
130        mut tensor: Tensor,
131        block_axis: usize,
132    ) -> TractResult<Tensor> {
133        ensure!(block_axis == tensor.rank() - 1);
134        ensure!(tensor.shape()[block_axis] % self.block_len() == 0);
135        let mut scratch = vec![0u8; self.block_bytes()];
136        if tensor.datum_type() == f32::datum_type() {
137            let mut tensor_plain = tensor.try_as_plain_mut()?;
138            for block in tensor_plain.as_slice_mut::<f32>()?.chunks_mut(self.block_len()) {
139                self.quant_block_f32(block, &mut scratch);
140                self.dequant_block_f32(&scratch, block);
141            }
142            drop(tensor_plain);
143            Ok(tensor)
144        } else if tensor.datum_type() == f16::datum_type() {
145            let mut tensor_plain = tensor.try_as_plain_mut()?;
146            for block in tensor_plain.as_slice_mut::<f16>()?.chunks_mut(self.block_len()) {
147                self.quant_block_f16(block, &mut scratch);
148                self.dequant_block_f16(&scratch, block);
149            }
150            drop(tensor_plain);
151            Ok(tensor)
152        } else {
153            todo!()
154        }
155    }
156
157    fn pack(
158        &self,
159        input: &[u8],
160        k: usize,
161        r: usize,
162        zip: usize,
163        scales_at_end: bool,
164    ) -> TractResult<EagerPackedInput>;
165
166    unsafe fn extract_packed_panel(
167        &self,
168        value: &EagerPackedInput,
169        target: &PackedFormat,
170        panel: usize,
171        scratch: *mut u8,
172    ) -> TractResult<()>;
173
174    fn extract_at_mn_f16(
175        &self,
176        value: &EagerPackedInput,
177        mn: usize,
178        target: &mut [f16],
179    ) -> TractResult<()>;
180
181    fn extract_at_mn_f32(
182        &self,
183        value: &EagerPackedInput,
184        mn: usize,
185        target: &mut [f32],
186    ) -> TractResult<()>;
187}
188
189dyn_clone::clone_trait_object!(BlockQuant);
190dyn_hash::hash_trait_object!(BlockQuant);
191dyn_eq::eq_trait_object!(BlockQuant);
192impl_downcast!(BlockQuant);
193
/// Packing format for block-quantized matrix operands: a block quantization scheme together
/// with the packing parameters forwarded to `BlockQuant::pack`.
// `Hash` is derived while `PartialEq` is hand-written (it compares the boxed scheme through
// `dyn_eq`); both take all four fields into account, so they stay consistent.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Hash)]
pub struct PackedBlockQuantFormat {
    /// The block quantization scheme.
    pub bq: Box<dyn BlockQuant>,
    /// Forwarded to `BlockQuant::pack` as `r` and reported by `MMMInputFormat::r`
    /// (presumably the packed panel width — confirm).
    pub r: usize,
    /// Forwarded to `BlockQuant::pack` as `zip`; shown as `Z<zip>` in `Display` when non-zero.
    pub zip: usize,
    /// Forwarded to `BlockQuant::pack` as `scales_at_end`; shown as `Se` in `Display` when set.
    pub scales_at_end: bool,
}
202
203impl PartialEq for PackedBlockQuantFormat {
204    fn eq(&self, other: &Self) -> bool {
205        *self.bq == *other.bq
206            && self.r == other.r
207            && self.zip == other.zip
208            && self.scales_at_end == other.scales_at_end
209    }
210}
211
212impl Eq for PackedBlockQuantFormat {}
213
214impl Display for PackedBlockQuantFormat {
215    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216        write!(f, "Packed{}[{}]", &*self.bq, self.r)?;
217        if self.zip != 0 {
218            write!(f, "Z{}", self.zip)?;
219        }
220        if self.scales_at_end {
221            write!(f, "Se")?;
222        }
223        Ok(())
224    }
225}
226
227impl Debug for PackedBlockQuantFormat {
228    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229        <Self as Display>::fmt(self, f)
230    }
231}
232
233impl PackedBlockQuantFormat {
234    pub fn new(bq: &dyn BlockQuant, r: usize, zip: usize, scales_at_end: bool) -> Self {
235        PackedBlockQuantFormat { bq: clone_box(bq), r, zip, scales_at_end }
236    }
237
238    pub fn simulate_precision_loss(
239        &self,
240        tensor: Tensor,
241        block_axis: usize,
242    ) -> TractResult<Tensor> {
243        self.bq.simulate_precision_loss(tensor, block_axis)
244    }
245
246    pub fn pack(&self, input: &[u8], k: usize) -> TractResult<EagerPackedInput> {
247        self.bq.pack(input, k, self.r, self.zip, self.scales_at_end)
248    }
249}
250
251impl MMMInputFormat for PackedBlockQuantFormat {
252    fn prepare_tensor(&self, t: &Tensor, _k_axis: usize, _mn_axis: usize) -> TractResult<Tensor> {
253        let bqs = t.try_storage_as::<BlockQuantStorage>()?;
254        let num_groups: usize =
255            if t.rank() > 2 { t.shape()[..t.rank() - 2].iter().product() } else { 1 };
256        let m_per_group = t.shape()[t.rank().saturating_sub(2)];
257        let k = *t.shape().last().unwrap();
258        let values = (0..num_groups)
259            .map(|g| {
260                let slice = block_quant_slice(bqs.value(), &*self.bq, m_per_group, k, g);
261                let packed = self.pack(slice, k)?;
262                Ok(Box::new(packed) as Box<dyn MMMInputValue>)
263            })
264            .collect::<TractResult<Vec<_>>>()?;
265        let leading_shape = &t.shape()[..t.rank().saturating_sub(2)];
266        Ok(crate::mmm::PackedMatrixStorage::new_batched(leading_shape, values)
267            .into_tensor(t.datum_type()))
268    }
269
270    fn prepare_one(
271        &self,
272        t: &Tensor,
273        k_axis: usize,
274        mn_axis: usize,
275    ) -> TractResult<Box<dyn MMMInputValue>> {
276        // this code path is essentially there for test scenarios
277        let t = if t.is_plain() && t.datum_type().is_number() {
278            let k = t.shape()[k_axis];
279            let m = t.shape()[mn_axis];
280            assert!(k % self.bq.block_len() == 0);
281            let t: Cow<Tensor> = if k_axis == 1 && mn_axis == 0 {
282                Cow::Borrowed(t)
283            } else {
284                Cow::Owned(t.clone().move_axis(1, 0)?)
285            };
286            let quant = if t.datum_type() == f32::datum_type() {
287                self.bq.quant_f32(t.try_as_plain()?.as_slice()?)?
288            } else if t.datum_type() == f16::datum_type() {
289                self.bq.quant_f16(t.try_as_plain()?.as_slice()?)?
290            } else {
291                todo!()
292            };
293            Cow::Owned(
294                BlockQuantStorage::new(self.bq.clone(), m, k, Arc::new(quant))?
295                    .into_tensor_with_shape(t.datum_type(), &[1, m, k]),
296            )
297        } else {
298            Cow::Borrowed(t)
299        };
300        ensure!(mn_axis == 0);
301        ensure!(k_axis == 1);
302        let bqs = t.try_storage_as::<BlockQuantStorage>()?;
303        let k = *t.shape().last().unwrap();
304        let packed = self.pack(bqs.value(), k)?;
305        Ok(Box::new(packed))
306    }
307
308    fn precursor(&self) -> WeightType {
309        WeightType::BlockQuant(self.bq.clone())
310    }
311
312    fn k_alignment(&self) -> usize {
313        self.bq.block_len()
314    }
315
316    fn r(&self) -> usize {
317        self.r
318    }
319
320    fn mem_size(&self, k: TDim, mn: TDim) -> TDim {
321        k * mn * self.bq.block_bytes() / self.bq.block_len()
322    }
323
324    fn extract_at_mn_f16(
325        &self,
326        data: &EagerPackedInput,
327        mn: usize,
328        slice: &mut [f16],
329    ) -> TractResult<()> {
330        self.bq.extract_at_mn_f16(data, mn, slice)
331    }
332
333    fn extract_at_mn_f32(
334        &self,
335        data: &EagerPackedInput,
336        mn: usize,
337        slice: &mut [f32],
338    ) -> TractResult<()> {
339        self.bq.extract_at_mn_f32(data, mn, slice)
340    }
341}