tract_core/optim/
change_axes.rs1use super::OptimizerSession;
2use super::TypedPass;
3use crate::internal::*;
4use crate::model::*;
5use crate::ops::dummy::Dummy;
6use crate::ops::einsum::EinSum;
7use crate::ops::konst::Const;
8use std::collections::HashSet;
9use std::collections::hash_map::Entry;
10use std::fmt::Debug;
11
12use crate::ops::change_axes::*;
13
14#[derive(Clone, Default)]
15pub struct ChangeAxes(HashSet<crate::ops::change_axes::AxisChange>, usize);
16
17impl Debug for ChangeAxes {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 write!(f, "ChangeAxes")
20 }
21}
22
23impl TypedPass for ChangeAxes {
24 fn reset(&mut self) -> TractResult<()> {
25 self.0.clear();
26 self.1 = 0;
27 Ok(())
28 }
29 fn next(
30 &mut self,
31 _session: &mut OptimizerSession,
32 model: &TypedModel,
33 ) -> TractResult<Option<TypedModelPatch>> {
34 let mut explored: HashSet<AxisChange> = Default::default();
35 let mut interfaces = model.output_outlets()?.to_vec();
36 interfaces.extend(model.input_outlets()?.iter());
37 for node in &model.nodes[self.1..] {
38 if node.op_is::<Dummy>() {
39 continue;
40 }
41 for suggestion in node.op.suggested_axis_changes()? {
42 let outlet = suggestion.0.as_outlet(node);
43 let change = AxisChange { outlet, op: suggestion.1 };
44 if self.0.insert(change.clone()) {
45 if let Some((patch, _)) =
46 change_axes(model, &change, &interfaces, &[], &mut explored)
47 .with_context(|| format!("Making patch for {change:?} from {node}"))?
48 {
49 self.1 = node.id;
50 return Ok(Some(patch));
51 }
52 }
53 }
54 for (slot, fact) in node.outputs.iter().enumerate() {
55 for (ix, dim) in fact.fact.shape.iter().enumerate() {
56 if dim.is_one() {
57 let change =
58 AxisChange { outlet: OutletId::new(node.id, slot), op: AxisOp::Rm(ix) };
59 if self.0.insert(change.clone()) {
60 if let Some((patch, _)) =
61 change_axes(model, &change, &interfaces, &[], &mut explored)
62 .with_context(|| {
63 format!("Making patch for {change:?} from {node}")
64 })?
65 {
66 self.1 = node.id;
67 return Ok(Some(patch));
68 }
69 }
70 }
71 }
72 }
73 }
74 Ok(None)
75 }
76}
77
78#[allow(clippy::type_complexity)]
79pub fn change_axes(
80 model: &TypedModel,
81 change: &AxisChange,
82 locked: &[OutletId],
83 bounds: &[TVec<OutletId>],
84 explored: &mut HashSet<AxisChange>,
85) -> TractResult<Option<(TypedModelPatch, TVec<(InOut, AxisOp)>)>> {
86 if explored.contains(change) {
87 debug!(" Not considering change because deja vu {change:?}");
88 return Ok(None);
89 }
90 if model
91 .node(change.outlet.node)
92 .op_as::<Const>()
93 .is_some_and(|c| c.val().volume() == 1 && c.val().datum_type() != Opaque::datum_type())
94 {
95 debug!(" Not considering change from const {change:?}");
96 return Ok(None);
97 }
98 debug!(" Considering change {change:?}");
99 let mut todo_changes = vec![(change.clone(), None)];
100 let mut changed_wires: HashMap<TVec<OutletId>, AxisOp> = HashMap::new();
101 let bound_outlets = |o: OutletId| -> TVec<OutletId> {
102 bounds.iter().find(|b| b.contains(&o)).cloned().unwrap_or_else(|| tvec!(o))
103 };
104 changed_wires.insert(bound_outlets(change.outlet), change.op.clone());
105 let mut changed_ops: HashMap<usize, Box<dyn TypedOp>> = HashMap::new();
106 let mut rewired_scalar_input: HashMap<InletId, (OutletId, AxisOp)> = Default::default();
107 while let Some((change, emitter)) = todo_changes.pop() {
108 rule_if!(explored.insert(change.clone()));
109 let outlet_group = bound_outlets(change.outlet);
110 for &outlet in &outlet_group {
111 if locked.contains(&outlet) {
112 debug!(" Change {change:?} blocked by locked interface {outlet:?}");
113 return Ok(None);
114 }
115 let mut interfaces: Vec<(usize, InOut)> = vec![(outlet.node, InOut::Out(outlet.slot))];
116 for inlet in model.outlet_successors(outlet) {
117 interfaces.push((inlet.node, InOut::In(inlet.slot)));
118 }
119 for (node_id, io) in interfaces {
120 if Some(node_id) == emitter {
121 continue;
122 }
123 let node = model.node(node_id);
124 let op = if let Some(op) = changed_ops.get(&node_id) {
126 trace!(" Change {:?} revisiting {}", change, model.node(node_id));
127 if op.is::<EinSum>() {
128 op
130 } else {
131 debug!(" Change {:?} blocked: revisiting {}", change, model.node(node_id));
132 return Ok(None);
133 }
134 } else {
135 &node.op
136 };
137 let more = op
138 .change_axes(model, node, io, &change.op)
139 .with_context(|| format!("Propagating {change:?} to node {node}"))?;
140 if more.is_none() {
141 debug!(" Propagation of {change:?} blocked by {node}");
142 return Ok(None);
143 }
144 let AxisChangeConsequence { substitute_op, wire_changes } = more.unwrap();
145 trace!(" Change {:?} enters {} from {:?}", change.op, node, io);
146 trace!(" propagates as {wire_changes:?}");
147 if let Some(op) = substitute_op {
148 trace!(" replace op by {op:?}");
149 changed_ops.insert(node.id, op);
150 }
151 for (wire, op) in wire_changes.into_iter() {
152 let outlet = wire.as_outlet(node);
153 if let InOut::In(inlet) = wire {
156 if model
157 .node(outlet.node)
158 .op_as::<Const>()
159 .is_some_and(|k| k.val().volume() == 1)
160 {
161 rewired_scalar_input.insert(InletId::new(node.id, inlet), (outlet, op));
162 continue;
163 }
164 }
165 let outlet_group = bound_outlets(wire.as_outlet(node));
166 match changed_wires.entry(outlet_group.clone()) {
167 Entry::Vacant(entry) => {
168 trace!(" {wire:?} {op:?} change on {outlet_group:?} is new");
169 entry.insert(op.clone());
170 todo_changes
171 .push((AxisChange { outlet: outlet_group[0], op }, Some(node_id)));
172 }
173 Entry::Occupied(previous) => {
174 if *previous.get() == op {
175 trace!(
176 " {wire:?} {op:?} change on {outlet_group:?} already done"
177 );
178 } else {
179 debug!(
180 " {wire:?} {op:?} change on {outlet_group:?} conflicting with {previous:?}. Blocked."
181 );
182 return Ok(None);
183 }
184 }
185 }
186 }
187 }
188 }
189 }
190 debug!("Translating {change:?} to patch");
191 let mut patch = TypedModelPatch::new(format!("{change:?}"));
192 let mut replaced_wires: HashMap<OutletId, OutletId> = HashMap::default();
193 let nodes_to_replace = changed_wires
194 .keys()
195 .flat_map(|outlets| outlets.iter().map(|o| o.node))
196 .chain(changed_ops.keys().copied())
197 .collect::<std::collections::HashSet<usize>>();
198 for node_id in model.eval_order()? {
199 let node = model.node(node_id);
200 if nodes_to_replace.contains(&node_id) {
201 let mut inputs = tvec!();
202 for (slot, orig) in node.inputs.iter().enumerate() {
203 let tgt = if let Some((outlet, alteration)) =
204 rewired_scalar_input.get(&InletId::new(node_id, slot))
205 {
206 let const_node = model.node(outlet.node);
207 let mut value =
208 const_node.op_as::<Const>().unwrap().val().clone().into_tensor();
209 alteration.change_tensor(&mut value, false)?;
210 let name = model.unique_name(&const_node.name);
211 patch.add_const(name, value)?
212 } else {
213 *replaced_wires
214 .entry(*orig)
215 .or_insert_with(|| patch.tap_model(model, *orig).unwrap())
216 };
217 inputs.push(tgt);
218 }
219 let op: Box<dyn TypedOp> =
220 changed_ops.get(&node_id).cloned().unwrap_or_else(|| node.op.clone());
221 let new_wires = patch
222 .wire_node(&node.name, op.clone(), &inputs)
223 .with_context(|| format!("wriring changed_op {op:?}"))?;
224 if new_wires.len() == 1
225 && patch.node(new_wires[0].node).op_is::<crate::ops::source::TypedSource>()
226 {
227 patch.inputs.insert(new_wires[0].node, node_id);
228 }
229 for (ix, w) in new_wires.iter().enumerate() {
230 replaced_wires.insert((node_id, ix).into(), *w);
231 }
232 } else {
233 for orig in &node.inputs {
234 if let Some(replacement) = replaced_wires.get(orig) {
235 patch.shunt_outside(model, *orig, *replacement)?;
236 }
237 }
238 }
239 }
240 for output in model.output_outlets()? {
241 if let Some(replacement) = replaced_wires.get(output) {
242 unsafe {
243 patch.shunt_outside_unchecked(*output, *replacement)?;
244 }
245 }
246 }
247 let mut interface_change = tvec!();
248 for (ix, input) in model.input_outlets()?.iter().enumerate() {
249 if let Some(change) = changed_wires.get(&bound_outlets(*input)) {
250 interface_change.push((InOut::In(ix), change.clone()));
251 }
252 }
253 for (ix, output) in model.output_outlets()?.iter().enumerate() {
254 if let Some(change) = changed_wires.get(&bound_outlets(*output)) {
255 interface_change.push((InOut::Out(ix), change.clone()));
256 }
257 }
258 debug!("Patch ready for {change:?}");
259 Ok(Some((patch, interface_change)))
260}