tract_core/model/
patch.rs

1use std::collections::HashSet;
2use std::fmt::{Debug, Display};
3use std::ops::{Deref, DerefMut};
4
5use tract_data::itertools::{izip, Itertools};
6
7use crate::internal::*;
8use crate::model::*;
9
10/// A change to apply to a model.
11///
12/// Actually structured around a model that represent the new nodes to be
13/// inserted, plus information about how to connect these new nodes to the
14/// pre-existing graph.
15#[derive(Clone, Debug)]
16pub struct ModelPatch<F, O>
17where
18    F: Fact + Clone + 'static,
19    O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
20{
21    /// patch label for auditing and debugging
22    pub context: Vec<String>,
23    /// optimizer will ignore this patch in node to node loop if it was already
24    /// encountered
25    pub dont_apply_twice: Option<String>,
26    /// the model-like 'patch' of nodes to add to the model
27    pub model: Graph<F, O>,
28    /// map of replaced inputs (patch node id to model node id)
29    pub inputs: HashMap<usize, usize>,
30    /// map of patch inputs to model wires
31    pub taps: HashMap<OutletId, OutletId>,
32    /// map of old model wires to be replaced by wires from the patch
33    pub shunts: HashMap<OutletId, OutletId>,
34    /// operations to discard from the model
35    pub obliterate: Vec<usize>,
36}
37
38impl<F, O> Default for ModelPatch<F, O>
39where
40    F: Fact + Clone + 'static,
41    O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
42{
43    fn default() -> ModelPatch<F, O> {
44        ModelPatch {
45            context: vec![],
46            dont_apply_twice: None,
47            model: Graph::default(),
48            inputs: HashMap::default(),
49            taps: HashMap::new(),
50            shunts: HashMap::new(),
51            obliterate: vec![],
52        }
53    }
54}
55
56impl<F, O> Deref for ModelPatch<F, O>
57where
58    F: Fact + Clone + 'static,
59    O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
60{
61    type Target = Graph<F, O>;
62    fn deref(&self) -> &Graph<F, O> {
63        &self.model
64    }
65}
66
67impl<F, O> DerefMut for ModelPatch<F, O>
68where
69    F: Fact + Clone + 'static,
70    O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
71{
72    fn deref_mut(&mut self) -> &mut Graph<F, O> {
73        &mut self.model
74    }
75}
76
77impl<F, O> ModelPatch<F, O>
78where
79    F: Fact + Clone + 'static,
80    O: Display + Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
81    Graph<F, O>: SpecialOps<F, O>,
82{
83    pub fn new(s: impl Into<String>) -> Self {
84        Self::default().with_context(s)
85    }
86
87    pub fn push_context(&mut self, s: impl Into<String>) {
88        self.context.push(s.into());
89    }
90
91    pub fn with_context(mut self, s: impl Into<String>) -> Self {
92        self.context.push(s.into());
93        self
94    }
95
96    pub fn is_empty(&self) -> bool {
97        self.model.nodes.is_empty() && self.shunts.is_empty() && self.obliterate.is_empty()
98    }
99
100    /// Draw a tap from a preexisting node.
101    ///
102    /// returns an OutletId usable in the little "patch" model
103    pub fn tap_model(&mut self, model: &Graph<F, O>, outlet: OutletId) -> TractResult<OutletId> {
104        let fact = model.outlet_fact(outlet)?;
105        let id = self.add_source(
106            format!("tap.{}-{}/{}", model.node(outlet.node).name, outlet.node, outlet.slot),
107            dyn_clone::clone(fact),
108        )?;
109        self.taps.insert(id, outlet);
110        Ok(id)
111    }
112
113    /// Draw taps from a preexisting node.
114    ///
115    /// returns an OutletId usable in the little "patch" model
116    pub fn taps<'a>(
117        &mut self,
118        model: &Graph<F, O>,
119        outlets: impl IntoIterator<Item = &'a OutletId>,
120    ) -> TractResult<TVec<OutletId>> {
121        outlets.into_iter().map(|o| self.tap_model(model, *o)).collect::<TractResult<TVec<_>>>()
122    }
123
124    pub unsafe fn shunt_outside_unchecked(
125        &mut self,
126        outlet: OutletId,
127        by: OutletId,
128    ) -> TractResult<()> {
129        self.shunts.insert(outlet, by);
130        Ok(())
131    }
132
133    /// Replace an Outlet in the target model by one from the patch.
134    pub fn shunt_outside(
135        &mut self,
136        model: &Graph<F, O>,
137        outlet: OutletId,
138        by: OutletId,
139    ) -> TractResult<()> {
140        let original_fact = model.outlet_fact(outlet)?;
141        let new_fact = self.model.outlet_fact(by)?;
142        if !original_fact.compatible_with(new_fact) {
143            bail!(
144                "Trying to substitute a {:?} by {:?} as output #{} of {}.\n{:?}",
145                original_fact,
146                new_fact,
147                outlet.slot,
148                model.node(outlet.node),
149                self
150            );
151        }
152        self.shunts.insert(outlet, by);
153        Ok(())
154    }
155
156    pub fn obliterate(&mut self, node: usize) -> TractResult<()> {
157        self.obliterate.push(node);
158        Ok(())
159    }
160
161    /// Convenience method creating a patch that replaces a single operation.
162    pub fn replace_single_op<IO: Into<O>>(
163        patched_model: &Graph<F, O>,
164        node: &Node<F, O>,
165        inputs: &[OutletId],
166        new_op: IO,
167    ) -> TractResult<ModelPatch<F, O>> {
168        let mut patch = ModelPatch::default();
169        let new_op = new_op.into();
170        let inputs = patch.taps(patched_model, inputs)?;
171        let wires = patch.wire_node(&node.name, new_op, &inputs)?;
172        for (ix, o) in wires.iter().enumerate() {
173            patch.shunt_outside(patched_model, OutletId::new(node.id, ix), *o)?;
174        }
175        patch.obliterate(node.id)?;
176        Ok(patch)
177    }
178
179    /// Convenience method creating a patch that fuses an op with the next one.
180    pub fn fuse_with_next<IO: Into<O>>(
181        patched_model: &Graph<F, O>,
182        node: &Node<F, O>,
183        new_op: IO,
184    ) -> TractResult<ModelPatch<F, O>> {
185        let mut patch = ModelPatch::default();
186        let succ = if let Some(succ) = patched_model.single_succ(node.id)? {
187            succ
188        } else {
189            bail!("Non single successor fuse attempt")
190        };
191        let inputs = patch.taps(patched_model, &node.inputs)?;
192        let output = patch.wire_node(&node.name, new_op.into(), &inputs)?;
193        patch.shunt_outside(patched_model, succ.id.into(), output[0])?;
194        Ok(patch)
195    }
196
197    /// Convenience method creating a patch that shunts the given node.
198    pub fn shunt_one_op(
199        patched_model: &Graph<F, O>,
200        node: &Node<F, O>,
201    ) -> TractResult<Option<ModelPatch<F, O>>> {
202        ensure!(node.inputs.len() == 1);
203        ensure!(node.outputs.len() == 1);
204        if patched_model.outputs.contains(&node.id.into())
205            && patched_model.outputs.contains(&node.inputs[0])
206        {
207            Ok(None)
208        } else {
209            Self::rewire(patched_model, &node.inputs, &[node.id.into()], &|_p, xs| Ok(xs.into()))
210                .with_context(|| format!("Shunting {node}"))
211                .map(Some)
212        }
213    }
214
215    #[allow(clippy::type_complexity)]
216    pub fn rewire(
217        patched_model: &Graph<F, O>,
218        from: &[OutletId],
219        to: &[OutletId],
220        wiring: &dyn Fn(&mut Self, &[OutletId]) -> TractResult<TVec<OutletId>>,
221    ) -> TractResult<ModelPatch<F, O>> {
222        let mut patch = ModelPatch::default();
223        let taps = patch.taps(patched_model, from)?;
224        let news = wiring(&mut patch, &taps)?;
225        if news.len() != to.len() {
226            bail!(
227                "Wrong number of outputs for rewiring, expected {}, function returned {}",
228                to.len(),
229                news.len()
230            );
231        }
232        for (new, &old) in izip!(news, to) {
233            patch.shunt_outside(patched_model, old, new)?;
234        }
235        Ok(patch)
236    }
237
238    /// Convenience method creating a patch that replace a single unary operation.
239    pub fn single_unary_op<IO: Into<O>>(
240        patched_model: &Graph<F, O>,
241        node: &Node<F, O>,
242        new_op: IO,
243    ) -> TractResult<ModelPatch<F, O>> {
244        Self::replace_single_op(patched_model, node, &[node.inputs[0]], new_op)
245    }
246
247    /// Convenience method creating a patch that insert an unary op on an outlet.
248    pub fn intercept<IO: Into<O>>(
249        patched_model: &Graph<F, O>,
250        outlet: OutletId,
251        name: impl Into<String>,
252        new_op: IO,
253        fact: F,
254    ) -> TractResult<ModelPatch<F, O>> {
255        let mut patch = ModelPatch::default();
256        let tap = patch.tap_model(patched_model, outlet)?;
257        let new_id = patch.add_node(name, new_op, tvec!(fact))?;
258        patch.add_edge(tap, InletId::new(new_id, 0))?;
259        patch.shunt_outside(patched_model, outlet, OutletId::new(new_id, 0))?;
260        Ok(patch)
261    }
262
263    pub fn wire_node(
264        &mut self,
265        name: impl Into<String>,
266        op: impl Into<O>,
267        inputs: &[OutletId],
268    ) -> TractResult<TVec<OutletId>> {
269        let mut name = name.into();
270        if self.nodes.iter().any(|n| n.name == *name) {
271            for i in 1.. {
272                let s = format!("{name}#{i}");
273                if self.nodes.iter().all(|n| n.name != s) {
274                    name = s;
275                    break;
276                }
277            }
278        }
279        self.model.wire_node(name, op.into(), inputs)
280    }
281
282    /// Apply all changes in the patch to the target model.
283    pub fn apply(self, target: &mut Graph<F, O>) -> TractResult<()> {
284        let prior_target_inputs = target.input_outlets()?.len();
285        let prior_target_outputs = target.output_outlets()?.len();
286        let ModelPatch {
287            model: patch,
288            taps: mut mapping,
289            shunts: shunt_outlet_by,
290            obliterate,
291            inputs: replaced_inputs,
292            ..
293        } = self;
294        let mut all_inputs = HashMap::new(); // new_node_id_in_model -> [ patch_outlet_id ]
295        let mut model_input_outlets = target.input_outlets()?.to_vec();
296        let mut new_nodes = HashSet::new();
297        for node in patch.nodes {
298            if <Graph<F, O>>::is_source(&node.op)
299                && mapping.contains_key(&OutletId::new(node.id, 0))
300            {
301                // this is a tap
302                continue;
303            }
304            let Node { id: patch_node_id, name, inputs, op, outputs } = node;
305            let n_outputs = outputs.len();
306            for dup in 0..target.nodes.len() {
307                if target.node(dup).op().same_as(op.as_ref())
308                    && inputs.len() == target.node(dup).inputs.len()
309                    && inputs
310                        .iter()
311                        .zip(target.node(dup).inputs.iter())
312                        .all(|(patch_input, d)| mapping[patch_input] == *d)
313                {
314                    for ix in 0..n_outputs {
315                        mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(dup, ix));
316                    }
317                    continue;
318                }
319            }
320            let facts = outputs.into_iter().map(|of| of.fact).collect();
321            let added_node_id = target.add_node(name, op, facts)?;
322            new_nodes.insert(added_node_id);
323            for ix in 0..n_outputs {
324                mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(added_node_id, ix));
325            }
326            all_inputs.insert(added_node_id, inputs);
327            if <Graph<F, O>>::is_source(&target.node(added_node_id).op) {
328                // this is actually an input replacement
329                model_input_outlets.iter_mut().for_each(|oo| {
330                    if oo.node == replaced_inputs[&patch_node_id] {
331                        oo.node = added_node_id;
332                    }
333                });
334            }
335        }
336        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
337        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
338        for (&outlet, &by) in shunt_outlet_by.iter().sorted() {
339            let replace_by = mapping[&by];
340            let succs = target.nodes()[outlet.node].outputs[outlet.slot].successors.clone();
341            for succ in succs {
342                target.add_edge(replace_by, succ)?;
343            }
344            for o in target.outputs.iter_mut() {
345                if *o == outlet {
346                    *o = replace_by;
347                }
348            }
349            if let Some(label) = target.outlet_labels.remove(&outlet) {
350                target.set_outlet_label(replace_by, label)?;
351            }
352        }
353        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
354        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
355        for (&node, inputs) in all_inputs.iter().sorted() {
356            for (ix, input) in inputs.iter().enumerate() {
357                target.add_edge(mapping[input], InletId::new(node, ix))?;
358            }
359        }
360        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
361        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
362        for node in obliterate {
363            target.node_mut(node).op = target.create_dummy();
364        }
365        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
366        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
367        target.set_input_outlets(&model_input_outlets)?;
368        let mut maybe_garbage: HashSet<usize> = shunt_outlet_by.iter().map(|o| o.0.node).collect();
369        while let Some(&maybe) = maybe_garbage.iter().next() {
370            maybe_garbage.remove(&maybe);
371            if !target.outputs.iter().any(|output| output.node == maybe)
372                && !target.inputs.iter().any(|input| input.node == maybe)
373                && target.node(maybe).outputs.iter().all(|of| of.successors.is_empty())
374            {
375                target.node_mut(maybe).op = target.create_dummy();
376                target.node_mut(maybe).name = format!("Dummy-node-{}", maybe);
377                target.node_mut(maybe).outputs.clear(); // necessary to drop facts and consts
378                let inputs = std::mem::take(&mut target.node_mut(maybe).inputs);
379                for &i in &inputs {
380                    target.node_mut(i.node).outputs[i.slot].successors.retain(|s| s.node != maybe);
381                    maybe_garbage.insert(i.node);
382                }
383                target.check_edges()?;
384            }
385        }
386        for n in new_nodes.iter() {
387            if let Some((prefix, _)) = target.nodes[*n].name.split_once('#') {
388                target.nodes[*n].name = target.unique_name(prefix).into();
389            } else if target
390                .nodes
391                .iter()
392                .any(|node| node.id != *n && target.nodes[*n].name == node.name)
393            {
394                target.nodes[*n].name = target.unique_name(&target.nodes[*n].name).to_string();
395            }
396        }
397        Ok(())
398    }
399}