tract_linalg/frame/block_quant/
mod.rs1use downcast_rs::{impl_downcast, Downcast};
2use dyn_clone::{clone_box, DynClone};
3use dyn_hash::DynHash;
4use num_traits::Zero;
5use tract_data::internal::*;
6use tract_data::itertools::Itertools;
7
8use std::alloc::Layout;
9use std::borrow::Cow;
10use std::fmt::{Debug, Display};
11use std::hash::Hash;
12use std::sync::Arc;
13
14mod helpers;
15mod q4_0;
16mod value;
17
18pub use helpers::{NibbleReader, NibbleWriter};
19pub use q4_0::Q4_0;
20pub use value::{BlockQuantFact, BlockQuantValue, PackedBlockQuantFact};
21
22use crate::mmm::{EagerPackedInput, MMMInputFormat};
23use crate::pack::PackedFormat;
24
25use crate::WeightType;
26
27use super::mmm::MMMInputValue;
28
29pub trait BlockQuant: Debug + Display + Send + Sync + DynClone + DynHash + Downcast {
30 fn same_as(&self, other: &dyn BlockQuant) -> bool;
31
32 fn block_len(&self) -> usize;
33
34 fn block_bytes(&self) -> usize;
35
36 fn dequant_block_f32(&self, quant: &[u8], block: &mut [f32]);
37 fn dequant_block_f16(&self, quant: &[u8], block: &mut [f16]);
38 fn quant_block_f16(&self, block: &[f16], quant: &mut [u8]);
39 fn quant_block_f32(&self, block: &[f32], quant: &mut [u8]);
40
41 fn quant_f16(&self, input: &[f16]) -> TractResult<Blob> {
42 unsafe {
43 let blocks = input.len() / self.block_len();
44 let mut quant = Blob::for_layout(
45 Layout::from_size_align(blocks * self.block_bytes(), 128).unwrap(),
46 );
47 for b in 0..blocks {
48 let block = &input[b * self.block_len()..][..self.block_len()];
49 let qblock = &mut quant[b * self.block_bytes()..][..self.block_bytes()];
50 self.quant_block_f16(block, qblock);
51 }
52 Ok(quant)
53 }
54 }
55
56 fn quant_f32(&self, input: &[f32]) -> TractResult<Blob> {
57 unsafe {
58 let blocks = input.len() / self.block_len();
59 let mut quant = Blob::for_layout(
60 Layout::from_size_align(blocks * self.block_bytes(), 128).unwrap(),
61 );
62 for b in 0..blocks {
63 let block = &input[b * self.block_len()..][..self.block_len()];
64 let qblock = &mut quant[b * self.block_bytes()..][..self.block_bytes()];
65 self.quant_block_f32(block, qblock);
66 }
67 Ok(quant)
68 }
69 }
70
71 fn dequant_f32(&self, input: &[u8]) -> TractResult<Tensor> {
72 unsafe {
73 let blocks = input.len() / self.block_bytes();
74 let mut tensor = Tensor::uninitialized::<f32>(&[blocks * self.block_len()])?;
75 let slice = tensor.as_slice_mut::<f32>()?;
76 for b in 0..blocks {
77 let block = &mut slice[b * self.block_len()..][..self.block_len()];
78 let qblock = &input[b * self.block_bytes()..][..self.block_bytes()];
79 self.dequant_block_f32(qblock, block);
80 }
81 Ok(tensor)
82 }
83 }
84
85 fn dequant_f16(&self, input: &[u8]) -> TractResult<Tensor> {
86 unsafe {
87 let blocks = input.len() / self.block_bytes();
88 let mut tensor = Tensor::uninitialized::<f16>(&[blocks * self.block_len()])?;
89 let slice = tensor.as_slice_mut::<f16>()?;
90 for b in 0..blocks {
91 let block = &mut slice[b * self.block_len()..][..self.block_len()];
92 let qblock = &input[b * self.block_bytes()..][..self.block_bytes()];
93 self.dequant_block_f16(qblock, block);
94 }
95 Ok(tensor)
96 }
97 }
98
99 fn extract_at_offset_f16(&self, input: &[u8], offset: usize) -> f16 {
100 let len = self.block_len();
101 let block_id = offset / len;
102 let mut block = vec![f16::zero(); self.block_len()];
103 self.dequant_block_f16(
104 &input[block_id * self.block_bytes()..][..self.block_bytes()],
105 &mut block,
106 );
107 block[offset % len]
108 }
109
110 fn extract_at_offset_f32(&self, input: &[u8], offset: usize) -> f32 {
111 let len = self.block_len();
112 let block_id = offset / len;
113 let mut block = vec![f32::zero(); self.block_len()];
114 self.dequant_block_f32(
115 &input[block_id * self.block_bytes()..][..self.block_bytes()],
116 &mut block,
117 );
118 block[offset % len]
119 }
120
121 fn simulate_precision_loss(
122 &self,
123 mut tensor: Tensor,
124 block_axis: usize,
125 ) -> TractResult<Tensor> {
126 ensure!(block_axis == tensor.rank() - 1);
127 ensure!(tensor.shape()[block_axis] % self.block_len() == 0);
128 let mut scratch = vec![0u8; self.block_bytes()];
129 if tensor.datum_type() == f32::datum_type() {
130 for block in tensor.as_slice_mut::<f32>()?.chunks_mut(self.block_len()) {
131 self.quant_block_f32(block, &mut scratch);
132 self.dequant_block_f32(&scratch, block);
133 }
134 Ok(tensor)
135 } else if tensor.datum_type() == f16::datum_type() {
136 for block in tensor.as_slice_mut::<f16>()?.chunks_mut(self.block_len()) {
137 self.quant_block_f16(block, &mut scratch);
138 self.dequant_block_f16(&scratch, block);
139 }
140 Ok(tensor)
141 } else {
142 todo!()
143 }
144 }
145
146 fn pack(
147 &self,
148 input: &[u8],
149 k: usize,
150 r: usize,
151 zip: usize,
152 scales_at_end: bool,
153 ) -> TractResult<EagerPackedInput>;
154
155 unsafe fn extract_packed_panel(
156 &self,
157 value: &EagerPackedInput,
158 target: &PackedFormat,
159 panel: usize,
160 scratch: *mut u8,
161 ) -> TractResult<()>;
162
163 fn extract_at_mn_f16(
164 &self,
165 value: &EagerPackedInput,
166 mn: usize,
167 target: &mut [f16],
168 ) -> TractResult<()>;
169
170 fn extract_at_mn_f32(
171 &self,
172 value: &EagerPackedInput,
173 mn: usize,
174 target: &mut [f32],
175 ) -> TractResult<()>;
176}
177
178dyn_clone::clone_trait_object!(BlockQuant);
179dyn_hash::hash_trait_object!(BlockQuant);
180impl_downcast!(BlockQuant);
181
182#[allow(clippy::derived_hash_with_manual_eq)]
183#[derive(Clone, Hash)]
184pub struct PackedBlockQuantFormat {
185 pub bq: Box<dyn BlockQuant>,
186 pub r: usize,
187 pub zip: usize,
188 pub scales_at_end: bool,
189}
190
191impl PartialEq for PackedBlockQuantFormat {
192 fn eq(&self, other: &Self) -> bool {
193 self.bq.same_as(&*other.bq)
194 && self.r == other.r
195 && self.zip == other.zip
196 && self.scales_at_end == other.scales_at_end
197 }
198}
199
200impl Eq for PackedBlockQuantFormat {}
201
202impl Display for PackedBlockQuantFormat {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 write!(f, "Packed{}[{}]", &*self.bq, self.r)?;
205 if self.zip != 0 {
206 write!(f, "Z{}", self.zip)?;
207 }
208 if self.scales_at_end {
209 write!(f, "Se")?;
210 }
211 Ok(())
212 }
213}
214
215impl Debug for PackedBlockQuantFormat {
216 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217 <Self as Display>::fmt(self, f)
218 }
219}
220
221impl PackedBlockQuantFormat {
222 pub fn new(bq: &dyn BlockQuant, r: usize, zip: usize, scales_at_end: bool) -> Self {
223 PackedBlockQuantFormat { bq: clone_box(bq), r, zip, scales_at_end }
224 }
225
226 pub fn simulate_precision_loss(
227 &self,
228 tensor: Tensor,
229 block_axis: usize,
230 ) -> TractResult<Tensor> {
231 self.bq.simulate_precision_loss(tensor, block_axis)
232 }
233
234 pub fn pack(&self, input: &[u8], k: usize) -> TractResult<EagerPackedInput> {
235 self.bq.pack(input, k, self.r, self.zip, self.scales_at_end)
236 }
237}
238
239impl MMMInputFormat for PackedBlockQuantFormat {
240 fn prepare_tensor(&self, t: &Tensor, _k_axis: usize, _mn_axis: usize) -> TractResult<Tensor> {
241 let packed = t
242 .as_slice::<Opaque>()?
243 .iter()
244 .map(|o| {
245 let bqv = o.downcast_ref::<BlockQuantValue>().unwrap();
246 let packed = self.pack(&bqv.value, bqv.fact.k())?;
247 Ok(Opaque(Arc::new(Box::new(packed) as Box<dyn MMMInputValue>)))
248 })
249 .collect::<TractResult<Vec<Opaque>>>()?;
250 tensor1(&packed).into_shape(t.shape())
251 }
252
253 fn prepare_one(
254 &self,
255 t: &Tensor,
256 k_axis: usize,
257 mn_axis: usize,
258 ) -> TractResult<Box<dyn MMMInputValue>> {
259 let t = if t.datum_type().is_number() {
261 let k = t.shape()[k_axis];
262 let m = t.shape()[mn_axis];
263 assert!(k % self.bq.block_len() == 0);
264 let t: Cow<Tensor> = if k_axis == 1 && mn_axis == 0 {
265 Cow::Borrowed(t)
266 } else {
267 Cow::Owned(t.clone().move_axis(1, 0)?)
268 };
269 let quant = if t.datum_type() == f32::datum_type() {
270 self.bq.quant_f32(t.as_slice()?)?
271 } else if t.datum_type() == f16::datum_type() {
272 self.bq.quant_f16(t.as_slice()?)?
273 } else {
274 todo!()
275 };
276 Cow::Owned(tensor0(Opaque(Arc::new(BlockQuantValue {
277 value: Arc::new(quant),
278 fact: BlockQuantFact::new(self.bq.clone(), tvec!(m, k)),
279 }))))
280 } else {
281 Cow::Borrowed(t)
282 };
283 ensure!(mn_axis == 0);
284 ensure!(k_axis == 1);
285 let bqv = t.to_scalar::<Opaque>()?.downcast_ref::<BlockQuantValue>().unwrap();
286 let packed = self.pack(&bqv.value, bqv.fact.k())?;
287 Ok(Box::new(packed))
288 }
289
290 fn precursor(&self) -> WeightType {
291 WeightType::BlockQuant(self.bq.clone())
292 }
293
294 fn k_alignment(&self) -> usize {
295 self.bq.block_len()
296 }
297
298 fn r(&self) -> usize {
299 self.r
300 }
301
302 fn mem_size(&self, k: TDim, mn: TDim) -> TDim {
303 k * mn * self.bq.block_bytes() / self.bq.block_len()
304 }
305
306 fn same_as(&self, other: &dyn MMMInputFormat) -> bool {
307 other.downcast_ref::<Self>().is_some_and(|other| self == other)
308 }
309
310 fn extract_at_mn_f16(
311 &self,
312 data: &EagerPackedInput,
313 mn: usize,
314 slice: &mut [f16],
315 ) -> TractResult<()> {
316 self.bq.extract_at_mn_f16(data, mn, slice)
317 }
318
319 fn extract_at_mn_f32(
320 &self,
321 data: &EagerPackedInput,
322 mn: usize,
323 slice: &mut [f32],
324 ) -> TractResult<()> {
325 self.bq.extract_at_mn_f32(data, mn, slice)
326 }
327}