use super::OptimizerSession;
use super::TypedPass;
use crate::internal::*;
use crate::model::*;
use crate::ops::einsum::EinSum;
use std::collections::HashSet;
use std::collections::hash_map::Entry;
use std::fmt::Debug;
use crate::ops::change_axes::*;
#[derive(Clone, Default)]
pub struct ChangeAxes(HashSet<crate::ops::change_axes::AxisChange>);
impl Debug for ChangeAxes {
    /// Prints only the pass name; the set of already-tried changes is not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("ChangeAxes")
    }
}
impl TypedPass for ChangeAxes {
fn reset(&mut self) -> TractResult<()> {
self.0.clear();
Ok(())
}
fn next(
&mut self,
_session: &mut OptimizerSession,
model: &TypedModel,
) -> TractResult<Option<TypedModelPatch>> {
let mut interfaces = model.output_outlets()?.to_vec();
interfaces.extend(model.input_outlets()?.iter());
for n in model.eval_order()? {
for suggestion in model.node(n).op.suggested_axis_changes()? {
let outlet = suggestion.0.as_outlet(model.node(n));
let change = AxisChange { outlet, op: suggestion.1 };
if self.0.insert(change.clone()) {
if let Some((patch, _)) = change_axes(model, &change, &interfaces, &[])
.with_context(|| {
format!("Making patch for {:?} from {}", change, model.node(n))
})?
{
return Ok(Some(patch));
}
}
}
for (slot, fact) in model.node(n).outputs.iter().enumerate() {
for (ix, dim) in fact.fact.shape.iter().enumerate() {
if dim.is_one() {
let change =
AxisChange { outlet: OutletId::new(n, slot), op: AxisOp::Rm(ix) };
if self.0.insert(change.clone()) {
if let Some((patch, _)) = change_axes(model, &change, &interfaces, &[])
.with_context(|| {
format!("Making patch for {:?} from {}", change, model.node(n))
})?
{
return Ok(Some(patch));
}
}
}
}
}
}
Ok(None)
}
}
/// Tries to apply `change` to `model`, propagating it through every op it reaches.
///
/// Starting from `change.outlet`, the change is offered to the producer and to every
/// consumer of each affected wire through `TypedOp::change_axes`. The consequences each op
/// reports (an optional substitute op and further wire changes) are queued and propagated
/// the same way until no work is left.
///
/// * `locked` — outlets that must not be modified; reaching one aborts the attempt.
/// * `bounds` — groups of outlets that must change together: a change on any outlet of a
///   group is treated as a change on the whole group.
///
/// Returns `Ok(None)` when propagation is blocked. This happens when a locked outlet is
/// reached, an op refuses the change, two different changes land on the same wire (group),
/// or a node whose op was already substituted is revisited (allowed only for `EinSum`).
///
/// On success, returns the patch plus the changes that landed on the model's interface:
/// `(InOut::In(ix), op)` for the `ix`-th model input and `(InOut::Out(ix), op)` for the
/// `ix`-th model output.
///
/// # Errors
/// Propagates errors from ops' `change_axes` implementations and from patch construction
/// (tapping, wiring and shunting).
// The return type is a nested tuple that clippy flags; it is only spelled out here.
#[allow(clippy::type_complexity)]
pub fn change_axes(
    model: &TypedModel,
    change: &AxisChange,
    locked: &[OutletId],
    bounds: &[TVec<OutletId>],
) -> TractResult<Option<(TypedModelPatch, TVec<(InOut, AxisOp)>)>> {
    debug!(" Considering change {:?}", change);
    // Outlets that must change together with `o`: its group in `bounds`, or `o` alone.
    let bound_outlets = |o: OutletId| -> TVec<OutletId> {
        bounds.iter().find(|b| b.contains(&o)).cloned().unwrap_or_else(|| tvec!(o))
    };
    // Pending changes, each tagged with the node that emitted it (`None` for the seed) so
    // that node is not offered its own change back.
    let mut todo_changes = vec![(change.clone(), None)];
    // Axis op applied so far to each wire group.
    let mut changed_wires: HashMap<TVec<OutletId>, AxisOp> = HashMap::new();
    changed_wires.insert(bound_outlets(change.outlet), change.op.clone());
    // Substitute ops for nodes whose op must be replaced.
    let mut changed_ops: HashMap<usize, Box<dyn TypedOp>> = HashMap::new();

    // Phase 1: propagate the change through the graph.
    while let Some((c, emitter)) = todo_changes.pop() {
        let outlet_group = bound_outlets(c.outlet);
        for &outlet in &outlet_group {
            if locked.contains(&outlet) {
                debug!(" Change {:?} blocked by locked interface {:?}", change, outlet);
                return Ok(None);
            }
            // The changed wire touches its producer (as an output) and each of its consumers
            // (as an input).
            let mut interfaces = vec![(outlet.node, InOut::Out(outlet.slot))];
            for inlet in model.outlet_successors(outlet) {
                interfaces.push((inlet.node, InOut::In(inlet.slot)));
            }
            for (node_id, io) in interfaces {
                if Some(node_id) == emitter {
                    continue;
                }
                let node = model.node(node_id);
                // A node already given a substitute op may only be revisited if that op is an
                // EinSum; any other revisit aborts the attempt.
                let op = if let Some(op) = changed_ops.get(&node_id) {
                    trace!(" Change {:?} revisiting {}", change, model.node(node_id));
                    if op.is::<EinSum>() {
                        op
                    } else {
                        debug!(" Change {:?} blocked: revisiting {}", change, model.node(node_id));
                        return Ok(None);
                    }
                } else {
                    &node.op
                };
                let consequence = op
                    .change_axes(model, node, io, &c.op)
                    .with_context(|| format!("Propagating {change:?} to node {node}"))?;
                let AxisChangeConsequence { substitute_op, wire_changes } = match consequence {
                    Some(consequence) => consequence,
                    None => {
                        debug!(" Propagation of {:?} blocked by {}", change, node);
                        return Ok(None);
                    }
                };
                trace!(" Change {:?} enters {} from {:?}", c.op, node, io);
                trace!(" propagates as {:?}", wire_changes);
                if let Some(op) = substitute_op {
                    trace!(" replace op by {:?}", op);
                    changed_ops.insert(node.id, op);
                }
                for (wire, op) in wire_changes.into_iter() {
                    let outlet_group = bound_outlets(wire.as_outlet(node));
                    match changed_wires.entry(outlet_group.clone()) {
                        Entry::Vacant(entry) => {
                            trace!(
                                " {:?} {:?} change on {:?} is new",
                                wire,
                                op,
                                outlet_group
                            );
                            entry.insert(op.clone());
                            todo_changes
                                .push((AxisChange { outlet: outlet_group[0], op }, Some(node_id)));
                        }
                        Entry::Occupied(previous) => {
                            if *previous.get() == op {
                                trace!(
                                    " {:?} {:?} change on {:?} already done",
                                    wire,
                                    op,
                                    outlet_group
                                );
                            } else {
                                debug!(
                                    " {:?} {:?} change on {:?} conflicting with {:?}. Blocked.",
                                    wire,
                                    op,
                                    outlet_group,
                                    previous
                                );
                                return Ok(None);
                            }
                        }
                    }
                }
            }
        }
    }

    // Phase 2: rebuild every affected node in a patch.
    debug!("Translating {:?} to patch", change);
    let mut patch = TypedModelPatch::new(format!("{change:?}"));
    // Outlet in `model` -> corresponding outlet in `patch`.
    let mut replaced_wires: HashMap<OutletId, OutletId> = HashMap::default();
    // Nodes producing a changed wire, or carrying a substitute op, are rebuilt.
    let nodes_to_replace = changed_wires
        .keys()
        .flat_map(|outlets| outlets.iter().map(|o| o.node))
        .chain(changed_ops.keys().copied())
        .collect::<HashSet<usize>>();
    for node_id in model.eval_order()? {
        let node = model.node(node_id);
        if nodes_to_replace.contains(&node_id) {
            let mut inputs = tvec!();
            for orig in &node.inputs {
                // Reuse the patch-side wire when one exists, otherwise tap the original wire.
                // Tapping is fallible, so its error is propagated instead of panicking.
                let tgt = match replaced_wires.entry(*orig) {
                    Entry::Occupied(known) => *known.get(),
                    Entry::Vacant(slot) => *slot.insert(patch.tap_model(model, *orig)?),
                };
                inputs.push(tgt);
            }
            let op: Box<dyn TypedOp> =
                changed_ops.get(&node_id).cloned().unwrap_or_else(|| node.op.clone());
            let new_wires = patch.wire_node(&node.name, op, &inputs)?;
            // A rebuilt source node is recorded in `patch.inputs` against the original node.
            if new_wires.len() == 1
                && patch.node(new_wires[0].node).op_is::<crate::ops::source::TypedSource>()
            {
                patch.inputs.insert(new_wires[0].node, node_id);
            }
            for (ix, w) in new_wires.iter().enumerate() {
                replaced_wires.insert((node_id, ix).into(), *w);
            }
        } else {
            // Untouched node: any of its inputs coming from a rebuilt node is shunted to the
            // patch-side replacement.
            for orig in &node.inputs {
                if let Some(replacement) = replaced_wires.get(orig) {
                    patch.shunt_outside(model, *orig, *replacement)?;
                }
            }
        }
    }
    for output in model.output_outlets()? {
        if let Some(replacement) = replaced_wires.get(output) {
            // SAFETY: NOTE(review) — the unchecked variant is presumably used because a model
            // output may legitimately change its fact here (the change is reported to the
            // caller through `interface_change` below), which a checked shunt would reject.
            // Confirm against the `shunt_outside_unchecked` contract.
            unsafe {
                patch.shunt_outside_unchecked(*output, *replacement)?;
            }
        }
    }

    // Report the changes that landed on the model's inputs and outputs, by position.
    let mut interface_change = tvec!();
    for (ix, input) in model.input_outlets()?.iter().enumerate() {
        if let Some(change) = changed_wires.get(&bound_outlets(*input)) {
            interface_change.push((InOut::In(ix), change.clone()));
        }
    }
    for (ix, output) in model.output_outlets()?.iter().enumerate() {
        if let Some(change) = changed_wires.get(&bound_outlets(*output)) {
            interface_change.push((InOut::Out(ix), change.clone()));
        }
    }
    debug!("Patch ready for {:?}", change);
    Ok(Some((patch, interface_change)))
}