Skip to main content

tract_core/optim/
mod.rs

1use crate::internal::*;
2use std::collections::HashSet;
3use std::fmt::Debug;
4use tract_itertools::Itertools;
5
6pub mod change_axes;
7mod concat_then_einsum;
8mod op_optim;
9mod prop_const;
10mod push_split_down;
11mod slice;
12
13use self::change_axes::ChangeAxes;
14use self::prop_const::PropConst;
15use self::push_split_down::PushSplitDown;
16use self::slice::PushSliceUp;
17use op_optim::OpOptim;
18
19pub trait TypedPass: Debug + Send + Sync + dyn_clone::DynClone {
20    fn reset(&mut self) -> TractResult<()>;
21    fn next(
22        &mut self,
23        session: &mut OptimizerSession,
24        model: &TypedModel,
25    ) -> TractResult<Option<TypedModelPatch>>;
26}
27
28dyn_clone::clone_trait_object!(TypedPass);
29
30#[derive(Debug)]
31pub struct Optimizer {
32    pub passes: Vec<Box<dyn TypedPass>>,
33    pub steps: Option<usize>,
34}
35
36impl Optimizer {
37    fn passes(passes: Vec<Box<dyn TypedPass>>) -> Optimizer {
38        Optimizer { passes, steps: None }
39    }
40
41    pub fn add_pass(&mut self, idx: usize, pass: Box<dyn TypedPass>) {
42        let num_pass = self.passes.len();
43        if idx > num_pass {
44            log::warn!(
45                "Cannot add new pass {pass:?} at index {idx}. Optimizer currently as {num_pass} passes, pass will be added as the last pass."
46            );
47            self.passes.push(pass);
48        } else {
49            self.passes.insert(idx, pass);
50        }
51    }
52
53    pub fn stopping_at(self, steps: usize) -> Optimizer {
54        Optimizer { steps: Some(steps), ..self }
55    }
56
57    pub fn prop_consts() -> Optimizer {
58        Optimizer::passes(vec![Box::<PropConst>::default()])
59    }
60
61    pub fn declutter() -> Optimizer {
62        Optimizer::passes(vec![
63            Box::<PropConst>::default(),
64            Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
65            Box::new(PushSliceUp),
66            Box::new(PushSplitDown),
67            Box::<concat_then_einsum::ConcatThenEinsum>::default(),
68            Box::<ChangeAxes>::default(),
69        ])
70    }
71
72    pub fn codegen() -> Optimizer {
73        Optimizer::passes(vec![
74            Box::<PropConst>::default(),
75            Box::new(OpOptim(
76                "codegen",
77                |op, _session, model, node| TypedOp::codegen(op, model, node),
78                0,
79            )),
80            Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
81            Box::new(PushSplitDown),
82            Box::new(OpOptim(
83                "fuse",
84                |op, _session, model, node| TypedOp::fuse(op, model, node),
85                0,
86            )),
87        ])
88    }
89
90    pub fn optimize(&self, model: &mut TypedModel) -> TractResult<()> {
91        self.session().optimize(model)
92    }
93
94    pub fn session(&self) -> OptimizerSession<'_> {
95        OptimizerSession { optimizer: self, counter: 0, seen: Default::default() }
96    }
97}
98
99#[derive(Debug)]
100pub struct OptimizerSession<'o> {
101    optimizer: &'o Optimizer,
102    counter: usize,
103    seen: HashSet<String>,
104}
105
106impl OptimizerSession<'_> {
107    pub fn optimize(&mut self, model: &mut TypedModel) -> TractResult<()> {
108        let _proof_session = model.symbols.proof_cache_session();
109        model.check_consistency().context("during optimizer preflight check")?;
110        model.compact().context("during optimizer preflight compaction")?;
111        model.check_names().context("after optimizer preflight compaction")?;
112        for i in 0.. {
113            let old = self.counter;
114            self.run_all_passes(i, model)?;
115            if old == self.counter {
116                return Ok(());
117            }
118            model.compact()?;
119        }
120        unreachable!()
121    }
122
123    pub fn run_all_passes(&mut self, i: usize, model: &mut TypedModel) -> TractResult<()> {
124        let mut passes = self.optimizer.passes.clone();
125        for p in passes.iter_mut() {
126            self.run_one_pass_outer(i, p.as_mut(), model)
127                .with_context(|| format!("running pass {p:?}"))?;
128            model.compact()?;
129            model
130                .check_consistency()
131                .with_context(|| format!("consistency check after pass {p:?}"))?;
132        }
133        Ok(())
134    }
135
136    pub fn run_one_pass_outer(
137        &mut self,
138        i: usize,
139        p: &mut dyn TypedPass,
140        model: &mut TypedModel,
141    ) -> TractResult<()> {
142        loop {
143            let old_counter = self.counter;
144            self.run_one_pass_inner(i, p, model)?;
145            if self.counter == old_counter {
146                return Ok(());
147            }
148            model.compact().with_context(|| format!("after pass {p:?}"))?;
149        }
150    }
151
152    pub fn run_one_pass_inner(
153        &mut self,
154        i: usize,
155        p: &mut dyn TypedPass,
156        model: &mut TypedModel,
157    ) -> TractResult<()> {
158        p.reset()?;
159        if let Some(steps) = self.optimizer.steps {
160            if self.counter >= steps {
161                return Ok(());
162            }
163        }
164        while let Some(mut patch) = p.next(self, model)? {
165            patch.push_context(format!("{p:?}/{i}"));
166            patch.model.check_consistency().context("checking patch internal consistency")?;
167            model
168                .check_consistency()
169                .context("Checking target model consistency before patching")?;
170            if let Some(watchdog) = patch.dont_apply_twice.take() {
171                if self.seen.contains(&watchdog) {
172                    debug!("Loop detected: {watchdog} seen before");
173                    continue;
174                } else {
175                    self.seen.insert(watchdog);
176                }
177            }
178            let patch_name = patch.context.iter().rev().join(" >> ");
179            debug!("applying patch #{}: {patch_name}", self.counter);
180            patch.apply(model).with_context(|| format!("Applying patch {patch_name}"))?;
181            model
182                .check_consistency()
183                .context("Checking target model consistency after patching")?;
184            self.counter += 1;
185            if let Some(steps) = self.optimizer.steps {
186                if self.counter >= steps {
187                    return Ok(());
188                }
189            }
190        }
191        model.check_consistency().with_context(|| format!("after pass {p:?}"))?;
192        Ok(())
193    }
194}