use crate::internal::*;
use std::fmt::Debug;
use tract_itertools::Itertools;

pub mod change_axes;
mod op_optim;
mod prop_const;
mod push_split_down;

use self::change_axes::ChangeAxes;
use self::prop_const::PropConst;
use self::push_split_down::PushSplitDown;
use op_optim::OpOptim;

/// A rewriting pass over a `TypedModel`.
///
/// `reset` is called before each sweep of the pass; `next` is then called
/// repeatedly and returns the next patch to apply, or `None` once the pass
/// has nothing more to propose for the current model.
pub trait TypedPass: Debug + Send + Sync + dyn_clone::DynClone {
    fn reset(&mut self) -> TractResult<()>;
    fn next(&mut self, model: &TypedModel) -> TractResult<Option<TypedModelPatch>>;
}

dyn_clone::clone_trait_object!(TypedPass);

/// Runs an ordered list of passes over a model until none of them changes it.
pub struct Optimizer {
    passes: Vec<Box<dyn TypedPass>>,
    /// Optional cap on the total number of patches applied (useful for debugging).
    steps: Option<usize>,
}

impl Optimizer {
    fn passes(passes: Vec<Box<dyn TypedPass>>) -> Optimizer {
        Optimizer { passes, steps: None }
    }

    /// Stops optimization as soon as `steps` patches have been applied.
    pub fn stopping_at(self, steps: usize) -> Optimizer {
        Optimizer { steps: Some(steps), ..self }
    }

    pub fn declutter() -> Optimizer {
        Optimizer::passes(vec![
            Box::new(OpOptim("declutter", TypedOp::declutter, 0)),
            Box::new(PropConst),
            Box::new(PushSplitDown),
            Box::new(ChangeAxes),
        ])
    }

    pub fn codegen() -> Optimizer {
        Optimizer::passes(vec![
            Box::new(OpOptim("codegen", TypedOp::codegen, 0)),
            Box::new(PushSplitDown),
            Box::new(OpOptim("fuse", TypedOp::fuse, 0)),
        ])
    }

    /// Applies the passes to a copy of `model` until a fixpoint is reached
    /// (a full round in which no pass applies a patch), or until the
    /// `stopping_at` step limit is hit.
    pub fn optimize(&self, model: &TypedModel) -> TractResult<TypedModel> {
        #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
        {
            model.check_consistent_facts()?;
        }
        let mut model = model.clone();
        let mut patches = 0;
        let mut passes = self.passes.clone();
        // Outer rounds: run every pass in order until a whole round changes nothing.
        for i in 0.. {
            model = model.compact()?;
            let mut done_something_this_time = false;
            'pass: for p in passes.iter_mut() {
                // Re-run the same pass for as long as it keeps producing patches.
                loop {
                    let mut done_something_this_pass = false;
                    // Watchdog keys of the patches already applied during this sweep.
                    let mut seen = std::collections::HashSet::new();
                    p.reset()?;
                    while let Some(mut patch) = p.next(&model)? {
                        patch.push_context(format!("{:?}/{}", p, i));
                        #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
                        {
                            patch.model.check_consistent_facts()?;
                            model.check_consistent_facts()?;
                            patch.model.invariants()?;
                            model.invariants()?;
                        }
                        // A patch tagged with `dont_apply_twice` must not be applied again
                        // under the same key: meeting the key twice means the pass is looping.
                        if let Some(watchdog) = patch.dont_apply_twice.take() {
                            if seen.contains(&watchdog) {
                                debug!("Loop detected: {} seen before", watchdog);
                                break 'pass;
                            } else {
                                seen.insert(watchdog);
                            }
                        }
                        debug!(
                            "applying patch #{}: {}",
                            patches,
                            patch.context.iter().rev().join(" >> "),
                        );
                        done_something_this_pass = true;
                        done_something_this_time = true;
                        patch.apply(&mut model)?;
                        patches += 1;
                        // Stop once the requested number of patches has been applied.
                        if let Some(steps) = self.steps {
                            if patches >= steps {
                                return Ok(model);
                            }
                        }
                    }
                    #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
                    {
                        model.check_edges()?;
                        model
                            .check_consistent_facts()
                            .with_context(|| format!("after pass {:?}", p))?
                    }
                    if !done_something_this_pass {
                        // This pass has converged: move on to the next one.
                        continue 'pass;
                    }
                }
            }
            if !done_something_this_time {
                return Ok(model);
            }
            model = model.compact()?;
        }
        unreachable!()
    }
}