1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
use crate::internal::*;

/// An optimisation pass that runs a per-node rewrite function over a
/// [`TypedModel`] and hands back at most one patch per call to `next`.
#[derive(Clone)]
pub struct OpOptim(
    /// Static label for this pass; it is exactly what the `Debug` impl prints.
    pub &'static str,
    /// Rewrite function, called with a node's op, the whole model and the node
    /// itself; returning `Ok(Some(patch))` proposes a rewrite of the model.
    pub fn(
        op: &dyn TypedOp,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>>,
    /// Position in the model's evaluation order where the next scan begins.
    pub usize,
);

impl OpOptim {
    /// Walks `model` in evaluation order, starting at the stored resume
    /// position, and returns the first patch the rewrite function proposes.
    ///
    /// On success the patch is tagged with this pass's label and the node that
    /// produced it. The resume position is set to that node's index — not the
    /// one after it — so the following call re-examines the same position
    /// (presumably because applying the patch may change what sits there;
    /// confirm against the optimizer driver).
    ///
    /// Returns `Ok(None)` when no remaining node yields a patch; the resume
    /// position is then left as it was.
    ///
    /// # Errors
    /// Propagates failures from `eval_order` and from the rewrite function; the
    /// latter are annotated with the pass label and the offending node.
    fn full_pass(&mut self, model: &TypedModel) -> TractResult<Option<TypedModelPatch>> {
        let order = model.eval_order()?;
        let start = self.2;
        for (position, &node_id) in order.iter().enumerate().skip(start) {
            let node = &model.nodes()[node_id];
            let candidate = (self.1)(node.op.as_ref(), model, node)
                .with_context(|| format!("{:?} node {}", self, node))?;
            match candidate {
                Some(mut patch) => {
                    patch.push_context(format!("{:?} {}", self, node));
                    // Resume at this same position on the next call.
                    self.2 = position;
                    return Ok(Some(patch));
                }
                None => continue,
            }
        }
        Ok(None)
    }
}

impl std::fmt::Debug for OpOptim {
    /// Prints only the pass label.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.0)
    }
}

impl super::TypedPass for OpOptim {
    /// Rewinds the resume position so the next scan starts from the first node
    /// in evaluation order.
    fn reset(&mut self) -> TractResult<()> {
        self.2 = 0;
        Ok(())
    }

    /// Produces the next patch, if any, by scanning from the resume position.
    fn next(&mut self, model: &TypedModel) -> TractResult<Option<TypedModelPatch>> {
        self.full_pass(model)
    }
}