1use std::fmt::Debug;
2
3use tract_data::prelude::DatumType;
4
5use crate::frame::block_quant::{BlockQuant, PackedBlockQuantFormat};
6
7use super::pack::PackedFormat;
8use super::panel_extract::PanelExtractor;
9use super::{MMMInputFormat, MatMatMul};
10
/// Storage type of the weight operand of an [`MMMKit`].
///
/// Either a plain datum type (as carried by a `PackedFormat`), or a block-quantized
/// format (as carried by a `PackedBlockQuantFormat`).
#[derive(Clone)]
pub enum WeightType {
    /// Weights stored as plain values of this datum type.
    Plain(DatumType),
    /// Weights stored in this block-quantized format. Equality between two such values
    /// is decided by `BlockQuant::same_as`, not by pointer identity.
    BlockQuant(Box<dyn BlockQuant>),
}
22
23impl From<DatumType> for WeightType {
24 fn from(value: DatumType) -> Self {
25 match value {
26 DatumType::F16 => WeightType::Plain(DatumType::F16),
27 DatumType::F32 => WeightType::Plain(DatumType::F32),
28 DatumType::I32 => WeightType::Plain(DatumType::I32),
29 _ => panic!(),
30 }
31 }
32}
33
34impl From<Box<dyn MMMInputFormat>> for WeightType {
35 fn from(value: Box<dyn MMMInputFormat>) -> Self {
36 (&*value).into()
37 }
38}
39
40impl From<&dyn MMMInputFormat> for WeightType {
41 fn from(value: &dyn MMMInputFormat) -> Self {
42 if let Some(pf) = value.downcast_ref::<PackedFormat>() {
43 WeightType::Plain(pf.dt)
44 } else if let Some(pbqf) = value.downcast_ref::<PackedBlockQuantFormat>() {
45 WeightType::BlockQuant(dyn_clone::clone_box(&*pbqf.bq))
46 } else {
47 todo!()
48 }
49 }
50}
51
52impl PartialEq for WeightType {
53 fn eq(&self, other: &Self) -> bool {
54 use WeightType::*;
55 match (self, other) {
56 (Plain(a), Plain(b)) => a == b,
57 (BlockQuant(a), BlockQuant(b)) => a.same_as(&**b),
58 _ => false,
59 }
60 }
61}
62
63impl<BQ: BlockQuant> From<BQ> for WeightType {
64 fn from(value: BQ) -> Self {
65 WeightType::BlockQuant(dyn_clone::clone_box(&value))
66 }
67}
68
69impl Debug for WeightType {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 match self {
72 Self::Plain(p) => write!(f, "{:?}", p),
73 Self::BlockQuant(bq) => write!(f, "{:?}", bq),
74 }
75 }
76}
77
/// Datum types a kit may use for its accumulator and activation operands.
///
/// A subset of `DatumType`; conversions from other datum types panic.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum KitDatumType {
    F16,
    F32,
    I32,
}
89
90impl From<DatumType> for KitDatumType {
91 fn from(value: DatumType) -> Self {
92 match value {
93 DatumType::F16 => KitDatumType::F16,
94 DatumType::F32 => KitDatumType::F32,
95 DatumType::I32 => KitDatumType::I32,
96 _ => panic!(),
97 }
98 }
99}
100
101impl From<&dyn MMMInputFormat> for KitDatumType {
102 fn from(value: &dyn MMMInputFormat) -> Self {
103 if let Some(pf) = value.downcast_ref::<PackedFormat>() {
104 pf.dt.into()
105 } else {
106 todo!()
107 }
108 }
109}
110
111impl From<Box<dyn MMMInputFormat>> for KitDatumType {
112 fn from(value: Box<dyn MMMInputFormat>) -> Self {
113 (&*value).into()
114 }
115}
116
/// A group of matrix-multiplication kernels sharing one weight packing ("static packer")
/// and one accumulator / activation datum type combination.
#[derive(Debug)]
pub struct MMMKit {
    /// Storage type of the weights; checked against `static_packer` in debug builds.
    pub weight: WeightType,
    /// Must match each item's `mmm.internal_type()` (debug-asserted in `add_item`).
    pub accumulator: KitDatumType,
    /// Must match the datum type of each item's second (activation) packing (debug-asserted).
    pub activation: KitDatumType,
    /// Format the weights are packed in ahead of time, shared by all items.
    pub static_packer: Box<dyn MMMInputFormat>,
    /// Kernels usable with this kit, each with its packing index.
    pub items: Vec<MMMKitItem>,
    /// Flag set through `with_generic_fallback`; defaults to `false`.
    /// NOTE(review): its exact meaning is decided by consumers outside this file — confirm.
    pub generic_fallback: bool,
}
126
/// One kernel within an [`MMMKit`].
#[derive(Debug)]
pub struct MMMKitItem {
    /// The kernel itself.
    pub mmm: Box<dyn MatMatMul>,
    /// Index into `mmm.packings()` selecting the (weight, activation) packing pair to use.
    pub packing: usize,
    /// `None` when the kernel consumes the kit's static packing directly (`with_native`).
    /// `Some` when panels must first be converted from the static packing to the kernel's
    /// weight packing (`with_extracting`).
    pub weight_panel_extractor: Option<PanelExtractor>,
}
133
134impl MMMKit {
135 pub(crate) fn new_for_mmm(mmm: Box<dyn MatMatMul>, packing: usize) -> MMMKit {
136 let static_packer = mmm.packings()[packing].0.clone();
137 Self::new(
138 static_packer.clone(),
139 mmm.internal_type(),
140 &*mmm.packings()[packing].1,
141 &*static_packer,
142 )
143 .with_native(mmm, packing)
144 }
145
146 pub(crate) fn new(
147 weight: impl Into<WeightType>,
148 accumulator: impl Into<KitDatumType>,
149 activation: impl Into<KitDatumType>,
150 static_packer: &dyn MMMInputFormat,
151 ) -> MMMKit {
152 let (weight, accumulator, activation) =
153 (weight.into(), accumulator.into(), activation.into());
154 let kit = MMMKit {
155 weight,
156 accumulator,
157 activation,
158 static_packer: dyn_clone::clone_box(static_packer),
159 items: vec![],
160 generic_fallback: false,
161 };
162 match &kit.weight {
163 WeightType::Plain(p) => {
164 debug_assert!(
165 kit.static_packer.downcast_ref::<PackedFormat>().is_some_and(|pf| pf.dt == *p),
166 "Static packer not compatible with weight format {kit:?}"
167 )
168 }
169 WeightType::BlockQuant(bq) => debug_assert!(
170 kit.static_packer
171 .downcast_ref::<PackedBlockQuantFormat>()
172 .is_some_and(|pbqf| pbqf.bq.same_as(&**bq)),
173 "Static packer not compatible with weight format {kit:?}"
174 ),
175 };
176 kit
177 }
178
179 fn add_item(
180 mut self,
181 mmm: Box<dyn MatMatMul>,
182 packing: usize,
183 weight_panel_extractor: Option<PanelExtractor>,
184 ) -> Self {
185 debug_assert!(
186 self.accumulator == mmm.internal_type().into(),
187 "Accumulator mismatch {self:?} {mmm:?}/{packing} {:?}",
188 mmm.packings()[packing].0
189 );
190 debug_assert!(
191 mmm.packings()[packing]
192 .1
193 .downcast_ref::<PackedFormat>()
194 .is_some_and(|pf| KitDatumType::from(pf.dt) == self.activation),
195 "Activation packed dt mismatch {self:?} {:?}",
196 mmm.packings()[packing].1
197 );
198 self.items.push(MMMKitItem { mmm, packing, weight_panel_extractor });
199 self
200 }
201
202 pub(crate) fn with_native(self, mmm: Box<dyn MatMatMul>, packing: usize) -> Self {
203 debug_assert!(
204 mmm.packings()[packing].0.same_as(&*self.static_packer),
205 "Weight packing mismatch {self:?} {mmm:?}/{packing} {:?}",
206 mmm.packings()[packing].0
207 );
208 self.add_item(mmm, packing, None)
209 }
210
211 #[allow(dead_code)]
212 pub(crate) fn with_extracting(
213 self,
214 mmm: Box<dyn MatMatMul>,
215 packing: usize,
216 weight_panel_extractor: PanelExtractor,
217 ) -> Self {
218 debug_assert!(
219 self.static_packer.same_as(&*weight_panel_extractor.from),
220 "Static weight packing/extractor mismatch {self:?} {mmm:?}/{packing} {:?} {weight_panel_extractor:?}",
221 mmm.packings()[packing].0
222 );
223 debug_assert!(
224 weight_panel_extractor.to.same_as(&*mmm.packings()[packing].0),
225 "Extractor/kernel packing mismatch {self:?} {mmm:?}/{packing} {:?} {weight_panel_extractor:?}",
226 mmm.packings()[packing].0
227 );
228 self.add_item(mmm, packing, Some(weight_panel_extractor))
229 }
230
231 pub(crate) fn with_generic_fallback(self, generic_fallback: bool) -> Self {
232 Self { generic_fallback, ..self }
233 }
234
235 pub fn name(&self) -> &str {
236 self.items[0].mmm.name()
237 }
238
239 pub fn item_for_mv(&self) -> &MMMKitItem {
240 self.items.iter().min_by_key(|item| item.n()).unwrap()
241 }
242
243 pub fn item_for_squarish(&self) -> &MMMKitItem {
244 self.items.iter().max_by_key(|item| item.n()).unwrap()
245 }
246}
247
248impl MMMKitItem {
249 pub fn n(&self) -> usize {
250 self.mmm.packings()[self.packing].1.r()
251 }
252}