1use std::collections::HashSet;
2use std::fmt::{Debug, Display};
3use std::ops::{Deref, DerefMut};
4
5use tract_data::itertools::{izip, Itertools};
6
7use crate::internal::*;
8use crate::model::*;
9
/// A set of changes to apply to a [`Graph`].
///
/// The patch owns a small graph (`model`) holding the nodes to add, plus the
/// bookkeeping that [`ModelPatch::apply`] uses to connect those nodes to the
/// graph being patched (called the "target" below).
#[derive(Clone, Debug)]
pub struct ModelPatch<F, O>
where
    F: Fact + Clone + 'static,
    O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
    /// Free-form labels attached to the patch; appended to by `new`,
    /// `push_context` and `with_context`.
    pub context: Vec<String>,
    /// Not read anywhere in this file.
    /// NOTE(review): presumably a key letting a caller skip a patch it has
    /// already applied once — confirm against the users of this field.
    pub dont_apply_twice: Option<String>,
    /// The patch's own graph: the nodes `apply` copies into the target.
    pub model: Graph<F, O>,
    /// Patch node id -> target node id. `apply` looks a patch source node up
    /// here (when it is not a tap) to find which target input it replaces.
    pub inputs: HashMap<usize, usize>,
    /// Patch outlet -> target outlet. Filled by `tap_model`: the key is the
    /// source outlet created in the patch graph, the value the tapped target outlet.
    pub taps: HashMap<OutletId, OutletId>,
    /// Target outlet -> patch outlet that must replace it when the patch is applied.
    pub shunts: HashMap<OutletId, OutletId>,
    /// Ids of target nodes whose op `apply` replaces by a dummy op.
    pub obliterate: Vec<usize>,
}
37
38impl<F, O> Default for ModelPatch<F, O>
39where
40 F: Fact + Clone + 'static,
41 O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
42{
43 fn default() -> ModelPatch<F, O> {
44 ModelPatch {
45 context: vec![],
46 dont_apply_twice: None,
47 model: Graph::default(),
48 inputs: HashMap::default(),
49 taps: HashMap::new(),
50 shunts: HashMap::new(),
51 obliterate: vec![],
52 }
53 }
54}
55
56impl<F, O> Deref for ModelPatch<F, O>
57where
58 F: Fact + Clone + 'static,
59 O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
60{
61 type Target = Graph<F, O>;
62 fn deref(&self) -> &Graph<F, O> {
63 &self.model
64 }
65}
66
67impl<F, O> DerefMut for ModelPatch<F, O>
68where
69 F: Fact + Clone + 'static,
70 O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
71{
72 fn deref_mut(&mut self) -> &mut Graph<F, O> {
73 &mut self.model
74 }
75}
76
77impl<F, O> ModelPatch<F, O>
78where
79 F: Fact + Clone + 'static,
80 O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
81 Graph<F, O>: SpecialOps<F, O>,
82{
83 pub fn new(s: impl Into<String>) -> Self {
84 Self::default().with_context(s)
85 }
86
87 pub fn push_context(&mut self, s: impl Into<String>) {
88 self.context.push(s.into());
89 }
90
91 pub fn with_context(mut self, s: impl Into<String>) -> Self {
92 self.context.push(s.into());
93 self
94 }
95
96 pub fn is_empty(&self) -> bool {
97 self.model.nodes.is_empty() && self.shunts.is_empty() && self.obliterate.is_empty()
98 }
99
100 pub fn tap_model(&mut self, model: &Graph<F, O>, outlet: OutletId) -> TractResult<OutletId> {
104 let fact = model.outlet_fact(outlet)?;
105 let id = self.add_source(
106 format!("tap.{}-{}/{}", model.node(outlet.node).name, outlet.node, outlet.slot),
107 dyn_clone::clone(fact),
108 )?;
109 self.taps.insert(id, outlet);
110 Ok(id)
111 }
112
113 pub fn taps<'a>(
117 &mut self,
118 model: &Graph<F, O>,
119 outlets: impl IntoIterator<Item = &'a OutletId>,
120 ) -> TractResult<TVec<OutletId>> {
121 outlets.into_iter().map(|o| self.tap_model(model, *o)).collect::<TractResult<TVec<_>>>()
122 }
123
124 pub unsafe fn shunt_outside_unchecked(
125 &mut self,
126 outlet: OutletId,
127 by: OutletId,
128 ) -> TractResult<()> {
129 self.shunts.insert(outlet, by);
130 Ok(())
131 }
132
133 pub fn shunt_outside(
135 &mut self,
136 model: &Graph<F, O>,
137 outlet: OutletId,
138 by: OutletId,
139 ) -> TractResult<()> {
140 let original_fact = model.outlet_fact(outlet)?;
141 let new_fact = self.model.outlet_fact(by)?;
142 if !original_fact.compatible_with(new_fact) {
143 bail!(
144 "Trying to substitute a {:?} by {:?} as output #{} of {}.\n{:?}",
145 original_fact,
146 new_fact,
147 outlet.slot,
148 model.node(outlet.node),
149 self
150 );
151 }
152 self.shunts.insert(outlet, by);
153 Ok(())
154 }
155
156 pub fn obliterate(&mut self, node: usize) -> TractResult<()> {
157 self.obliterate.push(node);
158 Ok(())
159 }
160
161 pub fn replace_single_op<IO: Into<O>>(
163 patched_model: &Graph<F, O>,
164 node: &Node<F, O>,
165 inputs: &[OutletId],
166 new_op: IO,
167 ) -> TractResult<ModelPatch<F, O>> {
168 let mut patch = ModelPatch::default();
169 let new_op = new_op.into();
170 let inputs = patch.taps(patched_model, inputs)?;
171 let wires = patch.wire_node(&node.name, new_op, &inputs)?;
172 for (ix, o) in wires.iter().enumerate() {
173 patch.shunt_outside(patched_model, OutletId::new(node.id, ix), *o)?;
174 }
175 patch.obliterate(node.id)?;
176 Ok(patch)
177 }
178
179 pub fn fuse_with_next<IO: Into<O>>(
181 patched_model: &Graph<F, O>,
182 node: &Node<F, O>,
183 new_op: IO,
184 ) -> TractResult<ModelPatch<F, O>> {
185 let mut patch = ModelPatch::default();
186 let succ = if let Some(succ) = patched_model.single_succ(node.id)? {
187 succ
188 } else {
189 bail!("Non single successor fuse attempt")
190 };
191 let inputs = patch.taps(patched_model, &node.inputs)?;
192 let output = patch.wire_node(&node.name, new_op.into(), &inputs)?;
193 patch.shunt_outside(patched_model, succ.id.into(), output[0])?;
194 Ok(patch)
195 }
196
197 pub fn shunt_one_op(
199 patched_model: &Graph<F, O>,
200 node: &Node<F, O>,
201 ) -> TractResult<Option<ModelPatch<F, O>>> {
202 ensure!(node.inputs.len() == 1);
203 ensure!(node.outputs.len() == 1);
204 if patched_model.outputs.contains(&node.id.into())
205 && patched_model.outputs.contains(&node.inputs[0])
206 {
207 Ok(None)
208 } else {
209 Self::rewire(patched_model, &node.inputs, &[node.id.into()], &|_p, xs| Ok(xs.into()))
210 .with_context(|| format!("Shunting {node}"))
211 .map(Some)
212 }
213 }
214
215 #[allow(clippy::type_complexity)]
216 pub fn rewire(
217 patched_model: &Graph<F, O>,
218 from: &[OutletId],
219 to: &[OutletId],
220 wiring: &dyn Fn(&mut Self, &[OutletId]) -> TractResult<TVec<OutletId>>,
221 ) -> TractResult<ModelPatch<F, O>> {
222 let mut patch = ModelPatch::default();
223 let taps = patch.taps(patched_model, from)?;
224 let news = wiring(&mut patch, &taps)?;
225 if news.len() != to.len() {
226 bail!(
227 "Wrong number of outputs for rewiring, expected {}, function returned {}",
228 to.len(),
229 news.len()
230 );
231 }
232 for (new, &old) in izip!(news, to) {
233 patch.shunt_outside(patched_model, old, new)?;
234 }
235 Ok(patch)
236 }
237
238 pub fn single_unary_op<IO: Into<O>>(
240 patched_model: &Graph<F, O>,
241 node: &Node<F, O>,
242 new_op: IO,
243 ) -> TractResult<ModelPatch<F, O>> {
244 Self::replace_single_op(patched_model, node, &[node.inputs[0]], new_op)
245 }
246
247 pub fn intercept<IO: Into<O>>(
249 patched_model: &Graph<F, O>,
250 outlet: OutletId,
251 name: impl Into<String>,
252 new_op: IO,
253 fact: F,
254 ) -> TractResult<ModelPatch<F, O>> {
255 let mut patch = ModelPatch::default();
256 let tap = patch.tap_model(patched_model, outlet)?;
257 let new_id = patch.add_node(name, new_op, tvec!(fact))?;
258 patch.add_edge(tap, InletId::new(new_id, 0))?;
259 patch.shunt_outside(patched_model, outlet, OutletId::new(new_id, 0))?;
260 Ok(patch)
261 }
262
263 pub fn wire_node(
264 &mut self,
265 name: impl Into<String>,
266 op: impl Into<O>,
267 inputs: &[OutletId],
268 ) -> TractResult<TVec<OutletId>> {
269 let mut name = name.into();
270 if self.nodes.iter().any(|n| n.name == *name) {
271 for i in 1.. {
272 let s = format!("{name}#{i}");
273 if self.nodes.iter().all(|n| n.name != s) {
274 name = s;
275 break;
276 }
277 }
278 }
279 self.model.wire_node(name, op.into(), inputs)
280 }
281
    /// Applies every change of the patch to `target`, consuming the patch.
    ///
    /// In order:
    /// 1. every patch node that is not a tap is added to `target`;
    /// 2. each shunted target outlet hands its successors, its model-output
    ///    role and its label over to the replacing outlet;
    /// 3. the inputs of the added nodes are connected;
    /// 4. nodes listed in `obliterate` get a dummy op;
    /// 5. the model inputs are updated for replaced source nodes;
    /// 6. shunted nodes left without any use are turned into dummies, and
    ///    their own inputs are examined in turn;
    /// 7. names of the added nodes are made unique in `target`.
    ///
    /// # Errors
    /// Propagates any error from the graph operations on `target`.
    ///
    /// # Panics
    /// The map lookups by index (`mapping[..]`, `replaced_inputs[..]`) panic
    /// when the key is missing.
    pub fn apply(self, target: &mut Graph<F, O>) -> TractResult<()> {
        // Only read by the debug assertions below: the number of model inputs
        // and outputs must stay the same through every step.
        let prior_target_inputs = target.input_outlets()?.len();
        let prior_target_outputs = target.output_outlets()?.len();
        // `mapping` starts as the taps (patch outlet -> target outlet) and is
        // extended with the outlets of every node added to `target`.
        let ModelPatch {
            model: patch,
            taps: mut mapping,
            shunts: shunt_outlet_by,
            obliterate,
            inputs: replaced_inputs,
            ..
        } = self;
        // Id of a node added to `target` -> its inputs, still as patch outlets.
        let mut all_inputs = HashMap::new();
        let mut model_input_outlets = target.input_outlets()?.to_vec();
        let mut new_nodes = HashSet::new();
        for node in patch.nodes {
            if <Graph<F, O>>::is_source(&node.op)
                && mapping.contains_key(&OutletId::new(node.id, 0))
            {
                // A source whose outlet is already mapped is a tap: nothing to add.
                continue;
            }
            let Node { id: patch_node_id, name, inputs, op, outputs } = node;
            let n_outputs = outputs.len();
            // Look for a node of `target` with the same op and the same
            // (mapped) inputs.
            // NOTE(review): the `continue` below only continues this inner
            // `for dup` loop, so the node is still added to `target` further
            // down and the mapping entries written here are overwritten.
            // Confirm whether skipping the outer iteration was intended.
            for dup in 0..target.nodes.len() {
                if target.node(dup).op().same_as(op.as_ref())
                    && inputs.len() == target.node(dup).inputs.len()
                    && inputs
                        .iter()
                        .zip(target.node(dup).inputs.iter())
                        .all(|(patch_input, d)| mapping[patch_input] == *d)
                {
                    for ix in 0..n_outputs {
                        mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(dup, ix));
                    }
                    continue;
                }
            }
            let facts = outputs.into_iter().map(|of| of.fact).collect();
            let added_node_id = target.add_node(name, op, facts)?;
            new_nodes.insert(added_node_id);
            for ix in 0..n_outputs {
                mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(added_node_id, ix));
            }
            // Edges are created later, once every patch node has a mapping.
            all_inputs.insert(added_node_id, inputs);
            if <Graph<F, O>>::is_source(&target.node(added_node_id).op) {
                // A source that is not a tap replaces a model input: redirect
                // the matching model input outlets to the new node.
                model_input_outlets.iter_mut().for_each(|oo| {
                    if oo.node == replaced_inputs[&patch_node_id] {
                        oo.node = added_node_id;
                    }
                });
            }
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Shunts are visited in sorted order rather than in hash-map order.
        for (&outlet, &by) in shunt_outlet_by.iter().sorted() {
            let replace_by = mapping[&by];
            // Cloned so `target` can be mutated while walking the successors.
            let succs = target.nodes()[outlet.node].outputs[outlet.slot].successors.clone();
            for succ in succs {
                target.add_edge(replace_by, succ)?;
            }
            // A shunted model output becomes the replacing outlet.
            for o in target.outputs.iter_mut() {
                if *o == outlet {
                    *o = replace_by;
                }
            }
            // The label, if any, moves to the replacing outlet.
            if let Some(label) = target.outlet_labels.remove(&outlet) {
                target.set_outlet_label(replace_by, label)?;
            }
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Connect the inputs of the added nodes, translating patch outlets
        // into target outlets through `mapping`.
        for (&node, inputs) in all_inputs.iter().sorted() {
            for (ix, input) in inputs.iter().enumerate() {
                target.add_edge(mapping[input], InletId::new(node, ix))?;
            }
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        for node in obliterate {
            target.node_mut(node).op = target.create_dummy();
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        target.set_input_outlets(&model_input_outlets)?;
        // Garbage collection: start from the nodes owning a shunted outlet.
        let mut maybe_garbage: HashSet<usize> = shunt_outlet_by.iter().map(|o| o.0.node).collect();
        while let Some(&maybe) = maybe_garbage.iter().next() {
            maybe_garbage.remove(&maybe);
            // A node is garbage when it is neither a model output nor a model
            // input and none of its outputs has a successor.
            if !target.outputs.iter().any(|output| output.node == maybe)
                && !target.inputs.iter().any(|input| input.node == maybe)
                && target.node(maybe).outputs.iter().all(|of| of.successors.is_empty())
            {
                target.node_mut(maybe).op = target.create_dummy();
                target.node_mut(maybe).name = format!("Dummy-node-{}", maybe);
                target.node_mut(maybe).outputs.clear();
                let inputs = std::mem::take(&mut target.node_mut(maybe).inputs);
                // Detach the dead node from its producers, which may have
                // become garbage themselves.
                for &i in &inputs {
                    target.node_mut(i.node).outputs[i.slot].successors.retain(|s| s.node != maybe);
                    maybe_garbage.insert(i.node);
                }
                target.check_edges()?;
            }
        }
        // Make the names of the added nodes unique in `target`.
        for n in new_nodes.iter() {
            // A '#' in the name is the marker `wire_node` appends for
            // patch-local uniqueness: restart from the part before it.
            if let Some((prefix, _)) = target.nodes[*n].name.split_once('#') {
                target.nodes[*n].name = target.unique_name(prefix).into();
            } else if target
                .nodes
                .iter()
                .any(|node| node.id != *n && target.nodes[*n].name == node.name)
            {
                target.nodes[*n].name = target.unique_name(&target.nodes[*n].name).to_string();
            }
        }
        Ok(())
    }
399}