tract_core/optim/mod.rs

1use crate::internal::*;
2use std::collections::HashSet;
3use std::fmt::Debug;
4use tract_itertools::Itertools;
5
6pub mod change_axes;
7mod op_optim;
8mod prop_const;
9mod push_split_down;
10mod slice;
11
12use self::change_axes::ChangeAxes;
13use self::prop_const::PropConst;
14use self::push_split_down::PushSplitDown;
15use self::slice::PushSliceUp;
16use op_optim::OpOptim;
17
18pub trait TypedPass: Debug + Send + Sync + dyn_clone::DynClone {
19    fn reset(&mut self) -> TractResult<()>;
20    fn next(
21        &mut self,
22        session: &mut OptimizerSession,
23        model: &TypedModel,
24    ) -> TractResult<Option<TypedModelPatch>>;
25}
26
27dyn_clone::clone_trait_object!(TypedPass);
28
29#[derive(Debug)]
30pub struct Optimizer {
31    pub passes: Vec<Box<dyn TypedPass>>,
32    pub steps: Option<usize>,
33}
34
35impl Optimizer {
36    fn passes(passes: Vec<Box<dyn TypedPass>>) -> Optimizer {
37        Optimizer { passes, steps: None }
38    }
39
40    pub fn add_pass(&mut self, idx: usize, pass: Box<dyn TypedPass>) {
41        let num_pass = self.passes.len();
42        if idx > num_pass {
43            log::warn!("Cannot add new pass {:?} at index {}. Optimizer currently as {} passes, pass will be added as the last pass.", pass, idx, num_pass);
44            self.passes.push(pass);
45        } else {
46            self.passes.insert(idx, pass);
47        }
48    }
49
50    pub fn stopping_at(self, steps: usize) -> Optimizer {
51        Optimizer { steps: Some(steps), ..self }
52    }
53
54    pub fn prop_consts() -> Optimizer {
55        Optimizer::passes(vec![Box::<PropConst>::default()])
56    }
57
58    pub fn declutter() -> Optimizer {
59        Optimizer::passes(vec![
60            Box::<PropConst>::default(),
61            Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
62            Box::new(PushSliceUp),
63            Box::new(PushSplitDown),
64            Box::<ChangeAxes>::default(),
65        ])
66    }
67
68    pub fn codegen() -> Optimizer {
69        Optimizer::passes(vec![
70            Box::<PropConst>::default(),
71            Box::new(OpOptim(
72                "codegen",
73                |op, _session, model, node| TypedOp::codegen(op, model, node),
74                0,
75            )),
76            Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
77            Box::new(PushSplitDown),
78            Box::new(OpOptim(
79                "fuse",
80                |op, _session, model, node| TypedOp::fuse(op, model, node),
81                0,
82            )),
83        ])
84    }
85
86    pub fn optimize(&self, model: &mut TypedModel) -> TractResult<()> {
87        self.session().optimize(model)
88    }
89
90    pub fn session(&self) -> OptimizerSession {
91        OptimizerSession { optimizer: self, counter: 0, seen: Default::default() }
92    }
93}
94
95#[derive(Debug)]
96pub struct OptimizerSession<'o> {
97    optimizer: &'o Optimizer,
98    counter: usize,
99    seen: HashSet<String>,
100}
101
102impl OptimizerSession<'_> {
103    pub fn optimize(&mut self, model: &mut TypedModel) -> TractResult<()> {
104        model.check_consistency().context("during optimizer preflight check")?;
105        model.compact().context("during optimizer preflight compaction")?;
106        model.check_names().context("after optimizer preflight compaction")?;
107        for i in 0.. {
108            let old = self.counter;
109            self.run_all_passes(i, model)?;
110            if old == self.counter {
111                return Ok(());
112            }
113            model.compact()?;
114        }
115        unreachable!()
116    }
117
118    pub fn run_all_passes(&mut self, i: usize, model: &mut TypedModel) -> TractResult<()> {
119        let mut passes = self.optimizer.passes.clone();
120        for p in passes.iter_mut() {
121            self.run_one_pass_outer(i, p.as_mut(), model)
122                .with_context(|| format!("running pass {p:?}"))?;
123            model.compact()?;
124            model
125                .check_consistency()
126                .with_context(|| format!("consistency check after pass {p:?}"))?;
127        }
128        Ok(())
129    }
130
131    pub fn run_one_pass_outer(
132        &mut self,
133        i: usize,
134        p: &mut dyn TypedPass,
135        model: &mut TypedModel,
136    ) -> TractResult<()> {
137        loop {
138            let old_counter = self.counter;
139            self.run_one_pass_inner(i, p, model)?;
140            if self.counter == old_counter {
141                return Ok(());
142            }
143            model.compact().with_context(|| format!("after pass {p:?}"))?;
144        }
145    }
146
147    pub fn run_one_pass_inner(
148        &mut self,
149        i: usize,
150        p: &mut dyn TypedPass,
151        model: &mut TypedModel,
152    ) -> TractResult<()> {
153        p.reset()?;
154        if let Some(steps) = self.optimizer.steps {
155            if self.counter >= steps {
156                return Ok(());
157            }
158        }
159        while let Some(mut patch) = p.next(self, model)? {
160            patch.push_context(format!("{p:?}/{i}"));
161            patch.model.check_consistency().context("checking patch internal consistency")?;
162            model
163                .check_consistency()
164                .context("Checking target model consistency before patching")?;
165            if let Some(watchdog) = patch.dont_apply_twice.take() {
166                if self.seen.contains(&watchdog) {
167                    debug!("Loop detected: {} seen before", watchdog);
168                    continue;
169                } else {
170                    self.seen.insert(watchdog);
171                }
172            }
173            let patch_name = patch.context.iter().rev().join(" >> ");
174            debug!("applying patch #{}: {patch_name}", self.counter);
175            patch.apply(model).with_context(|| format!("Applying patch {patch_name}"))?;
176            model
177                .check_consistency()
178                .context("Checking target model consistency after patching")?;
179            self.counter += 1;
180            if let Some(steps) = self.optimizer.steps {
181                if self.counter >= steps {
182                    return Ok(());
183                }
184            }
185        }
186        model.check_consistency().with_context(|| format!("after pass {p:?}"))?;
187        Ok(())
188    }
189}