Skip to main content

tract_linalg/frame/block_quant/
value.rs

1use super::{BlockQuant, PackedBlockQuantFormat};
2use tract_data::TVec;
3use tract_data::internal::*;
4
/// Fact describing a block-quantized value: a quantization format paired
/// with a shape.
///
/// `Hash` is derived here while `PartialEq` is implemented by hand elsewhere
/// in this file, which is what the clippy allowance below silences.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Hash)]
pub struct BlockQuantFact {
    /// Block quantization scheme, held as a boxed trait object.
    pub format: Box<dyn BlockQuant>,
    /// Dimensions of the value. The accessors of this type treat the last
    /// dimension as K and the ones before it as M (and leading groups).
    /// Private: read it through `shape()`.
    shape: TVec<usize>,
}
11impl BlockQuantFact {
12    pub fn new(format: Box<dyn BlockQuant>, shape: TVec<usize>) -> Self {
13        Self { format, shape }
14    }
15
16    /// Product of all leading dims except the last two (M, K).
17    /// For rank <= 2, returns 1.
18    pub fn num_groups(&self) -> usize {
19        if self.shape.len() <= 2 { 1 } else { self.shape[..self.shape.len() - 2].iter().product() }
20    }
21
22    /// Product of all dims except the last (K). This is the flat M
23    /// dimension (groups * m_per_group).
24    pub fn m(&self) -> usize {
25        self.shape[..self.shape.len() - 1].iter().product()
26    }
27
28    /// Last dimension.
29    pub fn k(&self) -> usize {
30        *self.shape.last().unwrap()
31    }
32
33    pub fn shape(&self) -> &[usize] {
34        &self.shape
35    }
36}
37
38impl std::fmt::Debug for BlockQuantFact {
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        write!(f, "{}({:?})", self.format, self.shape)
41    }
42}
43
44impl ExoticFact for BlockQuantFact {
45    fn buffer_sizes(&self) -> TVec<TDim> {
46        let total = self.m() * self.k() / self.format.block_len() * self.format.block_bytes();
47        tvec!(total.to_dim())
48    }
49}
50
51impl PartialEq for BlockQuantFact {
52    fn eq(&self, other: &Self) -> bool {
53        *self.format == *other.format && self.shape == other.shape
54    }
55}
56impl Eq for BlockQuantFact {}
57
/// Fact pairing a `PackedBlockQuantFormat` with a shape.
#[derive(Clone, Hash, PartialEq)]
pub struct PackedBlockQuantFact {
    /// Packed format descriptor. Its `bq` member supplies the block length
    /// and per-block byte count used by `buffer_sizes` in this file.
    pub format: PackedBlockQuantFormat,
    /// Dimensions of the value.
    pub shape: TVec<usize>,
}
// Marker impl: equality itself comes from the derived `PartialEq` above.
impl Eq for PackedBlockQuantFact {}
64
65impl std::fmt::Debug for PackedBlockQuantFact {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        write!(f, "{}({:?})", self.format, self.shape)
68    }
69}
70
71impl ExoticFact for PackedBlockQuantFact {
72    fn buffer_sizes(&self) -> TVec<TDim> {
73        tvec!(
74            (self.shape.iter().product::<usize>() / self.format.bq.block_len()
75                * self.format.bq.block_bytes())
76            .to_dim()
77        )
78    }
79}