1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
use crate::internal::*;
use std::fmt::Debug;
use tract_itertools::Itertools;

pub mod change_axes;
mod op_optim;
mod prop_const;
mod push_split_down;

use self::change_axes::ChangeAxes;
use self::prop_const::PropConst;
use self::push_split_down::PushSplitDown;
use op_optim::OpOptim;

/// A rewriting pass over a [`TypedModel`], driven by [`Optimizer::optimize`].
///
/// For each sweep, the optimizer calls `reset` once, then calls `next` repeatedly and applies
/// the returned patches to the model, until `next` returns `None` (or the optimizer stops
/// early, e.g. on loop detection or when a step limit is reached).
pub trait TypedPass: Debug + Send + Sync + dyn_clone::DynClone {
    /// Prepares the pass for a new sweep over the model.
    fn reset(&mut self) -> TractResult<()>;
    /// Returns the next patch to apply to `model`, or `None` when this sweep has nothing left.
    fn next(&mut self, model: &TypedModel) -> TractResult<Option<TypedModelPatch>>;
}

// Implements `Clone` for `Box<dyn TypedPass>`, so a `Vec<Box<dyn TypedPass>>` can be cloned;
// `Optimizer::optimize` relies on this to work on private, mutable copies of its passes.
dyn_clone::clone_trait_object!(TypedPass);

/// Drives an ordered list of [`TypedPass`]es over a [`TypedModel`], repeating them until an
/// outer iteration applies no patch at all.
pub struct Optimizer {
    /// Optional cap on the total number of patches to apply (set through `stopping_at`).
    steps: Option<usize>,
    /// The passes, run in this order on every outer iteration.
    passes: Vec<Box<dyn TypedPass>>,
}

impl Optimizer {
    /// Builds an optimizer from an ordered list of passes, with no step limit.
    fn passes(passes: Vec<Box<dyn TypedPass>>) -> Optimizer {
        Optimizer { passes, steps: None }
    }

    /// Caps the optimizer at `steps` applied patches in total.
    ///
    /// Once that many patches have been applied, [`Optimizer::optimize`] returns the
    /// partially-optimized model immediately. `stopping_at(0)` applies no patch at all.
    pub fn stopping_at(self, steps: usize) -> Optimizer {
        Optimizer { steps: Some(steps), ..self }
    }

    /// Optimizer running the declutter passes: `TypedOp::declutter` on ops (through
    /// `OpOptim`), then `PropConst`, `PushSplitDown` and `ChangeAxes`.
    pub fn declutter() -> Optimizer {
        Optimizer::passes(vec![
            Box::new(OpOptim("declutter", TypedOp::declutter, 0)),
            Box::new(PropConst),
            Box::new(PushSplitDown),
            Box::new(ChangeAxes),
        ])
    }

    /// Optimizer running the codegen passes: `TypedOp::codegen` (through `OpOptim`),
    /// `PushSplitDown`, then `TypedOp::fuse` (through `OpOptim`).
    pub fn codegen() -> Optimizer {
        Optimizer::passes(vec![
            Box::new(OpOptim("codegen", TypedOp::codegen, 0)),
            Box::new(PushSplitDown),
            Box::new(OpOptim("fuse", TypedOp::fuse, 0)),
        ])
    }

    /// Returns `true` once `applied` patches have reached the limit set by `stopping_at`.
    fn step_limit_reached(&self, applied: usize) -> bool {
        matches!(self.steps, Some(limit) if applied >= limit)
    }

    /// Runs the passes over a copy of `model` until an outer iteration applies no patch.
    ///
    /// Each pass is re-run (after `reset`) for as long as a sweep applies at least one patch.
    /// The model is compacted at the start and end of every outer iteration. If a step limit
    /// was set with `stopping_at`, the model is returned as soon as that many patches have
    /// been applied.
    ///
    /// # Errors
    /// Propagates any error from compaction, from the passes, or from applying a patch.
    pub fn optimize(&self, model: &TypedModel) -> TractResult<TypedModel> {
        #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
        {
            model.check_consistent_facts()?;
        }
        let mut model = model.clone();
        // `stopping_at(0)` means "apply nothing": return before looking for any patch.
        if self.step_limit_reached(0) {
            return Ok(model);
        }
        let mut patches = 0;
        // Passes are stateful (`reset`/`next` take `&mut self`), so work on private copies.
        let mut passes = self.passes.clone();
        for i in 0.. {
            model = model.compact()?;
            let mut done_something_this_time = false;
            'pass: for p in passes.iter_mut() {
                // Re-run the same pass until a whole sweep applies no patch.
                loop {
                    let mut done_something_this_pass = false;
                    // Watchdog keys of `dont_apply_twice` patches applied during this sweep.
                    let mut seen = std::collections::HashSet::new();
                    p.reset()?;
                    while let Some(mut patch) = p.next(&model)? {
                        patch.push_context(format!("{:?}/{}", p, i));
                        #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
                        {
                            patch.model.check_consistent_facts()?;
                            model.check_consistent_facts()?;
                            patch.model.invariants()?;
                            model.invariants()?;
                        }
                        if let Some(watchdog) = patch.dont_apply_twice.take() {
                            // The same watchdog showing up a second time within one sweep means
                            // the pass is cycling: stop rather than re-applying it forever.
                            // This leaves the pass list for the current outer iteration, so the
                            // remaining passes are skipped until the next iteration.
                            // NOTE(review): a deterministic cycle may be re-entered on the next
                            // outer iteration, since `seen` is per-sweep — confirm intended.
                            if seen.contains(&watchdog) {
                                debug!("Loop detected: {} seen before", watchdog);
                                break 'pass;
                            }
                            seen.insert(watchdog);
                        }
                        debug!(
                            "applying patch #{}: {}",
                            patches,
                            patch.context.iter().rev().join(" >> "),
                        );
                        done_something_this_pass = true;
                        done_something_this_time = true;
                        patch.apply(&mut model)?;
                        patches += 1;
                        if self.step_limit_reached(patches) {
                            return Ok(model);
                        }
                    }
                    #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
                    {
                        model.check_edges()?;
                        model
                            .check_consistent_facts()
                            .with_context(|| format!("after declutter pass {:?}", p))?
                    }
                    if !done_something_this_pass {
                        continue 'pass;
                    }
                }
            }
            if !done_something_this_time {
                return Ok(model);
            }
            model = model.compact()?;
        }
        unreachable!()
    }
}