tract_linalg/frame/
weights.rs1use std::fmt::Debug;
2use tract_data::prelude::DatumType;
3
4use crate::block_quant::{BlockQuant, PackedBlockQuantFormat};
5
6use crate::mmm::MMMInputFormat;
7use crate::pack::PackedFormat;
8
/// Describes how a weight operand is stored.
///
/// A weight type is either a plain [`DatumType`] or a block-quantized layout
/// described by a boxed [`BlockQuant`] implementation. It can be built from a
/// `DatumType`, from a concrete `BlockQuant`, or recovered from an
/// `MMMInputFormat` through the `From` impls in this module.
#[derive(Clone)]
pub enum WeightType {
    /// Plain (non-block-quantized) weights of the given datum type.
    Plain(DatumType),
    /// Block-quantized weights; the boxed descriptor identifies the quantization scheme.
    BlockQuant(Box<dyn BlockQuant>),
}
14
15impl From<DatumType> for WeightType {
16 fn from(value: DatumType) -> Self {
17 match value {
18 DatumType::F16 => WeightType::Plain(DatumType::F16),
19 DatumType::F32 => WeightType::Plain(DatumType::F32),
20 DatumType::I32 => WeightType::Plain(DatumType::I32),
21 DatumType::I8 | DatumType::QI8(_) => WeightType::Plain(DatumType::I8),
22 DatumType::U8 | DatumType::QU8(_) => WeightType::Plain(DatumType::U8),
23 _ => panic!("Can't build a WeightType from {value:?}"),
24 }
25 }
26}
27
28impl From<Box<dyn MMMInputFormat>> for WeightType {
29 fn from(value: Box<dyn MMMInputFormat>) -> Self {
30 (&*value).into()
31 }
32}
33
34impl From<&dyn MMMInputFormat> for WeightType {
35 fn from(value: &dyn MMMInputFormat) -> Self {
36 if let Some(pf) = value.downcast_ref::<PackedFormat>() {
37 WeightType::Plain(pf.dt)
38 } else if let Some(pbqf) = value.downcast_ref::<PackedBlockQuantFormat>() {
39 WeightType::BlockQuant(dyn_clone::clone_box(&*pbqf.bq))
40 } else {
41 todo!()
42 }
43 }
44}
45
46impl PartialEq for WeightType {
47 fn eq(&self, other: &Self) -> bool {
48 use WeightType::*;
49 match (self, other) {
50 (Plain(a), Plain(b)) => a == b,
51 (BlockQuant(a), BlockQuant(b)) => a.same_as(&**b),
52 _ => false,
53 }
54 }
55}
56
57impl<BQ: BlockQuant> From<BQ> for WeightType {
58 fn from(value: BQ) -> Self {
59 WeightType::BlockQuant(dyn_clone::clone_box(&value))
60 }
61}
62
63impl Debug for WeightType {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 match self {
66 Self::Plain(p) => write!(f, "{:?}", p),
67 Self::BlockQuant(bq) => write!(f, "{:?}", bq),
68 }
69 }
70}
71
72impl WeightType {
73 pub fn as_dt(&self) -> Option<DatumType> {
74 match self {
75 WeightType::Plain(dt) => Some(*dt),
76 _ => None,
77 }
78 }
79}