Skip to main content

tract_linalg/frame/
weights.rs

1use std::fmt::Debug;
2use tract_data::prelude::DatumType;
3
4use crate::block_quant::{BlockQuant, PackedBlockQuantFormat};
5
6use crate::mmm::MMMInputFormat;
7use crate::pack::PackedFormat;
8
9#[derive(Clone)]
10pub enum WeightType {
11    Plain(DatumType),
12    BlockQuant(Box<dyn BlockQuant>),
13}
14
15impl From<DatumType> for WeightType {
16    fn from(value: DatumType) -> Self {
17        match value {
18            DatumType::F16 => WeightType::Plain(DatumType::F16),
19            DatumType::F32 => WeightType::Plain(DatumType::F32),
20            DatumType::F64 => WeightType::Plain(DatumType::F64),
21            DatumType::I32 => WeightType::Plain(DatumType::I32),
22            DatumType::I8 | DatumType::QI8(_) => WeightType::Plain(DatumType::I8),
23            DatumType::U8 | DatumType::QU8(_) => WeightType::Plain(DatumType::U8),
24            _ => panic!("Can't build a WeightType from {value:?}"),
25        }
26    }
27}
28
29impl From<Box<dyn MMMInputFormat>> for WeightType {
30    fn from(value: Box<dyn MMMInputFormat>) -> Self {
31        (&*value).into()
32    }
33}
34
35impl From<&dyn MMMInputFormat> for WeightType {
36    fn from(value: &dyn MMMInputFormat) -> Self {
37        if let Some(pf) = value.downcast_ref::<PackedFormat>() {
38            WeightType::Plain(pf.dt)
39        } else if let Some(pbqf) = value.downcast_ref::<PackedBlockQuantFormat>() {
40            WeightType::BlockQuant(dyn_clone::clone_box(&*pbqf.bq))
41        } else {
42            todo!()
43        }
44    }
45}
46
47impl PartialEq for WeightType {
48    fn eq(&self, other: &Self) -> bool {
49        use WeightType::*;
50        match (self, other) {
51            (Plain(a), Plain(b)) => a == b,
52            (BlockQuant(a), BlockQuant(b)) => a.same_as(&**b),
53            _ => false,
54        }
55    }
56}
57
58impl<BQ: BlockQuant> From<BQ> for WeightType {
59    fn from(value: BQ) -> Self {
60        WeightType::BlockQuant(dyn_clone::clone_box(&value))
61    }
62}
63
64impl Debug for WeightType {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        match self {
67            Self::Plain(p) => write!(f, "{p:?}"),
68            Self::BlockQuant(bq) => write!(f, "{bq:?}"),
69        }
70    }
71}
72
73impl WeightType {
74    pub fn as_dt(&self) -> Option<DatumType> {
75        match self {
76            WeightType::Plain(dt) => Some(*dt),
77            _ => None,
78        }
79    }
80}