1use crate::internal::*;
2use std::collections::HashSet;
3use std::fmt::Debug;
4use tract_itertools::Itertools;
5
6pub mod change_axes;
7mod concat_then_einsum;
8mod op_optim;
9mod prop_const;
10mod push_split_down;
11mod slice;
12
13use self::change_axes::ChangeAxes;
14use self::prop_const::PropConst;
15use self::push_split_down::PushSplitDown;
16use self::slice::PushSliceUp;
17use op_optim::OpOptim;
18
19pub trait TypedPass: Debug + Send + Sync + dyn_clone::DynClone {
20 fn reset(&mut self) -> TractResult<()>;
21 fn next(
22 &mut self,
23 session: &mut OptimizerSession,
24 model: &TypedModel,
25 ) -> TractResult<Option<TypedModelPatch>>;
26}
27
28dyn_clone::clone_trait_object!(TypedPass);
29
30#[derive(Debug)]
31pub struct Optimizer {
32 pub passes: Vec<Box<dyn TypedPass>>,
33 pub steps: Option<usize>,
34}
35
36impl Optimizer {
37 fn passes(passes: Vec<Box<dyn TypedPass>>) -> Optimizer {
38 Optimizer { passes, steps: None }
39 }
40
41 pub fn add_pass(&mut self, idx: usize, pass: Box<dyn TypedPass>) {
42 let num_pass = self.passes.len();
43 if idx > num_pass {
44 log::warn!(
45 "Cannot add new pass {pass:?} at index {idx}. Optimizer currently as {num_pass} passes, pass will be added as the last pass."
46 );
47 self.passes.push(pass);
48 } else {
49 self.passes.insert(idx, pass);
50 }
51 }
52
53 pub fn stopping_at(self, steps: usize) -> Optimizer {
54 Optimizer { steps: Some(steps), ..self }
55 }
56
57 pub fn prop_consts() -> Optimizer {
58 Optimizer::passes(vec![Box::<PropConst>::default()])
59 }
60
61 pub fn declutter() -> Optimizer {
62 Optimizer::passes(vec![
63 Box::<PropConst>::default(),
64 Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
65 Box::new(PushSliceUp),
66 Box::new(PushSplitDown),
67 Box::<concat_then_einsum::ConcatThenEinsum>::default(),
68 Box::<ChangeAxes>::default(),
69 ])
70 }
71
72 pub fn codegen() -> Optimizer {
73 Optimizer::passes(vec![
74 Box::<PropConst>::default(),
75 Box::new(OpOptim(
76 "codegen",
77 |op, _session, model, node| TypedOp::codegen(op, model, node),
78 0,
79 )),
80 Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
81 Box::new(PushSplitDown),
82 Box::new(OpOptim(
83 "fuse",
84 |op, _session, model, node| TypedOp::fuse(op, model, node),
85 0,
86 )),
87 ])
88 }
89
90 pub fn optimize(&self, model: &mut TypedModel) -> TractResult<()> {
91 self.session().optimize(model)
92 }
93
94 pub fn session(&self) -> OptimizerSession<'_> {
95 OptimizerSession { optimizer: self, counter: 0, seen: Default::default() }
96 }
97}
98
99#[derive(Debug)]
100pub struct OptimizerSession<'o> {
101 optimizer: &'o Optimizer,
102 counter: usize,
103 seen: HashSet<String>,
104}
105
106impl OptimizerSession<'_> {
107 pub fn optimize(&mut self, model: &mut TypedModel) -> TractResult<()> {
108 let _proof_session = model.symbols.proof_cache_session();
109 model.check_consistency().context("during optimizer preflight check")?;
110 model.compact().context("during optimizer preflight compaction")?;
111 model.check_names().context("after optimizer preflight compaction")?;
112 for i in 0.. {
113 let old = self.counter;
114 self.run_all_passes(i, model)?;
115 if old == self.counter {
116 return Ok(());
117 }
118 model.compact()?;
119 }
120 unreachable!()
121 }
122
123 pub fn run_all_passes(&mut self, i: usize, model: &mut TypedModel) -> TractResult<()> {
124 let mut passes = self.optimizer.passes.clone();
125 for p in passes.iter_mut() {
126 self.run_one_pass_outer(i, p.as_mut(), model)
127 .with_context(|| format!("running pass {p:?}"))?;
128 model.compact()?;
129 model
130 .check_consistency()
131 .with_context(|| format!("consistency check after pass {p:?}"))?;
132 }
133 Ok(())
134 }
135
136 pub fn run_one_pass_outer(
137 &mut self,
138 i: usize,
139 p: &mut dyn TypedPass,
140 model: &mut TypedModel,
141 ) -> TractResult<()> {
142 loop {
143 let old_counter = self.counter;
144 self.run_one_pass_inner(i, p, model)?;
145 if self.counter == old_counter {
146 return Ok(());
147 }
148 model.compact().with_context(|| format!("after pass {p:?}"))?;
149 }
150 }
151
152 pub fn run_one_pass_inner(
153 &mut self,
154 i: usize,
155 p: &mut dyn TypedPass,
156 model: &mut TypedModel,
157 ) -> TractResult<()> {
158 p.reset()?;
159 if let Some(steps) = self.optimizer.steps {
160 if self.counter >= steps {
161 return Ok(());
162 }
163 }
164 while let Some(mut patch) = p.next(self, model)? {
165 patch.push_context(format!("{p:?}/{i}"));
166 patch.model.check_consistency().context("checking patch internal consistency")?;
167 model
168 .check_consistency()
169 .context("Checking target model consistency before patching")?;
170 if let Some(watchdog) = patch.dont_apply_twice.take() {
171 if self.seen.contains(&watchdog) {
172 debug!("Loop detected: {watchdog} seen before");
173 continue;
174 } else {
175 self.seen.insert(watchdog);
176 }
177 }
178 let patch_name = patch.context.iter().rev().join(" >> ");
179 debug!("applying patch #{}: {patch_name}", self.counter);
180 patch.apply(model).with_context(|| format!("Applying patch {patch_name}"))?;
181 model
182 .check_consistency()
183 .context("Checking target model consistency after patching")?;
184 self.counter += 1;
185 if let Some(steps) = self.optimizer.steps {
186 if self.counter >= steps {
187 return Ok(());
188 }
189 }
190 }
191 model.check_consistency().with_context(|| format!("after pass {p:?}"))?;
192 Ok(())
193 }
194}