//! tract_core/ops/matmul/de_block_quant.rs
//!
//! Model transform that rewrites float matmul weight constants into
//! block-quantized (Q4_0) equivalents.
use tract_linalg::block_quant::{BlockQuant, BlockQuantFact, BlockQuantStorage, Q4_0};

use crate::internal::*;
use crate::ops::einsum::einsum_matmul::EinSumMatMul;
use crate::ops::konst::Const;
use crate::transform::ModelTransform;
7
/// Transform that replaces rank-2 float weight constants feeding
/// matmul-shaped einsums with block-quantized (Q4_0) constants.
#[derive(Debug)]
pub struct BlockQuantTransform;
10
11impl ModelTransform for BlockQuantTransform {
12    fn name(&self) -> StaticName {
13        "block_quant".into()
14    }
15
16    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
17        crate::ops::einsum::einsum_matmul::detect_all(model)?;
18        Rewriter::<()>::default()
19            .with_rule_for("block_quant_einsum_weights", block_quant_einsum_weights)
20            .rewrite(&(), model)?;
21        crate::ops::einsum::einsum_matmul::flatten_all(model)?;
22        Ok(())
23    }
24}
25
26fn block_quant_einsum_weights(
27    _ctx: &(),
28    model: &TypedModel,
29    node: &TypedNode,
30    prefix: &str,
31    op: &EinSumMatMul,
32) -> TractResult<Option<TypedModelPatch>> {
33    rule_if!(node.inputs.len() == 2);
34    for (slot, fact) in model.node_input_facts(node.id)?.iter().enumerate() {
35        let Some(a) = fact.konst.as_ref() else { continue };
36        if a.rank() != 2 {
37            continue;
38        };
39        if op.k_axis().inputs[slot][0] == 0 {
40            let mut patch = TypedModelPatch::default();
41            let mut taps = patch.taps(model, &node.inputs)?;
42            taps[slot] = patch.wire_node(
43                format!("{}.t_{}", &node.name, slot),
44                AxisOp::Move(1, 0),
45                &[taps[slot]],
46            )?[0];
47            let mut new_op = op.clone();
48            new_op.op.axes = op
49                .op
50                .axes
51                .clone()
52                .remove_axis_occurency(InOut::In(slot), 0)?
53                .with_extra_axis_occurency(op.k_axis, InOut::In(slot), 1)?;
54            let output = patch.wire_node(prefix, new_op, &taps)?;
55            patch.shunt_outside(model, node.id.into(), output[0])?;
56            return Ok(Some(patch));
57        }
58        let format = Q4_0;
59        let mut patch = TypedModelPatch::default();
60        let weights = if a.datum_type() == f16::datum_type() {
61            format.quant_f16(a.try_as_plain()?.as_slice::<f16>()?)?
62        } else {
63            format.quant_f32(a.cast_to::<f32>()?.try_as_plain()?.as_slice::<f32>()?)?
64        };
65        let name = &model.node(node.inputs[0].node).name;
66        let m = a.shape()[0];
67        let k = a.shape()[1];
68        let bqs = BlockQuantStorage::new(Box::new(format), m, k, Arc::new(weights))?;
69        let fact =
70            Box::new(BlockQuantFact::new(dyn_clone::clone_box(bqs.format()), tvec!(1, m, k)));
71        let weights = patch.wire_node(
72            format!("{name}.bq"),
73            Const::new_with_exotic_fact(
74                Arc::new(bqs.into_tensor_with_shape(a.datum_type(), &[1, m, k])),
75                fact,
76            )?,
77            &[],
78        )?;
79        let tap = patch.tap_model(model, node.inputs[1])?;
80        // Block-quant tensor is rank 3 [G=1, M, K]; add a group dim to the axes
81        let mut new_op = op.op.clone();
82        new_op.axes = new_op.axes.with_extra_axis('G', InOut::In(slot), 0)?;
83        let wire = patch.wire_node(prefix, new_op, &[weights[0], tap])?;
84        patch.shunt_outside(model, node.id.into(), wire[0])?;
85        return Ok(Some(patch));
86    }
87    Ok(None)
88}