use crate::ops::einsum::EinSum;
use crate::ops::konst::Const;
use crate::optim::OptimizerSession;
use super::lir::{LirScan, LirScanOpParams};
use tract_data::internal::*;
use super::*;
#[derive(Debug, Clone, Default)]
pub struct Scan {
pub skip: usize,
pub reset_every_turn: bool,
pub body: TypedModel,
pub decluttered: bool,
pub input_mapping: Vec<InputMapping>,
pub output_mapping: Vec<OutputMapping<TDim>>,
}
impl Scan {
pub fn to_codegen_op(&self, optimize_inner: bool) -> TractResult<LirScan> {
let mut model = self.body.clone();
if optimize_inner {
model = model.into_optimized()?;
}
let plan = SimplePlan::new(model)?;
Ok(LirScan::new(Arc::new(LirScanOpParams::new(
self.skip,
self.reset_every_turn,
Arc::new(plan),
self.input_mapping.clone(),
self.output_mapping.clone(),
))))
}
pub fn new(
body: TypedModel,
input_mapping: Vec<InputMapping>,
output_mapping: Vec<OutputMapping<TDim>>,
skip: usize,
) -> TractResult<Scan> {
body.check_consistency()?;
ensure!(input_mapping.len() == body.input_outlets()?.len());
ensure!(output_mapping.len() == body.output_outlets()?.len());
Ok(Scan {
skip,
reset_every_turn: false,
body,
decluttered: false,
input_mapping,
output_mapping,
})
}
pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
self.to_codegen_op(false).unwrap().iteration_count(inputs)
}
fn declutter_body(
&self,
session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
if !self.decluttered {
let mut new = self.clone();
let mut body = self.body.clone();
session.optimize(&mut body)?;
new.body = body;
new.decluttered = true;
Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
} else {
Ok(None)
}
}
fn declutter_body_axes(
&self,
_session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let mut suggestions = vec![];
for n in self.body.eval_order()? {
let node = self.body.node(n);
for suggestion in node.op.suggested_axis_changes()? {
let outlet = suggestion.0.as_outlet(node);
suggestions.push(AxisChange { outlet, op: suggestion.1 })
}
for (slot, fact) in node.outputs.iter().enumerate() {
for (ix, dim) in fact.fact.shape.iter().enumerate() {
if dim.is_one() {
suggestions.push(AxisChange {
outlet: OutletId::new(n, slot),
op: AxisOp::Rm(ix),
});
}
}
}
}
let node_input_facts = model.node_input_facts(node.id)?;
for suggestion in suggestions.into_iter() {
if let Some(conseq) = self.try_body_axes_change(suggestion, true, &node_input_facts)? {
let mut patch = TypedModelPatch::default();
let mut inputs = tvec!();
for outlet in &node.inputs {
inputs.push(patch.tap_model(model, *outlet)?);
}
for change in conseq.wire_changes {
if let InOut::In(i) = change.0 {
let mut value = patch
.outlet_fact(inputs[i])?
.konst
.clone()
.context("Will only reshape constants")?
.into_tensor();
change.1.change_tensor(&mut value, false)?;
let konst_name = patch.node(inputs[i].node).name.clone();
inputs[i] = patch.add_const(konst_name, value)?;
}
}
let wires = patch.wire_node(
&node.name,
conseq.substitute_op.unwrap_or_else(|| Box::new(self.clone())),
&inputs,
)?;
for (ix, new) in wires.into_iter().enumerate() {
patch.shunt_outside(model, OutletId::new(node.id, ix), new)?;
}
return Ok(Some(patch));
}
}
Ok(None)
}
fn remove_outer_output_from_mappings(
mappings: &[OutputMapping<TDim>],
discarded: usize,
) -> Vec<OutputMapping<TDim>> {
mappings
.iter()
.map(|m| OutputMapping {
scan: m.scan.map(|(slot, info)| (slot - (slot > discarded) as usize, info)),
last_value_slot: m.last_value_slot.map(|n| n - (n > discarded) as usize),
full_dim_hint: m.full_dim_hint.clone(),
state: m.state,
})
.collect()
}
fn declutter_const_input(
&self,
_session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let inputs = model.node_input_facts(node.id)?;
for (slot, mapping) in self.input_mapping.iter().enumerate() {
if let InputMapping::Full = mapping {
if let Some(konst) = inputs[slot].konst.as_ref() {
let mut op = self.clone();
let src = op.body.inputs[slot];
op.body.inputs.remove(slot);
op.body.nodes[src.node].inputs.clear();
op.body.nodes[src.node].op = Box::new(Const::new(konst.clone()));
op.input_mapping.remove(slot);
let mut inputs = node.inputs.clone();
inputs.remove(slot);
return Ok(Some(TypedModelPatch::replace_single_op(model, node, &inputs, op)?));
}
}
}
Ok(None)
}
fn declutter_discard_unused_input(
&self,
_session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
for (slot, input) in self.body.input_outlets()?.iter().enumerate() {
let source_node = self.body.node(input.node);
if source_node.outputs[0].successors.len() == 0
&& !self.body.output_outlets()?.contains(input)
{
let mut new_inputs = node.inputs.clone();
new_inputs.remove(slot);
let mut new_mappings: Vec<_> = self.input_mapping.clone();
new_mappings.remove(slot);
let mut model_inputs = self.body.input_outlets()?.to_vec();
model_inputs.remove(slot);
let mut body = self.body.clone();
let mut patch = TypedModelPatch::default();
patch.obliterate(source_node.id)?;
patch.apply(&mut body)?;
body.set_input_outlets(&model_inputs)?;
body.declutter()?;
let op =
Self { body, input_mapping: new_mappings, decluttered: true, ..self.clone() };
return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
}
}
Ok(None)
}
fn declutter_discard_useless_outer_output(
&self,
_session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
for (ix, o) in node.outputs.iter().enumerate() {
if o.successors.len() == 0
&& !model.output_outlets()?.contains(&OutletId::new(node.id, ix))
{
let mappings = self
.output_mapping
.iter()
.map(|m| OutputMapping {
scan: m.scan.filter(|(slot, _info)| *slot != ix),
last_value_slot: m.last_value_slot.filter(|s| *s != ix),
full_dim_hint: m.full_dim_hint.clone(),
state: m.state,
})
.collect::<Vec<_>>();
let mut op = self.clone();
op.output_mapping = Self::remove_outer_output_from_mappings(&mappings, ix);
let mut patch = TypedModelPatch::default();
let inputs = node
.inputs
.iter()
.map(|&i| patch.tap_model(model, i))
.collect::<TractResult<Vec<_>>>()?;
let wires = patch.wire_node(&*node.name, op, &inputs)?;
for oix in 0..node.outputs.len() {
if oix != ix {
patch.shunt_outside(
model,
OutletId::new(node.id, oix),
wires[oix - (oix > ix) as usize],
)?;
}
}
return Ok(Some(patch));
}
}
Ok(None)
}
fn declutter_discard_empty_output_mapping_with_body_output(
&self,
_session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
for (ix, om) in self.output_mapping.iter().enumerate() {
if om.last_value_slot.is_none() && om.scan.is_none() && !om.state {
let mut new_op = self.clone();
new_op.output_mapping.remove(ix);
new_op.body.outputs.remove(ix);
new_op.decluttered = false;
return Ok(Some(TypedModelPatch::replace_single_op(
model,
node,
&node.inputs,
new_op,
)?));
}
}
Ok(None)
}
fn declutter_pull_batcheable_input(
&self,
_session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
'candidate: for (slot, input) in self.input_mapping.iter().enumerate() {
if let Some(scan_info) = input.as_scan() {
let scan_source = self.body.input_outlets()?[slot];
let scan_source_node = self.body.node(scan_source.node);
for mut succ in &scan_source_node.outputs[0].successors {
for &succ_input in &self.body.node(succ.node).inputs {
if succ_input != scan_source
&& self.body.outlet_fact(succ_input)?.konst.is_none()
{
continue 'candidate;
}
}
if self.body.node(succ.node).outputs.len() != 1 {
continue;
}
let mut new_body = self.body.clone();
if let Some(einsum) = new_body.node(succ.node).op_as::<EinSum>() {
if let Some(patch) = einsum
.propagate_axis(
&new_body,
new_body.node(succ.node),
InOut::In(succ.slot),
scan_info.axis,
)
.context("building axis propagating patch")?
{
patch.apply(&mut new_body)?;
let new_body_scan_input = new_body.input_outlets()?[slot];
succ = new_body.node(new_body_scan_input.node).outputs[0]
.successors
.last()
.unwrap();
}
}
let axes_mapping = {
let (input_facts, output_facts) =
new_body.node_facts(new_body.node(succ.node).id)?;
new_body.node(succ.node).op.axes_mapping(&input_facts, &output_facts)?
};
let axis_info = axes_mapping.axis((InOut::In(succ.slot), scan_info.axis))?;
if let &[axis_after] = &*axis_info.outputs[0] {
let mut outside_patch = TypedModelPatch::new(format!(
"Outer patch for input extraction of {}",
new_body.node(succ.node)
));
let mut patch_inputs = node
.inputs
.iter()
.map(|&i| outside_patch.tap_model(model, i))
.collect::<TractResult<TVec<_>>>()?;
let mut extracted_op_inputs = tvec!();
for (ix, outlet) in new_body.node(succ.node).inputs.iter().enumerate() {
let wire = if ix == succ.slot {
patch_inputs[slot]
} else if let Some(konst) =
new_body.outlet_fact(*outlet)?.konst.as_ref()
{
outside_patch.add_const(
format!(
"{}.extracted.{}",
node.name,
new_body.node(outlet.node).name
),
konst.clone(),
)?
} else {
unreachable!();
};
extracted_op_inputs.push(wire);
}
let new_input_wire = outside_patch.wire_node(
format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
new_body.node(succ.node).op.clone(),
&extracted_op_inputs,
)?[0];
patch_inputs.push(new_input_wire);
let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
let mut new_input_inner_fact = new_input_outer_fact.clone();
new_input_inner_fact.shape.set(axis_after, scan_info.chunk.abs().to_dim());
let mut new_body = new_body.clone();
let new_source_wire = new_body.add_source(
format!("{}.extracted.{}", node.name, new_body.node(succ.node).name),
new_input_inner_fact,
)?;
let mut inner_patch = TypedModelPatch::new(format!(
"Inner body patch for extraction of {}",
new_body.node(succ.node)
));
let new_source_wire_in_patch =
inner_patch.tap_model(&new_body, new_source_wire)?;
inner_patch
.shunt_outside(
&new_body,
OutletId::new(succ.node, 0),
new_source_wire_in_patch,
)
.with_context(|| "patching inner model")?;
inner_patch.apply(&mut new_body)?;
let mut input_mapping = self.input_mapping.clone();
input_mapping.push(InputMapping::Scan(ScanInfo {
axis: axis_after,
chunk: scan_info.chunk,
}));
let new_op = Self {
input_mapping,
decluttered: false,
body: new_body,
..self.clone()
};
let output_wires =
outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
for w in output_wires {
outside_patch
.shunt_outside(model, OutletId::new(node.id, w.slot), w)
.with_context(|| "patching outer model")?;
}
return Ok(Some(outside_patch));
}
}
}
}
Ok(None)
}
fn declutter_pull_constant_outputs(
&self,
_session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
for (model_output_ix, mapping) in self.output_mapping.iter().enumerate() {
if let Some(slot) = mapping.last_value_slot {
if let Some(k) = self.body.output_fact(model_output_ix)?.konst.clone() {
let inner_node = self.body.output_outlets()?[model_output_ix].node;
let inner_node = self.body.node(inner_node);
let mut patch =
TypedModelPatch::new(format!("Extract const node {inner_node}"));
let cst = patch.add_const(format!("{}.{}", &node.name, &inner_node.name), k)?;
patch.shunt_outside(model, OutletId::new(node.id, slot), cst)?;
return Ok(Some(patch));
}
}
}
Ok(None)
}
fn declutter_pull_batcheable_output(
&self,
_session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
for (mapping_ix, mapping) in self.output_mapping.iter().enumerate() {
if let Some((_, scan_info)) = mapping.scan {
let emitter_outlet = self.body.output_outlets()?[mapping_ix];
if self.body.node(emitter_outlet.node).outputs[emitter_outlet.slot].successors.len()
> 0
|| self.body.inputs.contains(&emitter_outlet)
|| mapping.state
|| mapping.scan.map(|(_slot, i)| i.chunk > 1).unwrap_or(true)
{
continue;
}
let mut new_body = self.body.clone();
if let Some(einsum) = new_body.node(emitter_outlet.node).op_as::<EinSum>() {
if let Some(patch) = einsum
.propagate_axis(
&new_body,
new_body.node(emitter_outlet.node),
InOut::Out(0),
scan_info.axis,
)
.context("building axis propagating patch")?
{
patch.apply(&mut new_body)?;
}
}
let emitter_outlet = new_body.output_outlets()?[mapping_ix];
let invariants = {
let (input_facts, output_facts) = new_body.node_facts(emitter_outlet.node)?;
new_body
.node(emitter_outlet.node)
.op
.axes_mapping(&input_facts, &output_facts)?
};
let axis_tracking =
invariants.axis((InOut::Out(emitter_outlet.slot), scan_info.axis))?;
if axis_tracking.outputs.iter().any(|o| o.len() > 1) {
return Ok(None);
}
let mut new_output_mapping = self.output_mapping.clone();
let mut new_scan_outputs = node.outputs.len();
let mut outer_slots = vec![];
for (input_slot, input) in
new_body.node(emitter_outlet.node).inputs.clone().iter().enumerate()
{
if new_body.outputs.iter().all(|o| o != input) {
new_output_mapping.push(OutputMapping::default());
new_body.outputs.push(*input);
}
let body_output_id = new_body.outputs.iter().position(|o| o == input).unwrap();
let mapping = &mut new_output_mapping[body_output_id];
let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
if mapping.last_value_slot.is_none() {
mapping.last_value_slot = Some(new_scan_outputs);
new_scan_outputs += 1;
}
mapping.last_value_slot.unwrap()
} else if let &[axis] = &*axis_tracking.inputs[input_slot] {
if mapping.scan.is_none() {
mapping.scan =
Some((new_scan_outputs, ScanInfo { axis, chunk: scan_info.chunk }));
new_scan_outputs += 1;
}
mapping.scan.unwrap().0
} else {
return Ok(None);
};
outer_slots.push(outer_slot);
}
let mut outside_patch = TypedModelPatch::new(format!(
"Outside patch for output extraction of {}",
new_body.node(emitter_outlet.node)
));
let inputs = node
.inputs
.iter()
.map(|&i| outside_patch.tap_model(model, i))
.collect::<TractResult<TVec<_>>>()?;
let new_op = Self {
output_mapping: new_output_mapping,
decluttered: false,
body: new_body.clone(), ..self.clone()
};
let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
let output = mapping.scan.unwrap();
let inputs =
outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
let wire = outside_patch.wire_node(
&new_body.node(emitter_outlet.node).name,
new_body.node(emitter_outlet.node).op.clone(),
&inputs,
)?[0];
outside_patch.shunt_outside(model, OutletId::new(node.id, output.0), wire)?;
for output_slot in 0..node.outputs.len() {
if output_slot != output.0 {
outside_patch.shunt_outside(
model,
OutletId::new(node.id, output_slot),
OutletId::new(scan_outputs[0].node, output_slot),
)?;
}
}
return Ok(Some(outside_patch));
}
}
Ok(None)
}
fn body_bounds(&self) -> TractResult<TVec<TVec<OutletId>>> {
let input_state_outlets = self
.input_mapping
.iter()
.zip(self.body.input_outlets()?.iter())
.filter(|(m, _)| m.is_state())
.map(|(_, o)| o);
let output_state_outlets = self
.output_mapping
.iter()
.zip(self.body.output_outlets()?.iter())
.filter(|(m, _)| m.state)
.map(|(_, o)| o);
Ok(input_state_outlets.zip(output_state_outlets).map(|(&i, &o)| tvec!(i, o)).collect())
}
fn body_locked_outlets(&self, node_input_facts: &[&TypedFact]) -> TractResult<TVec<OutletId>> {
let input_outlets =
self.body.input_outlets()?.iter().enumerate().filter_map(|(slot, o)| {
if node_input_facts[slot].konst.is_none() {
Some(o)
} else {
None
}
});
let output_outlets = self
.output_mapping
.iter()
.zip(self.body.output_outlets()?.iter())
.filter(|(m, _)| !m.invisible())
.map(|(_, o)| o);
Ok(input_outlets.chain(output_outlets).cloned().collect())
}
fn try_body_axes_change(
&self,
change: AxisChange,
locked_interface: bool,
node_input_facts: &[&TypedFact],
) -> TractResult<Option<AxisChangeConsequence>> {
self.body.check_consistency()?;
let locked_outlets = self.body_locked_outlets(node_input_facts)?;
let (body_patch, body_changed_wires) = if let Some(changes) =
crate::optim::change_axes::change_axes(
&self.body,
&change,
if locked_interface { &locked_outlets } else { &[] },
&self.body_bounds()?,
)? {
changes
} else {
return Ok(None);
};
let mut body = self.body.clone();
body_patch.apply(&mut body)?;
body.compact()?;
let mut wire_changes = tvec!();
let mut input_mapping: Vec<InputMapping> = self.input_mapping.clone();
for (slot, m) in input_mapping.iter_mut().enumerate() {
if let Some(change) = body_changed_wires
.iter()
.find(|(iface, _change)| iface == &InOut::In(slot))
.map(|pair| pair.1.clone())
{
wire_changes.push((InOut::In(slot), change.clone()));
if let InputMapping::Scan(info) = m {
if let Some(axis) = change.transform_axis(info.axis) {
info.axis = axis;
} else {
return Ok(None);
};
};
}
}
let mut output_mapping: Vec<OutputMapping<TDim>> = self.output_mapping.clone();
for (ix, m) in output_mapping.iter_mut().enumerate() {
if let Some(change) = body_changed_wires
.iter()
.find(|(iface, _change)| iface == &InOut::Out(ix))
.map(|pair| pair.1.clone())
{
if let Some((slot, info)) = m.scan.as_mut() {
if let Some(new_axis) = change.transform_axis(info.axis) {
info.axis = new_axis;
} else {
return Ok(None);
}
wire_changes.push((InOut::Out(*slot), change.clone()));
}
if let Some(slot) = m.last_value_slot {
wire_changes.push((InOut::Out(slot), change.clone()));
}
};
}
body.check_consistency()?;
let op = Some(Box::new(Scan {
body,
input_mapping,
output_mapping,
decluttered: false,
..self.clone()
}) as _);
Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes }))
}
}
impl Op for Scan {
fn name(&self) -> Cow<str> {
"Scan".into()
}
fn info(&self) -> TractResult<Vec<String>> {
let mut lines = vec![];
for (ix, im) in self.input_mapping.iter().enumerate() {
lines.push(format!("Model input #{ix}: {im:?}"));
}
for (ix, om) in self.output_mapping.iter().enumerate() {
lines.push(format!("Model output #{ix}: {om:?}"));
}
lines.push(format!("skip:{} reset_every_turn:{:?}", self.skip, self.reset_every_turn));
Ok(lines)
}
fn validation(&self) -> Validation {
Validation::Rounding
}
op_as_typed_op!();
}
impl EvalOp for Scan {
fn is_stateless(&self) -> bool {
false
}
fn state(
&self,
session: &mut SessionState,
node_id: usize,
) -> TractResult<Option<Box<dyn OpState>>> {
self.to_codegen_op(false)?.state(session, node_id)
}
}
impl TypedOp for Scan {
as_op!();
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
anyhow::ensure!(inputs.len() == self.body.inputs.len());
anyhow::ensure!(self.input_mapping.len() == self.body.inputs.len());
anyhow::ensure!(
self.input_mapping.iter().filter(|m| m.is_state()).count()
== self.output_mapping.iter().filter(|m| m.state).count()
);
for (i, o) in
self.input_mapping.iter().enumerate().filter(|(_, m)| m.is_state()).map(|(i, _)| i).zip(
self.output_mapping.iter().enumerate().filter(|(_, m)| m.state).map(|(o, _)| o),
)
{
let ifact = self.body.outlet_fact(self.body.inputs[i])?;
let ofact = self.body.outlet_fact(self.body.outputs[o])?;
anyhow::ensure!(ifact == ofact,
"inconsistent state shape: body input {i} is {ifact:?} and body output {o} is {ofact:?}",
)
}
let mut outputs = tvec!();
let iters = super::iteration_count(&self.input_mapping, inputs).context("No scan input")?;
for (ix, output) in self.output_mapping.iter().enumerate() {
let fact = self.body.output_fact(ix)?;
if let Some((slot, info)) = output.scan {
let mut shape = fact.shape.clone();
let scanning_dim =
output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
shape.set(info.axis, scanning_dim);
outputs.push((slot, fact.datum_type.fact(shape)));
}
if let Some(slot) = output.last_value_slot {
outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
}
}
outputs.sort_by_key(|a| a.0);
anyhow::ensure!(outputs.iter().enumerate().all(|(ix, (slot, _))| ix == *slot));
let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
Ok(outputs)
}
fn axes_mapping(
&self,
inputs: &[&TypedFact],
outputs: &[&TypedFact],
) -> TractResult<AxesMapping> {
let mut mappings = vec![];
let body_invs = self.body.axes_mapping().with_context(|| "Computing body axes mapping")?;
for body_axis in body_invs.iter_all_axes() {
let mut info = Axis::new(body_axis.repr, inputs.len(), outputs.len());
info.inputs = body_axis.inputs.clone();
for (ix, output_mapping) in self.output_mapping.iter().enumerate() {
let mut slots = vec![];
if let Some((slot, _scan)) = output_mapping.scan {
slots.push(slot);
}
if let Some(slot) = output_mapping.last_value_slot {
slots.push(slot);
}
for slot in slots {
info.outputs[slot] = body_axis.outputs[ix].clone();
}
}
if info.inputs.iter().any(|i| i.len() > 0) || info.outputs.iter().any(|i| i.len() > 0) {
mappings.push(info);
}
}
AxesMapping::new(inputs.len(), outputs.len(), mappings)
}
fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
let mut suggestions = tvec!();
for (slot, input) in self.input_mapping.iter().enumerate() {
if let InputMapping::Scan(info) = input {
if info.axis != 0 {
suggestions.push((InOut::In(slot), AxisOp::Move(info.axis, 0)))
}
}
}
for output in &self.output_mapping {
if let Some((slot, scan)) = output.scan {
if scan.axis != 0 {
suggestions.push((InOut::Out(slot), AxisOp::Move(scan.axis, 0)))
}
}
}
Ok(suggestions)
}
fn change_axes(
&self,
model: &TypedModel,
node: &TypedNode,
io: InOut,
change: &AxisOp,
) -> TractResult<Option<AxisChangeConsequence>> {
trace!("Propagating through {}: {:?} {:?}", node, io, change);
let body_leading_outlet = match io {
InOut::In(ix) => self.body.input_outlets()?[ix],
InOut::Out(slot) => {
let output = self
.output_mapping
.iter()
.position(|im| {
im.scan.map(|(slot, _i)| slot) == Some(slot)
|| im.last_value_slot == Some(slot)
})
.unwrap();
self.body.output_outlets()?[output]
}
};
let axis_change = AxisChange { outlet: body_leading_outlet, op: change.clone() };
let node_input_facts = model.node_input_facts(node.id)?;
let result = self
.try_body_axes_change(axis_change, false, &node_input_facts)
.with_context(|| "Attemping to run change through scan body".to_string())?;
if result.is_some() {
trace!("{} accepted axis change", node);
} else {
trace!("{} rejected axis change", node);
}
Ok(result)
}
fn declutter_with_session(
&self,
session: &mut OptimizerSession,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
macro_rules! pass {
($func:ident) => {
if let Some(mut r) = self
.$func(session, model, node)
.with_context(|| format!("{}", stringify!($func)))?
{
trace!(stringify!($func));
r.push_context(stringify!($func));
return Ok(Some(r));
}
};
}
pass!(declutter_const_input);
pass!(declutter_discard_unused_input);
pass!(declutter_discard_useless_outer_output);
pass!(declutter_discard_empty_output_mapping_with_body_output);
pass!(declutter_body);
pass!(declutter_body_axes);
pass!(declutter_pull_constant_outputs);
pass!(declutter_pull_batcheable_input);
pass!(declutter_pull_batcheable_output);
Ok(None)
}
fn concretize_dims(
&self,
_source: &TypedModel,
node: &TypedNode,
target: &mut TypedModel,
mapping: &HashMap<OutletId, OutletId>,
values: &SymbolValues,
) -> TractResult<TVec<OutletId>> {
let inputs = node.inputs.iter().map(|o| mapping[o]).collect::<TVec<_>>();
let op = Self {
output_mapping: self
.output_mapping
.iter()
.map(|om| om.concretize_dims(values))
.collect::<TractResult<Vec<_>>>()?,
body: self.body.concretize_dims(values)?,
..self.clone()
};
target.wire_node(&node.name, op, &inputs)
}
fn codegen(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
Ok(Some(TypedModelPatch::replace_single_op(
model,
node,
&node.inputs,
self.to_codegen_op(true)?,
)?))
}
}