1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
use crate::internal::*;
use super::OptimizerSession;
/// An optimizer pass built from a single per-node rewrite function.
///
/// Fields (positional):
/// - `.0`: the pass name. It is the `Debug` representation of the pass and appears in the
///   error/patch context strings built by `full_pass`.
/// - `.1`: the rewrite function. It is called on each node of the model in evaluation order
///   and may propose a `TypedModelPatch` for that node (`Ok(Some(_))`).
/// - `.2`: a cursor into the evaluation order. The next scan resumes from this position;
///   `reset` sets it back to 0.
#[derive(Clone)]
pub struct OpOptim(
/// Pass name, used for `Debug` output and error/patch context.
pub &'static str,
/// Per-node rewrite function: receives the node's op, the optimizer session, the whole
/// model and the node itself; returns `Ok(Some(patch))` to propose a rewrite.
pub fn(
op: &dyn TypedOp,
session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>>,
/// Position in evaluation order from which the next scan starts.
pub usize,
);
impl OpOptim {
    /// Scans the model's nodes in evaluation order and returns the first patch proposed by
    /// the wrapped rewrite function.
    ///
    /// The scan starts at the saved cursor (`self.2`), not at the first node. When a patch is
    /// found, the cursor is set to that node's position in the evaluation order. It is set one
    /// past it instead when the patch has `dont_apply_twice` set, so the same node is not
    /// offered again on the next call. Returns `Ok(None)` when no node from the cursor onward
    /// yields a patch; the cursor is left unchanged in that case.
    ///
    /// # Errors
    /// Propagates failures from computing the evaluation order. Propagates failures from the
    /// rewrite function, annotated with the pass name and the offending node.
    fn full_pass(
        &mut self,
        session: &mut OptimizerSession,
        new: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        // The rewrite function is a plain fn pointer, so copying it out is free.
        let optimizer = self.1;
        let order = new.eval_order()?;
        let start = self.2;
        // `enumerate` comes before `skip` so `position` is the absolute index in `order`.
        for (position, node_id) in order.iter().copied().enumerate().skip(start) {
            let node = &new.nodes()[node_id];
            let proposal = optimizer(node.op.as_ref(), session, new, node)
                .with_context(|| format!("{:?} node {}", self, node))?;
            match proposal {
                None => continue,
                Some(mut patch) => {
                    patch.push_context(format!("{:?} {}", self, node));
                    let step_past_node = usize::from(patch.dont_apply_twice.is_some());
                    self.2 = position + step_past_node;
                    return Ok(Some(patch));
                }
            }
        }
        Ok(None)
    }
}
impl std::fmt::Debug for OpOptim {
    /// Renders the pass as its bare name (`self.0`). Outer width or precision flags are not
    /// applied.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let name: &str = self.0;
        f.write_str(name)
    }
}
impl super::TypedPass for OpOptim {
fn reset(&mut self) -> TractResult<()> {
self.2 = 0;
Ok(())
}
fn next(
&mut self,
session: &mut OptimizerSession,
model: &TypedModel,
) -> TractResult<Option<TypedModelPatch>> {
self.full_pass(session, model)
}
}