1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
use crate::internal::*;
use std::fmt::Debug;
use tract_itertools::Itertools;

pub mod change_axes;
mod op_optim;
mod prop_const;
mod push_split_down;

use self::change_axes::ChangeAxes;
use self::prop_const::PropConst;
use self::push_split_down::PushSplitDown;
use op_optim::OpOptim;

/// A rewriting pass over a [`TypedModel`].
///
/// The [`Optimizer`] drives a pass by calling [`TypedPass::reset`] once, then
/// calling [`TypedPass::next`] repeatedly and applying each returned patch,
/// until `next` yields `None`.
pub trait TypedPass: Debug + Send + Sync + dyn_clone::DynClone {
    /// Prepares the pass for a new sweep over a model.
    fn reset(&mut self) -> TractResult<()>;
    /// Returns the next patch to apply to `model`, or `None` when the pass has
    /// nothing more to propose for this sweep.
    fn next(&mut self, model: &TypedModel) -> TractResult<Option<TypedModelPatch>>;
}

// Makes `Box<dyn TypedPass>` cloneable; `Optimizer::run_all_passes` relies on it
// to work on private copies of its passes.
dyn_clone::clone_trait_object!(TypedPass);

/// An ordered pipeline of [`TypedPass`]es, run repeatedly until a fixed point.
pub struct Optimizer {
    passes: Vec<Box<dyn TypedPass>>,
    /// Optional cap on the total number of patches applied.
    steps: Option<usize>,
}

impl Optimizer {
    fn passes(passes: Vec<Box<dyn TypedPass>>) -> Optimizer {
        Optimizer { passes, steps: None }
    }

    /// Returns this optimizer with a cap of `steps` patches applied in total.
    /// Once the cap is reached, further patches proposed by passes are not applied.
    pub fn stopping_at(self, steps: usize) -> Optimizer {
        Optimizer { steps: Some(steps), ..self }
    }

    /// Pipeline made of `TypedOp::declutter`, constant propagation,
    /// split push-down and axis changes.
    pub fn declutter() -> Optimizer {
        Optimizer::passes(vec![
            Box::new(OpOptim("declutter", TypedOp::declutter, 0)),
            Box::new(PropConst),
            Box::new(PushSplitDown),
            Box::new(ChangeAxes),
        ])
    }

    /// Pipeline made of `TypedOp::codegen`, `TypedOp::declutter`, constant
    /// propagation, split push-down and `TypedOp::fuse`.
    pub fn codegen() -> Optimizer {
        Optimizer::passes(vec![
            Box::new(OpOptim("codegen", TypedOp::codegen, 0)),
            Box::new(OpOptim("declutter", TypedOp::declutter, 0)),
            Box::new(PropConst),
            Box::new(PushSplitDown),
            Box::new(OpOptim("fuse", TypedOp::fuse, 0)),
        ])
    }

    /// Optimizes a compacted copy of `model` and returns it.
    ///
    /// All passes are run in rounds; optimization stops after the first round
    /// that applies no patch at all.
    pub fn optimize(&self, model: &TypedModel) -> TractResult<TypedModel> {
        #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
        {
            model.check_consistent_facts()?;
        }
        let mut model = model.compact()?;
        let mut counter = 0;
        let mut i = 0;
        loop {
            let (new_counter, new_model) = self.run_all_passes(i, counter, model)?;
            if new_counter == counter {
                return Ok(new_model);
            }
            counter = new_counter;
            // `run_all_passes` already compacts after every pass; one more
            // compaction between rounds is plenty.
            model = new_model.compact()?;
            i += 1;
        }
    }

    /// Runs every pass once, in order, each to its own fixed point.
    ///
    /// `i` is the round number (used to tag patch contexts). `counter` is the
    /// number of patches applied so far. Returns the updated counter and the
    /// compacted model.
    pub fn run_all_passes(
        &self,
        i: usize,
        mut counter: usize,
        mut model: TypedModel,
    ) -> TractResult<(usize, TypedModel)> {
        // `reset`/`next` need `&mut self` while `self` is only borrowed
        // immutably here, so iterate over clones of the passes.
        let mut passes = self.passes.clone();
        for pass in passes.iter_mut() {
            let (new_counter, new_model) =
                self.run_one_pass_outer(i, pass.as_mut(), counter, model)?;
            counter = new_counter;
            model = new_model.compact()?;
        }
        Ok((counter, model))
    }

    /// Repeats sweeps of pass `p` until a sweep applies no patch.
    ///
    /// Returns the updated patch counter and the model.
    pub fn run_one_pass_outer(
        &self,
        i: usize,
        p: &mut dyn TypedPass,
        mut counter: usize,
        mut model: TypedModel,
    ) -> TractResult<(usize, TypedModel)> {
        loop {
            let (new_counter, new_model) = self.run_one_pass_inner(i, p, counter, model)?;
            if new_counter == counter {
                return Ok((new_counter, new_model));
            }
            counter = new_counter;
            model = new_model.compact()?;
        }
    }

    /// Performs a single sweep of pass `p`.
    ///
    /// The pass is reset, then each patch it proposes is applied to `model`
    /// until it returns `None`. A patch carrying a `dont_apply_twice` watchdog
    /// that was already seen during this sweep is skipped. The sweep returns
    /// early once the optional step cap is reached. Returns the updated patch
    /// counter and the model.
    pub fn run_one_pass_inner(
        &self,
        i: usize,
        p: &mut dyn TypedPass,
        mut counter: usize,
        mut model: TypedModel,
    ) -> TractResult<(usize, TypedModel)> {
        // Watchdogs of patches applied during this sweep. The set is deliberately
        // not cleared after each apply: clearing it right after the insert meant
        // it was always empty at lookup time, so loops were never detected.
        let mut seen = std::collections::HashSet::new();
        p.reset()?;
        while let Some(mut patch) = p.next(&model)? {
            if let Some(steps) = self.steps {
                if counter >= steps {
                    return Ok((counter, model));
                }
            }
            patch.push_context(format!("{:?}/{}", p, i));
            #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
            {
                patch.model.check_consistent_facts()?;
                model.check_consistent_facts()?;
                patch.model.invariants()?;
                model.invariants()?;
            }
            if let Some(watchdog) = patch.dont_apply_twice.take() {
                if seen.contains(&watchdog) {
                    debug!("Loop detected: {} seen before", watchdog);
                    continue;
                } else {
                    seen.insert(watchdog);
                }
            }
            debug!("applying patch #{}: {}", counter, patch.context.iter().rev().join(" >> "));
            patch.apply(&mut model)?;
            counter += 1;
        }
        #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
        {
            model.check_edges()?;
            // This sweep may belong to any pipeline (declutter or codegen).
            model
                .check_consistent_facts()
                .with_context(|| format!("after pass {:?}", p))?;
        }
        Ok((counter, model))
    }
}