Skip to main content

tract_core/model/
rewriter.rs

1use std::any::TypeId;
2
3use crate::internal::*;
4
5type GenRewriteRule<Ctx> =
6    Box<dyn Fn(&Ctx, &TypedModel, &TypedNode) -> TractResult<Option<TypedModelPatch>>>;
7
8#[derive(Default)]
9#[allow(clippy::type_complexity)]
10pub struct Rewriter<Ctx> {
11    rules: HashMap<TypeId, Vec<(Cow<'static, str>, GenRewriteRule<Ctx>)>>,
12}
13
14impl<Ctx> Rewriter<Ctx> {
15    pub fn with_rule_for<O: Op + 'static>(
16        mut self,
17        name: impl Into<Cow<'static, str>>,
18        rule: impl Fn(&Ctx, &TypedModel, &TypedNode, &str, &O) -> TractResult<Option<TypedModelPatch>>
19        + 'static,
20    ) -> Self {
21        self.rules.entry(TypeId::of::<O>()).or_default().push((
22            name.into(),
23            Box::new(move |c: &Ctx, m: &TypedModel, n: &TypedNode| {
24                if let Some(o) = n.op_as::<O>() { rule(c, m, n, &n.name, o) } else { Ok(None) }
25            }),
26        ));
27        self
28    }
29
30    pub fn rewrite(&self, ctx: &Ctx, model: &mut TypedModel) -> TractResult<()> {
31        loop {
32            let mut done_anything = false;
33            for n in model.eval_order()? {
34                if let Some(rules) = self.rules.get(&(*model.node(n).op).type_id()) {
35                    for (name, rule) in rules {
36                        if let Some(patch) =
37                            (rule)(ctx, model, model.node(n)).with_context(|| {
38                                format!(
39                                    "Evaluating rewriting rule \"{name}\" on node {}",
40                                    model.node(n)
41                                )
42                            })?
43                        {
44                            patch.apply(model).with_context(|| {
45                                format!(
46                                    "Applying patch for rewriting rule \"{name}\" on node {}",
47                                    model.node(n)
48                                )
49                            })?;
50                            done_anything = true;
51                        }
52                    }
53                }
54            }
55            if done_anything {
56                model.prop_consts()?;
57                model.compact()?;
58            } else {
59                return Ok(());
60            }
61        }
62    }
63}