tract_core/optim/change_axes.rs

1use super::OptimizerSession;
2use super::TypedPass;
3use crate::internal::*;
4use crate::model::*;
5use crate::ops::dummy::Dummy;
6use crate::ops::einsum::EinSum;
7use crate::ops::konst::Const;
8use std::collections::hash_map::Entry;
9use std::collections::HashSet;
10use std::fmt::Debug;
11
12use crate::ops::change_axes::*;
13
14#[derive(Clone, Default)]
15pub struct ChangeAxes(HashSet<crate::ops::change_axes::AxisChange>, usize);
16
17impl Debug for ChangeAxes {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        write!(f, "ChangeAxes")
20    }
21}
22
23impl TypedPass for ChangeAxes {
24    fn reset(&mut self) -> TractResult<()> {
25        self.0.clear();
26        self.1 = 0;
27        Ok(())
28    }
29    fn next(
30        &mut self,
31        _session: &mut OptimizerSession,
32        model: &TypedModel,
33    ) -> TractResult<Option<TypedModelPatch>> {
34        let mut explored: HashSet<AxisChange> = Default::default();
35        let mut interfaces = model.output_outlets()?.to_vec();
36        interfaces.extend(model.input_outlets()?.iter());
37        for node in &model.nodes[self.1..] {
38            if node.op_is::<Dummy>() {
39                continue;
40            }
41            for suggestion in node.op.suggested_axis_changes()? {
42                let outlet = suggestion.0.as_outlet(node);
43                let change = AxisChange { outlet, op: suggestion.1 };
44                if self.0.insert(change.clone()) {
45                    if let Some((patch, _)) =
46                        change_axes(model, &change, &interfaces, &[], &mut explored).with_context(
47                            || format!("Making patch for {:?} from {}", change, node),
48                        )?
49                    {
50                        self.1 = node.id;
51                        return Ok(Some(patch));
52                    }
53                }
54            }
55            for (slot, fact) in node.outputs.iter().enumerate() {
56                for (ix, dim) in fact.fact.shape.iter().enumerate() {
57                    if dim.is_one() {
58                        let change =
59                            AxisChange { outlet: OutletId::new(node.id, slot), op: AxisOp::Rm(ix) };
60                        if self.0.insert(change.clone()) {
61                            if let Some((patch, _)) =
62                                change_axes(model, &change, &interfaces, &[], &mut explored)
63                                    .with_context(|| {
64                                        format!("Making patch for {:?} from {}", change, node)
65                                    })?
66                            {
67                                self.1 = node.id;
68                                return Ok(Some(patch));
69                            }
70                        }
71                    }
72                }
73            }
74        }
75        Ok(None)
76    }
77}
78
79#[allow(clippy::type_complexity)]
80pub fn change_axes(
81    model: &TypedModel,
82    change: &AxisChange,
83    locked: &[OutletId],
84    bounds: &[TVec<OutletId>],
85    explored: &mut HashSet<AxisChange>,
86) -> TractResult<Option<(TypedModelPatch, TVec<(InOut, AxisOp)>)>> {
87    if explored.contains(change) {
88        debug!("  Not considering change because deja vu {:?}", change);
89        return Ok(None);
90    }
91    if model
92        .node(change.outlet.node)
93        .op_as::<Const>()
94        .is_some_and(|c| c.0.volume() == 1 && c.0.datum_type() != Opaque::datum_type())
95    {
96        debug!("  Not considering change from const {:?}", change);
97        return Ok(None);
98    }
99    debug!("  Considering change {:?}", change);
100    let mut todo_changes = vec![(change.clone(), None)];
101    let mut changed_wires: HashMap<TVec<OutletId>, AxisOp> = HashMap::new();
102    let bound_outlets = |o: OutletId| -> TVec<OutletId> {
103        bounds.iter().find(|b| b.contains(&o)).cloned().unwrap_or_else(|| tvec!(o))
104    };
105    changed_wires.insert(bound_outlets(change.outlet), change.op.clone());
106    let mut changed_ops: HashMap<usize, Box<dyn TypedOp>> = HashMap::new();
107    let mut rewired_scalar_input: HashMap<InletId, (OutletId, AxisOp)> = Default::default();
108    while let Some((change, emitter)) = todo_changes.pop() {
109        if !explored.insert(change.clone()) {
110            return Ok(None);
111        }
112        let outlet_group = bound_outlets(change.outlet);
113        for &outlet in &outlet_group {
114            if locked.contains(&outlet) {
115                debug!("    Change {:?} blocked by locked interface {:?}", change, outlet);
116                return Ok(None);
117            }
118            let mut interfaces: Vec<(usize, InOut)> = vec![(outlet.node, InOut::Out(outlet.slot))];
119            for inlet in model.outlet_successors(outlet) {
120                interfaces.push((inlet.node, InOut::In(inlet.slot)));
121            }
122            for (node_id, io) in interfaces {
123                if Some(node_id) == emitter {
124                    continue;
125                }
126                let node = model.node(node_id);
127                // if this is a revisit...
128                let op = if let Some(op) = changed_ops.get(&node_id) {
129                    trace!("  Change {:?} revisiting {}", change, model.node(node_id));
130                    if op.is::<EinSum>() {
131                        // FIXME Einsum can swallow any combination of axis change on all interfaces
132                        op
133                    } else {
134                        debug!("  Change {:?} blocked: revisiting {}", change, model.node(node_id));
135                        return Ok(None);
136                    }
137                } else {
138                    &node.op
139                };
140                let more = op
141                    .change_axes(model, node, io, &change.op)
142                    .with_context(|| format!("Propagating {change:?} to node {node}"))?;
143                if more.is_none() {
144                    debug!("    Propagation of {:?} blocked by {}", change, node);
145                    return Ok(None);
146                }
147                let AxisChangeConsequence { substitute_op, wire_changes } = more.unwrap();
148                trace!("    Change {:?} enters {} from {:?}", change.op, node, io);
149                trace!("       propagates as {:?}", wire_changes);
150                if let Some(op) = substitute_op {
151                    trace!("       replace op by {:?}", op);
152                    changed_ops.insert(node.id, op);
153                }
154                for (wire, op) in wire_changes.into_iter() {
155                    let outlet = wire.as_outlet(node);
156                    // stop upstram propagation to a scalar constant: we will clone it and alter it
157                    // at patch generation time
158                    if let InOut::In(inlet) = wire {
159                        if model
160                            .node(outlet.node)
161                            .op_as::<Const>()
162                            .is_some_and(|k| k.0.volume() == 1)
163                        {
164                            rewired_scalar_input.insert(InletId::new(node.id, inlet), (outlet, op));
165                            continue;
166                        }
167                    }
168                    let outlet_group = bound_outlets(wire.as_outlet(node));
169                    match changed_wires.entry(outlet_group.clone()) {
170                        Entry::Vacant(entry) => {
171                            trace!(
172                                "         {:?} {:?} change on {:?} is new",
173                                wire,
174                                op,
175                                outlet_group
176                            );
177                            entry.insert(op.clone());
178                            todo_changes
179                                .push((AxisChange { outlet: outlet_group[0], op }, Some(node_id)));
180                        }
181                        Entry::Occupied(previous) => {
182                            if *previous.get() == op {
183                                trace!(
184                                    "         {:?} {:?} change on {:?} already done",
185                                    wire,
186                                    op,
187                                    outlet_group
188                                );
189                            } else {
190                                debug!(
191                                    "         {:?} {:?} change on {:?} conflicting with {:?}. Blocked.",
192                                    wire,
193                                    op,
194                                    outlet_group,
195                                    previous
196                                    );
197                                return Ok(None);
198                            }
199                        }
200                    }
201                }
202            }
203        }
204    }
205    debug!("Translating {:?} to patch", change);
206    let mut patch = TypedModelPatch::new(format!("{change:?}"));
207    let mut replaced_wires: HashMap<OutletId, OutletId> = HashMap::default();
208    let nodes_to_replace = changed_wires
209        .keys()
210        .flat_map(|outlets| outlets.iter().map(|o| o.node))
211        .chain(changed_ops.keys().copied())
212        .collect::<std::collections::HashSet<usize>>();
213    for node_id in model.eval_order()? {
214        let node = model.node(node_id);
215        if nodes_to_replace.contains(&node_id) {
216            let mut inputs = tvec!();
217            for (slot, orig) in node.inputs.iter().enumerate() {
218                let tgt = if let Some((outlet, alteration)) =
219                    rewired_scalar_input.get(&InletId::new(node_id, slot))
220                {
221                    let const_node = model.node(outlet.node);
222                    let mut value = const_node.op_as::<Const>().unwrap().0.clone().into_tensor();
223                    alteration.change_tensor(&mut value, false)?;
224                    let name = model.unique_name(&const_node.name);
225                    patch.add_const(name, value)?
226                } else {
227                    *replaced_wires
228                        .entry(*orig)
229                        .or_insert_with(|| patch.tap_model(model, *orig).unwrap())
230                };
231                inputs.push(tgt);
232            }
233            let op: Box<dyn TypedOp> =
234                changed_ops.get(&node_id).cloned().unwrap_or_else(|| node.op.clone());
235            let new_wires = patch.wire_node(&node.name, op, &inputs)?;
236            if new_wires.len() == 1
237                && patch.node(new_wires[0].node).op_is::<crate::ops::source::TypedSource>()
238            {
239                patch.inputs.insert(new_wires[0].node, node_id);
240            }
241            for (ix, w) in new_wires.iter().enumerate() {
242                replaced_wires.insert((node_id, ix).into(), *w);
243            }
244        } else {
245            for orig in &node.inputs {
246                if let Some(replacement) = replaced_wires.get(orig) {
247                    patch.shunt_outside(model, *orig, *replacement)?;
248                }
249            }
250        }
251    }
252    for output in model.output_outlets()? {
253        if let Some(replacement) = replaced_wires.get(output) {
254            unsafe {
255                patch.shunt_outside_unchecked(*output, *replacement)?;
256            }
257        }
258    }
259    let mut interface_change = tvec!();
260    for (ix, input) in model.input_outlets()?.iter().enumerate() {
261        if let Some(change) = changed_wires.get(&bound_outlets(*input)) {
262            interface_change.push((InOut::In(ix), change.clone()));
263        }
264    }
265    for (ix, output) in model.output_outlets()?.iter().enumerate() {
266        if let Some(change) = changed_wires.get(&bound_outlets(*output)) {
267            interface_change.push((InOut::Out(ix), change.clone()));
268        }
269    }
270    debug!("Patch ready for {:?}", change);
271    Ok(Some((patch, interface_change)))
272}