1use crate::internal::*;
2use std::collections::HashSet;
3use std::fmt::Debug;
4use tract_itertools::Itertools;
5
6pub mod change_axes;
7mod concat_then_einsum;
8mod op_optim;
9mod prop_const;
10pub mod propagate_roi;
11mod push_split_down;
12mod slice;
13mod uniform_mask;
14
15use self::change_axes::ChangeAxes;
16use self::prop_const::PropConst;
17use self::propagate_roi::PropagateRoi;
18use self::push_split_down::PushSplitDown;
19use self::slice::PushSliceUp;
20use self::uniform_mask::FoldUniformMask;
21use op_optim::OpOptim;
22
23pub trait TypedPass: Debug + Send + Sync + dyn_clone::DynClone {
24 fn reset(&mut self) -> TractResult<()>;
25 fn next(
26 &mut self,
27 session: &mut OptimizerSession,
28 model: &TypedModel,
29 ) -> TractResult<Option<TypedModelPatch>>;
30 fn run_direct(&mut self, _model: &mut TypedModel) -> TractResult<bool> {
32 Ok(false)
33 }
34}
35
36dyn_clone::clone_trait_object!(TypedPass);
37
/// An ordered list of [`TypedPass`]es, run round after round over a
/// [`TypedModel`] until a full round makes no change.
#[derive(Debug)]
pub struct Optimizer {
    /// The passes, run in this order during every optimization round.
    pub passes: Vec<Box<dyn TypedPass>>,
    /// Optional cap on the number of changes the session may apply;
    /// `None` means unlimited.
    pub steps: Option<usize>,
}
43
44impl Optimizer {
45 fn passes(passes: Vec<Box<dyn TypedPass>>) -> Optimizer {
46 Optimizer { passes, steps: None }
47 }
48
49 pub fn add_pass(&mut self, idx: usize, pass: Box<dyn TypedPass>) {
50 let num_pass = self.passes.len();
51 if idx > num_pass {
52 log::warn!(
53 "Cannot add new pass {pass:?} at index {idx}. Optimizer currently as {num_pass} passes, pass will be added as the last pass."
54 );
55 self.passes.push(pass);
56 } else {
57 self.passes.insert(idx, pass);
58 }
59 }
60
61 pub fn stopping_at(self, steps: usize) -> Optimizer {
62 Optimizer { steps: Some(steps), ..self }
63 }
64
65 pub fn prop_consts() -> Optimizer {
66 Optimizer::passes(vec![Box::<PropConst>::default()])
67 }
68
69 pub fn declutter() -> Optimizer {
70 Optimizer::passes(vec![
71 Box::<PropConst>::default(),
72 Box::<PropagateRoi>::default(),
73 Box::<FoldUniformMask>::default(),
74 Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
75 Box::new(PushSliceUp),
76 Box::new(PushSplitDown),
77 Box::<concat_then_einsum::ConcatThenEinsum>::default(),
78 Box::<ChangeAxes>::default(),
79 ])
80 }
81
82 pub fn codegen() -> Optimizer {
83 Optimizer::passes(vec![
84 Box::<PropConst>::default(),
85 Box::new(OpOptim(
86 "codegen",
87 |op, _session, model, node| TypedOp::codegen(op, model, node),
88 0,
89 )),
90 Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
91 Box::new(PushSplitDown),
92 Box::new(OpOptim(
93 "fuse",
94 |op, _session, model, node| TypedOp::fuse(op, model, node),
95 0,
96 )),
97 ])
98 }
99
100 pub fn optimize(&self, model: &mut TypedModel) -> TractResult<()> {
101 self.session().optimize(model)
102 }
103
104 pub fn session(&self) -> OptimizerSession<'_> {
105 OptimizerSession { optimizer: self, counter: 0, seen: Default::default() }
106 }
107}
108
/// Mutable state of a single optimization run driven by an [`Optimizer`].
#[derive(Debug)]
pub struct OptimizerSession<'o> {
    /// The optimizer whose passes and step limit drive this session.
    optimizer: &'o Optimizer,
    /// Number of model changes applied so far; each applied patch increments it.
    /// It is compared against `Optimizer::steps` and used to detect rounds that
    /// made no progress.
    counter: usize,
    /// `dont_apply_twice` watchdog keys of patches already applied in this
    /// session. A patch whose key is already here is skipped, to break loops.
    seen: HashSet<String>,
}
115
116impl OptimizerSession<'_> {
117 pub fn optimize(&mut self, model: &mut TypedModel) -> TractResult<()> {
118 let _proof_session = model.symbols.proof_cache_session();
119 model.check_consistency().context("during optimizer preflight check")?;
120 model.compact().context("during optimizer preflight compaction")?;
121 model.check_names().context("after optimizer preflight compaction")?;
122 for i in 0.. {
123 let old = self.counter;
124 self.run_all_passes(i, model)?;
125 if old == self.counter {
126 return Ok(());
127 }
128 model.compact()?;
129 }
130 unreachable!()
131 }
132
133 pub fn run_all_passes(&mut self, i: usize, model: &mut TypedModel) -> TractResult<()> {
134 let mut passes = self.optimizer.passes.clone();
135 for p in passes.iter_mut() {
136 self.run_one_pass_outer(i, p.as_mut(), model)
137 .with_context(|| format!("running pass {p:?}"))?;
138 model.compact()?;
139 model
140 .check_consistency()
141 .with_context(|| format!("consistency check after pass {p:?}"))?;
142 }
143 Ok(())
144 }
145
146 pub fn run_one_pass_outer(
147 &mut self,
148 i: usize,
149 p: &mut dyn TypedPass,
150 model: &mut TypedModel,
151 ) -> TractResult<()> {
152 loop {
153 let old_counter = self.counter;
154 self.run_one_pass_inner(i, p, model)?;
155 if self.counter == old_counter {
156 return Ok(());
157 }
158 model.compact().with_context(|| format!("after pass {p:?}"))?;
159 }
160 }
161
162 pub fn run_one_pass_inner(
163 &mut self,
164 i: usize,
165 p: &mut dyn TypedPass,
166 model: &mut TypedModel,
167 ) -> TractResult<()> {
168 p.reset()?;
169 if let Some(steps) = self.optimizer.steps
170 && self.counter >= steps
171 {
172 return Ok(());
173 }
174 while let Some(mut patch) = p.next(self, model)? {
175 patch.push_context(format!("{p:?}/{i}"));
176 patch.model.check_consistency().context("checking patch internal consistency")?;
177 model
178 .check_consistency()
179 .context("Checking target model consistency before patching")?;
180 if let Some(watchdog) = patch.dont_apply_twice.take() {
181 if self.seen.contains(&watchdog) {
182 debug!("Loop detected: {watchdog} seen before");
183 continue;
184 } else {
185 self.seen.insert(watchdog);
186 }
187 }
188 let patch_name = patch.context.iter().rev().join(" >> ");
189 debug!("applying patch #{}: {patch_name}", self.counter);
190 patch.apply(model).with_context(|| format!("Applying patch {patch_name}"))?;
191 model
192 .check_consistency()
193 .context("Checking target model consistency after patching")?;
194 self.counter += 1;
195 if let Some(steps) = self.optimizer.steps
196 && self.counter >= steps
197 {
198 return Ok(());
199 }
200 }
201 if p.run_direct(model)? {
202 model.check_consistency().with_context(|| format!("after run_direct {p:?}"))?;
203 }
204 model.check_consistency().with_context(|| format!("after pass {p:?}"))?;
205 Ok(())
206 }
207}