tract_linalg/frame/mmm/kit.rs

1use std::fmt::Debug;
2
3use tract_data::prelude::DatumType;
4
5use crate::frame::block_quant::{BlockQuant, PackedBlockQuantFormat};
6
7use super::pack::PackedFormat;
8use super::panel_extract::PanelExtractor;
9use super::{MMMInputFormat, MatMatMul};
10
// Final hypothesis:
// * A is a constant weight: either a plain DatumType or a block-quantized format.
// * m and k are constant; n is an undetermined TDim.
//
// For now (?), acc.dt == B.dt == C.dt.
16
17#[derive(Clone)]
18pub enum WeightType {
19    Plain(DatumType),
20    BlockQuant(Box<dyn BlockQuant>),
21}
22
23impl From<DatumType> for WeightType {
24    fn from(value: DatumType) -> Self {
25        match value {
26            DatumType::F16 => WeightType::Plain(DatumType::F16),
27            DatumType::F32 => WeightType::Plain(DatumType::F32),
28            DatumType::I32 => WeightType::Plain(DatumType::I32),
29            _ => panic!(),
30        }
31    }
32}
33
34impl From<Box<dyn MMMInputFormat>> for WeightType {
35    fn from(value: Box<dyn MMMInputFormat>) -> Self {
36        (&*value).into()
37    }
38}
39
40impl From<&dyn MMMInputFormat> for WeightType {
41    fn from(value: &dyn MMMInputFormat) -> Self {
42        if let Some(pf) = value.downcast_ref::<PackedFormat>() {
43            WeightType::Plain(pf.dt)
44        } else if let Some(pbqf) = value.downcast_ref::<PackedBlockQuantFormat>() {
45            WeightType::BlockQuant(dyn_clone::clone_box(&*pbqf.bq))
46        } else {
47            todo!()
48        }
49    }
50}
51
52impl PartialEq for WeightType {
53    fn eq(&self, other: &Self) -> bool {
54        use WeightType::*;
55        match (self, other) {
56            (Plain(a), Plain(b)) => a == b,
57            (BlockQuant(a), BlockQuant(b)) => a.same_as(&**b),
58            _ => false,
59        }
60    }
61}
62
63impl<BQ: BlockQuant> From<BQ> for WeightType {
64    fn from(value: BQ) -> Self {
65        WeightType::BlockQuant(dyn_clone::clone_box(&value))
66    }
67}
68
69impl Debug for WeightType {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        match self {
72            Self::Plain(p) => write!(f, "{:?}", p),
73            Self::BlockQuant(bq) => write!(f, "{:?}", bq),
74        }
75    }
76}
77
78// pub enum PackedWeightType {
79//     Plain(PackedFormat),
80//     BlockQuant(PackedBlockQuantFormat),
81// }
82
/// Datum types a kit may use for its accumulator and activation operands.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum KitDatumType {
    /// Counterpart of `DatumType::F16`.
    F16,
    /// Counterpart of `DatumType::F32`.
    F32,
    /// Counterpart of `DatumType::I32`.
    I32,
}
89
90impl From<DatumType> for KitDatumType {
91    fn from(value: DatumType) -> Self {
92        match value {
93            DatumType::F16 => KitDatumType::F16,
94            DatumType::F32 => KitDatumType::F32,
95            DatumType::I32 => KitDatumType::I32,
96            _ => panic!(),
97        }
98    }
99}
100
101impl From<&dyn MMMInputFormat> for KitDatumType {
102    fn from(value: &dyn MMMInputFormat) -> Self {
103        if let Some(pf) = value.downcast_ref::<PackedFormat>() {
104            pf.dt.into()
105        } else {
106            todo!()
107        }
108    }
109}
110
111impl From<Box<dyn MMMInputFormat>> for KitDatumType {
112    fn from(value: Box<dyn MMMInputFormat>) -> Self {
113        (&*value).into()
114    }
115}
116
117#[derive(Debug)]
118pub struct MMMKit {
119    pub weight: WeightType,
120    pub accumulator: KitDatumType,
121    pub activation: KitDatumType,
122    pub static_packer: Box<dyn MMMInputFormat>,
123    pub items: Vec<MMMKitItem>,
124    pub generic_fallback: bool,
125}
126
127#[derive(Debug)]
128pub struct MMMKitItem {
129    pub mmm: Box<dyn MatMatMul>,
130    pub packing: usize,
131    pub weight_panel_extractor: Option<PanelExtractor>,
132}
133
134impl MMMKit {
135    pub(crate) fn new_for_mmm(mmm: Box<dyn MatMatMul>, packing: usize) -> MMMKit {
136        let static_packer = mmm.packings()[packing].0.clone();
137        Self::new(
138            static_packer.clone(),
139            mmm.internal_type(),
140            &*mmm.packings()[packing].1,
141            &*static_packer,
142        )
143        .with_native(mmm, packing)
144    }
145
146    pub(crate) fn new(
147        weight: impl Into<WeightType>,
148        accumulator: impl Into<KitDatumType>,
149        activation: impl Into<KitDatumType>,
150        static_packer: &dyn MMMInputFormat,
151    ) -> MMMKit {
152        let (weight, accumulator, activation) =
153            (weight.into(), accumulator.into(), activation.into());
154        let kit = MMMKit {
155            weight,
156            accumulator,
157            activation,
158            static_packer: dyn_clone::clone_box(static_packer),
159            items: vec![],
160            generic_fallback: false,
161        };
162        match &kit.weight {
163            WeightType::Plain(p) => {
164                debug_assert!(
165                    kit.static_packer.downcast_ref::<PackedFormat>().is_some_and(|pf| pf.dt == *p),
166                    "Static packer not compatible with weight format {kit:?}"
167                )
168            }
169            WeightType::BlockQuant(bq) => debug_assert!(
170                kit.static_packer
171                    .downcast_ref::<PackedBlockQuantFormat>()
172                    .is_some_and(|pbqf| pbqf.bq.same_as(&**bq)),
173                "Static packer not compatible with weight format {kit:?}"
174            ),
175        };
176        kit
177    }
178
179    fn add_item(
180        mut self,
181        mmm: Box<dyn MatMatMul>,
182        packing: usize,
183        weight_panel_extractor: Option<PanelExtractor>,
184    ) -> Self {
185        debug_assert!(
186            self.accumulator == mmm.internal_type().into(),
187            "Accumulator mismatch {self:?} {mmm:?}/{packing} {:?}",
188            mmm.packings()[packing].0
189        );
190        debug_assert!(
191            mmm.packings()[packing]
192                .1
193                .downcast_ref::<PackedFormat>()
194                .is_some_and(|pf| KitDatumType::from(pf.dt) == self.activation),
195            "Activation packed dt mismatch {self:?} {:?}",
196            mmm.packings()[packing].1
197        );
198        self.items.push(MMMKitItem { mmm, packing, weight_panel_extractor });
199        self
200    }
201
202    pub(crate) fn with_native(self, mmm: Box<dyn MatMatMul>, packing: usize) -> Self {
203        debug_assert!(
204            mmm.packings()[packing].0.same_as(&*self.static_packer),
205            "Weight packing mismatch {self:?} {mmm:?}/{packing} {:?}",
206            mmm.packings()[packing].0
207        );
208        self.add_item(mmm, packing, None)
209    }
210
211    #[allow(dead_code)]
212    pub(crate) fn with_extracting(
213        self,
214        mmm: Box<dyn MatMatMul>,
215        packing: usize,
216        weight_panel_extractor: PanelExtractor,
217    ) -> Self {
218        debug_assert!(
219            self.static_packer.same_as(&*weight_panel_extractor.from),
220            "Static weight packing/extractor mismatch {self:?} {mmm:?}/{packing} {:?} {weight_panel_extractor:?}",
221            mmm.packings()[packing].0
222        );
223        debug_assert!(
224            weight_panel_extractor.to.same_as(&*mmm.packings()[packing].0),
225            "Extractor/kernel packing mismatch {self:?} {mmm:?}/{packing} {:?} {weight_panel_extractor:?}",
226            mmm.packings()[packing].0
227        );
228        self.add_item(mmm, packing, Some(weight_panel_extractor))
229    }
230
231    pub(crate) fn with_generic_fallback(self, generic_fallback: bool) -> Self {
232        Self { generic_fallback, ..self }
233    }
234
235    pub fn name(&self) -> &str {
236        self.items[0].mmm.name()
237    }
238
239    pub fn item_for_mv(&self) -> &MMMKitItem {
240        self.items.iter().min_by_key(|item| item.n()).unwrap()
241    }
242
243    pub fn item_for_squarish(&self) -> &MMMKitItem {
244        self.items.iter().max_by_key(|item| item.n()).unwrap()
245    }
246}
247
248impl MMMKitItem {
249    pub fn n(&self) -> usize {
250        self.mmm.packings()[self.packing].1.r()
251    }
252}