Skip to main content

tract_linalg/frame/block_quant/
storage.rs

1use std::fmt;
2use std::sync::Arc;
3
4use tract_data::internal::*;
5
6use super::BlockQuant;
7use super::BlockQuantFact;
8
9/// Concrete tensor storage for block-quantized weights.
10///
11/// Stores a single contiguous `Arc<Blob>` of quantized data along with the
12/// block-quant format. Shape lives on the tensor, not here.
13#[derive(Clone, PartialEq, Eq)]
14pub struct BlockQuantStorage {
15    format: Box<dyn BlockQuant>,
16    data: Arc<Blob>,
17}
18
19impl BlockQuantStorage {
20    fn expected_bytes(format: &dyn BlockQuant, m: usize, k: usize) -> usize {
21        m * k / format.block_len() * format.block_bytes()
22    }
23
24    pub fn new(
25        format: Box<dyn BlockQuant>,
26        m: usize,
27        k: usize,
28        data: Arc<Blob>,
29    ) -> TractResult<Self> {
30        let expected = Self::expected_bytes(&*format, m, k);
31        ensure!(
32            data.len() == expected,
33            "BlockQuantStorage::new: blob length {} does not match expected {} (m={}, k={}, format={})",
34            data.len(),
35            expected,
36            m,
37            k,
38            format,
39        );
40        Ok(Self { format, data })
41    }
42
43    pub fn format(&self) -> &dyn BlockQuant {
44        &*self.format
45    }
46
47    /// Returns the single contiguous blob.
48    pub fn value(&self) -> &Arc<Blob> {
49        &self.data
50    }
51
52    /// Converts this storage into a `Tensor` with the given shape.
53    ///
54    /// `dt` is the logical element type (e.g. f32, f16) — the type these
55    /// weights represent when dequantized.
56    pub fn into_tensor_with_shape(self, dt: DatumType, shape: &[usize]) -> Tensor {
57        Tensor::from_storage(dt, shape, self)
58    }
59}
60
61/// Returns a byte slice for a single group within contiguous block-quant data.
62pub fn block_quant_slice<'a>(
63    data: &'a [u8],
64    format: &dyn BlockQuant,
65    m_per_group: usize,
66    k: usize,
67    g: usize,
68) -> &'a [u8] {
69    let row_bytes = k / format.block_len() * format.block_bytes();
70    let group_bytes = m_per_group * row_bytes;
71    let start = g * group_bytes;
72    &data[start..start + group_bytes]
73}
74
75impl fmt::Debug for BlockQuantStorage {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        write!(f, "BlockQuantStorage({}, bytes={})", self.format, self.data.len())
78    }
79}
80
81impl fmt::Display for BlockQuantStorage {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        write!(f, "BlockQuantStorage({}, bytes={})", self.format, self.data.len())
84    }
85}
86
87impl TensorStorage for BlockQuantStorage {
88    fn byte_len(&self) -> usize {
89        self.data.len()
90    }
91
92    fn is_empty(&self) -> bool {
93        self.data.is_empty()
94    }
95
96    fn deep_clone(&self) -> Box<dyn TensorStorage> {
97        Box::new(self.clone())
98    }
99
100    fn as_plain(&self) -> Option<&PlainStorage> {
101        None
102    }
103
104    fn as_plain_mut(&mut self) -> Option<&mut PlainStorage> {
105        None
106    }
107
108    fn into_plain(self: Box<Self>) -> Option<PlainStorage> {
109        None
110    }
111
112    fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
113        state.write_u8(1);
114        self.format.dyn_hash(state);
115        state.write(self.data.as_bytes());
116    }
117
118    fn exotic_fact(&self, shape: &[usize]) -> TractResult<Option<Box<dyn ExoticFact>>> {
119        Ok(Some(Box::new(BlockQuantFact::new(dyn_clone::clone_box(&*self.format), shape.into()))))
120    }
121}