// tract_linalg/frame/block_quant/mod.rs
use downcast_rs::{Downcast, impl_downcast};
2use dyn_clone::{DynClone, clone_box};
3use dyn_eq::DynEq;
4use dyn_hash::DynHash;
5use num_traits::Zero;
6use tract_data::internal::*;
7use tract_data::itertools::Itertools;
8
9use std::alloc::Layout;
10use std::borrow::Cow;
11use std::fmt::{Debug, Display};
12use std::hash::Hash;
13use std::sync::Arc;
14
15mod helpers;
16mod q4_0;
17mod q8_1;
18mod storage;
19mod value;
20
21pub use helpers::{NibbleReader, NibbleWriter};
22pub use q4_0::Q4_0;
23pub use q8_1::Q8_1;
24pub use storage::{BlockQuantStorage, block_quant_slice};
25pub use value::{BlockQuantFact, PackedBlockQuantFact};
26
27use crate::mmm::{EagerPackedInput, MMMInputFormat};
28use crate::pack::PackedFormat;
29
30use crate::WeightType;
31
32use super::mmm::MMMInputValue;
33
/// A block-based quantization scheme (e.g. [Q4_0], [Q8_1]).
///
/// Scalars are grouped in fixed-size blocks of `block_len()` values, each
/// encoded into `block_bytes()` bytes. Implementors supply the per-block
/// codecs and packing; slice/tensor-level helpers are provided as default
/// methods. The supertraits make `dyn BlockQuant` cloneable, hashable,
/// comparable and downcastable (wired up by the macros after this trait).
pub trait BlockQuant:
    Debug + Display + Send + Sync + DynClone + DynHash + dyn_eq::DynEq + Downcast
{
    /// Number of scalar values held by one quantization block.
    fn block_len(&self) -> usize;

    /// Encoded size, in bytes, of one quantization block.
    fn block_bytes(&self) -> usize;

    // Per-block codecs. All default methods below pass `quant` slices of
    // exactly `block_bytes()` bytes and `block` slices of exactly
    // `block_len()` values.
    fn dequant_block_f32(&self, quant: &[u8], block: &mut [f32]);
    fn dequant_block_f16(&self, quant: &[u8], block: &mut [f16]);
    fn quant_block_f16(&self, block: &[f16], quant: &mut [u8]);
    fn quant_block_f32(&self, block: &[f32], quant: &mut [u8]);

    /// Quantize an f16 slice block by block into a fresh [Blob].
    ///
    /// Note: only `input.len() / block_len()` whole blocks are encoded; a
    /// trailing partial block, if any, is silently ignored.
    fn quant_f16(&self, input: &[f16]) -> TractResult<Blob> {
        unsafe {
            let blocks = input.len() / self.block_len();
            // 128-byte aligned output — presumably for SIMD kernels; TODO confirm.
            let mut quant = Blob::for_layout(
                Layout::from_size_align(blocks * self.block_bytes(), 128).unwrap(),
            );
            // Every byte of `quant` is overwritten below: each iteration
            // fills one full `block_bytes()` span, covering the whole blob.
            for b in 0..blocks {
                let block = &input[b * self.block_len()..][..self.block_len()];
                let qblock = &mut quant[b * self.block_bytes()..][..self.block_bytes()];
                self.quant_block_f16(block, qblock);
            }
            Ok(quant)
        }
    }

    /// Quantize an f32 slice block by block. Same contract as
    /// [Self::quant_f16] (trailing partial block silently ignored).
    fn quant_f32(&self, input: &[f32]) -> TractResult<Blob> {
        unsafe {
            let blocks = input.len() / self.block_len();
            let mut quant = Blob::for_layout(
                Layout::from_size_align(blocks * self.block_bytes(), 128).unwrap(),
            );
            // Every byte of `quant` is overwritten by the loop.
            for b in 0..blocks {
                let block = &input[b * self.block_len()..][..self.block_len()];
                let qblock = &mut quant[b * self.block_bytes()..][..self.block_bytes()];
                self.quant_block_f32(block, qblock);
            }
            Ok(quant)
        }
    }

    /// Dequantize encoded bytes into a 1-D f32 tensor of
    /// `(input.len() / block_bytes()) * block_len()` values.
    fn dequant_f32(&self, input: &[u8]) -> TractResult<Tensor> {
        unsafe {
            let blocks = input.len() / self.block_bytes();
            let mut tensor = Tensor::uninitialized::<f32>(&[blocks * self.block_len()])?;
            let mut tensor_plain = tensor.try_as_plain_mut()?;
            let slice = tensor_plain.as_slice_mut::<f32>()?;
            // Every element of the uninitialized tensor is written: the loop
            // covers all `blocks * block_len()` positions exactly once.
            for b in 0..blocks {
                let block = &mut slice[b * self.block_len()..][..self.block_len()];
                let qblock = &input[b * self.block_bytes()..][..self.block_bytes()];
                self.dequant_block_f32(qblock, block);
            }
            Ok(tensor)
        }
    }

    /// Dequantize encoded bytes into a 1-D f16 tensor. See
    /// [Self::dequant_f32].
    fn dequant_f16(&self, input: &[u8]) -> TractResult<Tensor> {
        unsafe {
            let blocks = input.len() / self.block_bytes();
            let mut tensor = Tensor::uninitialized::<f16>(&[blocks * self.block_len()])?;
            let mut tensor_plain = tensor.try_as_plain_mut()?;
            let slice = tensor_plain.as_slice_mut::<f16>()?;
            for b in 0..blocks {
                let block = &mut slice[b * self.block_len()..][..self.block_len()];
                let qblock = &input[b * self.block_bytes()..][..self.block_bytes()];
                self.dequant_block_f16(qblock, block);
            }
            Ok(tensor)
        }
    }

    /// Decode the single value at scalar position `offset` by dequantizing
    /// only the block that contains it.
    fn extract_at_offset_f16(&self, input: &[u8], offset: usize) -> f16 {
        let len = self.block_len();
        let block_id = offset / len;
        let mut block = vec![f16::zero(); self.block_len()];
        self.dequant_block_f16(
            &input[block_id * self.block_bytes()..][..self.block_bytes()],
            &mut block,
        );
        block[offset % len]
    }

    /// f32 variant of [Self::extract_at_offset_f16].
    fn extract_at_offset_f32(&self, input: &[u8], offset: usize) -> f32 {
        let len = self.block_len();
        let block_id = offset / len;
        let mut block = vec![f32::zero(); self.block_len()];
        self.dequant_block_f32(
            &input[block_id * self.block_bytes()..][..self.block_bytes()],
            &mut block,
        );
        block[offset % len]
    }

    /// Round-trip `tensor` through quant/dequant in place, so callers can
    /// observe the precision loss this scheme would introduce.
    ///
    /// `block_axis` must be the innermost axis and its dimension a multiple
    /// of `block_len()`. Only f32 and f16 tensors are supported.
    fn simulate_precision_loss(
        &self,
        mut tensor: Tensor,
        block_axis: usize,
    ) -> TractResult<Tensor> {
        ensure!(block_axis == tensor.rank() - 1);
        ensure!(tensor.shape()[block_axis] % self.block_len() == 0);
        // One scratch buffer reused as the encoded form of every block.
        let mut scratch = vec![0u8; self.block_bytes()];
        if tensor.datum_type() == f32::datum_type() {
            let mut tensor_plain = tensor.try_as_plain_mut()?;
            for block in tensor_plain.as_slice_mut::<f32>()?.chunks_mut(self.block_len()) {
                self.quant_block_f32(block, &mut scratch);
                self.dequant_block_f32(&scratch, block);
            }
            // End the mutable borrow before returning the tensor.
            drop(tensor_plain);
            Ok(tensor)
        } else if tensor.datum_type() == f16::datum_type() {
            let mut tensor_plain = tensor.try_as_plain_mut()?;
            for block in tensor_plain.as_slice_mut::<f16>()?.chunks_mut(self.block_len()) {
                self.quant_block_f16(block, &mut scratch);
                self.dequant_block_f16(&scratch, block);
            }
            drop(tensor_plain);
            Ok(tensor)
        } else {
            // Unsupported datum types are a hard stop for now.
            todo!()
        }
    }

    /// Pack quantized bytes (`input`) for use as a matmul input with `k`
    /// inner dimension and `r` rows per panel. `zip` and `scales_at_end`
    /// tune the packed layout — exact semantics are implementor-defined;
    /// see the concrete schemes (q4_0, q8_1).
    fn pack(
        &self,
        input: &[u8],
        k: usize,
        r: usize,
        zip: usize,
        scales_at_end: bool,
    ) -> TractResult<EagerPackedInput>;

    /// Expand one packed panel of `value` into `target`'s plain packed
    /// format, writing through `scratch`.
    ///
    /// # Safety
    /// `scratch` must be valid for writes of a full panel in `target`'s
    /// layout — TODO confirm the exact size/alignment contract with the
    /// implementors.
    unsafe fn extract_packed_panel(
        &self,
        value: &EagerPackedInput,
        target: &PackedFormat,
        panel: usize,
        scratch: *mut u8,
    ) -> TractResult<()>;

    /// Decode row/column `mn` of packed `value` into `target` as f16.
    fn extract_at_mn_f16(
        &self,
        value: &EagerPackedInput,
        mn: usize,
        target: &mut [f16],
    ) -> TractResult<()>;

    /// f32 variant of [Self::extract_at_mn_f16].
    fn extract_at_mn_f32(
        &self,
        value: &EagerPackedInput,
        mn: usize,
        target: &mut [f32],
    ) -> TractResult<()>;
}
188
// Make `dyn BlockQuant` trait objects cloneable, hashable, comparable and
// downcastable (object-safe versions of the corresponding std traits).
dyn_clone::clone_trait_object!(BlockQuant);
dyn_hash::hash_trait_object!(BlockQuant);
dyn_eq::eq_trait_object!(BlockQuant);
impl_downcast!(BlockQuant);
193
/// A packing layout for block-quantized matmul weights: a quantization
/// scheme plus panel parameters.
// Hash is derived while PartialEq is manual (below); they stay consistent
// because eq compares exactly the hashed fields.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Hash)]
pub struct PackedBlockQuantFormat {
    // The underlying block quantization scheme.
    pub bq: Box<dyn BlockQuant>,
    // Rows per packed panel (also returned by `MMMInputFormat::r`).
    pub r: usize,
    // Interleaving parameter forwarded to `BlockQuant::pack`; 0 means none
    // (omitted from the Display form).
    pub zip: usize,
    // When set, scales are stored at the end of the panel ("Se" in Display).
    pub scales_at_end: bool,
}
202
203impl PartialEq for PackedBlockQuantFormat {
204 fn eq(&self, other: &Self) -> bool {
205 *self.bq == *other.bq
206 && self.r == other.r
207 && self.zip == other.zip
208 && self.scales_at_end == other.scales_at_end
209 }
210}
211
212impl Eq for PackedBlockQuantFormat {}
213
214impl Display for PackedBlockQuantFormat {
215 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 write!(f, "Packed{}[{}]", &*self.bq, self.r)?;
217 if self.zip != 0 {
218 write!(f, "Z{}", self.zip)?;
219 }
220 if self.scales_at_end {
221 write!(f, "Se")?;
222 }
223 Ok(())
224 }
225}
226
227impl Debug for PackedBlockQuantFormat {
228 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229 <Self as Display>::fmt(self, f)
230 }
231}
232
233impl PackedBlockQuantFormat {
234 pub fn new(bq: &dyn BlockQuant, r: usize, zip: usize, scales_at_end: bool) -> Self {
235 PackedBlockQuantFormat { bq: clone_box(bq), r, zip, scales_at_end }
236 }
237
238 pub fn simulate_precision_loss(
239 &self,
240 tensor: Tensor,
241 block_axis: usize,
242 ) -> TractResult<Tensor> {
243 self.bq.simulate_precision_loss(tensor, block_axis)
244 }
245
246 pub fn pack(&self, input: &[u8], k: usize) -> TractResult<EagerPackedInput> {
247 self.bq.pack(input, k, self.r, self.zip, self.scales_at_end)
248 }
249}
250
251impl MMMInputFormat for PackedBlockQuantFormat {
252 fn prepare_tensor(&self, t: &Tensor, _k_axis: usize, _mn_axis: usize) -> TractResult<Tensor> {
253 let bqs = t.try_storage_as::<BlockQuantStorage>()?;
254 let num_groups: usize =
255 if t.rank() > 2 { t.shape()[..t.rank() - 2].iter().product() } else { 1 };
256 let m_per_group = t.shape()[t.rank().saturating_sub(2)];
257 let k = *t.shape().last().unwrap();
258 let values = (0..num_groups)
259 .map(|g| {
260 let slice = block_quant_slice(bqs.value(), &*self.bq, m_per_group, k, g);
261 let packed = self.pack(slice, k)?;
262 Ok(Box::new(packed) as Box<dyn MMMInputValue>)
263 })
264 .collect::<TractResult<Vec<_>>>()?;
265 let leading_shape = &t.shape()[..t.rank().saturating_sub(2)];
266 Ok(crate::mmm::PackedMatrixStorage::new_batched(leading_shape, values)
267 .into_tensor(t.datum_type()))
268 }
269
270 fn prepare_one(
271 &self,
272 t: &Tensor,
273 k_axis: usize,
274 mn_axis: usize,
275 ) -> TractResult<Box<dyn MMMInputValue>> {
276 let t = if t.is_plain() && t.datum_type().is_number() {
278 let k = t.shape()[k_axis];
279 let m = t.shape()[mn_axis];
280 assert!(k % self.bq.block_len() == 0);
281 let t: Cow<Tensor> = if k_axis == 1 && mn_axis == 0 {
282 Cow::Borrowed(t)
283 } else {
284 Cow::Owned(t.clone().move_axis(1, 0)?)
285 };
286 let quant = if t.datum_type() == f32::datum_type() {
287 self.bq.quant_f32(t.try_as_plain()?.as_slice()?)?
288 } else if t.datum_type() == f16::datum_type() {
289 self.bq.quant_f16(t.try_as_plain()?.as_slice()?)?
290 } else {
291 todo!()
292 };
293 Cow::Owned(
294 BlockQuantStorage::new(self.bq.clone(), m, k, Arc::new(quant))?
295 .into_tensor_with_shape(t.datum_type(), &[1, m, k]),
296 )
297 } else {
298 Cow::Borrowed(t)
299 };
300 ensure!(mn_axis == 0);
301 ensure!(k_axis == 1);
302 let bqs = t.try_storage_as::<BlockQuantStorage>()?;
303 let k = *t.shape().last().unwrap();
304 let packed = self.pack(bqs.value(), k)?;
305 Ok(Box::new(packed))
306 }
307
308 fn precursor(&self) -> WeightType {
309 WeightType::BlockQuant(self.bq.clone())
310 }
311
312 fn k_alignment(&self) -> usize {
313 self.bq.block_len()
314 }
315
316 fn r(&self) -> usize {
317 self.r
318 }
319
320 fn mem_size(&self, k: TDim, mn: TDim) -> TDim {
321 k * mn * self.bq.block_bytes() / self.bq.block_len()
322 }
323
324 fn extract_at_mn_f16(
325 &self,
326 data: &EagerPackedInput,
327 mn: usize,
328 slice: &mut [f16],
329 ) -> TractResult<()> {
330 self.bq.extract_at_mn_f16(data, mn, slice)
331 }
332
333 fn extract_at_mn_f32(
334 &self,
335 data: &EagerPackedInput,
336 mn: usize,
337 slice: &mut [f32],
338 ) -> TractResult<()> {
339 self.bq.extract_at_mn_f32(data, mn, slice)
340 }
341}