1use crate::internal::*;
2use std::collections::HashSet;
3use std::fmt::Debug;
4use tract_itertools::Itertools;
5
6pub mod change_axes;
7mod op_optim;
8mod prop_const;
9mod push_split_down;
10mod slice;
11
12use self::change_axes::ChangeAxes;
13use self::prop_const::PropConst;
14use self::push_split_down::PushSplitDown;
15use self::slice::PushSliceUp;
16use op_optim::OpOptim;
17
18pub trait TypedPass: Debug + Send + Sync + dyn_clone::DynClone {
19 fn reset(&mut self) -> TractResult<()>;
20 fn next(
21 &mut self,
22 session: &mut OptimizerSession,
23 model: &TypedModel,
24 ) -> TractResult<Option<TypedModelPatch>>;
25}
26
27dyn_clone::clone_trait_object!(TypedPass);
28
29#[derive(Debug)]
30pub struct Optimizer {
31 pub passes: Vec<Box<dyn TypedPass>>,
32 pub steps: Option<usize>,
33}
34
35impl Optimizer {
36 fn passes(passes: Vec<Box<dyn TypedPass>>) -> Optimizer {
37 Optimizer { passes, steps: None }
38 }
39
40 pub fn add_pass(&mut self, idx: usize, pass: Box<dyn TypedPass>) {
41 let num_pass = self.passes.len();
42 if idx > num_pass {
43 log::warn!("Cannot add new pass {:?} at index {}. Optimizer currently as {} passes, pass will be added as the last pass.", pass, idx, num_pass);
44 self.passes.push(pass);
45 } else {
46 self.passes.insert(idx, pass);
47 }
48 }
49
50 pub fn stopping_at(self, steps: usize) -> Optimizer {
51 Optimizer { steps: Some(steps), ..self }
52 }
53
54 pub fn prop_consts() -> Optimizer {
55 Optimizer::passes(vec![Box::<PropConst>::default()])
56 }
57
58 pub fn declutter() -> Optimizer {
59 Optimizer::passes(vec![
60 Box::<PropConst>::default(),
61 Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
62 Box::new(PushSliceUp),
63 Box::new(PushSplitDown),
64 Box::<ChangeAxes>::default(),
65 ])
66 }
67
68 pub fn codegen() -> Optimizer {
69 Optimizer::passes(vec![
70 Box::<PropConst>::default(),
71 Box::new(OpOptim(
72 "codegen",
73 |op, _session, model, node| TypedOp::codegen(op, model, node),
74 0,
75 )),
76 Box::new(OpOptim("declutter", TypedOp::declutter_with_session, 0)),
77 Box::new(PushSplitDown),
78 Box::new(OpOptim(
79 "fuse",
80 |op, _session, model, node| TypedOp::fuse(op, model, node),
81 0,
82 )),
83 ])
84 }
85
86 pub fn optimize(&self, model: &mut TypedModel) -> TractResult<()> {
87 self.session().optimize(model)
88 }
89
90 pub fn session(&self) -> OptimizerSession {
91 OptimizerSession { optimizer: self, counter: 0, seen: Default::default() }
92 }
93}
94
95#[derive(Debug)]
96pub struct OptimizerSession<'o> {
97 optimizer: &'o Optimizer,
98 counter: usize,
99 seen: HashSet<String>,
100}
101
102impl OptimizerSession<'_> {
103 pub fn optimize(&mut self, model: &mut TypedModel) -> TractResult<()> {
104 model.check_consistency().context("during optimizer preflight check")?;
105 model.compact().context("during optimizer preflight compaction")?;
106 model.check_names().context("after optimizer preflight compaction")?;
107 for i in 0.. {
108 let old = self.counter;
109 self.run_all_passes(i, model)?;
110 if old == self.counter {
111 return Ok(());
112 }
113 model.compact()?;
114 }
115 unreachable!()
116 }
117
118 pub fn run_all_passes(&mut self, i: usize, model: &mut TypedModel) -> TractResult<()> {
119 let mut passes = self.optimizer.passes.clone();
120 for p in passes.iter_mut() {
121 self.run_one_pass_outer(i, p.as_mut(), model)
122 .with_context(|| format!("running pass {p:?}"))?;
123 model.compact()?;
124 model
125 .check_consistency()
126 .with_context(|| format!("consistency check after pass {p:?}"))?;
127 }
128 Ok(())
129 }
130
131 pub fn run_one_pass_outer(
132 &mut self,
133 i: usize,
134 p: &mut dyn TypedPass,
135 model: &mut TypedModel,
136 ) -> TractResult<()> {
137 loop {
138 let old_counter = self.counter;
139 self.run_one_pass_inner(i, p, model)?;
140 if self.counter == old_counter {
141 return Ok(());
142 }
143 model.compact().with_context(|| format!("after pass {p:?}"))?;
144 }
145 }
146
147 pub fn run_one_pass_inner(
148 &mut self,
149 i: usize,
150 p: &mut dyn TypedPass,
151 model: &mut TypedModel,
152 ) -> TractResult<()> {
153 p.reset()?;
154 if let Some(steps) = self.optimizer.steps {
155 if self.counter >= steps {
156 return Ok(());
157 }
158 }
159 while let Some(mut patch) = p.next(self, model)? {
160 patch.push_context(format!("{p:?}/{i}"));
161 patch.model.check_consistency().context("checking patch internal consistency")?;
162 model
163 .check_consistency()
164 .context("Checking target model consistency before patching")?;
165 if let Some(watchdog) = patch.dont_apply_twice.take() {
166 if self.seen.contains(&watchdog) {
167 debug!("Loop detected: {} seen before", watchdog);
168 continue;
169 } else {
170 self.seen.insert(watchdog);
171 }
172 }
173 let patch_name = patch.context.iter().rev().join(" >> ");
174 debug!("applying patch #{}: {patch_name}", self.counter);
175 patch.apply(model).with_context(|| format!("Applying patch {patch_name}"))?;
176 model
177 .check_consistency()
178 .context("Checking target model consistency after patching")?;
179 self.counter += 1;
180 if let Some(steps) = self.optimizer.steps {
181 if self.counter >= steps {
182 return Ok(());
183 }
184 }
185 }
186 model.check_consistency().with_context(|| format!("after pass {p:?}"))?;
187 Ok(())
188 }
189}