1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
use tract_linalg::frame::block_quant::{BlockQuant, Q4_0};

use crate::internal::*;
use crate::ops::einsum::codegen::{ensure_mkn_axes, AxesOrPatch};
use crate::ops::einsum::EinSum;
use crate::ops::konst::Const;
use crate::transform::ModelTransform;

/// Fact attached to a block-quantized constant.
///
/// `block_quant_einsum_weights` builds one of these for every weight tensor it quantizes,
/// recording which quantization format was applied and the shape of the weights it started from.
#[derive(Clone, Hash)]
pub struct BlockQuantFact {
    /// Block quantization scheme; its `Display` form is what the `Debug` impl of this fact prints.
    pub format: Box<dyn BlockQuant>,
    /// Shape stored alongside the format. The rewrite rule in this file fills it with the shape of
    /// the original (not yet quantized) weights.
    pub shape: ShapeFact,
}

impl std::fmt::Debug for BlockQuantFact {
    /// Renders the fact as `<format>(<shape>)`, using the `Display` form of the format and the
    /// `Debug` form of the shape.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { format, shape } = self;
        f.write_fmt(format_args!("{}({:?})", format, shape))
    }
}

// Empty impl: `BlockQuantFact` overrides nothing from `OpaqueFact`.
impl OpaqueFact for BlockQuantFact {}

/// A block-quantized tensor: the quantized bytes together with the fact describing them.
///
/// `block_quant_einsum_weights` wraps one of these in an `Opaque` scalar tensor to replace the
/// constant weights of an einsum.
#[derive(Clone, Hash)]
pub struct BlockQuantValue {
    /// Format and shape describing `value`.
    pub fact: BlockQuantFact,
    /// Quantized data, as returned by the format's `quant_f16` / `quant_f32`.
    pub value: Blob,
}

// Empty impl: `BlockQuantValue` overrides nothing from `OpaquePayload`.
impl OpaquePayload for BlockQuantValue {}

impl std::fmt::Debug for BlockQuantValue {
    /// Prints the `Debug` form of the fact, a single space, then the `Debug` form of the blob.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { fact, value } = self;
        f.write_fmt(format_args!("{:?} {:?}", fact, value))
    }
}

impl std::fmt::Display for BlockQuantValue {
    /// `Display` output is the same text as the `Debug` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(self, f)
    }
}

/// Model transform that runs the `block_quant_einsum_weights` rewrite rule over a typed model.
#[derive(Debug)]
pub struct BlockQuantTransform;

impl ModelTransform for BlockQuantTransform {
    fn name(&self) -> Cow<str> {
        "BlockQuantTransform".into()
    }

    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
        Rewriter::<()>::default()
            .with_rule_for("block_quant_einsum_weights", block_quant_einsum_weights)
            .rewrite(&(), model)
    }
}

/// Rewriter rule turning the constant weights of a two-input [`EinSum`] into a `Q4_0`
/// block-quantized constant.
///
/// Each call performs at most one of the following steps and returns the matching patch; the
/// rule relies on being called again on the rewritten node for the next step:
///
/// 1. If the second input is a constant, the two inputs (and their axes) are swapped so that
///    the constant becomes input 0.
/// 2. If input 0 is a rank-2 constant whose `k` axis is at position 0 and `m` axis at position
///    1, the constant is replaced by a copy with axis 1 moved to position 0, and the axes
///    mapping is updated accordingly.
/// 3. Otherwise the constant is quantized with `Q4_0` and replaced by an opaque constant
///    carrying a [`BlockQuantValue`].
///
/// Returns `Ok(None)` when the node does not have exactly two inputs, when input 0 is not a
/// rank-2 constant, or when `ensure_mkn_axes` does not yield the m/k/n axes.
fn block_quant_einsum_weights(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    prefix: &str,
    op: &EinSum,
) -> TractResult<Option<TypedModelPatch>> {
    let &[a, b] = &*model.node_input_facts(node.id)? else { return Ok(None) };

    // Step 1: bring a constant second input to the front.
    if b.konst.is_some() {
        // Add a third input slot mirroring slot 0 axis by axis, then drop slot 0, so that the
        // mapping matches the swapped input order used below.
        let mirrored = op.axes.axes(InOut::In(0)).enumerate().try_fold(
            op.axes.clone().with_extra_input(2)?,
            |mapping, (position, axis)| {
                mapping.with_extra_axis_occurency(axis.repr, InOut::In(2), position)
            },
        )?;
        let swapped_axes = mirrored.remove_slot(InOut::In(0))?;
        let swapped_inputs = [node.inputs[1], node.inputs[0]];
        let patch = TypedModelPatch::replace_single_op(
            model,
            node,
            &swapped_inputs,
            EinSum { axes: swapped_axes, ..op.clone() },
        )?;
        return Ok(Some(patch));
    }

    // From here on, input 0 must be a constant matrix.
    let Some(weights) = a.konst.as_deref() else { return Ok(None) };
    if a.rank() != 2 {
        return Ok(None);
    }
    let (m, k) = match ensure_mkn_axes(op, model, node)? {
        AxesOrPatch::Axes(m, k, _) => (m, k),
        _ => return Ok(None),
    };
    let weights_name = &model.node(node.inputs[0].node).name;

    // Step 2: weights laid out with k first and m second get their two axes exchanged.
    if m.inputs[0][0] == 1 && k.inputs[0][0] == 0 {
        let mut patch = TypedModelPatch::default();
        let transposed = patch.add_const(weights_name, weights.clone().move_axis(1, 0)?)?;
        // Add an occurrence of k at position 2 of input 0, then remove the axis occurrence at
        // position 0 of that input.
        let with_trailing_k = op.axes.clone().with_extra_axis_occurency(k, InOut::In(0), 2)?;
        let axes = with_trailing_k.remove_axis_occurency(InOut::In(0), 0)?;
        let other = patch.tap_model(model, node.inputs[1])?;
        let rewired = patch.wire_node(prefix, EinSum { axes, ..op.clone() }, &[transposed, other])?;
        patch.shunt_outside(model, node.id.into(), rewired[0])?;
        return Ok(Some(patch));
    }

    // Step 3: quantize. f16 weights are quantized directly, anything else goes through f32.
    let format = Q4_0;
    let quantized = if a.datum_type == f16::datum_type() {
        format.quant_f16(weights.as_slice::<f16>()?)?
    } else {
        let widened = weights.cast_to::<f32>()?;
        format.quant_f32(widened.as_slice::<f32>()?)?
    };
    let fact = BlockQuantFact { format: Box::new(format), shape: a.shape.clone() };
    let payload = BlockQuantValue { fact: fact.clone(), value: quantized };
    let packed_const =
        Const::new_with_opaque_fact(rctensor0(Opaque(Arc::new(payload))), Box::new(fact));

    let mut patch = TypedModelPatch::default();
    let packed = patch.wire_node(weights_name, packed_const, &[])?;
    let other = patch.tap_model(model, node.inputs[1])?;
    let rewired = patch.wire_node(prefix, op.clone(), &[packed[0], other])?;
    patch.shunt_outside(model, node.id.into(), rewired[0])?;
    Ok(Some(patch))
}