tract_linalg/frame/block_quant/
mod.rs1use downcast_rs::{Downcast, impl_downcast};
2use dyn_clone::{DynClone, clone_box};
3use dyn_eq::DynEq;
4use dyn_hash::DynHash;
5use num_traits::Zero;
6use tract_data::internal::*;
7use tract_data::itertools::Itertools;
8
9use std::alloc::Layout;
10use std::borrow::Cow;
11use std::fmt::{Debug, Display};
12use std::hash::Hash;
13use std::sync::Arc;
14
15mod helpers;
16mod q4_0;
17mod q8_1;
18mod storage;
19mod value;
20
21pub use helpers::{NibbleReader, NibbleWriter};
22pub use q4_0::Q4_0;
23pub use q8_1::Q8_1;
24pub use storage::{BlockQuantStorage, block_quant_slice};
25pub use value::{BlockQuantFact, PackedBlockQuantFact};
26
27use crate::mmm::{EagerPackedInput, MMMInputFormat};
28use crate::pack::PackedFormat;
29
30use crate::WeightType;
31
32use super::mmm::MMMInputValue;
33
34pub trait BlockQuant:
35 Debug + Display + Send + Sync + DynClone + DynHash + dyn_eq::DynEq + Downcast
36{
37 fn block_len(&self) -> usize;
38
39 fn block_bytes(&self) -> usize;
40
41 fn dequant_block_f32(&self, quant: &[u8], block: &mut [f32]);
42 fn dequant_block_f16(&self, quant: &[u8], block: &mut [f16]);
43 fn quant_block_f16(&self, block: &[f16], quant: &mut [u8]);
44 fn quant_block_f32(&self, block: &[f32], quant: &mut [u8]);
45
46 fn quant_f16(&self, input: &[f16]) -> TractResult<Blob> {
47 unsafe {
48 let blocks = input.len() / self.block_len();
49 let mut quant = Blob::for_layout(
50 Layout::from_size_align(blocks * self.block_bytes(), 128).unwrap(),
51 );
52 for b in 0..blocks {
53 let block = &input[b * self.block_len()..][..self.block_len()];
54 let qblock = &mut quant[b * self.block_bytes()..][..self.block_bytes()];
55 self.quant_block_f16(block, qblock);
56 }
57 Ok(quant)
58 }
59 }
60
61 fn quant_f32(&self, input: &[f32]) -> TractResult<Blob> {
62 unsafe {
63 let blocks = input.len() / self.block_len();
64 let mut quant = Blob::for_layout(
65 Layout::from_size_align(blocks * self.block_bytes(), 128).unwrap(),
66 );
67 for b in 0..blocks {
68 let block = &input[b * self.block_len()..][..self.block_len()];
69 let qblock = &mut quant[b * self.block_bytes()..][..self.block_bytes()];
70 self.quant_block_f32(block, qblock);
71 }
72 Ok(quant)
73 }
74 }
75
76 fn dequant_f32(&self, input: &[u8]) -> TractResult<Tensor> {
77 unsafe {
78 let blocks = input.len() / self.block_bytes();
79 let mut tensor = Tensor::uninitialized::<f32>(&[blocks * self.block_len()])?;
80 let mut tensor_plain = tensor.try_as_plain_mut()?;
81 let slice = tensor_plain.as_slice_mut::<f32>()?;
82 for b in 0..blocks {
83 let block = &mut slice[b * self.block_len()..][..self.block_len()];
84 let qblock = &input[b * self.block_bytes()..][..self.block_bytes()];
85 self.dequant_block_f32(qblock, block);
86 }
87 Ok(tensor)
88 }
89 }
90
91 fn dequant_f16(&self, input: &[u8]) -> TractResult<Tensor> {
92 unsafe {
93 let blocks = input.len() / self.block_bytes();
94 let mut tensor = Tensor::uninitialized::<f16>(&[blocks * self.block_len()])?;
95 let mut tensor_plain = tensor.try_as_plain_mut()?;
96 let slice = tensor_plain.as_slice_mut::<f16>()?;
97 for b in 0..blocks {
98 let block = &mut slice[b * self.block_len()..][..self.block_len()];
99 let qblock = &input[b * self.block_bytes()..][..self.block_bytes()];
100 self.dequant_block_f16(qblock, block);
101 }
102 Ok(tensor)
103 }
104 }
105
106 fn extract_at_offset_f16(&self, input: &[u8], offset: usize) -> f16 {
107 let len = self.block_len();
108 let block_id = offset / len;
109 let mut block = vec![f16::zero(); self.block_len()];
110 self.dequant_block_f16(
111 &input[block_id * self.block_bytes()..][..self.block_bytes()],
112 &mut block,
113 );
114 block[offset % len]
115 }
116
117 fn extract_at_offset_f32(&self, input: &[u8], offset: usize) -> f32 {
118 let len = self.block_len();
119 let block_id = offset / len;
120 let mut block = vec![f32::zero(); self.block_len()];
121 self.dequant_block_f32(
122 &input[block_id * self.block_bytes()..][..self.block_bytes()],
123 &mut block,
124 );
125 block[offset % len]
126 }
127
128 fn simulate_precision_loss(
129 &self,
130 mut tensor: Tensor,
131 block_axis: usize,
132 ) -> TractResult<Tensor> {
133 ensure!(block_axis == tensor.rank() - 1);
134 ensure!(tensor.shape()[block_axis] % self.block_len() == 0);
135 let mut scratch = vec![0u8; self.block_bytes()];
136 if tensor.datum_type() == f32::datum_type() {
137 let mut tensor_plain = tensor.try_as_plain_mut()?;
138 for block in tensor_plain.as_slice_mut::<f32>()?.chunks_mut(self.block_len()) {
139 self.quant_block_f32(block, &mut scratch);
140 self.dequant_block_f32(&scratch, block);
141 }
142 Ok(tensor)
143 } else if tensor.datum_type() == f16::datum_type() {
144 let mut tensor_plain = tensor.try_as_plain_mut()?;
145 for block in tensor_plain.as_slice_mut::<f16>()?.chunks_mut(self.block_len()) {
146 self.quant_block_f16(block, &mut scratch);
147 self.dequant_block_f16(&scratch, block);
148 }
149 Ok(tensor)
150 } else {
151 todo!()
152 }
153 }
154
155 fn pack(
156 &self,
157 input: &[u8],
158 k: usize,
159 r: usize,
160 zip: usize,
161 scales_at_end: bool,
162 ) -> TractResult<EagerPackedInput>;
163
164 unsafe fn extract_packed_panel(
165 &self,
166 value: &EagerPackedInput,
167 target: &PackedFormat,
168 panel: usize,
169 scratch: *mut u8,
170 ) -> TractResult<()>;
171
172 fn extract_at_mn_f16(
173 &self,
174 value: &EagerPackedInput,
175 mn: usize,
176 target: &mut [f16],
177 ) -> TractResult<()>;
178
179 fn extract_at_mn_f32(
180 &self,
181 value: &EagerPackedInput,
182 mn: usize,
183 target: &mut [f32],
184 ) -> TractResult<()>;
185}
186
187dyn_clone::clone_trait_object!(BlockQuant);
188dyn_hash::hash_trait_object!(BlockQuant);
189dyn_eq::eq_trait_object!(BlockQuant);
190impl_downcast!(BlockQuant);
191
192#[allow(clippy::derived_hash_with_manual_eq)]
193#[derive(Clone, Hash)]
194pub struct PackedBlockQuantFormat {
195 pub bq: Box<dyn BlockQuant>,
196 pub r: usize,
197 pub zip: usize,
198 pub scales_at_end: bool,
199}
200
201impl PartialEq for PackedBlockQuantFormat {
202 fn eq(&self, other: &Self) -> bool {
203 *self.bq == *other.bq
204 && self.r == other.r
205 && self.zip == other.zip
206 && self.scales_at_end == other.scales_at_end
207 }
208}
209
210impl Eq for PackedBlockQuantFormat {}
211
212impl Display for PackedBlockQuantFormat {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 write!(f, "Packed{}[{}]", &*self.bq, self.r)?;
215 if self.zip != 0 {
216 write!(f, "Z{}", self.zip)?;
217 }
218 if self.scales_at_end {
219 write!(f, "Se")?;
220 }
221 Ok(())
222 }
223}
224
225impl Debug for PackedBlockQuantFormat {
226 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227 <Self as Display>::fmt(self, f)
228 }
229}
230
231impl PackedBlockQuantFormat {
232 pub fn new(bq: &dyn BlockQuant, r: usize, zip: usize, scales_at_end: bool) -> Self {
233 PackedBlockQuantFormat { bq: clone_box(bq), r, zip, scales_at_end }
234 }
235
236 pub fn simulate_precision_loss(
237 &self,
238 tensor: Tensor,
239 block_axis: usize,
240 ) -> TractResult<Tensor> {
241 self.bq.simulate_precision_loss(tensor, block_axis)
242 }
243
244 pub fn pack(&self, input: &[u8], k: usize) -> TractResult<EagerPackedInput> {
245 self.bq.pack(input, k, self.r, self.zip, self.scales_at_end)
246 }
247}
248
249impl MMMInputFormat for PackedBlockQuantFormat {
250 fn prepare_tensor(&self, t: &Tensor, _k_axis: usize, _mn_axis: usize) -> TractResult<Tensor> {
251 let bqs = t.try_storage_as::<BlockQuantStorage>()?;
252 let num_groups: usize =
253 if t.rank() > 2 { t.shape()[..t.rank() - 2].iter().product() } else { 1 };
254 let m_per_group = t.shape()[t.rank().saturating_sub(2)];
255 let k = *t.shape().last().unwrap();
256 let values = (0..num_groups)
257 .map(|g| {
258 let slice = block_quant_slice(bqs.value(), &*self.bq, m_per_group, k, g);
259 let packed = self.pack(slice, k)?;
260 Ok(Box::new(packed) as Box<dyn MMMInputValue>)
261 })
262 .collect::<TractResult<Vec<_>>>()?;
263 let leading_shape = &t.shape()[..t.rank().saturating_sub(2)];
264 Ok(crate::mmm::PackedMatrixStorage::new_batched(leading_shape, values)
265 .into_tensor(t.datum_type()))
266 }
267
268 fn prepare_one(
269 &self,
270 t: &Tensor,
271 k_axis: usize,
272 mn_axis: usize,
273 ) -> TractResult<Box<dyn MMMInputValue>> {
274 let t = if t.is_plain() && t.datum_type().is_number() {
276 let k = t.shape()[k_axis];
277 let m = t.shape()[mn_axis];
278 assert!(k % self.bq.block_len() == 0);
279 let t: Cow<Tensor> = if k_axis == 1 && mn_axis == 0 {
280 Cow::Borrowed(t)
281 } else {
282 Cow::Owned(t.clone().move_axis(1, 0)?)
283 };
284 let quant = if t.datum_type() == f32::datum_type() {
285 self.bq.quant_f32(t.try_as_plain()?.as_slice()?)?
286 } else if t.datum_type() == f16::datum_type() {
287 self.bq.quant_f16(t.try_as_plain()?.as_slice()?)?
288 } else {
289 todo!()
290 };
291 Cow::Owned(
292 BlockQuantStorage::new(self.bq.clone(), m, k, Arc::new(quant))?
293 .into_tensor_with_shape(t.datum_type(), &[1, m, k]),
294 )
295 } else {
296 Cow::Borrowed(t)
297 };
298 ensure!(mn_axis == 0);
299 ensure!(k_axis == 1);
300 let bqs = t.try_storage_as::<BlockQuantStorage>()?;
301 let k = *t.shape().last().unwrap();
302 let packed = self.pack(bqs.value(), k)?;
303 Ok(Box::new(packed))
304 }
305
306 fn precursor(&self) -> WeightType {
307 WeightType::BlockQuant(self.bq.clone())
308 }
309
310 fn k_alignment(&self) -> usize {
311 self.bq.block_len()
312 }
313
314 fn r(&self) -> usize {
315 self.r
316 }
317
318 fn mem_size(&self, k: TDim, mn: TDim) -> TDim {
319 k * mn * self.bq.block_bytes() / self.bq.block_len()
320 }
321
322 fn extract_at_mn_f16(
323 &self,
324 data: &EagerPackedInput,
325 mn: usize,
326 slice: &mut [f16],
327 ) -> TractResult<()> {
328 self.bq.extract_at_mn_f16(data, mn, slice)
329 }
330
331 fn extract_at_mn_f32(
332 &self,
333 data: &EagerPackedInput,
334 mn: usize,
335 slice: &mut [f32],
336 ) -> TractResult<()> {
337 self.bq.extract_at_mn_f32(data, mn, slice)
338 }
339}