tract_linalg/frame/block_quant/
storage.rs1use std::fmt;
2use std::sync::Arc;
3
4use tract_data::internal::*;
5
6use super::BlockQuant;
7use super::BlockQuantFact;
8
9#[derive(Clone, PartialEq, Eq)]
14pub struct BlockQuantStorage {
15 format: Box<dyn BlockQuant>,
16 data: Arc<Blob>,
17}
18
19impl BlockQuantStorage {
20 fn expected_bytes(format: &dyn BlockQuant, m: usize, k: usize) -> usize {
21 m * k / format.block_len() * format.block_bytes()
22 }
23
24 pub fn new(
25 format: Box<dyn BlockQuant>,
26 m: usize,
27 k: usize,
28 data: Arc<Blob>,
29 ) -> TractResult<Self> {
30 let expected = Self::expected_bytes(&*format, m, k);
31 ensure!(
32 data.len() == expected,
33 "BlockQuantStorage::new: blob length {} does not match expected {} (m={}, k={}, format={})",
34 data.len(),
35 expected,
36 m,
37 k,
38 format,
39 );
40 Ok(Self { format, data })
41 }
42
43 pub fn format(&self) -> &dyn BlockQuant {
44 &*self.format
45 }
46
47 pub fn value(&self) -> &Arc<Blob> {
49 &self.data
50 }
51
52 pub fn into_tensor_with_shape(self, dt: DatumType, shape: &[usize]) -> Tensor {
57 Tensor::from_storage(dt, shape, self)
58 }
59}
60
61pub fn block_quant_slice<'a>(
63 data: &'a [u8],
64 format: &dyn BlockQuant,
65 m_per_group: usize,
66 k: usize,
67 g: usize,
68) -> &'a [u8] {
69 let row_bytes = k / format.block_len() * format.block_bytes();
70 let group_bytes = m_per_group * row_bytes;
71 let start = g * group_bytes;
72 &data[start..start + group_bytes]
73}
74
75impl fmt::Debug for BlockQuantStorage {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 write!(f, "BlockQuantStorage({}, bytes={})", self.format, self.data.len())
78 }
79}
80
81impl fmt::Display for BlockQuantStorage {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 write!(f, "BlockQuantStorage({}, bytes={})", self.format, self.data.len())
84 }
85}
86
87impl TensorStorage for BlockQuantStorage {
88 fn byte_len(&self) -> usize {
89 self.data.len()
90 }
91
92 fn is_empty(&self) -> bool {
93 self.data.is_empty()
94 }
95
96 fn deep_clone(&self) -> Box<dyn TensorStorage> {
97 Box::new(self.clone())
98 }
99
100 fn as_plain(&self) -> Option<&PlainStorage> {
101 None
102 }
103
104 fn as_plain_mut(&mut self) -> Option<&mut PlainStorage> {
105 None
106 }
107
108 fn into_plain(self: Box<Self>) -> Option<PlainStorage> {
109 None
110 }
111
112 fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
113 state.write_u8(1);
114 self.format.dyn_hash(state);
115 state.write(self.data.as_bytes());
116 }
117
118 fn exotic_fact(&self, shape: &[usize]) -> TractResult<Option<Box<dyn ExoticFact>>> {
119 Ok(Some(Box::new(BlockQuantFact::new(dyn_clone::clone_box(&*self.format), shape.into()))))
120 }
121}