Skip to main content

tract_transformers/ops/
mod.rs

1pub mod apply_rope;
2pub mod dyn_kv_cache;
3pub mod gelu_approximate;
4pub mod rms_norm;
5pub mod scaled_masked_softmax;
6pub mod sdpa;
7pub mod silu;
8
9use tract_core::internal::*;
10use tract_core::ops::konst::Const;
11use tract_nnef::tract_core;
12
13pub use apply_rope::{apply_rope_rule, rotate_half_rule};
14pub use dyn_kv_cache::replace_kv_cache;
15pub use gelu_approximate::gelu_approx_rule;
16pub use rms_norm::rms_norm_rule;
17pub use scaled_masked_softmax::scaled_masked_softmax_rule;
18pub use sdpa::fuse_kv_cache_broadcast_rule;
19pub use silu::silu_rule;
20
21use tract_core::ops::binary::TypedBinOp;
22use tract_core::ops::math::{Add, Mul};
23
/// Leaves the enclosing function with `Ok(None)` when `$cond` does not hold.
///
/// The macro expands to a `return`, so it can only be used inside a function
/// (or closure) whose return type accepts `Ok(None)`. When the condition
/// holds, execution simply continues after the macro call.
#[macro_export]
macro_rules! rule_ensure {
    ($cond:expr) => {{
        // `$cond` is a single expression node, so the negation covers all of it.
        let unmet: bool = !$cond;
        if unmet {
            return Ok(None);
        }
    }};
}
32
33fn next_node<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<&'a TypedNode> {
34    if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
35        return None;
36    }
37    let succ = node.outputs[0].successors[0];
38    Some(&model.nodes()[succ.node])
39}
40
41fn previous_node<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<&'a TypedNode> {
42    if node.inputs.len() != 1 {
43        return None;
44    }
45    Some(&model.nodes()[node.inputs[0].node])
46}
47
48fn previous_nodes<'a>(model: &'a TypedModel, node: &TypedNode) -> TVec<&'a TypedNode> {
49    node.inputs.iter().map(|n| &model.nodes()[n.node]).collect()
50}
51
52fn collect_node_const_inputs<'a>(model: &'a TypedModel, node: &TypedNode) -> TVec<&'a Const> {
53    node.inputs
54        .iter()
55        .filter_map(|i| {
56            let prec = &model.nodes()[i.node];
57            prec.op_as::<Const>()
58        })
59        .collect::<TVec<_>>()
60}
61
62fn single_prev_node_as<'a, O: TypedOp>(
63    model: &'a TypedModel,
64    node: &TypedNode,
65) -> Option<(usize, &'a TypedNode)> {
66    let prev_nodes = node
67        .inputs
68        .iter()
69        .enumerate()
70        .filter_map(|(in_idx, i)| {
71            let prec = &model.nodes()[i.node];
72            prec.op_is::<O>().then_some((in_idx, prec))
73        })
74        .collect::<TVec<_>>();
75
76    if prev_nodes.len() != 1 {
77        None
78    } else {
79        Some(prev_nodes[0])
80    }
81}
82
83fn find_succ_mul_with_const<'a>(
84    model: &'a TypedModel,
85    node: &'a TypedNode,
86    konst: f32,
87) -> Option<&'a TypedNode> {
88    let mul_coef_a = next_node(model, node)?;
89    let mul_coef_a_op = mul_coef_a.op_as::<TypedBinOp>()?;
90    (mul_coef_a_op.0.is::<Mul>() && matches_single_input_const(model, mul_coef_a, konst))
91        .then_some(mul_coef_a)
92}
93
94fn find_succ_add_with<'a>(
95    model: &'a TypedModel,
96    node: &'a TypedNode,
97    outled_id: &OutletId,
98) -> Option<&'a TypedNode> {
99    let add_succ = next_node(model, node)?;
100    let add_succ_op = add_succ.op_as::<TypedBinOp>()?;
101    (add_succ_op.0.is::<Add>() && add_succ.inputs.contains(outled_id)).then_some(add_succ)
102}
103
104fn matches_single_input_const(model: &TypedModel, node: &TypedNode, konst: f32) -> bool {
105    let consts = collect_node_const_inputs(model, node);
106    if consts.len() != 1 {
107        return false;
108    }
109    let Ok(in_const) = consts[0].val().cast_to_dt(DatumType::F32) else {
110        return false;
111    };
112    let Ok(in_const) = in_const.to_scalar_tensor() else {
113        return false;
114    };
115
116    in_const.close_enough(&tensor0(konst), Approximation::Approximate).is_ok()
117}
118
119fn find_succ_add_with_const<'a>(
120    model: &'a TypedModel,
121    node: &'a TypedNode,
122    konst: f32,
123) -> Option<&'a TypedNode> {
124    let add_coef_a = next_node(model, node)?;
125    let add_coef_a_op = add_coef_a.op_as::<TypedBinOp>()?;
126    if !add_coef_a_op.0.is::<Add>() {
127        return None;
128    }
129    (add_coef_a_op.0.is::<Add>() && matches_single_input_const(model, add_coef_a, konst))
130        .then_some(add_coef_a)
131}