Skip to main content

tract_linalg/frame/block_quant/
value.rs

1use super::{BlockQuant, PackedBlockQuantFormat};
2use tract_data::TVec;
3use tract_data::internal::*;
4
/// Fact describing a tensor whose data is encoded with a block-quantization format.
///
/// Pairs the quantization `format` with the tensor `shape`.
// The lint is silenced because `PartialEq` is written by hand (formats are
// compared through `BlockQuant::same_as`) while `Hash` is derived.
// NOTE(review): this is only sound if formats that are `same_as` each other also
// hash identically — confirm the `Hash` impl behind `Box<dyn BlockQuant>` upholds that.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Hash)]
pub struct BlockQuantFact {
    /// Block-quantization scheme; supplies `block_len()` / `block_bytes()` for sizing.
    pub format: Box<dyn BlockQuant>,
    /// Tensor shape, counted in elements (`buffer_sizes` divides its product by
    /// `block_len()`). `m()` reads axis 0 and `k()` is the product of the remaining axes.
    shape: TVec<usize>,
}
11impl BlockQuantFact {
12    pub fn new(format: Box<dyn BlockQuant>, shape: TVec<usize>) -> Self {
13        Self { format, shape }
14    }
15
16    pub fn m(&self) -> usize {
17        self.shape[0]
18    }
19
20    pub fn k(&self) -> usize {
21        self.shape.iter().skip(1).product()
22    }
23
24    pub fn shape(&self) -> &[usize] {
25        &self.shape
26    }
27}
28
29impl std::fmt::Debug for BlockQuantFact {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        write!(f, "{}({:?})", self.format, self.shape)
32    }
33}
34
35impl OpaqueFact for BlockQuantFact {
36    fn same_as(&self, other: &dyn OpaqueFact) -> bool {
37        other.downcast_ref::<Self>().is_some_and(|o| o == self)
38    }
39
40    fn buffer_sizes(&self) -> TVec<TDim> {
41        tvec!(
42            (self.shape.iter().product::<usize>() / self.format.block_len()
43                * self.format.block_bytes())
44            .to_dim()
45        )
46    }
47}
48
49impl PartialEq for BlockQuantFact {
50    fn eq(&self, other: &Self) -> bool {
51        self.format.same_as(&*other.format) && self.shape == other.shape
52    }
53}
54
/// Fact describing block-quantized data held in a `PackedBlockQuantFormat`.
// `Hash` and `PartialEq` are derived, so both consider `format` and then `shape`.
#[derive(Clone, Hash, PartialEq)]
pub struct PackedBlockQuantFact {
    /// Packed format; its `bq` field supplies `block_len()` / `block_bytes()` for sizing.
    pub format: PackedBlockQuantFormat,
    /// Tensor shape, counted in elements (`buffer_sizes` divides its product by `block_len()`).
    pub shape: TVec<usize>,
}
60
61impl std::fmt::Debug for PackedBlockQuantFact {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        write!(f, "{}({:?})", self.format, self.shape)
64    }
65}
66
67impl OpaqueFact for PackedBlockQuantFact {
68    fn same_as(&self, other: &dyn OpaqueFact) -> bool {
69        other.downcast_ref::<Self>().is_some_and(|o| o == self)
70    }
71
72    fn buffer_sizes(&self) -> TVec<TDim> {
73        tvec!(
74            (self.shape.iter().product::<usize>() / self.format.bq.block_len()
75                * self.format.bq.block_bytes())
76            .to_dim()
77        )
78    }
79}