tract_core/model/
rewriter.rs1use std::any::TypeId;
2
3use crate::internal::*;
4
5type GenRewriteRule<Ctx> =
6 Box<dyn Fn(&Ctx, &TypedModel, &TypedNode) -> TractResult<Option<TypedModelPatch>>>;
7
8#[derive(Default)]
9#[allow(clippy::type_complexity)]
10pub struct Rewriter<Ctx> {
11 rules: HashMap<TypeId, Vec<(Cow<'static, str>, GenRewriteRule<Ctx>)>>,
12}
13
14impl<Ctx> Rewriter<Ctx> {
15 pub fn with_rule_for<O: Op + 'static>(
16 mut self,
17 name: impl Into<Cow<'static, str>>,
18 rule: impl Fn(&Ctx, &TypedModel, &TypedNode, &str, &O) -> TractResult<Option<TypedModelPatch>>
19 + 'static,
20 ) -> Self {
21 self.rules.entry(TypeId::of::<O>()).or_default().push((
22 name.into(),
23 Box::new(move |c: &Ctx, m: &TypedModel, n: &TypedNode| {
24 if let Some(o) = n.op_as::<O>() {
25 rule(c, m, n, &n.name, o)
26 } else {
27 Ok(None)
28 }
29 }),
30 ));
31 self
32 }
33
34 pub fn rewrite(&self, ctx: &Ctx, model: &mut TypedModel) -> TractResult<()> {
35 loop {
36 let mut done_anything = false;
37 for n in model.eval_order()? {
38 if let Some(rules) = self.rules.get(&(*model.node(n).op).type_id()) {
39 for (name, rule) in rules {
40 if let Some(patch) = (rule)(ctx, model, model.node(n))
41 .with_context(|| format!("Evaluating rewriting rule \"{name}\" on node {}", model.node(n)))?
42 {
43 patch.apply(model).with_context(|| {
44 format!("Applying patch for rewriting rule \"{name}\" on node {}", model.node(n))
45 })?;
46 done_anything = true;
47 }
48 }
49 }
50 }
51 if done_anything {
52 model.prop_consts()?;
53 model.compact()?;
54 } else {
55 return Ok(());
56 }
57 }
58 }
59}