Skip to main content

tract_core/optim/
mod.rs

1use crate::internal::*;
2use std::collections::HashSet;
3use std::fmt::Debug;
4use tract_itertools::Itertools;
5
6pub mod change_axes;
7mod concat_then_einsum;
8mod op_optim;
9mod prop_const;
10pub mod propagate_roi;
11pub mod propagate_uniform_tdim;
12mod push_split_down;
13mod slice;
14mod uniform_mask;
15
16use self::change_axes::ChangeAxes;
17use self::prop_const::PropConst;
18use self::propagate_roi::PropagateRoi;
19use self::propagate_uniform_tdim::PropagateUniformTdim;
20use self::push_split_down::PushSplitDown;
21use self::slice::PushSliceUp;
22use self::uniform_mask::FoldUniformMask;
23use op_optim::OpOptim;
24
25pub trait TypedPass: Debug + Send + Sync + dyn_clone::DynClone {
26    fn reset(&mut self) -> TractResult<()>;
27    fn next(
28        &mut self,
29        session: &mut OptimizerSession,
30        model: &TypedModel,
31    ) -> TractResult<Option<TypedModelPatch>>;
32    /// In-place model mutation hook. Returns true if the model was changed.
33    fn run_direct(&mut self, _model: &mut TypedModel) -> TractResult<bool> {
34        Ok(false)
35    }
36}
37
38#[derive(Clone, Debug, Default)]
39struct MergeConsecutiveSameRoleAxes;
40
41impl TypedPass for MergeConsecutiveSameRoleAxes {
42    fn reset(&mut self) -> TractResult<()> {
43        Ok(())
44    }
45    fn next(
46        &mut self,
47        _session: &mut OptimizerSession,
48        _model: &TypedModel,
49    ) -> TractResult<Option<TypedModelPatch>> {
50        Ok(None)
51    }
52    fn run_direct(&mut self, model: &mut TypedModel) -> TractResult<bool> {
53        let before = model.nodes.len();
54        crate::ops::einsum::einsum_matmul::merge_consecutive_same_role_axes(model)?;
55        Ok(model.nodes.len() != before)
56    }
57}
58
59dyn_clone::clone_trait_object!(TypedPass);
60
61#[derive(Debug)]
62pub struct Optimizer {
63    pub passes: Vec<Box<dyn TypedPass>>,
64    pub steps: Option<usize>,
65}
66
67impl Optimizer {
68    fn passes(passes: Vec<Box<dyn TypedPass>>) -> Optimizer {
69        Optimizer { passes, steps: None }
70    }
71
72    pub fn add_pass(&mut self, idx: usize, pass: Box<dyn TypedPass>) {
73        let num_pass = self.passes.len();
74        if idx > num_pass {
75            log::warn!(
76                "Cannot add new pass {pass:?} at index {idx}. Optimizer currently as {num_pass} passes, pass will be added as the last pass."
77            );
78            self.passes.push(pass);
79        } else {
80            self.passes.insert(idx, pass);
81        }
82    }
83
84    pub fn stopping_at(self, steps: usize) -> Optimizer {
85        Optimizer { steps: Some(steps), ..self }
86    }
87
88    pub fn prop_consts() -> Optimizer {
89        Optimizer::passes(vec![Box::<PropConst>::default()])
90    }
91
92    pub fn declutter() -> Optimizer {
93        Optimizer::passes(vec![
94            Box::<PropConst>::default(),
95            Box::<PropagateUniformTdim>::default(),
96            Box::<PropagateRoi>::default(),
97            Box::<FoldUniformMask>::default(),
98            Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
99            Box::new(PushSliceUp),
100            Box::new(PushSplitDown),
101            Box::<concat_then_einsum::ConcatThenEinsum>::default(),
102            Box::<ChangeAxes>::default(),
103        ])
104    }
105
106    pub fn codegen() -> Optimizer {
107        Optimizer::passes(vec![
108            Box::<PropConst>::default(),
109            Box::<MergeConsecutiveSameRoleAxes>::default(),
110            Box::new(OpOptim(
111                "codegen",
112                |op, _session, model, node| TypedOp::codegen(op, model, node),
113                0,
114            )),
115            Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
116            Box::new(PushSplitDown),
117            Box::new(OpOptim(
118                "fuse",
119                |op, _session, model, node| TypedOp::fuse(op, model, node),
120                0,
121            )),
122        ])
123    }
124
125    pub fn optimize(&self, model: &mut TypedModel) -> TractResult<()> {
126        self.session().optimize(model)
127    }
128
129    pub fn session(&self) -> OptimizerSession<'_> {
130        OptimizerSession { optimizer: self, counter: 0, seen: Default::default() }
131    }
132}
133
134#[derive(Debug)]
135pub struct OptimizerSession<'o> {
136    optimizer: &'o Optimizer,
137    counter: usize,
138    seen: HashSet<String>,
139}
140
141impl OptimizerSession<'_> {
142    pub fn optimize(&mut self, model: &mut TypedModel) -> TractResult<()> {
143        let _proof_session = model.symbols.proof_cache_session();
144        model.check_consistency().context("during optimizer preflight check")?;
145        model.compact().context("during optimizer preflight compaction")?;
146        model.check_names().context("after optimizer preflight compaction")?;
147        for i in 0.. {
148            let old = self.counter;
149            self.run_all_passes(i, model)?;
150            if old == self.counter {
151                return Ok(());
152            }
153            model.compact()?;
154        }
155        unreachable!()
156    }
157
158    pub fn run_all_passes(&mut self, i: usize, model: &mut TypedModel) -> TractResult<()> {
159        let mut passes = self.optimizer.passes.clone();
160        for p in passes.iter_mut() {
161            self.run_one_pass_outer(i, p.as_mut(), model)
162                .with_context(|| format!("running pass {p:?}"))?;
163            model.compact()?;
164            model
165                .check_consistency()
166                .with_context(|| format!("consistency check after pass {p:?}"))?;
167        }
168        Ok(())
169    }
170
171    pub fn run_one_pass_outer(
172        &mut self,
173        i: usize,
174        p: &mut dyn TypedPass,
175        model: &mut TypedModel,
176    ) -> TractResult<()> {
177        loop {
178            let old_counter = self.counter;
179            self.run_one_pass_inner(i, p, model)?;
180            if self.counter == old_counter {
181                return Ok(());
182            }
183            model.compact().with_context(|| format!("after pass {p:?}"))?;
184        }
185    }
186
187    pub fn run_one_pass_inner(
188        &mut self,
189        i: usize,
190        p: &mut dyn TypedPass,
191        model: &mut TypedModel,
192    ) -> TractResult<()> {
193        p.reset()?;
194        if let Some(steps) = self.optimizer.steps
195            && self.counter >= steps
196        {
197            return Ok(());
198        }
199        while let Some(mut patch) = p.next(self, model)? {
200            patch.push_context(format!("{p:?}/{i}"));
201            patch.model.check_consistency().context("checking patch internal consistency")?;
202            model
203                .check_consistency()
204                .context("Checking target model consistency before patching")?;
205            if let Some(watchdog) = patch.dont_apply_twice.take() {
206                if self.seen.contains(&watchdog) {
207                    debug!("Loop detected: {watchdog} seen before");
208                    continue;
209                } else {
210                    self.seen.insert(watchdog);
211                }
212            }
213            let patch_name = patch.context.iter().rev().join(" >> ");
214            debug!("applying patch #{}: {patch_name}", self.counter);
215            patch.apply(model).with_context(|| format!("Applying patch {patch_name}"))?;
216            model
217                .check_consistency()
218                .context("Checking target model consistency after patching")?;
219            self.counter += 1;
220            if let Some(steps) = self.optimizer.steps
221                && self.counter >= steps
222            {
223                return Ok(());
224            }
225        }
226        if p.run_direct(model)? {
227            model.check_consistency().with_context(|| format!("after run_direct {p:?}"))?;
228        }
229        model.check_consistency().with_context(|| format!("after pass {p:?}"))?;
230        Ok(())
231    }
232}