tract_gpu/rewrite_rules/
mod.rs1use tract_core::model::{TypedModel, TypedNode};
2use tract_core::prelude::TVec;
3
4pub mod rewire_syncs;
5
/// Early-exits a rewrite rule with `Ok(None)` ("rule does not apply") when `$cond` is false.
///
/// Intended for use inside functions returning `Result<Option<_>, _>`.
///
/// The `Ok`/`None` paths are fully qualified because `macro_rules!` hygiene does not cover
/// item paths: a bare `Ok(None)` would resolve at the call site and break if the caller
/// shadows either name. `::std` is used (rather than `::core`) so the path resolves in every
/// Rust edition.
#[macro_export]
macro_rules! rule_ensure {
    ($cond:expr) => {
        if !($cond) {
            return ::std::result::Result::Ok(::std::option::Option::None);
        }
    };
}
14
15pub fn next_node<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<&'a TypedNode> {
16 if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
17 return None;
18 }
19 let succ = node.outputs[0].successors[0];
20 Some(&model.nodes()[succ.node])
21}
22
23pub fn previous_node<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<&'a TypedNode> {
24 if node.inputs.len() != 1 {
25 return None;
26 }
27 Some(&model.nodes()[node.inputs[0].node])
28}
29
30pub fn previous_nodes<'a>(model: &'a TypedModel, node: &TypedNode) -> TVec<&'a TypedNode> {
31 node.inputs.iter().map(|n| &model.nodes()[n.node]).collect()
32}