Skip to main content

tract_core/optim/change_axes.rs

1use super::OptimizerSession;
2use super::TypedPass;
3use crate::internal::*;
4use crate::model::*;
5use crate::ops::dummy::Dummy;
6use crate::ops::einsum::EinSum;
7use crate::ops::konst::Const;
8use std::collections::HashSet;
9use std::collections::hash_map::Entry;
10use std::fmt::Debug;
11
12use crate::ops::change_axes::*;
13
14#[derive(Clone, Default)]
15pub struct ChangeAxes(HashSet<crate::ops::change_axes::AxisChange>, usize);
16
17impl Debug for ChangeAxes {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        write!(f, "ChangeAxes")
20    }
21}
22
23impl TypedPass for ChangeAxes {
24    fn reset(&mut self) -> TractResult<()> {
25        self.0.clear();
26        self.1 = 0;
27        Ok(())
28    }
29    fn next(
30        &mut self,
31        _session: &mut OptimizerSession,
32        model: &TypedModel,
33    ) -> TractResult<Option<TypedModelPatch>> {
34        let mut explored: HashSet<AxisChange> = Default::default();
35        let mut interfaces = model.output_outlets()?.to_vec();
36        interfaces.extend(model.input_outlets()?.iter());
37        for node in &model.nodes[self.1..] {
38            if node.op_is::<Dummy>() {
39                continue;
40            }
41            for suggestion in node.op.suggested_axis_changes()? {
42                let outlet = suggestion.0.as_outlet(node);
43                let change = AxisChange { outlet, op: suggestion.1 };
44                if self.0.insert(change.clone())
45                    && let Some((patch, _)) =
46                        change_axes(model, &change, &interfaces, &[], &mut explored)
47                            .with_context(|| format!("Making patch for {change:?} from {node}"))?
48                {
49                    self.1 = node.id;
50                    return Ok(Some(patch));
51                }
52            }
53            for (slot, fact) in node.outputs.iter().enumerate() {
54                for (ix, dim) in fact.fact.shape.iter().enumerate() {
55                    if dim.is_one() {
56                        let change =
57                            AxisChange { outlet: OutletId::new(node.id, slot), op: AxisOp::Rm(ix) };
58                        if self.0.insert(change.clone())
59                            && let Some((patch, _)) =
60                                change_axes(model, &change, &interfaces, &[], &mut explored)
61                                    .with_context(|| {
62                                        format!("Making patch for {change:?} from {node}")
63                                    })?
64                        {
65                            self.1 = node.id;
66                            return Ok(Some(patch));
67                        }
68                    }
69                }
70            }
71        }
72        Ok(None)
73    }
74}
75
76#[allow(clippy::type_complexity)]
77pub fn change_axes(
78    model: &TypedModel,
79    change: &AxisChange,
80    locked: &[OutletId],
81    bounds: &[TVec<OutletId>],
82    explored: &mut HashSet<AxisChange>,
83) -> TractResult<Option<(TypedModelPatch, TVec<(InOut, AxisOp)>)>> {
84    if explored.contains(change) {
85        debug!("  Not considering change because deja vu {change:?}");
86        return Ok(None);
87    }
88    if model
89        .node(change.outlet.node)
90        .op_as::<Const>()
91        .is_some_and(|c| c.val().volume() == 1 && c.val().is_plain())
92    {
93        debug!("  Not considering change from const {change:?}");
94        return Ok(None);
95    }
96    debug!("  Considering change {change:?}");
97    let mut todo_changes = vec![(change.clone(), None)];
98    let mut changed_wires: HashMap<TVec<OutletId>, AxisOp> = HashMap::new();
99    let bound_outlets = |o: OutletId| -> TVec<OutletId> {
100        bounds.iter().find(|b| b.contains(&o)).cloned().unwrap_or_else(|| tvec!(o))
101    };
102    changed_wires.insert(bound_outlets(change.outlet), change.op.clone());
103    let mut changed_ops: HashMap<usize, Box<dyn TypedOp>> = HashMap::new();
104    let mut rewired_scalar_input: HashMap<InletId, (OutletId, AxisOp)> = Default::default();
105    while let Some((change, emitter)) = todo_changes.pop() {
106        rule_if!(explored.insert(change.clone()));
107        let outlet_group = bound_outlets(change.outlet);
108        for &outlet in &outlet_group {
109            if locked.contains(&outlet) {
110                debug!("    Change {change:?} blocked by locked interface {outlet:?}");
111                return Ok(None);
112            }
113            let mut interfaces: Vec<(usize, InOut)> = vec![(outlet.node, InOut::Out(outlet.slot))];
114            for inlet in model.outlet_successors(outlet) {
115                interfaces.push((inlet.node, InOut::In(inlet.slot)));
116            }
117            for (node_id, io) in interfaces {
118                if Some(node_id) == emitter {
119                    continue;
120                }
121                let node = model.node(node_id);
122                // if this is a revisit...
123                let op = if let Some(op) = changed_ops.get(&node_id) {
124                    trace!("  Change {:?} revisiting {}", change, model.node(node_id));
125                    if op.is::<EinSum>() {
126                        // FIXME Einsum can swallow any combination of axis change on all interfaces
127                        op
128                    } else {
129                        debug!("  Change {:?} blocked: revisiting {}", change, model.node(node_id));
130                        return Ok(None);
131                    }
132                } else {
133                    &node.op
134                };
135                let more = op
136                    .change_axes(model, node, io, &change.op)
137                    .with_context(|| format!("Propagating {change:?} to node {node}"))?;
138                if more.is_none() {
139                    debug!("    Propagation of {change:?} blocked by {node}");
140                    return Ok(None);
141                }
142                let AxisChangeConsequence { substitute_op, wire_changes } = more.unwrap();
143                trace!("    Change {:?} enters {} from {:?}", change.op, node, io);
144                trace!("       propagates as {wire_changes:?}");
145                if let Some(op) = substitute_op {
146                    trace!("       replace op by {op:?}");
147                    changed_ops.insert(node.id, op);
148                }
149                for (wire, op) in wire_changes.into_iter() {
150                    let outlet = wire.as_outlet(node);
151                    // stop upstram propagation to a scalar constant: we will clone it and alter it
152                    // at patch generation time
153                    if let InOut::In(inlet) = wire
154                        && model
155                            .node(outlet.node)
156                            .op_as::<Const>()
157                            .is_some_and(|k| k.val().volume() == 1)
158                    {
159                        rewired_scalar_input.insert(InletId::new(node.id, inlet), (outlet, op));
160                        continue;
161                    }
162                    let outlet_group = bound_outlets(wire.as_outlet(node));
163                    match changed_wires.entry(outlet_group.clone()) {
164                        Entry::Vacant(entry) => {
165                            trace!("         {wire:?} {op:?} change on {outlet_group:?} is new");
166                            entry.insert(op.clone());
167                            todo_changes
168                                .push((AxisChange { outlet: outlet_group[0], op }, Some(node_id)));
169                        }
170                        Entry::Occupied(previous) => {
171                            if *previous.get() == op {
172                                trace!(
173                                    "         {wire:?} {op:?} change on {outlet_group:?} already done"
174                                );
175                            } else {
176                                debug!(
177                                    "         {wire:?} {op:?} change on {outlet_group:?} conflicting with {previous:?}. Blocked."
178                                );
179                                return Ok(None);
180                            }
181                        }
182                    }
183                }
184            }
185        }
186    }
187    debug!("Translating {change:?} to patch");
188    let mut patch = TypedModelPatch::new(format!("{change:?}"));
189    let mut replaced_wires: HashMap<OutletId, OutletId> = HashMap::default();
190    let nodes_to_replace = changed_wires
191        .keys()
192        .flat_map(|outlets| outlets.iter().map(|o| o.node))
193        .chain(changed_ops.keys().copied())
194        .collect::<std::collections::HashSet<usize>>();
195    for node_id in model.eval_order()? {
196        let node = model.node(node_id);
197        if nodes_to_replace.contains(&node_id) {
198            let mut inputs = tvec!();
199            for (slot, orig) in node.inputs.iter().enumerate() {
200                let tgt = if let Some((outlet, alteration)) =
201                    rewired_scalar_input.get(&InletId::new(node_id, slot))
202                {
203                    let const_node = model.node(outlet.node);
204                    let mut value =
205                        const_node.op_as::<Const>().unwrap().val().clone().into_tensor();
206                    alteration.change_tensor(&mut value, false)?;
207                    let name = model.unique_name(&const_node.name);
208                    patch.add_const(name, value)?
209                } else {
210                    *replaced_wires
211                        .entry(*orig)
212                        .or_insert_with(|| patch.tap_model(model, *orig).unwrap())
213                };
214                inputs.push(tgt);
215            }
216            let op: Box<dyn TypedOp> =
217                changed_ops.get(&node_id).cloned().unwrap_or_else(|| node.op.clone());
218            let new_wires = patch
219                .wire_node(&node.name, op.clone(), &inputs)
220                .with_context(|| format!("wriring changed_op {op:?}"))?;
221            if new_wires.len() == 1
222                && patch.node(new_wires[0].node).op_is::<crate::ops::source::TypedSource>()
223            {
224                patch.inputs.insert(new_wires[0].node, node_id);
225            }
226            for (ix, w) in new_wires.iter().enumerate() {
227                replaced_wires.insert((node_id, ix).into(), *w);
228            }
229        } else {
230            for orig in &node.inputs {
231                if let Some(replacement) = replaced_wires.get(orig) {
232                    patch.shunt_outside(model, *orig, *replacement)?;
233                }
234            }
235        }
236    }
237    for output in model.output_outlets()? {
238        if let Some(replacement) = replaced_wires.get(output) {
239            unsafe {
240                patch.shunt_outside_unchecked(*output, *replacement)?;
241            }
242        }
243    }
244    let mut interface_change = tvec!();
245    for (ix, input) in model.input_outlets()?.iter().enumerate() {
246        if let Some(change) = changed_wires.get(&bound_outlets(*input)) {
247            interface_change.push((InOut::In(ix), change.clone()));
248        }
249    }
250    for (ix, output) in model.output_outlets()?.iter().enumerate() {
251        if let Some(change) = changed_wires.get(&bound_outlets(*output)) {
252            interface_change.push((InOut::Out(ix), change.clone()));
253        }
254    }
255    debug!("Patch ready for {change:?}");
256    Ok(Some((patch, interface_change)))
257}