tract_linalg/frame/
weights.rs1use std::fmt::Debug;
2use tract_data::prelude::DatumType;
3
4use crate::block_quant::{BlockQuant, PackedBlockQuantFormat};
5
6use crate::mmm::MMMInputFormat;
7use crate::pack::PackedFormat;
8
/// How a set of weights is represented: either as plain elements of a
/// datum type, or through a block-quantization scheme.
#[derive(Clone)]
pub enum WeightType {
    /// Weights stored as plain elements of the given datum type.
    Plain(DatumType),
    /// Weights stored using a (boxed, dynamically typed) block-quantization scheme.
    BlockQuant(Box<dyn BlockQuant>),
}
14
15impl From<DatumType> for WeightType {
16 fn from(value: DatumType) -> Self {
17 match value {
18 DatumType::F16 => WeightType::Plain(DatumType::F16),
19 DatumType::F32 => WeightType::Plain(DatumType::F32),
20 DatumType::F64 => WeightType::Plain(DatumType::F64),
21 DatumType::I32 => WeightType::Plain(DatumType::I32),
22 DatumType::I8 | DatumType::QI8(_) => WeightType::Plain(DatumType::I8),
23 DatumType::U8 | DatumType::QU8(_) => WeightType::Plain(DatumType::U8),
24 _ => panic!("Can't build a WeightType from {value:?}"),
25 }
26 }
27}
28
29impl From<Box<dyn MMMInputFormat>> for WeightType {
30 fn from(value: Box<dyn MMMInputFormat>) -> Self {
31 (&*value).into()
32 }
33}
34
35impl From<&dyn MMMInputFormat> for WeightType {
36 fn from(value: &dyn MMMInputFormat) -> Self {
37 if let Some(pf) = value.downcast_ref::<PackedFormat>() {
38 WeightType::Plain(pf.dt)
39 } else if let Some(pbqf) = value.downcast_ref::<PackedBlockQuantFormat>() {
40 WeightType::BlockQuant(dyn_clone::clone_box(&*pbqf.bq))
41 } else {
42 todo!()
43 }
44 }
45}
46
47impl PartialEq for WeightType {
48 fn eq(&self, other: &Self) -> bool {
49 use WeightType::*;
50 match (self, other) {
51 (Plain(a), Plain(b)) => a == b,
52 (BlockQuant(a), BlockQuant(b)) => a.same_as(&**b),
53 _ => false,
54 }
55 }
56}
57
58impl<BQ: BlockQuant> From<BQ> for WeightType {
59 fn from(value: BQ) -> Self {
60 WeightType::BlockQuant(dyn_clone::clone_box(&value))
61 }
62}
63
64impl Debug for WeightType {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 match self {
67 Self::Plain(p) => write!(f, "{p:?}"),
68 Self::BlockQuant(bq) => write!(f, "{bq:?}"),
69 }
70 }
71}
72
73impl WeightType {
74 pub fn as_dt(&self) -> Option<DatumType> {
75 match self {
76 WeightType::Plain(dt) => Some(*dt),
77 _ => None,
78 }
79 }
80}