Skip to main content

tract_core/optim/
mod.rs

1use crate::internal::*;
2use std::collections::HashSet;
3use std::fmt::Debug;
4use tract_itertools::Itertools;
5
6pub mod change_axes;
7mod concat_then_einsum;
8mod op_optim;
9mod prop_const;
10mod push_split_down;
11mod slice;
12
13use self::change_axes::ChangeAxes;
14use self::prop_const::PropConst;
15use self::push_split_down::PushSplitDown;
16use self::slice::PushSliceUp;
17use op_optim::OpOptim;
18
19pub trait TypedPass: Debug + Send + Sync + dyn_clone::DynClone {
20    fn reset(&mut self) -> TractResult<()>;
21    fn next(
22        &mut self,
23        session: &mut OptimizerSession,
24        model: &TypedModel,
25    ) -> TractResult<Option<TypedModelPatch>>;
26}
27
28dyn_clone::clone_trait_object!(TypedPass);
29
30#[derive(Debug)]
31pub struct Optimizer {
32    pub passes: Vec<Box<dyn TypedPass>>,
33    pub steps: Option<usize>,
34}
35
36impl Optimizer {
37    fn passes(passes: Vec<Box<dyn TypedPass>>) -> Optimizer {
38        Optimizer { passes, steps: None }
39    }
40
41    pub fn add_pass(&mut self, idx: usize, pass: Box<dyn TypedPass>) {
42        let num_pass = self.passes.len();
43        if idx > num_pass {
44            log::warn!(
45                "Cannot add new pass {pass:?} at index {idx}. Optimizer currently as {num_pass} passes, pass will be added as the last pass."
46            );
47            self.passes.push(pass);
48        } else {
49            self.passes.insert(idx, pass);
50        }
51    }
52
53    pub fn stopping_at(self, steps: usize) -> Optimizer {
54        Optimizer { steps: Some(steps), ..self }
55    }
56
57    pub fn prop_consts() -> Optimizer {
58        Optimizer::passes(vec![Box::<PropConst>::default()])
59    }
60
61    pub fn declutter() -> Optimizer {
62        Optimizer::passes(vec![
63            Box::<PropConst>::default(),
64            Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
65            Box::new(PushSliceUp),
66            Box::new(PushSplitDown),
67            Box::<concat_then_einsum::ConcatThenEinsum>::default(),
68            Box::<ChangeAxes>::default(),
69        ])
70    }
71
72    pub fn codegen() -> Optimizer {
73        Optimizer::passes(vec![
74            Box::<PropConst>::default(),
75            Box::new(OpOptim(
76                "codegen",
77                |op, _session, model, node| TypedOp::codegen(op, model, node),
78                0,
79            )),
80            Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
81            Box::new(PushSplitDown),
82            Box::new(OpOptim(
83                "fuse",
84                |op, _session, model, node| TypedOp::fuse(op, model, node),
85                0,
86            )),
87        ])
88    }
89
90    pub fn optimize(&self, model: &mut TypedModel) -> TractResult<()> {
91        self.session().optimize(model)
92    }
93
94    pub fn session(&self) -> OptimizerSession<'_> {
95        OptimizerSession { optimizer: self, counter: 0, seen: Default::default() }
96    }
97}
98
99#[derive(Debug)]
100pub struct OptimizerSession<'o> {
101    optimizer: &'o Optimizer,
102    counter: usize,
103    seen: HashSet<String>,
104}
105
106impl OptimizerSession<'_> {
107    pub fn optimize(&mut self, model: &mut TypedModel) -> TractResult<()> {
108        model.check_consistency().context("during optimizer preflight check")?;
109        model.compact().context("during optimizer preflight compaction")?;
110        model.check_names().context("after optimizer preflight compaction")?;
111        for i in 0.. {
112            let old = self.counter;
113            self.run_all_passes(i, model)?;
114            if old == self.counter {
115                return Ok(());
116            }
117            model.compact()?;
118        }
119        unreachable!()
120    }
121
122    pub fn run_all_passes(&mut self, i: usize, model: &mut TypedModel) -> TractResult<()> {
123        let mut passes = self.optimizer.passes.clone();
124        for p in passes.iter_mut() {
125            self.run_one_pass_outer(i, p.as_mut(), model)
126                .with_context(|| format!("running pass {p:?}"))?;
127            model.compact()?;
128            model
129                .check_consistency()
130                .with_context(|| format!("consistency check after pass {p:?}"))?;
131        }
132        Ok(())
133    }
134
135    pub fn run_one_pass_outer(
136        &mut self,
137        i: usize,
138        p: &mut dyn TypedPass,
139        model: &mut TypedModel,
140    ) -> TractResult<()> {
141        loop {
142            let old_counter = self.counter;
143            self.run_one_pass_inner(i, p, model)?;
144            if self.counter == old_counter {
145                return Ok(());
146            }
147            model.compact().with_context(|| format!("after pass {p:?}"))?;
148        }
149    }
150
151    pub fn run_one_pass_inner(
152        &mut self,
153        i: usize,
154        p: &mut dyn TypedPass,
155        model: &mut TypedModel,
156    ) -> TractResult<()> {
157        p.reset()?;
158        if let Some(steps) = self.optimizer.steps {
159            if self.counter >= steps {
160                return Ok(());
161            }
162        }
163        while let Some(mut patch) = p.next(self, model)? {
164            patch.push_context(format!("{p:?}/{i}"));
165            patch.model.check_consistency().context("checking patch internal consistency")?;
166            model
167                .check_consistency()
168                .context("Checking target model consistency before patching")?;
169            if let Some(watchdog) = patch.dont_apply_twice.take() {
170                if self.seen.contains(&watchdog) {
171                    debug!("Loop detected: {watchdog} seen before");
172                    continue;
173                } else {
174                    self.seen.insert(watchdog);
175                }
176            }
177            let patch_name = patch.context.iter().rev().join(" >> ");
178            debug!("applying patch #{}: {patch_name}", self.counter);
179            patch.apply(model).with_context(|| format!("Applying patch {patch_name}"))?;
180            model
181                .check_consistency()
182                .context("Checking target model consistency after patching")?;
183            self.counter += 1;
184            if let Some(steps) = self.optimizer.steps {
185                if self.counter >= steps {
186                    return Ok(());
187                }
188            }
189        }
190        model.check_consistency().with_context(|| format!("after pass {p:?}"))?;
191        Ok(())
192    }
193}