tract_transformers/ops/
mod.rs1pub mod apply_rope;
2pub mod dyn_kv_cache;
3pub mod gelu_approximate;
4pub mod rms_norm;
5pub mod scaled_masked_softmax;
6pub mod sdpa;
7pub mod silu;
8
9use tract_core::internal::*;
10use tract_core::ops::konst::Const;
11use tract_nnef::tract_core;
12
13pub use apply_rope::{apply_rope_rule, rotate_half_rule};
14pub use dyn_kv_cache::replace_kv_cache;
15pub use gelu_approximate::gelu_approx_rule;
16pub use rms_norm::rms_norm_rule;
17pub use scaled_masked_softmax::scaled_masked_softmax_rule;
18pub use sdpa::fuse_kv_cache_broadcast_rule;
19pub use silu::silu_rule;
20
21use tract_core::ops::binary::TypedBinOp;
22use tract_core::ops::math::{Add, Mul};
23
/// Early-exit guard for graph rewrite rules.
///
/// When `$cond` evaluates to `false`, the enclosing function immediately
/// returns `Ok(None)`; otherwise execution continues. It can therefore only be
/// used inside functions whose return type accepts `Ok(None)`, i.e. some
/// `Result<Option<_>, _>`.
#[macro_export]
macro_rules! rule_ensure {
    ($cond:expr) => {
        if !($cond) {
            return Ok(None);
        }
    };
}
32
33fn next_node<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<&'a TypedNode> {
34 if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
35 return None;
36 }
37 let succ = node.outputs[0].successors[0];
38 Some(&model.nodes()[succ.node])
39}
40
41fn previous_node<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<&'a TypedNode> {
42 if node.inputs.len() != 1 {
43 return None;
44 }
45 Some(&model.nodes()[node.inputs[0].node])
46}
47
48fn previous_nodes<'a>(model: &'a TypedModel, node: &TypedNode) -> TVec<&'a TypedNode> {
49 node.inputs.iter().map(|n| &model.nodes()[n.node]).collect()
50}
51
52fn collect_node_const_inputs<'a>(model: &'a TypedModel, node: &TypedNode) -> TVec<&'a Const> {
53 node.inputs
54 .iter()
55 .filter_map(|i| {
56 let prec = &model.nodes()[i.node];
57 prec.op_as::<Const>()
58 })
59 .collect::<TVec<_>>()
60}
61
62fn single_prev_node_as<'a, O: TypedOp>(
63 model: &'a TypedModel,
64 node: &TypedNode,
65) -> Option<(usize, &'a TypedNode)> {
66 let prev_nodes = node
67 .inputs
68 .iter()
69 .enumerate()
70 .filter_map(|(in_idx, i)| {
71 let prec = &model.nodes()[i.node];
72 prec.op_is::<O>().then_some((in_idx, prec))
73 })
74 .collect::<TVec<_>>();
75
76 if prev_nodes.len() != 1 {
77 None
78 } else {
79 Some(prev_nodes[0])
80 }
81}
82
83fn find_succ_mul_with_const<'a>(
84 model: &'a TypedModel,
85 node: &'a TypedNode,
86 konst: f32,
87) -> Option<&'a TypedNode> {
88 let mul_coef_a = next_node(model, node)?;
89 let mul_coef_a_op = mul_coef_a.op_as::<TypedBinOp>()?;
90 (mul_coef_a_op.0.is::<Mul>() && matches_single_input_const(model, mul_coef_a, konst))
91 .then_some(mul_coef_a)
92}
93
94fn find_succ_add_with<'a>(
95 model: &'a TypedModel,
96 node: &'a TypedNode,
97 outled_id: &OutletId,
98) -> Option<&'a TypedNode> {
99 let add_succ = next_node(model, node)?;
100 let add_succ_op = add_succ.op_as::<TypedBinOp>()?;
101 (add_succ_op.0.is::<Add>() && add_succ.inputs.contains(outled_id)).then_some(add_succ)
102}
103
104fn matches_single_input_const(model: &TypedModel, node: &TypedNode, konst: f32) -> bool {
105 let consts = collect_node_const_inputs(model, node);
106 if consts.len() != 1 {
107 return false;
108 }
109 let Ok(in_const) = consts[0].val().cast_to_dt(DatumType::F32) else {
110 return false;
111 };
112 let Ok(in_const) = in_const.to_scalar_tensor() else {
113 return false;
114 };
115
116 in_const.close_enough(&tensor0(konst), Approximation::Approximate).is_ok()
117}
118
119fn find_succ_add_with_const<'a>(
120 model: &'a TypedModel,
121 node: &'a TypedNode,
122 konst: f32,
123) -> Option<&'a TypedNode> {
124 let add_coef_a = next_node(model, node)?;
125 let add_coef_a_op = add_coef_a.op_as::<TypedBinOp>()?;
126 if !add_coef_a_op.0.is::<Add>() {
127 return None;
128 }
129 (add_coef_a_op.0.is::<Add>() && matches_single_input_const(model, add_coef_a, konst))
130 .then_some(add_coef_a)
131}