tract_core/model/rewriter.rs

1use std::any::TypeId;
2
3use crate::internal::*;
4
5type GenRewriteRule<Ctx> =
6    Box<dyn Fn(&Ctx, &TypedModel, &TypedNode) -> TractResult<Option<TypedModelPatch>>>;
7
8#[derive(Default)]
9#[allow(clippy::type_complexity)]
10pub struct Rewriter<Ctx> {
11    rules: HashMap<TypeId, Vec<(Cow<'static, str>, GenRewriteRule<Ctx>)>>,
12}
13
14impl<Ctx> Rewriter<Ctx> {
15    pub fn with_rule_for<O: Op + 'static>(
16        mut self,
17        name: impl Into<Cow<'static, str>>,
18        rule: impl Fn(&Ctx, &TypedModel, &TypedNode, &str, &O) -> TractResult<Option<TypedModelPatch>>
19            + 'static,
20    ) -> Self {
21        self.rules.entry(TypeId::of::<O>()).or_default().push((
22            name.into(),
23            Box::new(move |c: &Ctx, m: &TypedModel, n: &TypedNode| {
24                if let Some(o) = n.op_as::<O>() {
25                    rule(c, m, n, &n.name, o)
26                } else {
27                    Ok(None)
28                }
29            }),
30        ));
31        self
32    }
33
34    pub fn rewrite(&self, ctx: &Ctx, model: &mut TypedModel) -> TractResult<()> {
35        loop {
36            let mut done_anything = false;
37            for n in model.eval_order()? {
38                if let Some(rules) = self.rules.get(&(*model.node(n).op).type_id()) {
39                    for (name, rule) in rules {
40                        if let Some(patch) = (rule)(ctx, model, model.node(n))
41                            .with_context(|| format!("Evaluating rewriting rule \"{name}\" on node {}", model.node(n)))?
42                        {
43                            patch.apply(model).with_context(|| {
44                                format!("Applying patch for rewriting rule \"{name}\" on node {}", model.node(n))
45                            })?;
46                            done_anything = true;
47                        }
48                    }
49                }
50            }
51            if done_anything {
52                model.prop_consts()?;
53                model.compact()?;
54            } else {
55                return Ok(());
56            }
57        }
58    }
59}