tract_linalg/frame/block_quant/
mod.rs1use downcast_rs::{Downcast, impl_downcast};
2use dyn_clone::{DynClone, clone_box};
3use dyn_hash::DynHash;
4use num_traits::Zero;
5use tract_data::internal::*;
6use tract_data::itertools::Itertools;
7
8use std::alloc::Layout;
9use std::borrow::Cow;
10use std::fmt::{Debug, Display};
11use std::hash::Hash;
12use std::sync::Arc;
13
14mod helpers;
15mod q4_0;
16mod q8_1;
17mod value;
18
19pub use helpers::{NibbleReader, NibbleWriter};
20pub use q4_0::Q4_0;
21pub use q8_1::Q8_1;
22pub use value::{BlockQuantFact, PackedBlockQuantFact};
23
24use crate::mmm::{EagerPackedInput, MMMInputFormat};
25use crate::pack::PackedFormat;
26
27use crate::WeightType;
28
29use super::mmm::MMMInputValue;
30
/// A block quantization scheme (e.g. Q4_0, Q8_1): fixed-size runs of floats
/// encoded into a compact, fixed-size byte representation.
///
/// Implementors provide the per-block encode/decode primitives; the provided
/// methods lift them to whole buffers and to packed-panel manipulation.
pub trait BlockQuant: Debug + Display + Send + Sync + DynClone + DynHash + Downcast {
    /// True when `other` denotes the same quantization format as `self`.
    fn same_as(&self, other: &dyn BlockQuant) -> bool;

    /// Number of scalar values encoded by one block.
    fn block_len(&self) -> usize;

    /// Number of bytes one encoded block occupies.
    fn block_bytes(&self) -> usize;

    /// Decode one encoded block (`block_bytes()` bytes) into `block_len()` f32 values.
    fn dequant_block_f32(&self, quant: &[u8], block: &mut [f32]);
    /// Decode one encoded block (`block_bytes()` bytes) into `block_len()` f16 values.
    fn dequant_block_f16(&self, quant: &[u8], block: &mut [f16]);
    /// Encode `block_len()` f16 values into one block of `block_bytes()` bytes.
    fn quant_block_f16(&self, block: &[f16], quant: &mut [u8]);
    /// Encode `block_len()` f32 values into one block of `block_bytes()` bytes.
    fn quant_block_f32(&self, block: &[f32], quant: &mut [u8]);

    /// Quantize a whole f16 buffer, block by block, into a fresh `Blob`.
    ///
    /// NOTE(review): `input.len()` is rounded *down* to a whole number of
    /// blocks, so a trailing partial block is silently dropped — confirm
    /// callers always pass a multiple of `block_len()`.
    fn quant_f16(&self, input: &[f16]) -> TractResult<Blob> {
        unsafe {
            let blocks = input.len() / self.block_len();
            // 128-byte alignment — presumably for SIMD/cache-line friendliness;
            // confirm against the consuming kernels.
            let mut quant = Blob::for_layout(
                Layout::from_size_align(blocks * self.block_bytes(), 128).unwrap(),
            );
            // SAFETY: `Blob::for_layout` hands back uninitialized memory; the
            // loop below writes every byte of every block before returning it.
            for b in 0..blocks {
                let block = &input[b * self.block_len()..][..self.block_len()];
                let qblock = &mut quant[b * self.block_bytes()..][..self.block_bytes()];
                self.quant_block_f16(block, qblock);
            }
            Ok(quant)
        }
    }

    /// Quantize a whole f32 buffer, block by block, into a fresh `Blob`.
    ///
    /// NOTE(review): as with `quant_f16`, a trailing partial block is
    /// silently dropped.
    fn quant_f32(&self, input: &[f32]) -> TractResult<Blob> {
        unsafe {
            let blocks = input.len() / self.block_len();
            let mut quant = Blob::for_layout(
                Layout::from_size_align(blocks * self.block_bytes(), 128).unwrap(),
            );
            // SAFETY: every byte of the uninitialized blob is written by the
            // per-block quantization loop before the blob is returned.
            for b in 0..blocks {
                let block = &input[b * self.block_len()..][..self.block_len()];
                let qblock = &mut quant[b * self.block_bytes()..][..self.block_bytes()];
                self.quant_block_f32(block, qblock);
            }
            Ok(quant)
        }
    }

    /// Decode a whole quantized byte buffer into a rank-1 f32 tensor of
    /// `blocks * block_len()` values.
    ///
    /// NOTE(review): `input.len()` is rounded down to whole blocks; trailing
    /// bytes that do not form a complete block are ignored.
    fn dequant_f32(&self, input: &[u8]) -> TractResult<Tensor> {
        unsafe {
            let blocks = input.len() / self.block_bytes();
            // SAFETY: the tensor starts uninitialized; the loop fills every
            // element (blocks * block_len values) before the tensor escapes.
            let mut tensor = Tensor::uninitialized::<f32>(&[blocks * self.block_len()])?;
            let slice = tensor.as_slice_mut::<f32>()?;
            for b in 0..blocks {
                let block = &mut slice[b * self.block_len()..][..self.block_len()];
                let qblock = &input[b * self.block_bytes()..][..self.block_bytes()];
                self.dequant_block_f32(qblock, block);
            }
            Ok(tensor)
        }
    }

    /// Decode a whole quantized byte buffer into a rank-1 f16 tensor of
    /// `blocks * block_len()` values.
    fn dequant_f16(&self, input: &[u8]) -> TractResult<Tensor> {
        unsafe {
            let blocks = input.len() / self.block_bytes();
            // SAFETY: same invariant as dequant_f32 — every element is written
            // by the loop before the tensor is returned.
            let mut tensor = Tensor::uninitialized::<f16>(&[blocks * self.block_len()])?;
            let slice = tensor.as_slice_mut::<f16>()?;
            for b in 0..blocks {
                let block = &mut slice[b * self.block_len()..][..self.block_len()];
                let qblock = &input[b * self.block_bytes()..][..self.block_bytes()];
                self.dequant_block_f16(qblock, block);
            }
            Ok(tensor)
        }
    }

    /// Extract the single f16 value at element index `offset` of the
    /// quantized stream, by dequantizing only the block that contains it.
    fn extract_at_offset_f16(&self, input: &[u8], offset: usize) -> f16 {
        let len = self.block_len();
        let block_id = offset / len;
        let mut block = vec![f16::zero(); self.block_len()];
        self.dequant_block_f16(
            &input[block_id * self.block_bytes()..][..self.block_bytes()],
            &mut block,
        );
        block[offset % len]
    }

    /// Extract the single f32 value at element index `offset` of the
    /// quantized stream, by dequantizing only the block that contains it.
    fn extract_at_offset_f32(&self, input: &[u8], offset: usize) -> f32 {
        let len = self.block_len();
        let block_id = offset / len;
        let mut block = vec![f32::zero(); self.block_len()];
        self.dequant_block_f32(
            &input[block_id * self.block_bytes()..][..self.block_bytes()],
            &mut block,
        );
        block[offset % len]
    }

    /// Round-trip `tensor` through quantize/dequantize in place, so the result
    /// carries exactly the precision loss this scheme would introduce.
    ///
    /// Requires the block axis to be the innermost axis and its dimension to
    /// be a multiple of `block_len()`. Only f32 and f16 tensors are handled;
    /// other datum types hit `todo!()`.
    fn simulate_precision_loss(
        &self,
        mut tensor: Tensor,
        block_axis: usize,
    ) -> TractResult<Tensor> {
        ensure!(block_axis == tensor.rank() - 1);
        ensure!(tensor.shape()[block_axis] % self.block_len() == 0);
        // One scratch block reused for every quant/dequant round trip.
        let mut scratch = vec![0u8; self.block_bytes()];
        if tensor.datum_type() == f32::datum_type() {
            // chunks_mut yields exact blocks: length is a multiple of block_len.
            for block in tensor.as_slice_mut::<f32>()?.chunks_mut(self.block_len()) {
                self.quant_block_f32(block, &mut scratch);
                self.dequant_block_f32(&scratch, block);
            }
            Ok(tensor)
        } else if tensor.datum_type() == f16::datum_type() {
            for block in tensor.as_slice_mut::<f16>()?.chunks_mut(self.block_len()) {
                self.quant_block_f16(block, &mut scratch);
                self.dequant_block_f16(&scratch, block);
            }
            Ok(tensor)
        } else {
            todo!()
        }
    }

    /// Pack a quantized byte stream into kernel-ready panels.
    ///
    /// `k` is the reduction dimension, `r` the panel width; `zip` and
    /// `scales_at_end` select the in-panel layout (scheme-specific).
    fn pack(
        &self,
        input: &[u8],
        k: usize,
        r: usize,
        zip: usize,
        scales_at_end: bool,
    ) -> TractResult<EagerPackedInput>;

    /// Dequantize one packed panel of `value` into `scratch`, laid out per
    /// `target`.
    ///
    /// # Safety
    /// `scratch` must be valid for writes of one full panel in the `target`
    /// format (exact size contract is implementor-specific — see impls).
    unsafe fn extract_packed_panel(
        &self,
        value: &EagerPackedInput,
        target: &PackedFormat,
        panel: usize,
        scratch: *mut u8,
    ) -> TractResult<()>;

    /// Extract row/column `mn` of the packed input as f16 values into `target`.
    fn extract_at_mn_f16(
        &self,
        value: &EagerPackedInput,
        mn: usize,
        target: &mut [f16],
    ) -> TractResult<()>;

    /// Extract row/column `mn` of the packed input as f32 values into `target`.
    fn extract_at_mn_f32(
        &self,
        value: &EagerPackedInput,
        mn: usize,
        target: &mut [f32],
    ) -> TractResult<()>;
}
179
// Make `Box<dyn BlockQuant>` clonable, hashable, and downcastable to concrete
// schemes (required by the derive/impl blocks below).
dyn_clone::clone_trait_object!(BlockQuant);
dyn_hash::hash_trait_object!(BlockQuant);
impl_downcast!(BlockQuant);
183
/// A block-quantized weight format as packed for a matmul kernel: the
/// quantization scheme plus the panel geometry passed to [`BlockQuant::pack`].
// Hash is derived while PartialEq is manual (dyn BlockQuant can't derive Eq);
// the manual eq compares the same fields, so the Hash/Eq contract holds.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Hash)]
pub struct PackedBlockQuantFormat {
    /// The underlying quantization scheme (e.g. Q4_0, Q8_1).
    pub bq: Box<dyn BlockQuant>,
    /// Panel width (value returned by `MMMInputFormat::r`).
    pub r: usize,
    /// In-panel interleaving parameter; 0 means no zipping. Exact semantics
    /// are scheme-specific — see the `pack` implementations.
    pub zip: usize,
    /// When true, scales are stored after the quantized payload in each
    /// panel (presumably; layout is defined by the `pack` implementations).
    pub scales_at_end: bool,
}
192
193impl PartialEq for PackedBlockQuantFormat {
194 fn eq(&self, other: &Self) -> bool {
195 self.bq.same_as(&*other.bq)
196 && self.r == other.r
197 && self.zip == other.zip
198 && self.scales_at_end == other.scales_at_end
199 }
200}
201
202impl Eq for PackedBlockQuantFormat {}
203
204impl Display for PackedBlockQuantFormat {
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 write!(f, "Packed{}[{}]", &*self.bq, self.r)?;
207 if self.zip != 0 {
208 write!(f, "Z{}", self.zip)?;
209 }
210 if self.scales_at_end {
211 write!(f, "Se")?;
212 }
213 Ok(())
214 }
215}
216
217impl Debug for PackedBlockQuantFormat {
218 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
219 <Self as Display>::fmt(self, f)
220 }
221}
222
223impl PackedBlockQuantFormat {
224 pub fn new(bq: &dyn BlockQuant, r: usize, zip: usize, scales_at_end: bool) -> Self {
225 PackedBlockQuantFormat { bq: clone_box(bq), r, zip, scales_at_end }
226 }
227
228 pub fn simulate_precision_loss(
229 &self,
230 tensor: Tensor,
231 block_axis: usize,
232 ) -> TractResult<Tensor> {
233 self.bq.simulate_precision_loss(tensor, block_axis)
234 }
235
236 pub fn pack(&self, input: &[u8], k: usize) -> TractResult<EagerPackedInput> {
237 self.bq.pack(input, k, self.r, self.zip, self.scales_at_end)
238 }
239}
240
241impl MMMInputFormat for PackedBlockQuantFormat {
242 fn prepare_tensor(&self, t: &Tensor, _k_axis: usize, _mn_axis: usize) -> TractResult<Tensor> {
243 let packed = t
244 .as_slice::<Opaque>()?
245 .iter()
246 .map(|o| {
247 let bwf = o.downcast_ref::<BlobWithFact>().unwrap();
248 let bqf = bwf.fact.downcast_ref::<BlockQuantFact>().unwrap();
249 let packed = self.pack(&bwf.value, bqf.k())?;
250 Ok(Opaque(Arc::new(Box::new(packed) as Box<dyn MMMInputValue>)))
251 })
252 .collect::<TractResult<Vec<Opaque>>>()?;
253 tensor1(&packed).into_shape(t.shape())
254 }
255
256 fn prepare_one(
257 &self,
258 t: &Tensor,
259 k_axis: usize,
260 mn_axis: usize,
261 ) -> TractResult<Box<dyn MMMInputValue>> {
262 let t = if t.datum_type().is_number() {
264 let k = t.shape()[k_axis];
265 let m = t.shape()[mn_axis];
266 assert!(k % self.bq.block_len() == 0);
267 let t: Cow<Tensor> = if k_axis == 1 && mn_axis == 0 {
268 Cow::Borrowed(t)
269 } else {
270 Cow::Owned(t.clone().move_axis(1, 0)?)
271 };
272 let quant = if t.datum_type() == f32::datum_type() {
273 self.bq.quant_f32(t.as_slice()?)?
274 } else if t.datum_type() == f16::datum_type() {
275 self.bq.quant_f16(t.as_slice()?)?
276 } else {
277 todo!()
278 };
279 Cow::Owned(tensor0(Opaque(Arc::new(BlobWithFact {
280 value: Arc::new(quant),
281 fact: Box::new(BlockQuantFact::new(self.bq.clone(), tvec!(m, k))),
282 }))))
283 } else {
284 Cow::Borrowed(t)
285 };
286 ensure!(mn_axis == 0);
287 ensure!(k_axis == 1);
288 let bwf = t.to_scalar::<Opaque>()?.downcast_ref::<BlobWithFact>().unwrap();
289 let bqf = bwf.fact.downcast_ref::<BlockQuantFact>().unwrap();
290 let packed = self.pack(&bwf.value, bqf.k())?;
291 Ok(Box::new(packed))
292 }
293
294 fn precursor(&self) -> WeightType {
295 WeightType::BlockQuant(self.bq.clone())
296 }
297
298 fn k_alignment(&self) -> usize {
299 self.bq.block_len()
300 }
301
302 fn r(&self) -> usize {
303 self.r
304 }
305
306 fn mem_size(&self, k: TDim, mn: TDim) -> TDim {
307 k * mn * self.bq.block_bytes() / self.bq.block_len()
308 }
309
310 fn same_as(&self, other: &dyn MMMInputFormat) -> bool {
311 other.downcast_ref::<Self>().is_some_and(|other| self == other)
312 }
313
314 fn extract_at_mn_f16(
315 &self,
316 data: &EagerPackedInput,
317 mn: usize,
318 slice: &mut [f16],
319 ) -> TractResult<()> {
320 self.bq.extract_at_mn_f16(data, mn, slice)
321 }
322
323 fn extract_at_mn_f32(
324 &self,
325 data: &EagerPackedInput,
326 mn: usize,
327 slice: &mut [f32],
328 ) -> TractResult<()> {
329 self.bq.extract_at_mn_f32(data, mn, slice)
330 }
331}