1use crate::internal::*;
2use std::collections::HashSet;
3use std::fmt::Debug;
4use tract_itertools::Itertools;
5
6pub mod change_axes;
7mod concat_then_einsum;
8mod op_optim;
9mod prop_const;
10pub mod propagate_roi;
11pub mod propagate_uniform_tdim;
12mod push_split_down;
13mod slice;
14mod uniform_mask;
15
16use self::change_axes::ChangeAxes;
17use self::prop_const::PropConst;
18use self::propagate_roi::PropagateRoi;
19use self::propagate_uniform_tdim::PropagateUniformTdim;
20use self::push_split_down::PushSplitDown;
21use self::slice::PushSliceUp;
22use self::uniform_mask::FoldUniformMask;
23use op_optim::OpOptim;
24
25pub trait TypedPass: Debug + Send + Sync + dyn_clone::DynClone {
26 fn reset(&mut self) -> TractResult<()>;
27 fn next(
28 &mut self,
29 session: &mut OptimizerSession,
30 model: &TypedModel,
31 ) -> TractResult<Option<TypedModelPatch>>;
32 fn run_direct(&mut self, _model: &mut TypedModel) -> TractResult<bool> {
34 Ok(false)
35 }
36}
37
38#[derive(Clone, Debug, Default)]
39struct MergeConsecutiveSameRoleAxes;
40
41impl TypedPass for MergeConsecutiveSameRoleAxes {
42 fn reset(&mut self) -> TractResult<()> {
43 Ok(())
44 }
45 fn next(
46 &mut self,
47 _session: &mut OptimizerSession,
48 _model: &TypedModel,
49 ) -> TractResult<Option<TypedModelPatch>> {
50 Ok(None)
51 }
52 fn run_direct(&mut self, model: &mut TypedModel) -> TractResult<bool> {
53 let before = model.nodes.len();
54 crate::ops::einsum::einsum_matmul::merge_consecutive_same_role_axes(model)?;
55 Ok(model.nodes.len() != before)
56 }
57}
58
// Implements `Clone` for `Box<dyn TypedPass>`, which lets sessions clone the
// optimizer's pass list (`self.optimizer.passes.clone()`) before running it.
dyn_clone::clone_trait_object!(TypedPass);
60
61#[derive(Debug)]
62pub struct Optimizer {
63 pub passes: Vec<Box<dyn TypedPass>>,
64 pub steps: Option<usize>,
65}
66
67impl Optimizer {
68 fn passes(passes: Vec<Box<dyn TypedPass>>) -> Optimizer {
69 Optimizer { passes, steps: None }
70 }
71
72 pub fn add_pass(&mut self, idx: usize, pass: Box<dyn TypedPass>) {
73 let num_pass = self.passes.len();
74 if idx > num_pass {
75 log::warn!(
76 "Cannot add new pass {pass:?} at index {idx}. Optimizer currently as {num_pass} passes, pass will be added as the last pass."
77 );
78 self.passes.push(pass);
79 } else {
80 self.passes.insert(idx, pass);
81 }
82 }
83
84 pub fn stopping_at(self, steps: usize) -> Optimizer {
85 Optimizer { steps: Some(steps), ..self }
86 }
87
88 pub fn prop_consts() -> Optimizer {
89 Optimizer::passes(vec![Box::<PropConst>::default()])
90 }
91
92 pub fn declutter() -> Optimizer {
93 Optimizer::passes(vec![
94 Box::<PropConst>::default(),
95 Box::<PropagateUniformTdim>::default(),
96 Box::<PropagateRoi>::default(),
97 Box::<FoldUniformMask>::default(),
98 Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
99 Box::new(PushSliceUp),
100 Box::new(PushSplitDown),
101 Box::<concat_then_einsum::ConcatThenEinsum>::default(),
102 Box::<ChangeAxes>::default(),
103 ])
104 }
105
106 pub fn codegen() -> Optimizer {
107 Optimizer::passes(vec![
108 Box::<PropConst>::default(),
109 Box::<MergeConsecutiveSameRoleAxes>::default(),
110 Box::new(OpOptim(
111 "codegen",
112 |op, _session, model, node| TypedOp::codegen(op, model, node),
113 0,
114 )),
115 Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
116 Box::new(PushSplitDown),
117 Box::new(OpOptim(
118 "fuse",
119 |op, _session, model, node| TypedOp::fuse(op, model, node),
120 0,
121 )),
122 ])
123 }
124
125 pub fn optimize(&self, model: &mut TypedModel) -> TractResult<()> {
126 self.session().optimize(model)
127 }
128
129 pub fn session(&self) -> OptimizerSession<'_> {
130 OptimizerSession { optimizer: self, counter: 0, seen: Default::default() }
131 }
132}
133
134#[derive(Debug)]
135pub struct OptimizerSession<'o> {
136 optimizer: &'o Optimizer,
137 counter: usize,
138 seen: HashSet<String>,
139}
140
141impl OptimizerSession<'_> {
142 pub fn optimize(&mut self, model: &mut TypedModel) -> TractResult<()> {
143 let _proof_session = model.symbols.proof_cache_session();
144 model.check_consistency().context("during optimizer preflight check")?;
145 model.compact().context("during optimizer preflight compaction")?;
146 model.check_names().context("after optimizer preflight compaction")?;
147 for i in 0.. {
148 let old = self.counter;
149 self.run_all_passes(i, model)?;
150 if old == self.counter {
151 return Ok(());
152 }
153 model.compact()?;
154 }
155 unreachable!()
156 }
157
158 pub fn run_all_passes(&mut self, i: usize, model: &mut TypedModel) -> TractResult<()> {
159 let mut passes = self.optimizer.passes.clone();
160 for p in passes.iter_mut() {
161 self.run_one_pass_outer(i, p.as_mut(), model)
162 .with_context(|| format!("running pass {p:?}"))?;
163 model.compact()?;
164 model
165 .check_consistency()
166 .with_context(|| format!("consistency check after pass {p:?}"))?;
167 }
168 Ok(())
169 }
170
171 pub fn run_one_pass_outer(
172 &mut self,
173 i: usize,
174 p: &mut dyn TypedPass,
175 model: &mut TypedModel,
176 ) -> TractResult<()> {
177 loop {
178 let old_counter = self.counter;
179 self.run_one_pass_inner(i, p, model)?;
180 if self.counter == old_counter {
181 return Ok(());
182 }
183 model.compact().with_context(|| format!("after pass {p:?}"))?;
184 }
185 }
186
187 pub fn run_one_pass_inner(
188 &mut self,
189 i: usize,
190 p: &mut dyn TypedPass,
191 model: &mut TypedModel,
192 ) -> TractResult<()> {
193 p.reset()?;
194 if let Some(steps) = self.optimizer.steps
195 && self.counter >= steps
196 {
197 return Ok(());
198 }
199 while let Some(mut patch) = p.next(self, model)? {
200 patch.push_context(format!("{p:?}/{i}"));
201 patch.model.check_consistency().context("checking patch internal consistency")?;
202 model
203 .check_consistency()
204 .context("Checking target model consistency before patching")?;
205 if let Some(watchdog) = patch.dont_apply_twice.take() {
206 if self.seen.contains(&watchdog) {
207 debug!("Loop detected: {watchdog} seen before");
208 continue;
209 } else {
210 self.seen.insert(watchdog);
211 }
212 }
213 let patch_name = patch.context.iter().rev().join(" >> ");
214 debug!("applying patch #{}: {patch_name}", self.counter);
215 patch.apply(model).with_context(|| format!("Applying patch {patch_name}"))?;
216 model
217 .check_consistency()
218 .context("Checking target model consistency after patching")?;
219 self.counter += 1;
220 if let Some(steps) = self.optimizer.steps
221 && self.counter >= steps
222 {
223 return Ok(());
224 }
225 }
226 if p.run_direct(model)? {
227 model.check_consistency().with_context(|| format!("after run_direct {p:?}"))?;
228 }
229 model.check_consistency().with_context(|| format!("after pass {p:?}"))?;
230 Ok(())
231 }
232}