tract_core/model/
rewriter.rs1use std::any::TypeId;
2
3use crate::internal::*;
4
5type GenRewriteRule<Ctx> =
6 Box<dyn Fn(&Ctx, &TypedModel, &TypedNode) -> TractResult<Option<TypedModelPatch>>>;
7
8#[derive(Default)]
9#[allow(clippy::type_complexity)]
10pub struct Rewriter<Ctx> {
11 rules: HashMap<TypeId, Vec<(Cow<'static, str>, GenRewriteRule<Ctx>)>>,
12}
13
14impl<Ctx> Rewriter<Ctx> {
15 pub fn with_rule_for<O: Op + 'static>(
16 mut self,
17 name: impl Into<Cow<'static, str>>,
18 rule: impl Fn(&Ctx, &TypedModel, &TypedNode, &str, &O) -> TractResult<Option<TypedModelPatch>>
19 + 'static,
20 ) -> Self {
21 self.rules.entry(TypeId::of::<O>()).or_default().push((
22 name.into(),
23 Box::new(move |c: &Ctx, m: &TypedModel, n: &TypedNode| {
24 if let Some(o) = n.op_as::<O>() { rule(c, m, n, &n.name, o) } else { Ok(None) }
25 }),
26 ));
27 self
28 }
29
30 pub fn rewrite(&self, ctx: &Ctx, model: &mut TypedModel) -> TractResult<()> {
31 loop {
32 let mut done_anything = false;
33 for n in model.eval_order()? {
34 if let Some(rules) = self.rules.get(&(*model.node(n).op).type_id()) {
35 for (name, rule) in rules {
36 if let Some(patch) =
37 (rule)(ctx, model, model.node(n)).with_context(|| {
38 format!(
39 "Evaluating rewriting rule \"{name}\" on node {}",
40 model.node(n)
41 )
42 })?
43 {
44 patch.apply(model).with_context(|| {
45 format!(
46 "Applying patch for rewriting rule \"{name}\" on node {}",
47 model.node(n)
48 )
49 })?;
50 done_anything = true;
51 }
52 }
53 }
54 }
55 if done_anything {
56 model.prop_consts()?;
57 model.compact()?;
58 } else {
59 return Ok(());
60 }
61 }
62 }
63}