1use crate::internal::*;
2use std::collections::HashSet;
3use std::fmt::Debug;
4use tract_itertools::Itertools;
5
6pub mod change_axes;
7mod concat_then_einsum;
8mod op_optim;
9mod prop_const;
10mod push_split_down;
11mod slice;
12
13use self::change_axes::ChangeAxes;
14use self::prop_const::PropConst;
15use self::push_split_down::PushSplitDown;
16use self::slice::PushSliceUp;
17use op_optim::OpOptim;
18
19pub trait TypedPass: Debug + Send + Sync + dyn_clone::DynClone {
20 fn reset(&mut self) -> TractResult<()>;
21 fn next(
22 &mut self,
23 session: &mut OptimizerSession,
24 model: &TypedModel,
25 ) -> TractResult<Option<TypedModelPatch>>;
26}
27
28dyn_clone::clone_trait_object!(TypedPass);
29
30#[derive(Debug)]
31pub struct Optimizer {
32 pub passes: Vec<Box<dyn TypedPass>>,
33 pub steps: Option<usize>,
34}
35
36impl Optimizer {
37 fn passes(passes: Vec<Box<dyn TypedPass>>) -> Optimizer {
38 Optimizer { passes, steps: None }
39 }
40
41 pub fn add_pass(&mut self, idx: usize, pass: Box<dyn TypedPass>) {
42 let num_pass = self.passes.len();
43 if idx > num_pass {
44 log::warn!(
45 "Cannot add new pass {pass:?} at index {idx}. Optimizer currently as {num_pass} passes, pass will be added as the last pass."
46 );
47 self.passes.push(pass);
48 } else {
49 self.passes.insert(idx, pass);
50 }
51 }
52
53 pub fn stopping_at(self, steps: usize) -> Optimizer {
54 Optimizer { steps: Some(steps), ..self }
55 }
56
57 pub fn prop_consts() -> Optimizer {
58 Optimizer::passes(vec![Box::<PropConst>::default()])
59 }
60
61 pub fn declutter() -> Optimizer {
62 Optimizer::passes(vec![
63 Box::<PropConst>::default(),
64 Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
65 Box::new(PushSliceUp),
66 Box::new(PushSplitDown),
67 Box::<concat_then_einsum::ConcatThenEinsum>::default(),
68 Box::<ChangeAxes>::default(),
69 ])
70 }
71
72 pub fn codegen() -> Optimizer {
73 Optimizer::passes(vec![
74 Box::<PropConst>::default(),
75 Box::new(OpOptim(
76 "codegen",
77 |op, _session, model, node| TypedOp::codegen(op, model, node),
78 0,
79 )),
80 Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
81 Box::new(PushSplitDown),
82 Box::new(OpOptim(
83 "fuse",
84 |op, _session, model, node| TypedOp::fuse(op, model, node),
85 0,
86 )),
87 ])
88 }
89
90 pub fn optimize(&self, model: &mut TypedModel) -> TractResult<()> {
91 self.session().optimize(model)
92 }
93
94 pub fn session(&self) -> OptimizerSession<'_> {
95 OptimizerSession { optimizer: self, counter: 0, seen: Default::default() }
96 }
97}
98
99#[derive(Debug)]
100pub struct OptimizerSession<'o> {
101 optimizer: &'o Optimizer,
102 counter: usize,
103 seen: HashSet<String>,
104}
105
106impl OptimizerSession<'_> {
107 pub fn optimize(&mut self, model: &mut TypedModel) -> TractResult<()> {
108 model.check_consistency().context("during optimizer preflight check")?;
109 model.compact().context("during optimizer preflight compaction")?;
110 model.check_names().context("after optimizer preflight compaction")?;
111 for i in 0.. {
112 let old = self.counter;
113 self.run_all_passes(i, model)?;
114 if old == self.counter {
115 return Ok(());
116 }
117 model.compact()?;
118 }
119 unreachable!()
120 }
121
122 pub fn run_all_passes(&mut self, i: usize, model: &mut TypedModel) -> TractResult<()> {
123 let mut passes = self.optimizer.passes.clone();
124 for p in passes.iter_mut() {
125 self.run_one_pass_outer(i, p.as_mut(), model)
126 .with_context(|| format!("running pass {p:?}"))?;
127 model.compact()?;
128 model
129 .check_consistency()
130 .with_context(|| format!("consistency check after pass {p:?}"))?;
131 }
132 Ok(())
133 }
134
135 pub fn run_one_pass_outer(
136 &mut self,
137 i: usize,
138 p: &mut dyn TypedPass,
139 model: &mut TypedModel,
140 ) -> TractResult<()> {
141 loop {
142 let old_counter = self.counter;
143 self.run_one_pass_inner(i, p, model)?;
144 if self.counter == old_counter {
145 return Ok(());
146 }
147 model.compact().with_context(|| format!("after pass {p:?}"))?;
148 }
149 }
150
151 pub fn run_one_pass_inner(
152 &mut self,
153 i: usize,
154 p: &mut dyn TypedPass,
155 model: &mut TypedModel,
156 ) -> TractResult<()> {
157 p.reset()?;
158 if let Some(steps) = self.optimizer.steps {
159 if self.counter >= steps {
160 return Ok(());
161 }
162 }
163 while let Some(mut patch) = p.next(self, model)? {
164 patch.push_context(format!("{p:?}/{i}"));
165 patch.model.check_consistency().context("checking patch internal consistency")?;
166 model
167 .check_consistency()
168 .context("Checking target model consistency before patching")?;
169 if let Some(watchdog) = patch.dont_apply_twice.take() {
170 if self.seen.contains(&watchdog) {
171 debug!("Loop detected: {watchdog} seen before");
172 continue;
173 } else {
174 self.seen.insert(watchdog);
175 }
176 }
177 let patch_name = patch.context.iter().rev().join(" >> ");
178 debug!("applying patch #{}: {patch_name}", self.counter);
179 patch.apply(model).with_context(|| format!("Applying patch {patch_name}"))?;
180 model
181 .check_consistency()
182 .context("Checking target model consistency after patching")?;
183 self.counter += 1;
184 if let Some(steps) = self.optimizer.steps {
185 if self.counter >= steps {
186 return Ok(());
187 }
188 }
189 }
190 model.check_consistency().with_context(|| format!("after pass {p:?}"))?;
191 Ok(())
192 }
193}