tract_core/ops/matmul/
de_block_quant.rs1use tract_linalg::block_quant::{BlockQuant, BlockQuantFact, Q4_0};
2
3use crate::internal::*;
4use crate::ops::einsum::einsum_matmul::EinSumMatMul;
5use crate::ops::konst::Const;
6use crate::transform::ModelTransform;
7
8#[derive(Debug)]
9pub struct BlockQuantTransform;
10
11impl ModelTransform for BlockQuantTransform {
12 fn name(&self) -> StaticName {
13 "BlockQuantTransform".into()
14 }
15
16 fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
17 crate::ops::einsum::einsum_matmul::detect_all(model)?;
18 Rewriter::<()>::default()
19 .with_rule_for("block_quant_einsum_weights", block_quant_einsum_weights)
20 .rewrite(&(), model)?;
21 crate::ops::einsum::einsum_matmul::flatten_all(model)?;
22 Ok(())
23 }
24}
25
26fn block_quant_einsum_weights(
27 _ctx: &(),
28 model: &TypedModel,
29 node: &TypedNode,
30 prefix: &str,
31 op: &EinSumMatMul,
32) -> TractResult<Option<TypedModelPatch>> {
33 rule_if!(node.inputs.len() == 2);
34 for (slot, fact) in model.node_input_facts(node.id)?.iter().enumerate() {
35 let Some(a) = fact.konst.as_ref() else { continue };
36 if a.rank() != 2 {
37 continue;
38 };
39 if op.k_axis().inputs[slot][0] == 0 {
40 let mut patch = TypedModelPatch::default();
41 let mut taps = patch.taps(model, &node.inputs)?;
42 taps[slot] = patch.wire_node(
43 format!("{}.t_{}", &node.name, slot),
44 AxisOp::Move(1, 0),
45 &[taps[slot]],
46 )?[0];
47 let mut new_op = op.clone();
48 new_op.op.axes = op
49 .op
50 .axes
51 .clone()
52 .remove_axis_occurency(InOut::In(slot), 0)?
53 .with_extra_axis_occurency(op.k_axis, InOut::In(slot), 1)?;
54 let output = patch.wire_node(prefix, new_op, &taps)?;
55 patch.shunt_outside(model, node.id.into(), output[0])?;
56 return Ok(Some(patch));
57 }
58 let format = Q4_0;
59 let mut patch = TypedModelPatch::default();
60 let weights = if a.datum_type() == f16::datum_type() {
61 format.quant_f16(a.as_slice::<f16>()?)?
62 } else {
63 format.quant_f32(a.cast_to::<f32>()?.as_slice::<f32>()?)?
64 };
65 let name = &model.node(node.inputs[0].node).name;
66 let fact = BlockQuantFact::new(Box::new(format), a.shape().into());
67 let value = BlobWithFact { fact: Box::new(fact.clone()), value: Arc::new(weights) };
68 let weights = patch.wire_node(
69 format!("{name}.bq"),
70 Const::new_with_opaque_fact(rctensor0(Opaque(Arc::new(value))), Box::new(fact))?,
71 &[],
72 )?;
73 let tap = patch.tap_model(model, node.inputs[1])?;
74 let wire = patch.wire_node(prefix, op.op.clone(), &[weights[0], tap])?;
75 patch.shunt_outside(model, node.id.into(), wire[0])?;
76 return Ok(Some(patch));
77 }
78 Ok(None)
79}