// tract_linalg/frame/weights.rs

1use std::fmt::Debug;
2use tract_data::prelude::DatumType;
3
4use crate::block_quant::{BlockQuant, PackedBlockQuantFormat};
5
6use crate::mmm::MMMInputFormat;
7use crate::pack::PackedFormat;
8
9#[derive(Clone)]
10pub enum WeightType {
11    Plain(DatumType),
12    BlockQuant(Box<dyn BlockQuant>),
13}
14
15impl From<DatumType> for WeightType {
16    fn from(value: DatumType) -> Self {
17        match value {
18            DatumType::F16 => WeightType::Plain(DatumType::F16),
19            DatumType::F32 => WeightType::Plain(DatumType::F32),
20            DatumType::I32 => WeightType::Plain(DatumType::I32),
21            DatumType::I8 | DatumType::QI8(_) => WeightType::Plain(DatumType::I8),
22            DatumType::U8 | DatumType::QU8(_) => WeightType::Plain(DatumType::U8),
23            _ => panic!("Can't build a WeightType from {value:?}"),
24        }
25    }
26}
27
28impl From<Box<dyn MMMInputFormat>> for WeightType {
29    fn from(value: Box<dyn MMMInputFormat>) -> Self {
30        (&*value).into()
31    }
32}
33
34impl From<&dyn MMMInputFormat> for WeightType {
35    fn from(value: &dyn MMMInputFormat) -> Self {
36        if let Some(pf) = value.downcast_ref::<PackedFormat>() {
37            WeightType::Plain(pf.dt)
38        } else if let Some(pbqf) = value.downcast_ref::<PackedBlockQuantFormat>() {
39            WeightType::BlockQuant(dyn_clone::clone_box(&*pbqf.bq))
40        } else {
41            todo!()
42        }
43    }
44}
45
46impl PartialEq for WeightType {
47    fn eq(&self, other: &Self) -> bool {
48        use WeightType::*;
49        match (self, other) {
50            (Plain(a), Plain(b)) => a == b,
51            (BlockQuant(a), BlockQuant(b)) => a.same_as(&**b),
52            _ => false,
53        }
54    }
55}
56
57impl<BQ: BlockQuant> From<BQ> for WeightType {
58    fn from(value: BQ) -> Self {
59        WeightType::BlockQuant(dyn_clone::clone_box(&value))
60    }
61}
62
63impl Debug for WeightType {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        match self {
66            Self::Plain(p) => write!(f, "{:?}", p),
67            Self::BlockQuant(bq) => write!(f, "{:?}", bq),
68        }
69    }
70}
71
72impl WeightType {
73    pub fn as_dt(&self) -> Option<DatumType> {
74        match self {
75            WeightType::Plain(dt) => Some(*dt),
76            _ => None,
77        }
78    }
79}