Skip to main content

tract_core/optim/
mod.rs

1use crate::internal::*;
2use std::collections::HashSet;
3use std::fmt::Debug;
4use tract_itertools::Itertools;
5
6pub mod change_axes;
7mod concat_then_einsum;
8mod op_optim;
9mod prop_const;
10pub mod propagate_roi;
11mod push_split_down;
12mod slice;
13mod uniform_mask;
14
15use self::change_axes::ChangeAxes;
16use self::prop_const::PropConst;
17use self::propagate_roi::PropagateRoi;
18use self::push_split_down::PushSplitDown;
19use self::slice::PushSliceUp;
20use self::uniform_mask::FoldUniformMask;
21use op_optim::OpOptim;
22
23pub trait TypedPass: Debug + Send + Sync + dyn_clone::DynClone {
24    fn reset(&mut self) -> TractResult<()>;
25    fn next(
26        &mut self,
27        session: &mut OptimizerSession,
28        model: &TypedModel,
29    ) -> TractResult<Option<TypedModelPatch>>;
30    /// In-place model mutation hook. Returns true if the model was changed.
31    fn run_direct(&mut self, _model: &mut TypedModel) -> TractResult<bool> {
32        Ok(false)
33    }
34}
35
36dyn_clone::clone_trait_object!(TypedPass);
37
38#[derive(Debug)]
39pub struct Optimizer {
40    pub passes: Vec<Box<dyn TypedPass>>,
41    pub steps: Option<usize>,
42}
43
44impl Optimizer {
45    fn passes(passes: Vec<Box<dyn TypedPass>>) -> Optimizer {
46        Optimizer { passes, steps: None }
47    }
48
49    pub fn add_pass(&mut self, idx: usize, pass: Box<dyn TypedPass>) {
50        let num_pass = self.passes.len();
51        if idx > num_pass {
52            log::warn!(
53                "Cannot add new pass {pass:?} at index {idx}. Optimizer currently as {num_pass} passes, pass will be added as the last pass."
54            );
55            self.passes.push(pass);
56        } else {
57            self.passes.insert(idx, pass);
58        }
59    }
60
61    pub fn stopping_at(self, steps: usize) -> Optimizer {
62        Optimizer { steps: Some(steps), ..self }
63    }
64
65    pub fn prop_consts() -> Optimizer {
66        Optimizer::passes(vec![Box::<PropConst>::default()])
67    }
68
69    pub fn declutter() -> Optimizer {
70        Optimizer::passes(vec![
71            Box::<PropConst>::default(),
72            Box::<PropagateRoi>::default(),
73            Box::<FoldUniformMask>::default(),
74            Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
75            Box::new(PushSliceUp),
76            Box::new(PushSplitDown),
77            Box::<concat_then_einsum::ConcatThenEinsum>::default(),
78            Box::<ChangeAxes>::default(),
79        ])
80    }
81
82    pub fn codegen() -> Optimizer {
83        Optimizer::passes(vec![
84            Box::<PropConst>::default(),
85            Box::new(OpOptim(
86                "codegen",
87                |op, _session, model, node| TypedOp::codegen(op, model, node),
88                0,
89            )),
90            Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
91            Box::new(PushSplitDown),
92            Box::new(OpOptim(
93                "fuse",
94                |op, _session, model, node| TypedOp::fuse(op, model, node),
95                0,
96            )),
97        ])
98    }
99
100    pub fn optimize(&self, model: &mut TypedModel) -> TractResult<()> {
101        self.session().optimize(model)
102    }
103
104    pub fn session(&self) -> OptimizerSession<'_> {
105        OptimizerSession { optimizer: self, counter: 0, seen: Default::default() }
106    }
107}
108
109#[derive(Debug)]
110pub struct OptimizerSession<'o> {
111    optimizer: &'o Optimizer,
112    counter: usize,
113    seen: HashSet<String>,
114}
115
116impl OptimizerSession<'_> {
117    pub fn optimize(&mut self, model: &mut TypedModel) -> TractResult<()> {
118        let _proof_session = model.symbols.proof_cache_session();
119        model.check_consistency().context("during optimizer preflight check")?;
120        model.compact().context("during optimizer preflight compaction")?;
121        model.check_names().context("after optimizer preflight compaction")?;
122        for i in 0.. {
123            let old = self.counter;
124            self.run_all_passes(i, model)?;
125            if old == self.counter {
126                return Ok(());
127            }
128            model.compact()?;
129        }
130        unreachable!()
131    }
132
133    pub fn run_all_passes(&mut self, i: usize, model: &mut TypedModel) -> TractResult<()> {
134        let mut passes = self.optimizer.passes.clone();
135        for p in passes.iter_mut() {
136            self.run_one_pass_outer(i, p.as_mut(), model)
137                .with_context(|| format!("running pass {p:?}"))?;
138            model.compact()?;
139            model
140                .check_consistency()
141                .with_context(|| format!("consistency check after pass {p:?}"))?;
142        }
143        Ok(())
144    }
145
146    pub fn run_one_pass_outer(
147        &mut self,
148        i: usize,
149        p: &mut dyn TypedPass,
150        model: &mut TypedModel,
151    ) -> TractResult<()> {
152        loop {
153            let old_counter = self.counter;
154            self.run_one_pass_inner(i, p, model)?;
155            if self.counter == old_counter {
156                return Ok(());
157            }
158            model.compact().with_context(|| format!("after pass {p:?}"))?;
159        }
160    }
161
162    pub fn run_one_pass_inner(
163        &mut self,
164        i: usize,
165        p: &mut dyn TypedPass,
166        model: &mut TypedModel,
167    ) -> TractResult<()> {
168        p.reset()?;
169        if let Some(steps) = self.optimizer.steps
170            && self.counter >= steps
171        {
172            return Ok(());
173        }
174        while let Some(mut patch) = p.next(self, model)? {
175            patch.push_context(format!("{p:?}/{i}"));
176            patch.model.check_consistency().context("checking patch internal consistency")?;
177            model
178                .check_consistency()
179                .context("Checking target model consistency before patching")?;
180            if let Some(watchdog) = patch.dont_apply_twice.take() {
181                if self.seen.contains(&watchdog) {
182                    debug!("Loop detected: {watchdog} seen before");
183                    continue;
184                } else {
185                    self.seen.insert(watchdog);
186                }
187            }
188            let patch_name = patch.context.iter().rev().join(" >> ");
189            debug!("applying patch #{}: {patch_name}", self.counter);
190            patch.apply(model).with_context(|| format!("Applying patch {patch_name}"))?;
191            model
192                .check_consistency()
193                .context("Checking target model consistency after patching")?;
194            self.counter += 1;
195            if let Some(steps) = self.optimizer.steps
196                && self.counter >= steps
197            {
198                return Ok(());
199            }
200        }
201        if p.run_direct(model)? {
202            model.check_consistency().with_context(|| format!("after run_direct {p:?}"))?;
203        }
204        model.check_consistency().with_context(|| format!("after pass {p:?}"))?;
205        Ok(())
206    }
207}