use tract_linalg::frame::block_quant::{BlockQuant, Q4_0};
use crate::internal::*;
use crate::ops::einsum::codegen::{ensure_mkn_axes, AxesOrPatch};
use crate::ops::einsum::EinSum;
use crate::ops::konst::Const;
use crate::transform::ModelTransform;
/// Opaque fact attached to a block-quantized tensor.
///
/// It records which block-quantization scheme was used and the shape the
/// tensor had before being packed.
#[derive(Clone, Hash)]
pub struct BlockQuantFact {
    /// Block-quantization scheme used to encode the data.
    pub format: Box<dyn BlockQuant>,
    /// Shape of the tensor prior to quantization.
    pub shape: ShapeFact,
}
impl std::fmt::Debug for BlockQuantFact {
    /// Renders the fact as `<format>(<shape>)`: the format through its
    /// `Display` impl, the shape through its `Debug` impl.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { format, shape } = self;
        write!(f, "{format}({shape:?})")
    }
}
// Empty impl: `BlockQuantFact` overrides nothing from `OpaqueFact`.
impl OpaqueFact for BlockQuantFact {}
/// Opaque payload holding block-quantized data together with the fact
/// describing it.
#[derive(Clone, Hash)]
pub struct BlockQuantValue {
    /// Format and original shape of the quantized data.
    pub fact: BlockQuantFact,
    /// Raw bytes produced by the quantizer.
    pub value: Blob,
}
// Empty impl: `BlockQuantValue` overrides nothing from `OpaquePayload`.
impl OpaquePayload for BlockQuantValue {}
impl std::fmt::Debug for BlockQuantValue {
    /// Prints the fact followed by the raw blob, separated by a single space.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { fact, value } = self;
        write!(f, "{fact:?} {value:?}")
    }
}
impl std::fmt::Display for BlockQuantValue {
    /// `Display` output is the same text as the `Debug` output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
    }
}
/// Model transform that runs the `block_quant_einsum_weights` rewriting rule
/// over a typed model (see its `ModelTransform` impl).
#[derive(Debug)]
pub struct BlockQuantTransform;
impl ModelTransform for BlockQuantTransform {
fn name(&self) -> Cow<str> {
"BlockQuantTransform".into()
}
fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
Rewriter::<()>::default()
.with_rule_for("block_quant_einsum_weights", block_quant_einsum_weights)
.rewrite(&(), model)
}
}
/// Rewriting rule turning the constant weights of a two-input `EinSum` into a
/// Q4_0 block-quantized opaque constant.
///
/// Each invocation performs at most one step and returns the matching patch:
///
/// 1. If only the second input is constant, the two inputs are swapped (and
///    the axes mapping updated accordingly) so the constant comes first.
/// 2. If the constant first input is stored with `k` on axis 0 and `m` on
///    axis 1, it is replaced by a constant with these two axes exchanged.
/// 3. Otherwise the constant is quantized with `Q4_0` and replaced by a scalar
///    opaque constant wrapping a [`BlockQuantValue`].
///
/// Returns `Ok(None)` when the node does not have exactly two inputs, when the
/// first input is not a rank-2 constant, or when `ensure_mkn_axes` does not
/// yield plain axes.
///
/// # Errors
///
/// Propagates any error from fact lookup, axes manipulation, tensor
/// conversion, quantization or patch construction.
fn block_quant_einsum_weights(
    _ctx: &(),
    model: &TypedModel,
    node: &TypedNode,
    prefix: &str,
    op: &EinSum,
) -> TractResult<Option<TypedModelPatch>> {
    let &[a, b] = &*model.node_input_facts(node.id)? else { return Ok(None) };

    // Bring the constant operand to the first slot. The swap is only done when
    // `a` is not itself constant: with two constant inputs the swapped node
    // would still have a constant second input, so this rule would produce a
    // new swap patch every time it is invoked and never settle.
    if a.konst.is_none() && b.konst.is_some() {
        // Append a copy of input 0's axes as a third input, then drop slot 0:
        // the former input 1 becomes input 0 and the copy becomes input 1.
        let mut new_axes = op.axes.clone().with_extra_input(2)?;
        for (ix, axis) in op.axes.axes(InOut::In(0)).enumerate() {
            new_axes = new_axes.with_extra_axis_occurency(axis.repr, InOut::In(2), ix)?;
        }
        new_axes = new_axes.remove_slot(InOut::In(0))?;
        return Ok(Some(TypedModelPatch::replace_single_op(
            model,
            node,
            &[node.inputs[1], node.inputs[0]],
            EinSum { axes: new_axes, ..op.clone() },
        )?));
    }

    // Only rank-2 constant weights are handled. Once quantized, the weights
    // are a scalar opaque constant, so this guard also stops the rule from
    // firing again on its own output.
    let Some(weights_tensor) = a.konst.as_deref() else { return Ok(None) };
    if a.rank() != 2 {
        return Ok(None);
    }

    let AxesOrPatch::Axes(m, k, _n) = ensure_mkn_axes(op, model, node)? else { return Ok(None) };

    // Weights laid out with k first and m second: exchange the two axes so the
    // next invocation sees m on axis 0 and k on axis 1.
    if m.inputs[0][0] == 1 && k.inputs[0][0] == 0 {
        let mut patch = TypedModelPatch::default();
        let konst = patch.add_const(
            &model.node(node.inputs[0].node).name,
            weights_tensor.clone().move_axis(1, 0)?,
        )?;
        // Add k at the end of input 0 then remove its occurrence at position 0,
        // which leaves input 0 with m at position 0 and k at position 1.
        let axes = op
            .axes
            .clone()
            .with_extra_axis_occurency(k, InOut::In(0), 2)?
            .remove_axis_occurency(InOut::In(0), 0)?;
        let tap = patch.tap_model(model, node.inputs[1])?;
        let output = patch.wire_node(prefix, EinSum { axes, ..op.clone() }, &[konst, tap])?;
        patch.shunt_outside(model, node.id.into(), output[0])?;
        return Ok(Some(patch));
    }

    let format = Q4_0;
    let mut patch = TypedModelPatch::default();
    // f16 weights are quantized directly; anything else goes through f32.
    // NOTE(review): non-float weights are also cast to f32 and quantized here;
    // confirm that is intended.
    // NOTE(review): nothing here checks that the weight length fits whole Q4_0
    // blocks; presumably `quant_*` validates it (TODO confirm).
    let quantized = if a.datum_type == f16::datum_type() {
        format.quant_f16(weights_tensor.as_slice::<f16>()?)?
    } else {
        format.quant_f32(weights_tensor.cast_to::<f32>()?.as_slice::<f32>()?)?
    };

    // The new constant reuses the name of the node that produced the weights.
    let name = &model.node(node.inputs[0].node).name;
    let fact = BlockQuantFact { format: Box::new(format), shape: a.shape.clone() };
    let value = BlockQuantValue { fact: fact.clone(), value: quantized };
    let weights = patch.wire_node(
        name,
        Const::new_with_opaque_fact(rctensor0(Opaque(Arc::new(value))), Box::new(fact)),
        &[],
    )?;
    let tap = patch.tap_model(model, node.inputs[1])?;
    let wire = patch.wire_node(prefix, op.clone(), &[weights[0], tap])?;
    patch.shunt_outside(model, node.id.into(), wire[0])?;
    Ok(Some(patch))
}