// tract_linalg/frame/block_quant/value.rs

1use std::ops::Range;
2use std::sync::Arc;
3
4use super::{BlockQuant, PackedBlockQuantFormat};
5use tract_data::internal::*;
6use tract_data::TVec;
7
8#[allow(clippy::derived_hash_with_manual_eq)]
9#[derive(Clone, Hash)]
10pub struct BlockQuantFact {
11    pub format: Box<dyn BlockQuant>,
12    shape: TVec<usize>,
13}
14impl BlockQuantFact {
15    pub fn new(format: Box<dyn BlockQuant>, shape: TVec<usize>) -> Self {
16        Self { format, shape }
17    }
18
19    pub fn m(&self) -> usize {
20        self.shape[0]
21    }
22
23    pub fn k(&self) -> usize {
24        self.shape.iter().skip(1).product()
25    }
26
27    pub fn shape(&self) -> &[usize] {
28        &self.shape
29    }
30}
31
32impl std::fmt::Debug for BlockQuantFact {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        write!(f, "{}({:?})", self.format, self.shape)
35    }
36}
37
38impl OpaqueFact for BlockQuantFact {
39    fn mem_size(&self) -> TDim {
40        (self.shape.iter().product::<usize>() / self.format.block_len() * self.format.block_bytes())
41            .to_dim()
42    }
43
44    fn same_as(&self, other: &dyn OpaqueFact) -> bool {
45        other.downcast_ref::<Self>().is_some_and(|o| o == self)
46    }
47}
48
49impl PartialEq for BlockQuantFact {
50    fn eq(&self, other: &Self) -> bool {
51        self.format.same_as(&*other.format) && self.shape == other.shape
52    }
53}
54
55#[derive(Clone, Hash)]
56pub struct BlockQuantValue {
57    pub fact: BlockQuantFact,
58    pub value: Arc<Blob>,
59}
60
61impl BlockQuantValue {
62    pub fn split_rows(&self, range: Range<usize>) -> TractResult<BlockQuantValue> {
63        let row_bytes =
64            self.fact.k() / self.fact.format.block_len() * self.fact.format.block_bytes();
65        let mut value =
66            unsafe { Blob::new_for_size_and_align(range.len() * row_bytes, vector_size()) };
67        value.copy_from_slice(&self.value[range.start * row_bytes..][..range.len() * row_bytes]);
68        let mut shape = self.fact.shape.clone();
69        shape[0] = range.len();
70        Ok(BlockQuantValue {
71            fact: BlockQuantFact { format: self.fact.format.clone(), shape },
72            value: Arc::new(value),
73        })
74    }
75}
76
77impl OpaquePayload for BlockQuantValue {
78    fn same_as(&self, other: &dyn OpaquePayload) -> bool {
79        other.downcast_ref::<Self>().is_some_and(|o| o.fact == self.fact && o.value == self.value)
80    }
81}
82
83impl std::fmt::Debug for BlockQuantValue {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        write!(f, "{:?} {:?}", self.fact, self.value)
86    }
87}
88
89impl std::fmt::Display for BlockQuantValue {
90    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91        write!(f, "{self:?}")
92    }
93}
94
95#[derive(Clone, Hash, PartialEq)]
96pub struct PackedBlockQuantFact {
97    pub format: PackedBlockQuantFormat,
98    pub shape: TVec<usize>,
99}
100
101impl std::fmt::Debug for PackedBlockQuantFact {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        write!(f, "{}({:?})", self.format, self.shape)
104    }
105}
106
107impl OpaqueFact for PackedBlockQuantFact {
108    fn mem_size(&self) -> TDim {
109        (self.shape.iter().product::<usize>() / self.format.bq.block_len()
110            * self.format.bq.block_bytes())
111        .to_dim()
112    }
113    fn same_as(&self, other: &dyn OpaqueFact) -> bool {
114        other.downcast_ref::<Self>().is_some_and(|o| o == self)
115    }
116}