Skip to main content

tract_gpu/rewrite_rules/
mod.rs

1use tract_core::model::{TypedModel, TypedNode};
2use tract_core::prelude::TVec;
3
4pub mod rewire_syncs;
5
/// Guard for rewrite rules: leaves the enclosing function with `Ok(None)` when the given
/// boolean expression evaluates to `false`, and does nothing otherwise.
///
/// The enclosing function must therefore return a `Result` whose success type is an `Option`.
#[macro_export]
macro_rules! rule_ensure {
    ($predicate:expr) => {
        // `$predicate` is captured as a single expression, so the negation applies to all of it.
        if !$predicate {
            return Ok(None);
        }
    };
}
14
15pub fn next_node<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<&'a TypedNode> {
16    if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
17        return None;
18    }
19    let succ = node.outputs[0].successors[0];
20    Some(&model.nodes()[succ.node])
21}
22
23pub fn previous_node<'a>(model: &'a TypedModel, node: &TypedNode) -> Option<&'a TypedNode> {
24    if node.inputs.len() != 1 {
25        return None;
26    }
27    Some(&model.nodes()[node.inputs[0].node])
28}
29
30pub fn previous_nodes<'a>(model: &'a TypedModel, node: &TypedNode) -> TVec<&'a TypedNode> {
31    node.inputs.iter().map(|n| &model.nodes()[n.node]).collect()
32}