1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
use super::OptimizerSession;
use super::TypedPass;
use crate::internal::*;
use crate::model::*;
use std::collections::HashSet;
use std::fmt::Debug;
use crate::ops::change_axes::*;
#[derive(Clone, Default)]
pub struct ChangeAxes(HashSet<(usize, (InOut, AxisOp))>);
impl Debug for ChangeAxes {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "ChangeAxes")
}
}
impl TypedPass for ChangeAxes {
fn reset(&mut self) -> TractResult<()> {
self.0.clear();
Ok(())
}
fn next(
&mut self,
_session: &mut OptimizerSession,
model: &TypedModel,
) -> TractResult<Option<TypedModelPatch>> {
let mut interfaces = model.output_outlets()?.to_vec();
interfaces.extend(model.input_outlets()?.iter());
for n in model.eval_order()? {
for suggestion in model.node(n).op.suggested_axis_changes()? {
if self.0.insert((n, suggestion.clone())) {
let outlet = suggestion.0.as_outlet(model.node(n));
let change = AxisChange { outlet, op: suggestion.1.clone() };
if let Some((patch, _)) = change_axes(model, &change, &interfaces, &[])
.with_context(|| {
format!("Making patch for {:?} from {}", change, model.node(n))
})?
{
return Ok(Some(patch));
}
}
}
}
Ok(None)
}
}