Struct tract_core::model::Graph

source ·
/// Main model class: a graph of nodes, parameterized by a fact type `F` (tensor
/// information attached to each outlet) and an operator type `O`.
pub struct Graph<F, O>
where
    F: Fact + Hash + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
{
    /// All nodes in the model.
    pub nodes: Vec<Node<F, O>>,
    /// Model inputs.
    pub inputs: Vec<OutletId>,
    /// Model outputs.
    pub outputs: Vec<OutletId>,
    /// Outlet labels.
    pub outlet_labels: HashMap<OutletId, String>,
    /// Model properties.
    pub properties: HashMap<String, Arc<Tensor>>,
    /// Symbol table.
    pub symbol_table: SymbolTable,
}
Expand description

Main model class.

Parameterized by a fact type `F`, which describes the tensor produced at each outlet, and by an operator type `O`.

Fields§

§nodes: Vec<Node<F, O>>

all nodes in the model

§inputs: Vec<OutletId>

model inputs

§outputs: Vec<OutletId>

model outputs

§outlet_labels: HashMap<OutletId, String>

outlet labels

§properties: HashMap<String, Arc<Tensor>>

model properties

§symbol_table: SymbolTable

symbol table

Implementations§

Examples found in repository?
src/model/patch.rs (lines 104-107)
102
103
104
105
106
107
108
109
110
    /// Expose `outlet` of `model` inside this patch as a new source node.
    ///
    /// The source is named `incoming-<node>/<slot>` and carries a copy of the outlet's fact.
    /// The mapping from the new source back to `outlet` is remembered in `incoming`.
    pub fn tap_model(&mut self, model: &Graph<F, O>, outlet: OutletId) -> TractResult<OutletId> {
        let tapped_fact = dyn_clone::clone(model.outlet_fact(outlet)?);
        let source_name = format!("incoming-{}/{}", outlet.node, outlet.slot);
        let tap = self.add_source(source_name, tapped_fact)?;
        self.incoming.insert(tap, outlet);
        Ok(tap)
    }
More examples
Hide additional examples
src/ops/cnn/deconv/unary.rs (line 159)
156
157
158
159
160
161
162
163
    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        // Evaluate by building a one-off model around the deconv-sum wiring and running it.
        let input = args_1!(inputs);
        let source_fact = input.datum_type().fact(input.shape());
        let mut adhoc = TypedModel::default();
        let source = adhoc.add_source("source", source_fact)?;
        let outputs = self.wire_with_deconv_sum("adhoc", &mut adhoc, source)?;
        adhoc.set_output_outlets(&outputs)?;
        let plan = adhoc.into_runnable()?;
        plan.run(tvec!(input))
    }
src/ops/cnn/conv/unary.rs (line 786)
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut model = TypedModel::default();

        let mut wires: TVec<OutletId> = inputs
            .iter()
            .enumerate()
            .map(|(ix, v)| {
                model.add_source(format!("source.{}", ix), v.datum_type().fact(v.shape()))
            })
            .collect::<TractResult<_>>()?;
        let new_op = self.kernel_offset_u8_as_i8(&mut wires, &mut model)?;
        let wire = unsafe {
            if self.q_params.is_some() {
                let op_ref = if let Some(op) = new_op.as_ref() { op } else { self };
                op_ref.wire_as_quant_im2col(
                    &mut model,
                    "im2col-adhoc",
                    inputs[0].datum_type(),
                    &wires,
                )?
            } else {
                self.wire_as_im2col_pair(&mut model, "im2col-adhoc", wires[0])?
            }
        };
        model.set_output_outlets(&[wire])?;
        model.into_runnable()?.run(inputs)
    }
src/model/translator.rs (lines 95-102)
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
    /// Translate one node of `source` into `target`, returning the target outlets that
    /// correspond to the node's outputs.
    ///
    /// A node whose outputs are all model inputs is recreated as source node(s). Every other
    /// node has its op and output facts converted, is added to `target`, and has its input
    /// edges replayed through `mapping` (source outlet -> target outlet).
    fn translate_node(
        &self,
        source: &Graph<TI1, O1>,
        node: &Node<TI1, O1>,
        target: &mut Graph<TI2, O2>,
        mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        let output_count = node.outputs.len();
        let all_outputs_are_model_inputs =
            (0..output_count).all(|o| source.inputs.contains(&(node.id, o).into()));
        if !all_outputs_are_model_inputs {
            let new_op = O2::try_from(&node.op)?;
            let facts = node
                .outputs
                .iter()
                .map(|of| Ok(TI2::try_from(&of.fact)?))
                .collect::<TractResult<TVec<_>>>()?;
            let new_id = target.add_node(node.name.clone(), new_op, facts)?;
            for (ix, o) in node.inputs.iter().enumerate() {
                target.add_edge(mapping[o], InletId::new(new_id, ix))?;
            }
            return Ok((0..output_count).map(|ix| OutletId::new(new_id, ix)).collect());
        }
        // Multi-output sources get one source node per output, suffixed with the slot index.
        (0..output_count)
            .map(|i| {
                let source_name = if output_count > 1 {
                    format!("{}-{}", node.name, i)
                } else {
                    node.name.to_string()
                };
                target.add_source(source_name, TI2::try_from(&node.outputs[i].fact)?)
            })
            .collect()
    }
src/ops/scan/mir.rs (lines 332-335)
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
    /// Declutter rule: pull a body operation that consumes a scanned input out of the scan.
    ///
    /// For each scanned input, this looks at the consumers of the matching body source. The
    /// first consumer that has exactly one input and one output, and through which the scanned
    /// axis can be tracked (`unary_track_axis_down`), is hoisted. The rewrite:
    /// - wires that op outside the loop, on the full outer input;
    /// - feeds its result into the loop as a new scanned input;
    /// - reroutes the inner consumer's output to the new body source.
    ///
    /// Returns the outer patch for the first such candidate, or `Ok(None)` if there is none.
    fn declutter_pull_batcheable_input(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // `input_mapping[i]` is paired with body input `i` (see `input_outlets()?[model_input]`).
        for (model_input, input) in self.input_mapping.iter().enumerate() {
            if let Some(info) = input.as_scan() {
                let scan_source = self.body.input_outlets()?[model_input];
                let scan_source_node = self.body.node(scan_source.node);
                for successor in &scan_source_node.outputs[0].successors {
                    let successor_node = self.body.node(successor.node);
                    // Only simple unary ops are candidates for hoisting.
                    if successor_node.inputs.len() != 1 || successor_node.outputs.len() != 1 {
                        continue;
                    }
                    let (input_facts, output_facts) = self.body.node_facts(successor_node.id)?;
                    let invariants = successor_node.op.invariants(&input_facts, &output_facts)?;
                    // The scanned axis must survive the op so its output can still be scanned
                    // along `axis_after`.
                    if let Some(axis_after) = invariants.unary_track_axis_down(info.axis, false) {
                        let mut outside_patch = TypedModelPatch::new(format!(
                            "Outer patch for input extraction of {}",
                            successor_node
                        ));
                        // Tap every input of the outer scan node into the outer patch.
                        let mut patch_inputs = node
                            .inputs
                            .iter()
                            .map(|&i| outside_patch.tap_model(model, i))
                            .collect::<TractResult<TVec<_>>>()?;
                        // `info.slot` is the outer input slot carrying this scanned input.
                        let input = patch_inputs[info.slot];
                        let new_input_wire = outside_patch.wire_node(
                            format!("{}.extracted.{}", node.name, successor_node.name),
                            successor_node.op.clone(),
                            &[input],
                        )?[0];
                        // The hoisted op's output becomes an extra input of the new scan node.
                        patch_inputs.push(new_input_wire);
                        let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
                        // Inside the body, the new input has the outer fact except along the
                        // scanned axis, whose size is set to the (absolute) chunk size.
                        let mut new_input_inner_fact = new_input_outer_fact.clone();
                        new_input_inner_fact.shape.set(axis_after, info.chunk.abs().to_dim());

                        let mut new_body = self.body.clone();
                        let new_source_wire = new_body.add_source(
                            format!("{}.extracted.{}", node.name, successor_node.name),
                            new_input_inner_fact,
                        )?;
                        // Reroute consumers of the inner op's output to the new body source.
                        // The inner op node itself is not removed by this patch.
                        let mut inner_patch = TypedModelPatch::new(format!(
                            "Inner body patch for extraction of {}",
                            successor_node
                        ));
                        let new_source_wire_in_patch =
                            inner_patch.tap_model(&new_body, new_source_wire)?;
                        inner_patch
                            .shunt_outside(
                                &new_body,
                                OutletId::new(successor.node, 0),
                                new_source_wire_in_patch,
                            )
                            .with_context(|| "patching inner model")?;
                        inner_patch.apply(&mut new_body)?;

                        // Declare the new body source as a scanned input. Its outer slot is
                        // the old input count, i.e. the slot just pushed onto `patch_inputs`.
                        let mut input_mapping = self.input_mapping.clone();
                        input_mapping.push(InputMapping::Scan(ScanInfo {
                            axis: axis_after,
                            chunk: info.chunk,
                            slot: node.inputs.len(),
                        }));

                        let new_op = Self {
                            input_mapping,
                            output_mapping: self.output_mapping.clone(),
                            decluttered: false,
                            body: new_body,
                            skip: self.skip,
                            seq_length_input_slot: self.seq_length_input_slot,
                        };
                        // Replace the outer scan node, output by output, with the rebuilt op.
                        let output_wires =
                            outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
                        for w in output_wires {
                            outside_patch
                                .shunt_outside(model, OutletId::new(node.id, w.slot), w)
                                .with_context(|| "patching outer model")?;
                        }
                        return Ok(Some(outside_patch));
                    }
                }
            }
        }
        Ok(None)
    }
src/ops/quant.rs (line 189)
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
    /// Try to fold a chain of element-wise ops that starts at `dequant` and ends in a
    /// `QuantizeLinearU8`/`QuantizeLinearI8` element-wise op.
    ///
    /// `dequant` is presumably a dequantization node (judging by its name and how it is
    /// plugged in below; the caller is not visible here). The chain is walked through
    /// `single_succ` for as long as each node reports `element_wise()` invariants. When a
    /// quantizer is found, two rewrites are attempted:
    /// 1. Replace every op between `dequant` and the quantizer by the result of its
    ///    `quantize()`. `dequant` and the quantizer themselves are dropped.
    /// 2. Otherwise, if the incoming datum type is `I8` or `U8`, evaluate the whole
    ///    dequant -> ... -> quant chain on all 256 byte values and replace it with a lookup table.
    ///
    /// Returns `Ok(None)` when no rewrite applies.
    fn declutter(
        &self,
        model: &TypedModel,
        dequant: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // `current` is the last node of the chain walked so far; its single successor is
        // examined on each iteration.
        let mut current = dequant;
        let incoming_dt = model.node_input_facts(dequant.id)?[0].datum_type;
        while let Some(quant) = model.single_succ(current.id)? {
            // `quant` is just the next node; it is a quantizer only if `q_params` is `Some`.
            let q_params = if let Some(op) = quant.op_as::<ElementWiseOp>() {
                if let Some(mop) = op.0.downcast_ref::<QuantizeLinearU8>() {
                    Some((mop.scale, mop.zero_point as i32, u8::datum_type()))
                } else {
                    op.0.downcast_ref::<QuantizeLinearI8>()
                        .map(|mop| (mop.scale, mop.zero_point as i32, i8::datum_type()))
                }
            } else {
                None
            };
            if let Some((scale, zero_point, dt)) = q_params {
                // first, try Op::quantize() on all ops in the chain
                // The new chain starts from dequant's own input, bypassing dequant.
                let mut patch = TypedModelPatch::default();
                let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;
                let mut next = model.single_succ(dequant.id)?.unwrap();
                loop {
                    // NOTE(review): `dequant` (not `next`) is passed as the node argument of
                    // `quantize`; confirm this is what `quantize` expects.
                    if let Some(op) = next
                        .op
                        .quantize(model, dequant, dt, scale, zero_point)
                        .with_context(|| format!("Quantizing {}", next))?
                    {
                        wire = patch.wire_node(&*next.name, op, [wire].as_ref())?[0];
                    } else {
                        // One op cannot be quantized: give up on this strategy.
                        break;
                    }
                    if next.id == current.id {
                        // Every op up to the quantizer was quantized: the quantizer's output
                        // is now fed directly by the quantized chain.
                        patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
                        return Ok(Some(patch));
                    } else {
                        next = model.single_succ(next.id)?.unwrap();
                    }
                }
                // or else make a lookup table
                if incoming_dt == DatumType::I8 || incoming_dt == DatumType::U8 {
                    // Build a standalone model reproducing dequant -> chain -> quant so it can
                    // be evaluated on every possible 8-bit input.
                    // NOTE(review): the ad-hoc source and the enumerated input below use the
                    // quantizer's output type `dt`, not `incoming_dt`; if the two differ the
                    // table is computed from inputs of the wrong type — confirm intended.
                    let mut adhoc_model = TypedModel::default();
                    let mut wire = adhoc_model.add_source("ad-hoc", dt.fact([256]))?;
                    let mut next = model.single_succ(dequant.id)?.unwrap();
                    // Name of the first op between dequant and quant, used for the LUT node.
                    let mut name = None;
                    // plug in dequant
                    wire = adhoc_model.wire_node(
                        &*dequant.name,
                        dequant.op.clone(),
                        [wire].as_ref(),
                    )?[0];
                    while next.id != quant.id {
                        name.get_or_insert(&*next.name);
                        wire =
                            adhoc_model.wire_node(&*next.name, next.op.clone(), [wire].as_ref())?
                                [0];
                        next = model.single_succ(next.id)?.unwrap();
                    }
                    // plug in quant
                    wire =
                        adhoc_model.wire_node(&*quant.name, quant.op.clone(), [wire].as_ref())?[0];
                    adhoc_model.set_output_outlets(&[wire])?;
                    // All 256 byte values, so index `i` of the output is the image of byte `i`.
                    let input = (0u8..=255).collect::<Vec<u8>>();
                    // `u8` and `i8` share size and alignment, so the byte slice can be
                    // reinterpreted in place.
                    let input = match dt {
                        DatumType::I8 => unsafe {
                            tensor1(std::mem::transmute::<&[u8], &[i8]>(&*input))
                        },
                        DatumType::U8 => tensor1(&input),
                        _ => unreachable!(),
                    };
                    let output =
                        SimplePlan::new(adhoc_model)?.run(tvec!(input.into_tvalue()))?.remove(0);
                    // The lookup table kernel takes raw bytes; reinterpret i8 output likewise.
                    let table: &[u8] = match dt {
                        DatumType::I8 => unsafe { std::mem::transmute(output.as_slice::<i8>()?) },
                        DatumType::U8 => output.as_slice::<u8>()?,
                        _ => unreachable!(),
                    };
                    let op = lookup_table((tract_linalg::ops().lut_u8)(table));
                    // Replace the whole chain (dequant included) by the single LUT node.
                    let mut patch = TypedModelPatch::default();
                    let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;

                    wire = patch.wire_node(name.unwrap_or(&*dequant.name), op, [wire].as_ref())?[0];
                    patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
                    return Ok(Some(patch));
                }
            }
            // Keep walking only through element-wise ops.
            let (input_facts, output_facts) = model.node_facts(quant.id)?;
            let invariants = quant
                .op
                .invariants(&input_facts, &output_facts)
                .with_context(|| format!("Querying invariants for {}", quant))?;
            if invariants.element_wise() {
                current = quant;
            } else {
                break;
            }
        }
        Ok(None)
    }
Examples found in repository?
src/model/graph.rs (line 91)
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
    /// Add a source node carrying `fact` and register its single outlet as a model input.
    ///
    /// The source op is built by `create_source`. The returned outlet is slot 0 of the new
    /// node and is appended to `inputs`.
    pub fn add_source(&mut self, name: impl Into<String>, fact: F) -> TractResult<OutletId> {
        let source_op = self.create_source(fact.clone());
        let node_id = self.add_node(name, source_op, tvec!(fact))?;
        let outlet = OutletId::new(node_id, 0);
        self.inputs.push(outlet);
        Ok(outlet)
    }
}

impl<F, O> Graph<F, O>
where
    F: Fact + Hash + Clone + 'static,
    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
{
    pub fn add_node(
        &mut self,
        name: impl Into<String>,
        op: impl Into<O>,
        output_facts: TVec<F>,
    ) -> TractResult<usize> {
        let op = op.into();
        let name = name.into();
        let id = self.nodes.len();
        let outputs =
            output_facts.into_iter().map(|fact| Outlet { fact, successors: tvec!() }).collect();
        let node = Node { id, name, op, inputs: vec![], outputs };
        self.nodes.push(node);
        Ok(id)
    }

    /// Connect a node outlet to a node inlet.
    pub fn add_edge(&mut self, outlet: OutletId, inlet: InletId) -> TractResult<()> {
        if let Some(previous) = self.nodes[inlet.node].inputs.get(inlet.slot).cloned() {
            self.nodes[previous.node].outputs[previous.slot]
                .successors
                .retain(|&mut succ| succ != inlet);
        }
        {
            let prec = &mut self.nodes[outlet.node];
            prec.outputs[outlet.slot].successors.push(inlet);
        }
        let succ = &mut self.nodes[inlet.node];
        #[allow(clippy::comparison_chain)]
        if inlet.slot == succ.inputs.len() {
            succ.inputs.push(outlet);
        } else if inlet.slot < succ.inputs.len() {
            succ.inputs[inlet.slot] = outlet;
        } else {
            bail!("Edges must be added in order and consecutive. Trying to connect input {:?} of node {:?} ", inlet.slot, succ)
        }
        Ok(())
    }

    // Inputs

    /// Get model inputs.
    pub fn input_outlets(&self) -> TractResult<&[OutletId]> {
        Ok(&self.inputs)
    }

    /// Change model inputs.
    pub fn set_input_outlets(&mut self, inputs: &[OutletId]) -> TractResult<()> {
        self.inputs = inputs.to_vec();
        Ok(())
    }

    /// Change model inputs and return `self`.
    pub fn with_input_outlets(mut self, inputs: &[OutletId]) -> TractResult<Self> {
        self.set_input_outlets(inputs)?;
        Ok(self)
    }

    /// Set model inputs by the node name.
    pub fn set_input_names(
        &mut self,
        inputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<()> {
        let mut ids = vec![];
        for i in inputs.into_iter() {
            let node = self.node_by_name(&i)?;
            for o in 0..node.outputs.len() {
                ids.push(OutletId::new(node.id, o))
            }
        }
        self.inputs = ids;
        Ok(())
    }

    /// Set model inputs by the node name and return `self`.
    pub fn with_input_names(
        mut self,
        inputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<Self> {
        self.set_input_names(inputs)?;
        Ok(self)
    }

    /// Get the `ix`-th input tensor type information.
    pub fn input_fact(&self, ix: usize) -> TractResult<&F> {
        let input = self.input_outlets()?[ix];
        self.outlet_fact(input)
    }

    /// Get the `ix`-th input tensor type information, mutably.
    pub fn input_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
        let input = self.input_outlets()?[ix];
        self.outlet_fact_mut(input)
    }

    /// Set the `ix`-th input tensor type information.
    pub fn set_input_fact(&mut self, input: usize, fact: F) -> TractResult<()> {
        let outlet = self.inputs[input];
        self.set_outlet_fact(outlet, fact)
    }

    /// Set the `ix`-th input tensor type information and return `self`.
    pub fn with_input_fact(mut self, input: usize, fact: F) -> TractResult<Self> {
        self.set_input_fact(input, fact)?;
        Ok(self)
    }

    // Outputs
    /// Get model outputs.
    pub fn output_outlets(&self) -> TractResult<&[OutletId]> {
        Ok(&self.outputs)
    }

    /// Guess outputs from the topology: node or nodes with no successors.
    pub fn auto_outputs(&mut self) -> TractResult<()> {
        let outputs = self
            .nodes
            .iter()
            .flat_map(|n| {
                let id = n.id;
                n.outputs.iter().enumerate().map(move |(ix, output_fact)| {
                    (OutletId::new(id, ix), output_fact.successors.len())
                })
            })
            .filter(|(_f, succs)| *succs == 0)
            .map(|(f, _)| f)
            .collect();
        self.outputs = outputs;
        Ok(())
    }

    /// Change model outputs.
    pub fn set_output_outlets(&mut self, outputs: &[OutletId]) -> TractResult<()> {
        self.outputs = outputs.to_vec();
        Ok(())
    }

    /// Change model outputs and return `self`.
    pub fn with_output_outlets(mut self, outputs: &[OutletId]) -> TractResult<Self> {
        self.set_output_outlets(outputs)?;
        Ok(self)
    }

    /// Set model outputs by node names.
    pub fn set_output_names(
        &mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<()> {
        let mut labels: HashMap<Cow<str>, OutletId> =
            self.outlet_labels.iter().map(|(o, s)| (Cow::Borrowed(&**s), *o)).collect();
        for n in self.nodes() {
            for ix in 0..n.outputs.len() {
                labels.insert(Cow::Owned(format!("{}:{}", &n.name, ix)), OutletId::new(n.id, ix));
            }
        }
        let ids: Vec<OutletId> = outputs
            .into_iter()
            .map(|s| {
                let s = s.as_ref();
                labels
                    .get(s)
                    .cloned()
                    .or_else(|| self.nodes.iter().find(|n| n.name == s).map(|n| n.id.into()))
                    .ok_or_else(|| format_err!("Node {} not found", s))
            })
            .collect::<TractResult<_>>()?;
        self.outputs = ids;
        Ok(())
    }

    /// Set model outputs by node names and return `self`.
    pub fn with_output_names(
        mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<Self> {
        self.set_output_names(outputs)?;
        Ok(self)
    }

    /// Get the `ix`-th input tensor type information.
    pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
        let output = self.output_outlets()?[ix];
        self.outlet_fact(output)
    }

    /// Get the `ix`-th input tensor type information, mutably.
    pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
        let output = self.output_outlets()?[ix];
        self.outlet_fact_mut(output)
    }

    /// Set the `ix`-th output tensor type information.
    pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()> {
        let outlet = self.outputs[output];
        self.set_outlet_fact(outlet, fact)
    }

    /// Set the `ix`-th output tensor type information and return `self`.
    pub fn with_output_fact(mut self, output: usize, fact: F) -> TractResult<Self> {
        self.set_output_fact(output, fact)?;
        Ok(self)
    }

    // nodes and their facts

    /// Iterate over all node names.
    pub fn node_names(&self) -> impl Iterator<Item = &str> {
        self.nodes.iter().map(|s| &*s.name)
    }

    pub fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
        self.nodes
            .iter()
            .find(|n| n.name == name)
            .map(|n| n.id)
            .with_context(|| format!("No node found for name: \"{}\"", name))
    }

    /// Find a node by its name.
    pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>> {
        let id: usize = self.node_id_by_name(name.as_ref())?;
        Ok(&self.nodes[id])
    }

    /// Borrow mutably a node by its name.
    pub fn node_by_name_mut(&mut self, name: impl AsRef<str>) -> TractResult<&mut Node<F, O>> {
        let id: usize = self.node_id_by_name(name.as_ref())?;
        Ok(&mut self.nodes[id])
    }

    pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
        self.node_mut(id).name = name.to_string();
        Ok(())
    }

    /// Find a node by its id.
    pub fn node(&self, id: usize) -> &Node<F, O> {
        &self.nodes[id]
    }

    /// Find a node by its id.
    pub fn node_mut(&mut self, id: usize) -> &mut Node<F, O> {
        &mut self.nodes[id]
    }

    /// Access the nodes table.
    pub fn nodes(&self) -> &[Node<F, O>] {
        &self.nodes
    }

    /// Access the nodes table.
    pub fn nodes_mut(&mut self) -> &mut [Node<F, O>] {
        &mut self.nodes
    }

    /// Get input and output tensor information for a node.
    pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
        Ok((self.node_input_facts(id)?, self.node_output_facts(id)?))
    }

    /// Get input tensor information for a node.
    pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        self.nodes[node_id].inputs.iter().map(|o| self.outlet_fact(*o)).collect()
    }

    /// Get output tensor information for a node.
    pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        Ok(self.nodes[node_id].outputs.iter().map(|o| &o.fact).collect())
    }

    // outlets

    /// Get tensor information for a single outlet.
    pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
        anyhow::ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
        let outlets = &self.nodes[outlet.node].outputs;
        outlets
            .get(outlet.slot)
            .map(|o| &o.fact)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
    }

    /// Get tensor information for a single outlet.
    pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
        let outlets = &mut self.nodes[outlet.node].outputs;
        outlets
            .get_mut(outlet.slot)
            .map(|o| &mut o.fact)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
    }

    /// Get multiple mutable tensor information for outlets.
    pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
        assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
        unsafe {
            outlets
                .iter()
                .map(|o| Ok((self.outlet_fact(*o)? as *const F as *mut F).as_mut().unwrap()))
                .collect()
        }
    }

    /// Set tensor information for a single outlet.
    pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
        let outlets = &mut self.nodes[outlet.node].outputs;
        if outlets.len() <= outlet.slot {
            bail!("Invalid outlet refererence: {:?}", outlet)
        }
        outlets[outlet.slot].fact = fact;
        Ok(())
    }

    /// Set tensor information for a single outlet and return `self`.
    pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
        self.set_outlet_fact(outlet, fact)?;
        Ok(self)
    }

    // outlet labels

    /// Get label for an outlet.
    pub fn outlet_label(&self, outlet: OutletId) -> Option<&str> {
        self.outlet_labels.get(&outlet).map(|s| &**s)
    }

    /// Set label for an outlet.
    pub fn set_outlet_label(&mut self, outlet: OutletId, label: String) -> TractResult<()> {
        self.outlet_labels.insert(outlet, label);
        Ok(())
    }

    /// Set label for an outlet and return `self`.
    pub fn with_outlet_label(mut self, outlet: OutletId, label: String) -> TractResult<Self> {
        self.set_outlet_label(outlet, label)?;
        Ok(self)
    }

    /// Find outlet by label.
    pub fn find_outlet_label(&self, label: &str) -> Option<OutletId> {
        self.outlet_labels.iter().find(|(_k, v)| **v == label).map(|(k, _v)| *k)
    }

    // misc

    /// Computes an evalutation order for the graph inputs and outputs
    pub fn eval_order(&self) -> TractResult<Vec<usize>> {
        eval_order(self)
    }

    #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
    #[inline]
    pub fn check_edges(&self) -> TractResult<()> {
        Ok(())
    }

    /// Performs a sanity check on network connections.
    #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
    #[inline]
    pub fn check_edges(&self) -> TractResult<()> {
        for node_id in self.eval_order()? {
            let node = &self.nodes[node_id];
            for (ix, input) in node.inputs.iter().enumerate() {
                let prec = &self.nodes[input.node];
                if !prec.outputs[input.slot].successors.contains(&InletId::new(node.id, ix)) {
                    bail!(
                        "Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
                        node.id,
                        ix,
                        prec
                    )
                }
            }
            for (ix, output) in node.outputs.iter().enumerate() {
                for succ in &output.successors {
                    if self.nodes[succ.node].inputs[succ.slot] != OutletId::new(node.id, ix) {
                        bail!(
                            "Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
                            node.id,
                            ix,
                            succ
                        )
                    }
                }
            }
        }
        Ok(())
    }

    /// Converts the model into a `RunnableModel` which fixes the inputs and outputs and allows passing data through the model.
    pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>> {
        crate::plan::SimplePlan::new(self)
    }

    pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let node = &self.nodes()[id];
        if node.inputs.len() != 1 {
            return Ok(None);
        }
        let prec = &self.nodes()[node.inputs[0].node];
        if prec.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
            return Ok(None);
        }
        Ok(Some(prec))
    }

    pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut node = self.node(id);
        for _ in 0..count {
            if let Some(next) = self.single_prec(node.id)? {
                node = next
            } else {
                return Ok(None);
            }
        }
        Ok(Some(node))
    }

    pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut node = self.node(id);
        for _ in 0..count {
            if let Some(next) = self.single_succ(node.id)? {
                node = next
            } else {
                return Ok(None);
            }
        }
        Ok(Some(node))
    }

    /// Return the unique successor of node `id`, if there is one.
    ///
    /// `Some(succ)` is returned only when node `id` has exactly one outgoing edge in total
    /// (summed over all of its outputs) and the node at the end of that edge has exactly one
    /// input.
    ///
    /// # Panics
    /// Panics if `id` is out of range.
    pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let node = &self.nodes()[id];
        // Search the whole fan-out rather than `outputs[0]`: the lone edge may leave from a
        // later output slot, in which case `outputs[0].successors[0]` would panic.
        let mut successors = node.outputs.iter().flat_map(|of| of.successors.iter());
        let succ = match (successors.next(), successors.next()) {
            (Some(succ), None) => succ,
            _ => return Ok(None),
        };
        let succ = &self.nodes()[succ.node];
        if succ.inputs.len() != 1 {
            return Ok(None);
        }
        Ok(Some(succ))
    }

    /// Borrow the inlets consuming `outlet`.
    ///
    /// # Panics
    /// Panics if the node or slot of `outlet` is out of range.
    pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
        let output = &self.nodes[outlet.node].outputs[outlet.slot];
        &output.successors
    }
}

impl<F: Fact + Clone + 'static, O> Graph<F, O>
where
    F: Fact + Clone + 'static + From<std::sync::Arc<Tensor>> + Hash,
    O: fmt::Debug
        + fmt::Display
        + From<crate::ops::konst::Const>
        + AsRef<dyn Op>
        + AsMut<dyn Op>
        + Clone
        + Hash
        + 'static,
{
    /// Add a `Const` node named `name` holding the tensor `v`.
    ///
    /// The node's single output fact is built from the tensor with `F::from`. Returns the
    /// outlet of the new node.
    pub fn add_const(
        &mut self,
        name: impl Into<String>,
        v: impl IntoArcTensor,
    ) -> TractResult<OutletId> {
        let tensor = v.into_arc_tensor();
        let fact = F::from(tensor.clone());
        let op = crate::ops::konst::Const::new(tensor);
        let node_id = self.add_node(name.into(), op, tvec!(fact))?;
        Ok(node_id.into())
    }
More examples
Hide additional examples
src/model/patch.rs (line 176)
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
    /// Build a patch fusing `node` with its single successor into one node running `new_op`.
    ///
    /// The new node is named after `node` and consumes the same inputs as `node`. It has a
    /// single output whose fact is copied from the successor's first output. That output takes
    /// over the successor's first output in `patched_model`.
    ///
    /// # Errors
    /// Fails with "Non single successor fuse attempt" when `single_succ` finds no unique
    /// successor. Also propagates failures from tapping, wiring or shunting in the patch.
    pub fn fuse_with_next<IO: Into<O>>(
        patched_model: &Graph<F, O>,
        node: &Node<F, O>,
        new_op: IO,
    ) -> TractResult<ModelPatch<F, O>> {
        let mut patch = ModelPatch::default();
        let succ = if let Some(succ) = patched_model.single_succ(node.id)? {
            succ
        } else {
            bail!("Non single successor fuse attempt")
        };
        let new_op = new_op.into();
        let by = patch.add_node(&*node.name, new_op, tvec!(succ.outputs[0].fact.clone()))?;
        for (ix, i) in node.inputs.iter().enumerate() {
            let o = patch.tap_model(patched_model, *i)?;
            patch.add_edge(o, InletId::new(by, ix))?;
        }
        // `by` is declared with exactly one output above, so only the successor's first output
        // is taken over. Looping over `node.outputs` here would address `by`/`succ` slots that
        // do not exist whenever `node` has several outputs.
        patch.shunt_outside(patched_model, OutletId::new(succ.id, 0), OutletId::new(by, 0))?;
        Ok(patch)
    }

    /// Convenience method creating a patch that shunts the given node.
    pub fn shunt_one_op(
        patched_model: &Graph<F, O>,
        node: &Node<F, O>,
    ) -> TractResult<ModelPatch<F, O>> {
        Self::rewire(patched_model, &node.inputs, &[node.id.into()], &|_p, xs| Ok(xs.into()))
    }

    /// Build a patch in four steps:
    /// 1. tap every outlet of `from`;
    /// 2. let `wiring` build new nodes on top of those taps inside the patch;
    /// 3. check that `wiring` returned one outlet per entry of `to`;
    /// 4. shunt each outlet of `to` by the matching outlet `wiring` returned.
    ///
    /// # Errors
    /// Fails if tapping or `wiring` fails, or if a shunt fails. Also fails when `wiring`
    /// returns a different number of outlets than `to` holds.
    #[allow(clippy::type_complexity)]
    pub fn rewire(
        patched_model: &Graph<F, O>,
        from: &[OutletId],
        to: &[OutletId],
        wiring: &dyn Fn(&mut Self, &[OutletId]) -> TractResult<TVec<OutletId>>,
    ) -> TractResult<ModelPatch<F, O>> {
        let mut patch = ModelPatch::default();
        let taps: TVec<OutletId> = from
            .iter()
            .copied()
            .map(|outlet| patch.tap_model(patched_model, outlet))
            .collect::<TractResult<_>>()?;
        let replacements = wiring(&mut patch, &taps)?;
        if replacements.len() != to.len() {
            bail!(
                "Wrong number of outputs for rewiring, expected {}, function returned {}",
                to.len(),
                replacements.len()
            );
        }
        for (replacement, original) in replacements.into_iter().zip(to.iter().copied()) {
            patch.shunt_outside(patched_model, original, replacement)?;
        }
        Ok(patch)
    }

    /// Build a patch replacing the unary `node` by `new_op`, fed from the node's first input.
    ///
    /// # Panics
    /// Panics if `node` has no input.
    pub fn single_unary_op<IO: Into<O>>(
        patched_model: &Graph<F, O>,
        node: &Node<F, O>,
        new_op: IO,
    ) -> TractResult<ModelPatch<F, O>> {
        let sole_input = node.inputs[0];
        Self::replace_single_op(patched_model, node, &[sole_input], new_op)
    }

    /// Build a patch inserting the unary op `new_op` (named `name`, output fact `fact`) right
    /// after `outlet`. Former consumers of `outlet` are moved onto the new node's output.
    pub fn intercept<IO: Into<O>>(
        patched_model: &Graph<F, O>,
        outlet: OutletId,
        name: impl Into<String>,
        new_op: IO,
        fact: F,
    ) -> TractResult<ModelPatch<F, O>> {
        let mut patch = ModelPatch::default();
        let tapped = patch.tap_model(patched_model, outlet)?;
        let interceptor = patch.add_node(name, new_op, tvec!(fact))?;
        patch.add_edge(tapped, InletId::new(interceptor, 0))?;
        let replacement = OutletId::new(interceptor, 0);
        patch.shunt_outside(patched_model, outlet, replacement)?;
        Ok(patch)
    }

    /// Apply all changes in the patch to the target model.
    ///
    /// Phases, in order:
    /// 1. copy every non-tap patch node into `target`, recording patch outlet -> target outlet
    ///    in `mapping` (which starts as the patch's `incoming` tap map);
    /// 2. move the successors, model-output entries and outlet labels of every shunted outlet
    ///    onto its replacement;
    /// 3. add the input edges of the copied nodes;
    /// 4. replace the op of every obliterated node with `target.create_dummy()`;
    /// 5. reset the model inputs, substituting replacement source nodes.
    ///
    /// # Errors
    /// Fails if two model outputs end up on the same outlet after shunting. Also propagates
    /// failures from `add_node`, `add_edge`, `set_outlet_label` or `set_input_outlets`.
    ///
    /// # Panics
    /// Indexing `mapping[..]` / `replaced_inputs[..]` panics if the patch refers to an outlet
    /// (or source node) that was neither tapped, created by the patch, nor registered.
    pub fn apply(self, target: &mut Graph<F, O>) -> TractResult<()> {
        let prior_target_inputs = target.input_outlets()?.len();
        let prior_target_outputs = target.output_outlets()?.len();
        // `mapping` starts as the tap map (patch source outlet -> tapped target outlet) and is
        // extended below with every patch node copied into `target`.
        let ModelPatch {
            model: patch,
            incoming: mut mapping,
            shunt_outlet_by,
            obliterate,
            inputs: replaced_inputs,
            ..
        } = self;
        let mut all_inputs = HashMap::new(); // new_node_id_in_model -> [ patch_outlet_id ]
        let mut model_input_outlets = target.input_outlets()?.to_vec();
        // Phase 1: copy patch nodes into `target`; their edges are added in phase 3, once
        // every patch outlet has a target counterpart.
        for node in patch.nodes {
            if <Graph<F, O>>::is_source(&node.op)
                && mapping.contains_key(&OutletId::new(node.id, 0))
            {
                // this is a tap
                continue;
            }
            let Node { id: patch_node_id, name, inputs, op, outputs } = node;
            let n_outputs = outputs.len();
            // Look for an existing target node with the same op wired to the same inputs.
            // NOTE(review): the `continue` below only continues this inner `for dup` loop, so
            // a match never skips node creation. The node is still added after the loop and its
            // `mapping` entries are overwritten. Confirm whether a labeled `continue` on the
            // outer loop was intended before changing it. That change would also skip
            // `all_inputs` and the input-replacement handling for this node.
            for dup in 0..target.nodes.len() {
                if target.node(dup).op().same_as(op.as_ref())
                    && inputs.len() == target.node(dup).inputs.len()
                    && inputs
                        .iter()
                        .zip(target.node(dup).inputs.iter())
                        .all(|(patch_input, d)| mapping[patch_input] == *d)
                {
                    for ix in 0..n_outputs {
                        mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(dup, ix));
                    }
                    continue;
                }
            }
            let facts = outputs.into_iter().map(|of| of.fact).collect();
            let added_node_id = target.add_node(name, op, facts)?;
            for ix in 0..n_outputs {
                mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(added_node_id, ix));
            }
            all_inputs.insert(added_node_id, inputs);
            if <Graph<F, O>>::is_source(&target.node(added_node_id).op) {
                // this is actually an input replacement
                // `replaced_inputs` presumably maps this patch source node to the id of the
                // target input node it replaces; indexing panics if it has no entry — confirm
                // that callers always register one.
                model_input_outlets.iter_mut().for_each(|oo| {
                    if oo.node == replaced_inputs[&patch_node_id] {
                        oo.node = added_node_id;
                    }
                });
            }
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Phase 2: move consumers, model outputs and labels of each shunted outlet onto the
        // target outlet its replacement was mapped to.
        for (outlet, by) in shunt_outlet_by {
            let replace_by = mapping[&by];
            let succs = target.nodes()[outlet.node].outputs[outlet.slot].successors.clone();
            for succ in succs {
                target.add_edge(replace_by, succ)?;
            }
            for o in target.outputs.iter_mut() {
                if *o == outlet {
                    *o = replace_by;
                }
            }
            if let Some(label) = target.outlet_labels.remove(&outlet) {
                target.set_outlet_label(replace_by, label)?;
            }
        }
        // Shunting two outputs onto the same outlet would make the model expose it twice.
        if target.outputs.len() > target.outputs.iter().sorted().dedup().count() {
            bail!("Duplicate usage of node as output");
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Phase 3: wire the inputs of the copied nodes through `mapping`.
        for (node, inputs) in all_inputs {
            for (ix, input) in inputs.into_iter().enumerate() {
                target.add_edge(mapping[&input], InletId::new(node, ix))?;
            }
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Phase 4: neutralise obliterated nodes by swapping their op for a dummy.
        for node in obliterate {
            target.node_mut(node).op = target.create_dummy();
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Phase 5: install the (possibly substituted) model inputs.
        target.set_input_outlets(&model_input_outlets)?;
        Ok(())
    }
src/model/translator.rs (line 112)
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
    /// Translate one `source` node into `target`, returning the target outlets that stand for
    /// the node's outputs.
    ///
    /// Model inputs (nodes whose every output is a `source` input) become one new source per
    /// output. The sources are named `"<name>-<slot>"` when there are several outputs, and
    /// just the node name otherwise. Any other node is converted with `O2::try_from`. Its
    /// facts are converted with `TI2::try_from` and its inputs are wired through `mapping`.
    ///
    /// # Panics
    /// Panics if an input of a non-input node is missing from `mapping`.
    fn translate_node(
        &self,
        source: &Graph<TI1, O1>,
        node: &Node<TI1, O1>,
        target: &mut Graph<TI2, O2>,
        mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        let output_count = node.outputs.len();
        let is_model_input =
            (0..output_count).all(|slot| source.inputs.contains(&(node.id, slot).into()));
        if !is_model_input {
            let translated_op = O2::try_from(&node.op)?;
            let facts = node
                .outputs
                .iter()
                .map(|outlet| -> TractResult<TI2> { Ok(TI2::try_from(&outlet.fact)?) })
                .collect::<TractResult<TVec<_>>>()?;
            let new_id = target.add_node(node.name.clone(), translated_op, facts)?;
            for (slot, upstream) in node.inputs.iter().enumerate() {
                target.add_edge(mapping[upstream], InletId::new(new_id, slot))?;
            }
            return Ok((0..output_count).map(|slot| OutletId::new(new_id, slot)).collect());
        }
        node.outputs
            .iter()
            .enumerate()
            .map(|(slot, outlet)| {
                let source_name = if output_count > 1 {
                    format!("{}-{}", node.name, slot)
                } else {
                    node.name.to_string()
                };
                target.add_source(source_name, TI2::try_from(&outlet.fact)?)
            })
            .collect()
    }
src/model/typed.rs (line 69)
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
    /// Add a node running `op`, named `name`, wired to `inputs`, and return its output outlets.
    ///
    /// Output facts come from `op.output_facts`. When every input fact carries a constant
    /// value and the op is stateless, the op is also evaluated eagerly. If that succeeds, the
    /// output facts are built from the computed tensors instead.
    ///
    /// # Errors
    /// Fails if an input outlet is invalid or `output_facts` fails. Also fails if adding the
    /// node or one of its edges fails. Every such error carries a `Wiring node "<name>"`
    /// context.
    fn wire_node(
        &mut self,
        name: impl Into<String>,
        op: impl Into<Box<dyn TypedOp>>,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let op = op.into();
        let name = name.into();

        let output_facts = || -> TractResult<TVec<TypedFact>> {
            let input_facts = inputs
                .iter()
                .map(|o| self.outlet_fact(*o))
                .collect::<TractResult<TVec<_>>>()?;
            let facts = op.output_facts(&input_facts).context("in output_facts invocation")?;
            if input_facts.iter().all(|f| f.konst.is_some()) && op.is_stateless() {
                let tensors = input_facts
                    .iter()
                    .map(|f| f.konst.clone().unwrap().into_tvalue())
                    .collect::<TVec<_>>();
                // Best effort: if eager evaluation fails, keep the declared facts.
                if let Ok(outputs) = op.eval(tensors) {
                    return Ok(outputs.into_iter().map(|t| TypedFact::from(&*t)).collect());
                }
            }
            Ok(facts)
        };

        // The `?`s below live inside a closure, not a plain block, so their errors flow into
        // `wired` and receive the final "Wiring node" context instead of returning early.
        let wired = output_facts()
            .with_context(|| format!("wiring {} ({:?}), determining output_facts", name, op))
            .and_then(|facts| -> TractResult<TVec<OutletId>> {
                let id = self.add_node(&name, &op, facts)?;
                inputs
                    .iter()
                    .enumerate()
                    .try_for_each(|(ix, i)| self.add_edge(*i, InletId::new(id, ix)))?;
                Ok(self
                    .node(id)
                    .outputs
                    .iter()
                    .enumerate()
                    .map(|(ix, _)| OutletId::new(id, ix))
                    .collect())
            });
        wired.with_context(|| format!("Wiring node \"{}\", {:?}", name, op))
    }

Connect a node outlet to a node inlet.

Examples found in repository?
src/model/patch.rs (line 179)
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
    /// Build a patch fusing `node` with its single successor into one node running `new_op`.
    ///
    /// The new node is named after `node` and consumes the same inputs as `node`. It has a
    /// single output whose fact is copied from the successor's first output. That output takes
    /// over the successor's first output in `patched_model`.
    ///
    /// # Errors
    /// Fails with "Non single successor fuse attempt" when `single_succ` finds no unique
    /// successor. Also propagates failures from tapping, wiring or shunting in the patch.
    pub fn fuse_with_next<IO: Into<O>>(
        patched_model: &Graph<F, O>,
        node: &Node<F, O>,
        new_op: IO,
    ) -> TractResult<ModelPatch<F, O>> {
        let mut patch = ModelPatch::default();
        let succ = if let Some(succ) = patched_model.single_succ(node.id)? {
            succ
        } else {
            bail!("Non single successor fuse attempt")
        };
        let new_op = new_op.into();
        let by = patch.add_node(&*node.name, new_op, tvec!(succ.outputs[0].fact.clone()))?;
        for (ix, i) in node.inputs.iter().enumerate() {
            let o = patch.tap_model(patched_model, *i)?;
            patch.add_edge(o, InletId::new(by, ix))?;
        }
        // `by` is declared with exactly one output above, so only the successor's first output
        // is taken over. Looping over `node.outputs` here would address `by`/`succ` slots that
        // do not exist whenever `node` has several outputs.
        patch.shunt_outside(patched_model, OutletId::new(succ.id, 0), OutletId::new(by, 0))?;
        Ok(patch)
    }

    /// Convenience method creating a patch that shunts the given node.
    pub fn shunt_one_op(
        patched_model: &Graph<F, O>,
        node: &Node<F, O>,
    ) -> TractResult<ModelPatch<F, O>> {
        Self::rewire(patched_model, &node.inputs, &[node.id.into()], &|_p, xs| Ok(xs.into()))
    }

    /// Build a patch in four steps:
    /// 1. tap every outlet of `from`;
    /// 2. let `wiring` build new nodes on top of those taps inside the patch;
    /// 3. check that `wiring` returned one outlet per entry of `to`;
    /// 4. shunt each outlet of `to` by the matching outlet `wiring` returned.
    ///
    /// # Errors
    /// Fails if tapping or `wiring` fails, or if a shunt fails. Also fails when `wiring`
    /// returns a different number of outlets than `to` holds.
    #[allow(clippy::type_complexity)]
    pub fn rewire(
        patched_model: &Graph<F, O>,
        from: &[OutletId],
        to: &[OutletId],
        wiring: &dyn Fn(&mut Self, &[OutletId]) -> TractResult<TVec<OutletId>>,
    ) -> TractResult<ModelPatch<F, O>> {
        let mut patch = ModelPatch::default();
        let taps: TVec<OutletId> = from
            .iter()
            .copied()
            .map(|outlet| patch.tap_model(patched_model, outlet))
            .collect::<TractResult<_>>()?;
        let replacements = wiring(&mut patch, &taps)?;
        if replacements.len() != to.len() {
            bail!(
                "Wrong number of outputs for rewiring, expected {}, function returned {}",
                to.len(),
                replacements.len()
            );
        }
        for (replacement, original) in replacements.into_iter().zip(to.iter().copied()) {
            patch.shunt_outside(patched_model, original, replacement)?;
        }
        Ok(patch)
    }

    /// Build a patch replacing the unary `node` by `new_op`, fed from the node's first input.
    ///
    /// # Panics
    /// Panics if `node` has no input.
    pub fn single_unary_op<IO: Into<O>>(
        patched_model: &Graph<F, O>,
        node: &Node<F, O>,
        new_op: IO,
    ) -> TractResult<ModelPatch<F, O>> {
        let sole_input = node.inputs[0];
        Self::replace_single_op(patched_model, node, &[sole_input], new_op)
    }

    /// Build a patch inserting the unary op `new_op` (named `name`, output fact `fact`) right
    /// after `outlet`. Former consumers of `outlet` are moved onto the new node's output.
    pub fn intercept<IO: Into<O>>(
        patched_model: &Graph<F, O>,
        outlet: OutletId,
        name: impl Into<String>,
        new_op: IO,
        fact: F,
    ) -> TractResult<ModelPatch<F, O>> {
        let mut patch = ModelPatch::default();
        let tapped = patch.tap_model(patched_model, outlet)?;
        let interceptor = patch.add_node(name, new_op, tvec!(fact))?;
        patch.add_edge(tapped, InletId::new(interceptor, 0))?;
        let replacement = OutletId::new(interceptor, 0);
        patch.shunt_outside(patched_model, outlet, replacement)?;
        Ok(patch)
    }

    /// Apply all changes in the patch to the target model.
    ///
    /// Phases, in order:
    /// 1. copy every non-tap patch node into `target`, recording patch outlet -> target outlet
    ///    in `mapping` (which starts as the patch's `incoming` tap map);
    /// 2. move the successors, model-output entries and outlet labels of every shunted outlet
    ///    onto its replacement;
    /// 3. add the input edges of the copied nodes;
    /// 4. replace the op of every obliterated node with `target.create_dummy()`;
    /// 5. reset the model inputs, substituting replacement source nodes.
    ///
    /// # Errors
    /// Fails if two model outputs end up on the same outlet after shunting. Also propagates
    /// failures from `add_node`, `add_edge`, `set_outlet_label` or `set_input_outlets`.
    ///
    /// # Panics
    /// Indexing `mapping[..]` / `replaced_inputs[..]` panics if the patch refers to an outlet
    /// (or source node) that was neither tapped, created by the patch, nor registered.
    pub fn apply(self, target: &mut Graph<F, O>) -> TractResult<()> {
        let prior_target_inputs = target.input_outlets()?.len();
        let prior_target_outputs = target.output_outlets()?.len();
        // `mapping` starts as the tap map (patch source outlet -> tapped target outlet) and is
        // extended below with every patch node copied into `target`.
        let ModelPatch {
            model: patch,
            incoming: mut mapping,
            shunt_outlet_by,
            obliterate,
            inputs: replaced_inputs,
            ..
        } = self;
        let mut all_inputs = HashMap::new(); // new_node_id_in_model -> [ patch_outlet_id ]
        let mut model_input_outlets = target.input_outlets()?.to_vec();
        // Phase 1: copy patch nodes into `target`; their edges are added in phase 3, once
        // every patch outlet has a target counterpart.
        for node in patch.nodes {
            if <Graph<F, O>>::is_source(&node.op)
                && mapping.contains_key(&OutletId::new(node.id, 0))
            {
                // this is a tap
                continue;
            }
            let Node { id: patch_node_id, name, inputs, op, outputs } = node;
            let n_outputs = outputs.len();
            // Look for an existing target node with the same op wired to the same inputs.
            // NOTE(review): the `continue` below only continues this inner `for dup` loop, so
            // a match never skips node creation. The node is still added after the loop and its
            // `mapping` entries are overwritten. Confirm whether a labeled `continue` on the
            // outer loop was intended before changing it. That change would also skip
            // `all_inputs` and the input-replacement handling for this node.
            for dup in 0..target.nodes.len() {
                if target.node(dup).op().same_as(op.as_ref())
                    && inputs.len() == target.node(dup).inputs.len()
                    && inputs
                        .iter()
                        .zip(target.node(dup).inputs.iter())
                        .all(|(patch_input, d)| mapping[patch_input] == *d)
                {
                    for ix in 0..n_outputs {
                        mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(dup, ix));
                    }
                    continue;
                }
            }
            let facts = outputs.into_iter().map(|of| of.fact).collect();
            let added_node_id = target.add_node(name, op, facts)?;
            for ix in 0..n_outputs {
                mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(added_node_id, ix));
            }
            all_inputs.insert(added_node_id, inputs);
            if <Graph<F, O>>::is_source(&target.node(added_node_id).op) {
                // this is actually an input replacement
                // `replaced_inputs` presumably maps this patch source node to the id of the
                // target input node it replaces; indexing panics if it has no entry — confirm
                // that callers always register one.
                model_input_outlets.iter_mut().for_each(|oo| {
                    if oo.node == replaced_inputs[&patch_node_id] {
                        oo.node = added_node_id;
                    }
                });
            }
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Phase 2: move consumers, model outputs and labels of each shunted outlet onto the
        // target outlet its replacement was mapped to.
        for (outlet, by) in shunt_outlet_by {
            let replace_by = mapping[&by];
            let succs = target.nodes()[outlet.node].outputs[outlet.slot].successors.clone();
            for succ in succs {
                target.add_edge(replace_by, succ)?;
            }
            for o in target.outputs.iter_mut() {
                if *o == outlet {
                    *o = replace_by;
                }
            }
            if let Some(label) = target.outlet_labels.remove(&outlet) {
                target.set_outlet_label(replace_by, label)?;
            }
        }
        // Shunting two outputs onto the same outlet would make the model expose it twice.
        if target.outputs.len() > target.outputs.iter().sorted().dedup().count() {
            bail!("Duplicate usage of node as output");
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Phase 3: wire the inputs of the copied nodes through `mapping`.
        for (node, inputs) in all_inputs {
            for (ix, input) in inputs.into_iter().enumerate() {
                target.add_edge(mapping[&input], InletId::new(node, ix))?;
            }
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Phase 4: neutralise obliterated nodes by swapping their op for a dummy.
        for node in obliterate {
            target.node_mut(node).op = target.create_dummy();
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Phase 5: install the (possibly substituted) model inputs.
        target.set_input_outlets(&model_input_outlets)?;
        Ok(())
    }
More examples
Hide additional examples
src/model/translator.rs (line 114)
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
    /// Translate one `source` node into `target`, returning the target outlets that stand for
    /// the node's outputs.
    ///
    /// Model inputs (nodes whose every output is a `source` input) become one new source per
    /// output. The sources are named `"<name>-<slot>"` when there are several outputs, and
    /// just the node name otherwise. Any other node is converted with `O2::try_from`. Its
    /// facts are converted with `TI2::try_from` and its inputs are wired through `mapping`.
    ///
    /// # Panics
    /// Panics if an input of a non-input node is missing from `mapping`.
    fn translate_node(
        &self,
        source: &Graph<TI1, O1>,
        node: &Node<TI1, O1>,
        target: &mut Graph<TI2, O2>,
        mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        let output_count = node.outputs.len();
        let is_model_input =
            (0..output_count).all(|slot| source.inputs.contains(&(node.id, slot).into()));
        if !is_model_input {
            let translated_op = O2::try_from(&node.op)?;
            let facts = node
                .outputs
                .iter()
                .map(|outlet| -> TractResult<TI2> { Ok(TI2::try_from(&outlet.fact)?) })
                .collect::<TractResult<TVec<_>>>()?;
            let new_id = target.add_node(node.name.clone(), translated_op, facts)?;
            for (slot, upstream) in node.inputs.iter().enumerate() {
                target.add_edge(mapping[upstream], InletId::new(new_id, slot))?;
            }
            return Ok((0..output_count).map(|slot| OutletId::new(new_id, slot)).collect());
        }
        node.outputs
            .iter()
            .enumerate()
            .map(|(slot, outlet)| {
                let source_name = if output_count > 1 {
                    format!("{}-{}", node.name, slot)
                } else {
                    node.name.to_string()
                };
                target.add_source(source_name, TI2::try_from(&outlet.fact)?)
            })
            .collect()
    }
src/model/typed.rs (line 73)
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
    /// Add a node running `op`, named `name`, wired to `inputs`, and return its output outlets.
    ///
    /// Output facts come from `op.output_facts`. When every input fact carries a constant
    /// value and the op is stateless, the op is also evaluated eagerly. If that succeeds, the
    /// output facts are built from the computed tensors instead.
    ///
    /// # Errors
    /// Fails if an input outlet is invalid or `output_facts` fails. Also fails if adding the
    /// node or one of its edges fails. Every such error carries a `Wiring node "<name>"`
    /// context.
    fn wire_node(
        &mut self,
        name: impl Into<String>,
        op: impl Into<Box<dyn TypedOp>>,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>> {
        let op = op.into();
        let name = name.into();

        let output_facts = || -> TractResult<TVec<TypedFact>> {
            let input_facts = inputs
                .iter()
                .map(|o| self.outlet_fact(*o))
                .collect::<TractResult<TVec<_>>>()?;
            let facts = op.output_facts(&input_facts).context("in output_facts invocation")?;
            if input_facts.iter().all(|f| f.konst.is_some()) && op.is_stateless() {
                let tensors = input_facts
                    .iter()
                    .map(|f| f.konst.clone().unwrap().into_tvalue())
                    .collect::<TVec<_>>();
                // Best effort: if eager evaluation fails, keep the declared facts.
                if let Ok(outputs) = op.eval(tensors) {
                    return Ok(outputs.into_iter().map(|t| TypedFact::from(&*t)).collect());
                }
            }
            Ok(facts)
        };

        // The `?`s below live inside a closure, not a plain block, so their errors flow into
        // `wired` and receive the final "Wiring node" context instead of returning early.
        let wired = output_facts()
            .with_context(|| format!("wiring {} ({:?}), determining output_facts", name, op))
            .and_then(|facts| -> TractResult<TVec<OutletId>> {
                let id = self.add_node(&name, &op, facts)?;
                inputs
                    .iter()
                    .enumerate()
                    .try_for_each(|(ix, i)| self.add_edge(*i, InletId::new(id, ix)))?;
                Ok(self
                    .node(id)
                    .outputs
                    .iter()
                    .enumerate()
                    .map(|(ix, _)| OutletId::new(id, ix))
                    .collect())
            });
        wired.with_context(|| format!("Wiring node \"{}\", {:?}", name, op))
    }

Get model inputs.

Examples found in repository?
src/model/graph.rs (line 188)
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
    /// Get the `ix`-th input tensor type information.
    ///
    /// # Errors
    /// Fails if `ix` is not a valid model input index, or if the input outlet is invalid.
    pub fn input_fact(&self, ix: usize) -> TractResult<&F> {
        // `get` turns an out-of-range index into an error instead of a panic, matching the
        // fallible signature.
        let input = *self
            .input_outlets()?
            .get(ix)
            .with_context(|| format!("Invalid input index: {}", ix))?;
        self.outlet_fact(input)
    }

    /// Get the `ix`-th input tensor type information, mutably.
    ///
    /// # Errors
    /// Fails if `ix` is not a valid model input index, or if the input outlet is invalid.
    pub fn input_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
        // `get` turns an out-of-range index into an error instead of a panic, matching the
        // fallible signature.
        let input = *self
            .input_outlets()?
            .get(ix)
            .with_context(|| format!("Invalid input index: {}", ix))?;
        self.outlet_fact_mut(input)
    }

    /// Set the `input`-th input tensor type information.
    ///
    /// # Errors
    /// Fails if `input` is not a valid model input index, or if setting the fact fails.
    pub fn set_input_fact(&mut self, input: usize, fact: F) -> TractResult<()> {
        // Report a bad index as an error rather than panicking on `self.inputs[input]`.
        let outlet = *self
            .inputs
            .get(input)
            .with_context(|| format!("Invalid input index: {}", input))?;
        self.set_outlet_fact(outlet, fact)
    }

    /// Builder form of `set_input_fact`: set the `input`-th input fact, then hand back the
    /// graph.
    pub fn with_input_fact(mut self, input: usize, fact: F) -> TractResult<Self> {
        self.set_input_fact(input, fact).map(|()| self)
    }

    // Outputs
    /// Borrow the model outputs, in order.
    pub fn output_outlets(&self) -> TractResult<&[OutletId]> {
        Ok(self.outputs.as_slice())
    }

    /// Guess outputs from the topology: every outlet without any successor becomes a model
    /// output, in node order then slot order. Previous outputs are replaced.
    pub fn auto_outputs(&mut self) -> TractResult<()> {
        let mut dangling = Vec::new();
        for node in &self.nodes {
            for (slot, outlet) in node.outputs.iter().enumerate() {
                if outlet.successors.is_empty() {
                    dangling.push(OutletId::new(node.id, slot));
                }
            }
        }
        self.outputs = dangling;
        Ok(())
    }

    /// Replace the model outputs with `outputs`.
    pub fn set_output_outlets(&mut self, outputs: &[OutletId]) -> TractResult<()> {
        self.outputs.clear();
        self.outputs.extend_from_slice(outputs);
        Ok(())
    }

    /// Builder form of `set_output_outlets`: replace the outputs, then hand back the graph.
    pub fn with_output_outlets(mut self, outputs: &[OutletId]) -> TractResult<Self> {
        self.set_output_outlets(outputs).map(|()| self)
    }

    /// Set model outputs by name.
    ///
    /// Each requested name is resolved in this order:
    /// 1. as an outlet label, or as `"<node name>:<slot>"` (the latter wins on a clash);
    /// 2. failing that, as a node name, which selects the node's first output.
    ///
    /// The outputs are only replaced if every name resolves.
    ///
    /// # Errors
    /// Fails with "Node <name> not found" for the first name that matches nothing.
    pub fn set_output_names(
        &mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<()> {
        let mut by_label: HashMap<Cow<str>, OutletId> = HashMap::new();
        for (outlet, label) in &self.outlet_labels {
            by_label.insert(Cow::Borrowed(label.as_str()), *outlet);
        }
        for node in &self.nodes {
            for slot in 0..node.outputs.len() {
                let key = format!("{}:{}", &node.name, slot);
                by_label.insert(Cow::Owned(key), OutletId::new(node.id, slot));
            }
        }
        let mut resolved: Vec<OutletId> = Vec::new();
        for wanted in outputs {
            let wanted = wanted.as_ref();
            let outlet: OutletId = match by_label.get(wanted) {
                Some(found) => *found,
                None => match self.nodes.iter().find(|n| n.name == wanted) {
                    Some(n) => n.id.into(),
                    None => return Err(format_err!("Node {} not found", wanted)),
                },
            };
            resolved.push(outlet);
        }
        self.outputs = resolved;
        Ok(())
    }

    /// Builder form of `set_output_names`: resolve and set the outputs, then hand back the
    /// graph.
    pub fn with_output_names(
        mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<Self> {
        self.set_output_names(outputs).map(|()| self)
    }

    /// Get the `ix`-th output tensor type information.
    ///
    /// # Errors
    /// Fails if `ix` is not a valid model output index, or if the output outlet is invalid.
    pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
        // `get` turns an out-of-range index into an error instead of a panic, matching the
        // fallible signature.
        let output = *self
            .output_outlets()?
            .get(ix)
            .with_context(|| format!("Invalid output index: {}", ix))?;
        self.outlet_fact(output)
    }

    /// Get the `ix`-th output tensor type information, mutably.
    ///
    /// # Errors
    /// Fails if `ix` is not a valid model output index, or if the output outlet is invalid.
    pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
        // `get` turns an out-of-range index into an error instead of a panic, matching the
        // fallible signature.
        let output = *self
            .output_outlets()?
            .get(ix)
            .with_context(|| format!("Invalid output index: {}", ix))?;
        self.outlet_fact_mut(output)
    }

    /// Set the `output`-th output tensor type information.
    ///
    /// # Errors
    /// Fails if `output` is not a valid model output index, or if setting the fact fails.
    pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()> {
        // Report a bad index as an error rather than panicking on `self.outputs[output]`.
        let outlet = *self
            .outputs
            .get(output)
            .with_context(|| format!("Invalid output index: {}", output))?;
        self.set_outlet_fact(outlet, fact)
    }

    /// Builder form of `set_output_fact`: set the `output`-th output fact, then hand back the
    /// graph.
    pub fn with_output_fact(mut self, output: usize, fact: F) -> TractResult<Self> {
        self.set_output_fact(output, fact).map(|()| self)
    }

    // nodes and their facts

    /// Iterate over the names of all nodes, in node order.
    pub fn node_names(&self) -> impl Iterator<Item = &str> {
        self.nodes.iter().map(|node| node.name.as_str())
    }

    /// Return the id of the first node named `name`.
    ///
    /// # Errors
    /// Fails when no node carries that name.
    pub fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
        let found = self.nodes.iter().find(|node| node.name == name);
        found
            .map(|node| node.id)
            .with_context(|| format!("No node found for name: \"{}\"", name))
    }

    /// Borrow the first node named `name`.
    ///
    /// # Errors
    /// Fails when no node carries that name.
    pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>> {
        self.node_id_by_name(name.as_ref()).map(|id| &self.nodes[id])
    }

    /// Borrow mutably the first node named `name`.
    ///
    /// # Errors
    /// Fails when no node carries that name.
    pub fn node_by_name_mut(&mut self, name: impl AsRef<str>) -> TractResult<&mut Node<F, O>> {
        let id = self.node_id_by_name(name.as_ref())?;
        Ok(self.node_mut(id))
    }

    /// Give node `id` the new name `name`.
    ///
    /// # Panics
    /// Panics if `id` is out of range.
    pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
        self.nodes[id].name = name.to_owned();
        Ok(())
    }

    /// Borrow the node with id `id`.
    ///
    /// # Panics
    /// Panics if `id` is out of range.
    pub fn node(&self, id: usize) -> &Node<F, O> {
        &self.nodes()[id]
    }

    /// Borrow mutably the node with id `id`.
    ///
    /// # Panics
    /// Panics if `id` is out of range.
    pub fn node_mut(&mut self, id: usize) -> &mut Node<F, O> {
        &mut self.nodes_mut()[id]
    }

    /// Borrow the whole node table.
    pub fn nodes(&self) -> &[Node<F, O>] {
        self.nodes.as_slice()
    }

    /// Borrow the whole node table mutably. The table can be edited in place but not resized.
    pub fn nodes_mut(&mut self) -> &mut [Node<F, O>] {
        self.nodes.as_mut_slice()
    }

    /// Return `(input facts, output facts)` for node `id`.
    ///
    /// # Errors
    /// Fails if one of the node's input outlets is invalid.
    pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
        let inputs = self.node_input_facts(id)?;
        let outputs = self.node_output_facts(id)?;
        Ok((inputs, outputs))
    }

    /// Get input tensor information for a node.
    ///
    /// Each entry is the fact of the outlet feeding the corresponding input.
    ///
    /// # Errors
    ///
    /// Fails when `node_id` is not a valid node id, or when an input references a
    /// missing outlet.
    pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        // Checked lookup: indexing would panic despite the `TractResult` signature.
        let node =
            self.nodes.get(node_id).with_context(|| format!("Invalid node id: {}", node_id))?;
        node.inputs.iter().map(|o| self.outlet_fact(*o)).collect()
    }

    /// Get output tensor information for a node, one fact per output slot.
    ///
    /// # Errors
    ///
    /// Fails when `node_id` is not a valid node id.
    pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        // Checked lookup: indexing would panic despite the `TractResult` signature.
        let node =
            self.nodes.get(node_id).with_context(|| format!("Invalid node id: {}", node_id))?;
        Ok(node.outputs.iter().map(|o| &o.fact).collect())
    }

    // outlets

    /// Get tensor information for a single outlet.
    pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
        anyhow::ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
        let outlets = &self.nodes[outlet.node].outputs;
        outlets
            .get(outlet.slot)
            .map(|o| &o.fact)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
    }

    /// Get mutable tensor information for a single outlet.
    ///
    /// # Errors
    ///
    /// Fails when `outlet` does not designate an existing output of an existing node.
    pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
        // The node index is checked too, as `outlet_fact` does. Indexing `self.nodes`
        // directly would panic on a dangling outlet instead of returning an error.
        self.nodes
            .get_mut(outlet.node)
            .and_then(|node| node.outputs.get_mut(outlet.slot))
            .map(|o| &mut o.fact)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
    }

    /// Get multiple mutable tensor information for outlets.
    pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
        assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
        unsafe {
            outlets
                .iter()
                .map(|o| Ok((self.outlet_fact(*o)? as *const F as *mut F).as_mut().unwrap()))
                .collect()
        }
    }

    /// Set tensor information for a single outlet.
    ///
    /// # Errors
    ///
    /// Fails when `outlet` does not designate an existing output of an existing node.
    pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
        // Checked on both levels: an out-of-range node used to panic, and the slot
        // error message had a typo ("refererence").
        let output = self
            .nodes
            .get_mut(outlet.node)
            .and_then(|node| node.outputs.get_mut(outlet.slot))
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))?;
        output.fact = fact;
        Ok(())
    }

    /// Builder-style variant of `set_outlet_fact`: sets the fact and returns the model.
    pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
        self.set_outlet_fact(outlet, fact).map(|()| self)
    }

    // outlet labels

    /// Returns the label attached to `outlet`, if any.
    pub fn outlet_label(&self, outlet: OutletId) -> Option<&str> {
        self.outlet_labels.get(&outlet).map(String::as_str)
    }

    /// Attaches `label` to `outlet`, replacing any previous label. Never fails.
    pub fn set_outlet_label(&mut self, outlet: OutletId, label: String) -> TractResult<()> {
        let _previous = self.outlet_labels.insert(outlet, label);
        Ok(())
    }

    /// Builder-style variant of `set_outlet_label`: labels the outlet and returns the model.
    pub fn with_outlet_label(mut self, outlet: OutletId, label: String) -> TractResult<Self> {
        self.set_outlet_label(outlet, label).map(|()| self)
    }

    /// Finds an outlet carrying `label`.
    ///
    /// If several outlets share the label, which one is returned depends on the
    /// iteration order of the `outlet_labels` hash map.
    pub fn find_outlet_label(&self, label: &str) -> Option<OutletId> {
        for (outlet, candidate) in &self.outlet_labels {
            if candidate == label {
                return Some(*outlet);
            }
        }
        None
    }

    // misc

    /// Computes an evaluation order (a list of node ids) for the graph inputs and outputs.
    ///
    /// Delegates to the free `eval_order` function.
    pub fn eval_order(&self) -> TractResult<Vec<usize>> {
        eval_order(self)
    }

    /// Sanity check on network connections. This variant is a no-op: the real check is
    /// only compiled in when both `debug_assertions` and the `paranoid_assertions`
    /// feature are enabled.
    #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
    #[inline]
    pub fn check_edges(&self) -> TractResult<()> {
        Ok(())
    }

    /// Performs a sanity check on network connections.
    ///
    /// For every node in evaluation order, checks that:
    /// - each input edge is listed among the successors of the producing outlet;
    /// - each successor of each output points back to that output.
    ///
    /// Edges that reference missing nodes, output slots or input slots are reported as
    /// errors. A sanity checker must not panic on exactly the corruption it is meant
    /// to detect.
    #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
    #[inline]
    pub fn check_edges(&self) -> TractResult<()> {
        for node_id in self.eval_order()? {
            let node = &self.nodes[node_id];
            for (ix, input) in node.inputs.iter().enumerate() {
                let prec = match self.nodes.get(input.node) {
                    Some(prec) => prec,
                    None => bail!(
                        "Dangling oncoming edge, node:{} input:{} refers to missing node {:?}",
                        node.id,
                        ix,
                        input
                    ),
                };
                // A missing output slot counts as "not reciprocated".
                let reciprocated = prec
                    .outputs
                    .get(input.slot)
                    .map_or(false, |o| o.successors.contains(&InletId::new(node.id, ix)));
                if !reciprocated {
                    bail!(
                        "Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
                        node.id,
                        ix,
                        prec
                    )
                }
            }
            for (ix, output) in node.outputs.iter().enumerate() {
                for succ in &output.successors {
                    let reciprocated = self
                        .nodes
                        .get(succ.node)
                        .and_then(|s| s.inputs.get(succ.slot))
                        .map_or(false, |o| *o == OutletId::new(node.id, ix));
                    if !reciprocated {
                        bail!(
                            "Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
                            node.id,
                            ix,
                            succ
                        )
                    }
                }
            }
        }
        Ok(())
    }

    /// Converts the model into a `RunnableModel` which fixes the inputs and outputs and allows passing data through the model.
    ///
    /// Consumes the graph and builds a `crate::plan::SimplePlan` over it.
    pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>> {
        crate::plan::SimplePlan::new(self)
    }

    /// Returns the predecessor of node `id` when the link is exclusive: `id` has exactly
    /// one input, and the producing node has exactly one successor across all its outputs.
    /// Otherwise returns `None`.
    pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let node = &self.nodes()[id];
        let prec = match &node.inputs[..] {
            [only] => &self.nodes()[only.node],
            _ => return Ok(None),
        };
        let fanout: usize = prec.outputs.iter().map(|of| of.successors.len()).sum();
        Ok(if fanout == 1 { Some(prec) } else { None })
    }

    /// Walks `count` steps upstream from node `id`, each step through `single_prec`.
    ///
    /// Returns `None` as soon as one step has no exclusive predecessor. With
    /// `count == 0`, the node itself is returned.
    pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut current = self.node(id);
        for _ in 0..count {
            current = match self.single_prec(current.id)? {
                Some(prec) => prec,
                None => return Ok(None),
            };
        }
        Ok(Some(current))
    }

    /// Walks `count` steps downstream from node `id`, each step through `single_succ`.
    ///
    /// Returns `None` as soon as one step has no exclusive successor. With
    /// `count == 0`, the node itself is returned.
    pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut current = self.node(id);
        for _ in 0..count {
            current = match self.single_succ(current.id)? {
                Some(succ) => succ,
                None => return Ok(None),
            };
        }
        Ok(Some(current))
    }

    /// Returns the successor of node `id` when the link is exclusive: `id` has exactly
    /// one successor across all its outputs, and that successor has exactly one input.
    /// Otherwise returns `None`.
    pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let node = &self.nodes()[id];
        // Look for the single successor on every output. The previous code read
        // `outputs[0].successors[0]`, which panics when the only successor hangs off
        // another output slot.
        let mut successors = node.outputs.iter().flat_map(|of| of.successors.iter());
        let succ = match (successors.next(), successors.next()) {
            (Some(only), None) => only,
            _ => return Ok(None),
        };
        let succ = &self.nodes()[succ.node];
        if succ.inputs.len() != 1 {
            return Ok(None);
        }
        Ok(Some(succ))
    }

    /// Inlets consuming `outlet`.
    ///
    /// # Panics
    ///
    /// Panics if `outlet` does not reference an existing output of an existing node.
    pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
        let producer = &self.nodes[outlet.node];
        &producer.outputs[outlet.slot].successors
    }
}

impl<F: Fact + Clone + 'static, O> Graph<F, O>
where
    F: Fact + Clone + 'static + From<std::sync::Arc<Tensor>> + Hash,
    O: fmt::Debug
        + fmt::Display
        + From<crate::ops::konst::Const>
        + AsRef<dyn Op>
        + AsMut<dyn Op>
        + Clone
        + Hash
        + 'static,
{
    /// Adds a node holding the constant `v`, and returns its (single) outlet.
    ///
    /// The outlet's fact is built from the tensor itself through `F::from`.
    pub fn add_const(
        &mut self,
        name: impl Into<String>,
        v: impl IntoArcTensor,
    ) -> TractResult<OutletId> {
        let tensor = v.into_arc_tensor();
        let fact = F::from(tensor.clone());
        let node_name: String = name.into();
        let op = crate::ops::konst::Const::new(tensor);
        let node_id = self.add_node(node_name, op, tvec!(fact))?;
        Ok(node_id.into())
    }
}

impl<F, O> fmt::Display for Graph<F, O>
where
    F: Fact + Hash + Clone + 'static,
    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
{
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        for i in 0..self.nodes.len() {
            let input_1 = self.nodes[i]
                .inputs
                .get(0)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let input_2 = self.nodes[i]
                .inputs
                .get(1)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let output_1 = self
                .outlet_successors(OutletId::new(i, 0))
                .get(0)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let output_2 = self
                .outlet_successors(OutletId::new(i, 0))
                .get(1)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            writeln!(
                fmt,
                "{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {:?} => {:?}",
                i,
                input_1,
                input_2,
                output_1,
                output_2,
                self.nodes[i].op().name(),
                self.nodes[i].name,
                self.node_input_facts(i).unwrap(),
                self.node_output_facts(i).unwrap(),
            )?;
            if self.nodes[i].inputs.len() > 2 {
                writeln!(
                    fmt,
                    "                                               |   * inputs: {}",
                    self.nodes[i].inputs.iter().map(|s| format!("{:?}", s)).join(", ")
                )?;
            }
            if self.nodes[i].outputs.len() > 1
                || self.outlet_successors((i, 0).into()).len() > 2
                || (self.outlet_label(i.into()).is_some()
                    && self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
            {
                for o in 0..self.nodes[i].outputs.len() {
                    if self.outlet_successors((i, o).into()).len() > 0 {
                        writeln!(
                                    fmt,
                                    "                                               |   * output #{}: {} {}",
                                    o,
                                    self.outlet_label((i, o).into()).unwrap_or(""),
                                    self.outlet_successors((i, o).into())
                                    .iter()
                                    .map(|s| format!("{:?}", s))
                                    .join(", "),
                                    )?;
                    }
                }
            }
        }
        writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{:?}", o)).join(", "))?;
        Ok(())
    }
}

impl<F, O> Graph<F, O>
where
    F: Fact + Clone + 'static + std::hash::Hash + for<'a> std::convert::From<&'a F>,
    O: std::fmt::Display
        + std::fmt::Debug
        + Clone
        + AsRef<dyn Op>
        + AsMut<dyn Op>
        + Clone
        + 'static
        + std::hash::Hash
        + for<'a> std::convert::From<&'a O>,
    Graph<F, O>: SpecialOps<F, O>,
{
    /// Checks that the graph is compact (debug builds only).
    ///
    /// Compact means all of the following hold:
    /// - every node is in the evaluation order, except unused sources (inputs with no
    ///   successor that are not outputs);
    /// - the evaluation order is the identity `0, 1, 2, ...`;
    /// - each node's `id` equals its position in the node table;
    /// - node names are unique.
    ///
    /// On a duplicate name, the whole graph is dumped to stderr before failing.
    #[cfg(debug_assertions)]
    pub fn check_compact(&self) -> TractResult<()> {
        let order = self.eval_order()?;
        // Fetched once and propagated with `?`, instead of `unwrap()`-ed for every input.
        let outputs = self.output_outlets()?;
        let useless_sources = self
            .input_outlets()?
            .iter()
            .filter(|io| self.outlet_successors(**io).is_empty() && !outputs.contains(io))
            .count();
        if order.len() + useless_sources != self.nodes.len() {
            bail!(
                "Eval order is {} long, nodes are {}, including {} unused sources",
                order.len(),
                self.nodes.len(),
                useless_sources
            );
        }
        if order.iter().enumerate().any(|(ix, &node)| node != ix) {
            bail!("eval order is not trivial");
        }
        let mut seen = std::collections::HashSet::new();
        for (ix, n) in self.nodes.iter().enumerate() {
            if ix != n.id {
                bail!("Invalid node id: position is {}, node is {}", ix, n);
            }
            // `insert` returns false when the name was already present.
            if !seen.insert(&n.name) {
                eprintln!("{}", self);
                bail!("duplicate name {}", n.name);
            }
        }
        Ok(())
    }
More examples
Hide additional examples
src/model/order.rs (line 13)
8
9
10
11
12
13
14
15
16
/// Computes an evaluation order (node ids) that goes from the model's input nodes to its
/// output nodes, with no extra ordering constraints.
pub fn eval_order<F, O>(model: &super::Graph<F, O>) -> TractResult<Vec<usize>>
where
    F: Fact + Hash + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
{
    let input_nodes: Vec<usize> = model.input_outlets()?.iter().map(|outlet| outlet.node).collect();
    let output_nodes: Vec<usize> =
        model.output_outlets()?.iter().map(|outlet| outlet.node).collect();
    eval_order_for_nodes(model.nodes(), &input_nodes, &output_nodes, &[])
}
src/ops/invariants.rs (line 346)
341
342
343
344
345
346
347
348
349
350
351
352
/// Computes the axis invariants of a whole model.
///
/// Each full axis tracking becomes one `AxisInfo`. It records the tracking's entry (if
/// any) for every model input and output, the tracking's `disposable` flag, and a
/// `period` of 1.
pub fn for_model(model: &TypedModel) -> TractResult<Invariants> {
    let trackings = full_axis_tracking(model)?;
    trackings
        .into_iter()
        .map(|tracking| {
            let inputs = model
                .input_outlets()?
                .iter()
                .map(|outlet| tracking.outlets.get(outlet).cloned())
                .collect();
            let outputs = model
                .output_outlets()?
                .iter()
                .map(|outlet| tracking.outlets.get(outlet).cloned())
                .collect();
            let disposable = tracking.disposable;
            Ok(AxisInfo { inputs, outputs, disposable, period: 1 })
        })
        .collect()
}
src/optim/change_axes.rs (line 30)
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
    /// Tries the next untested axis-change suggestion made by any node.
    ///
    /// Nodes are scanned in evaluation order. Each `(node, suggestion)` pair is recorded
    /// in `self.0` and never tried twice. Model inputs and outputs are passed to
    /// `change_axes` as interfaces. Returns the first patch `change_axes` produces, or
    /// `None` when no pending suggestion yields one.
    fn next(
        &mut self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut interfaces = model.output_outlets()?.to_vec();
        interfaces.extend(model.input_outlets()?.iter());
        for node_id in model.eval_order()? {
            let node = model.node(node_id);
            for suggestion in node.op.suggested_axis_changes()? {
                if !self.0.insert((node_id, suggestion.clone())) {
                    // Already tried this one in a previous pass.
                    continue;
                }
                let outlet = suggestion.0.as_outlet(node);
                let change = AxisChange { outlet, op: suggestion.1.clone() };
                let attempt = change_axes(model, &change, &interfaces, &[]).with_context(|| {
                    format!("Making patch for {:?} from {}", change, node)
                })?;
                if let Some((patch, _)) = attempt {
                    return Ok(Some(patch));
                }
            }
        }
        Ok(None)
    }
src/plan.rs (line 78)
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
    pub fn new_for_outputs_and_deps(
        model: M,
        outputs: &[OutletId],
        deps: &[(usize, usize)],
    ) -> TractResult<SimplePlan<F, O, M>> {
        let inputs = model.borrow().input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
        let outputs_nodes = outputs.iter().map(|n| n.node).collect::<Vec<usize>>();
        let order = eval_order_for_nodes(model.borrow().nodes(), &inputs, &outputs_nodes, deps)?;
        let mut values_needed_until_step = vec![0; model.borrow().nodes().len()];
        for (step, node) in order.iter().enumerate() {
            for i in &model.borrow().node(*node).inputs {
                values_needed_until_step[i.node] = step;
            }
        }
        for o in outputs.iter() {
            values_needed_until_step[o.node] = order.len();
        }
        let mut flush_lists: Vec<TVec<usize>> = vec![tvec!(); order.len() + 1];
        for (node, &flush_at) in values_needed_until_step.iter().enumerate() {
            if flush_at != 0 {
                flush_lists[flush_at].push(node)
            }
        }
        let mut symbols: std::collections::HashSet<Symbol> = Default::default();
        for node in &model.borrow().nodes {
            for output in &node.outputs {
                if let Ok(fact) = output.fact.to_typed_fact() {
                    symbols.extend(fact.shape.iter().flat_map(|d| d.symbols()))
                }
            }
        }
        Ok(SimplePlan {
            model,
            order,
            flush_lists,
            outputs: outputs.to_vec(),
            has_unresolved_symbols: !symbols.is_empty(),
            _casper: PhantomData,
        })
    }

    /// One-shot execution: builds a fresh `SimpleState` over this plan and runs `inputs`
    /// through it.
    pub fn run(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        SimpleState::new(self)?.run(inputs)
    }

    pub fn model(&self) -> &Graph<F, O> {
        self.model.borrow()
    }
}

/// Mutable execution state for a `SimplePlan`: per-node op states, the session state,
/// and the values produced while running.
#[derive(Clone, Debug)]
pub struct SimpleState<F, O, M, P>
where
    F: Fact + Hash + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
    M: Borrow<Graph<F, O>> + Hash,
    P: Borrow<SimplePlan<F, O, M>>,
{
    /// The plan being executed, owned or borrowed through `P: Borrow<SimplePlan>`.
    plan: P,
    /// Op state of each node, indexed by node id, as returned by the op's `state` method.
    pub states: Vec<Option<Box<dyn OpState>>>,
    /// Session-wide state: provided inputs, resolved symbols, etc.
    pub session_state: SessionState,
    /// Outputs of each node, indexed by node id. `None` until the node has run in the
    /// current turn, or once its value has been flushed or the turn reset.
    pub values: Vec<Option<TVec<TValue>>>,
    /// Ties the otherwise unused type parameters to the struct.
    _phantom: PhantomData<(M, F, O)>,
}

impl<F, O, M, P> SimpleState<F, O, M, P>
where
    F: Fact + Hash + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
    M: Borrow<Graph<F, O>> + Hash,
    P: Borrow<SimplePlan<F, O, M>> + Clone,
{
    /// Creates a fresh execution state for `plan`.
    ///
    /// All values start empty, and every node gets its initial op state, built against a
    /// new default `SessionState`.
    pub fn new(plan: P) -> TractResult<SimpleState<F, O, M, P>> {
        let mut session_state = SessionState::default();
        let (states, values) = {
            let model = plan.borrow().model();
            let states = model
                .nodes()
                .iter()
                .map(|node: &Node<F, O>| node.op().state(&mut session_state, node.id))
                .collect::<TractResult<Vec<Option<Box<dyn OpState>>>>>()?;
            (states, vec![None; model.nodes().len()])
        };
        Ok(SimpleState { plan, states, session_state, values, _phantom: PhantomData })
    }

    /// Reset wires state.
    pub fn reset_turn(&mut self) -> TractResult<()> {
        self.values.iter_mut().for_each(|s| *s = None);
        Ok(())
    }

    /// Reset op inner state.
    pub fn reset_op_states(&mut self) -> TractResult<()> {
        let &mut SimpleState { ref plan, ref mut session_state, ref mut states, .. } = self;
        *states = plan
            .borrow()
            .model()
            .nodes()
            .iter()
            .map(|n| n.op().state(session_state, n.id))
            .collect::<TractResult<_>>()?;
        Ok(())
    }

    /// Runs one full turn on `inputs` with this module's `eval` function, and returns
    /// the model outputs.
    pub fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        self.run_plan_with_eval(inputs, self::eval)
    }

    /// Executes the plan once with this module's `eval` function. Inputs must already be
    /// set, and values are neither collected nor reset afterwards.
    pub fn exec(&mut self) -> TractResult<()> {
        self.exec_plan_with_eval(self::eval)
    }

    /// Runs one full turn with a custom `eval`: sets `inputs`, executes the plan,
    /// collects the outputs, then clears the wire values (op states are kept).
    ///
    /// On error, the wire values are left as they are.
    pub fn run_plan_with_eval<Eval, E>(
        &mut self,
        inputs: TVec<TValue>,
        eval: Eval,
    ) -> TractResult<TVec<TValue>>
    where
        Eval: for<'a, 'b, 'c> FnMut(
            &'a mut SessionState,
            Option<&'b mut (dyn OpState + 'static)>,
            &'c Node<F, O>,
            TVec<TValue>,
        ) -> Result<TVec<TValue>, E>,
        E: Into<anyhow::Error> + Send + Sync + 'static,
    {
        self.set_inputs(inputs)?;
        self.exec_plan_with_eval(eval)?;
        let results = self.outputs()?;
        self.reset_turn().map(|()| results)
    }

    /// Executes every step of the plan once, using `eval` to compute each node.
    ///
    /// For each node of `plan.order`, in sequence:
    /// 1. gathers its inputs from `values`, failing if a precursor has no value yet;
    /// 2. drops the values listed in `flush_lists[step]`;
    /// 3. calls `eval` with the session state, the node's op state and the inputs;
    /// 4. stores the outputs in `values[node.id]`.
    ///
    /// When the plan has unresolved symbols, they are bound from the concrete output
    /// shapes. In debug builds, inputs and outputs are checked against the node facts.
    /// NOTE(review): model inputs are expected to be set beforehand (see `set_input`);
    /// how source nodes pick them up is up to `eval` and not visible here.
    pub fn exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
    where
        Eval: for<'a, 'b, 'c> FnMut(
            &'a mut SessionState,
            Option<&'b mut (dyn OpState + 'static)>,
            &'c Node<F, O>,
            TVec<TValue>,
        ) -> Result<TVec<TValue>, E>,
        E: Into<anyhow::Error> + Send + Sync + 'static,
    {
        {
            // Split borrows: the plan is read while session state, op states and values
            // are mutated.
            let &mut SimpleState {
                ref plan,
                ref mut session_state,
                ref mut states,
                ref mut values,
                ..
            } = self;
            let plan = plan.borrow();
            let model = plan.model().borrow();
            for (step, n) in plan.order.iter().enumerate() {
                let node = model.node(*n);
                trace!("Running step {}, node {}", step, node);
                let mut inputs: TVec<TValue> = tvec![];
                for i in &node.inputs {
                    trace!("  use input {:?}", i);
                    let prec_node = model.node(i.node);
                    let prec = values[i.node].as_ref().ok_or_else(|| {
                        format_err!("Computing {}, precursor {} not done:", node, prec_node)
                    })?;
                    inputs.push(prec[i.slot].clone())
                }

                // Values whose last reader is this step have just been cloned into
                // `inputs`, so they can be released now.
                for flush in &plan.flush_lists[step] {
                    trace!("  Ran {} can now flush {}", node, model.node(*flush));
                    values[*flush] = None;
                }

                if cfg!(debug_assertions) {
                    let facts = model.node_input_facts(node.id)?;
                    if facts.len() != inputs.len() {
                        bail!(
                            "Evaluating {}: expected {} inputs, got {}",
                            node,
                            facts.len(),
                            inputs.len()
                        );
                    }
                    for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
                        if !f.matches(v, Some(&session_state.resolved_symbols))? {
                            bail!(
                                "Evaluating {}: input {:?}, expected {:?}, got {:?}",
                                node,
                                ix,
                                f,
                                v
                            );
                        }
                    }
                }

                let vs = eval(session_state, states[node.id].as_deref_mut(), node, inputs)
                    .map_err(|e| e.into())?;

                // Bind symbolic dimensions of the output facts to the concrete shapes
                // just produced.
                if plan.has_unresolved_symbols {
                    for (o, v) in node.outputs.iter().zip(vs.iter()) {
                        if let Ok(f) = o.fact.to_typed_fact() {
                            for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
                                Self::resolve(
                                    &mut session_state.resolved_symbols,
                                    &dim_abstract,
                                    *dim_concrete as i64,
                                );
                            }
                        }
                    }
                }
                if cfg!(debug_assertions) {
                    let facts = model.node_output_facts(node.id)?;
                    if facts.len() != vs.len() {
                        bail!(
                            "Evaluating {}: expected {} outputs, got {}",
                            node,
                            facts.len(),
                            vs.len()
                        );
                    }
                    for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
                        // Outputs nobody consumes are not checked.
                        if node.outputs[ix].successors.len() == 0 {
                            continue;
                        }
                        if !f.matches(v, Some(&session_state.resolved_symbols))? {
                            bail!(
                                "Evaluating {}: output {:?}, expected {:?}, got {:?}",
                                node,
                                ix,
                                f,
                                v
                            );
                        }
                    }
                }

                values[node.id] = Some(vs);
            }
        }
        Ok(())
    }

    /// Provides all model inputs at once, in model input order.
    ///
    /// # Errors
    ///
    /// Fails if the number of tensors differs from the number of model inputs, or if
    /// `set_input` rejects any of them.
    pub fn set_inputs(&mut self, inputs: TVec<TValue>) -> TractResult<()> {
        let expected = self.model().inputs.len();
        ensure!(
            inputs.len() == expected,
            "Wrong number of inputs for model. Expected {} got {}",
            expected,
            inputs.len()
        );
        inputs.into_iter().enumerate().try_for_each(|(ix, t)| self.set_input(ix, t))
    }

    /// Binds the symbols of `expected` so that it evaluates to `provided`.
    ///
    /// Handles a bare symbol (`S` is bound to `provided`) and an integer multiple
    /// (`k * expr` recurses with `provided / k`, truncating). Any other expression is
    /// left alone.
    fn resolve(symbols: &mut SymbolValues, expected: &TDim, provided: i64) {
        match expected {
            TDim::Sym(s) => symbols[s] = Some(provided),
            TDim::MulInt(x, expr) => {
                // `checked_div` skips a degenerate `0 * expr` (and the i64::MIN / -1
                // overflow), which plain `/` turned into a runtime panic.
                if let Some(inner) = provided.checked_div(*x) {
                    Self::resolve(symbols, expr, inner)
                }
            }
            _ => (),
        }
    }

    pub fn set_input(&mut self, input: usize, t: TValue) -> TractResult<()> {
        let outlet: OutletId = *self
            .model()
            .input_outlets()?
            .get(input)
            .ok_or_else(|| format_err!("Invalid input id for model ({}).", input))?;
        let SimpleState { plan, session_state, .. } = self;
        let plan = (*plan).borrow();
        let model = plan.model.borrow();
        if let Ok(fact) = model.outlet_fact(outlet)?.to_typed_fact() {
            for (expected, provided) in fact.shape.iter().zip(t.shape()) {
                Self::resolve(&mut session_state.resolved_symbols, &expected, *provided as i64)
            }
        }
        let fact = self.plan.borrow().model().outlet_fact(outlet)?;
        ensure!(
            fact.matches(&t, Some(&self.session_state.resolved_symbols))
            .with_context(|| format!("Setting input {}", input))?,
            "Input at index {} has incorrect dtype or shape (got shape {:?} and dtype {:?}, expected to match fact {:?})",
            input,
            t.shape(),
            t.datum_type(),
            fact
            );
        self.session_state.inputs.insert(outlet.node, t);
        Ok(())
    }
src/model/translator.rs (line 46)
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
    /// Translates `source` node by node into a new graph.
    ///
    /// Returns the new graph and the mapping from source outlets to target outlets.
    /// Nodes are visited in evaluation order. Source inputs that are not reachable
    /// (useless sources) are translated too, so the input interface is preserved. Input
    /// and output order, outlet labels, the symbol table and properties are carried over.
    fn translate_model_with_mappings(
        &self,
        source: &Graph<TI1, O1>,
    ) -> TractResult<(Graph<TI2, O2>, HashMap<OutletId, OutletId>)> {
        let mut target = Graph::default();
        let mut mapping = HashMap::new();
        for old_id in source.eval_order()? {
            let node = source.node(old_id);
            trace!("Translating {} {:?}", node, self);
            let outlets = self
                .translate_node(source, node, &mut target, &mapping)
                .with_context(|| format!("Translating node {} {:?}", node, self))?;
            for (ix, outlet) in outlets.into_iter().enumerate() {
                mapping.insert(OutletId::new(node.id, ix), outlet);
                if let Some(label) = source.outlet_label(OutletId::new(node.id, ix)) {
                    target.set_outlet_label(outlet, label.to_string())?;
                }
            }
        }
        // do not drop inputs, even if they are useless, to maintain interface
        for i in source.input_outlets()? {
            if !mapping.contains_key(i) {
                let node = source.node(i.node);
                trace!("Translate useless source {}", node);
                let outlets = self
                    .translate_node(source, node, &mut target, &mapping)
                    .with_context(|| format!("Translating input {} {:?}", node, self))?;
                // Report an error instead of panicking on `outlets[0]` when the
                // translator produced nothing.
                let outlet = *outlets.first().with_context(|| {
                    format!("Translating input {} {:?}: no outlet produced", node, self)
                })?;
                mapping.insert(*i, outlet);
                // Keep the label, as is done for the nodes translated above.
                if let Some(label) = source.outlet_label(*i) {
                    target.set_outlet_label(outlet, label.to_string())?;
                }
            }
        }
        // maintaining order of i/o interface
        target.inputs = source.input_outlets()?.iter().map(|i| mapping[i]).collect();
        target.outputs = source.output_outlets()?.iter().map(|o| mapping[o]).collect();
        target.symbol_table = source.symbol_table.clone();
        target.properties = source.properties.clone();
        Ok((target, mapping))
    }

Change model inputs.

Examples found in repository?
src/model/graph.rs (line 157)
156
157
158
159
    /// Builder-style variant of `set_input_outlets`: changes the model inputs and
    /// returns the model.
    pub fn with_input_outlets(mut self, inputs: &[OutletId]) -> TractResult<Self> {
        self.set_input_outlets(inputs).map(|()| self)
    }
More examples
Hide additional examples
src/ops/scan/mir.rs (line 210)
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
    /// Removes one body input that nothing uses.
    ///
    /// Looks for the first body input whose source node has no successor and is not a
    /// body output. Its input mapping is removed, and so is the outer input slot when
    /// the mapping names one (`Full`, `Scan`, or a `State` initialized `FromInput`).
    /// The source node is obliterated from a copy of the body, which is then
    /// decluttered. Returns a patch replacing the scan node with the rebuilt op, or
    /// `None` when every input is used. Only one input is handled per call.
    fn declutter_discard_unused_input_mapping(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (inner_input_id, input) in self.body.input_outlets()?.iter().enumerate() {
            let source_node = self.body.node(input.node);
            if source_node.outputs[0].successors.len() == 0
                && !self.body.output_outlets()?.contains(input)
            {
                let mut new_inputs = node.inputs.clone();
                // Outer input slot fed into this body input, if any. A state
                // initialized from something other than an outer input has none.
                let slot = match &self.input_mapping[inner_input_id] {
                    InputMapping::Full { slot } => Some(*slot),
                    InputMapping::Scan(info) => Some(info.slot),
                    InputMapping::State { initializer } => match initializer {
                        StateInitializer::FromInput(n) => Some(*n),
                        _ => None,
                    },
                };
                let mut new_mappings: Vec<_> = self.input_mapping.clone();
                new_mappings.remove(inner_input_id);
                if let Some(slot) = slot {
                    // Shift the outer slot numbers of the remaining mappings.
                    new_mappings = Self::remove_outer_input_from_mappings(&new_mappings, slot);
                }
                let mut model_inputs = self.body.input_outlets()?.to_vec();
                if let Some(slot) = slot {
                    new_inputs.remove(slot);
                }
                model_inputs.remove(inner_input_id);
                let mut body = self.body.clone();
                let mut patch = TypedModelPatch::default();
                patch.obliterate(source_node.id)?;
                patch.apply(&mut body)?;
                body.set_input_outlets(&model_inputs)?;
                body.declutter()?;
                let op = Self {
                    body,
                    skip: self.skip,
                    seq_length_input_slot: self.seq_length_input_slot,
                    input_mapping: new_mappings,
                    decluttered: true,
                    output_mapping: self.output_mapping.clone(),
                };
                return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
            }
        }
        Ok(None)
    }
src/model/patch.rs (line 336)
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
    /// Splice this patch into `target`.
    ///
    /// The work happens in this order:
    /// 1. Every patch node is copied into `target`, except taps (patch sources whose outlet
    ///    is already a key of `incoming`). `mapping` records patch outlet -> target outlet.
    ///    A copied node that is a source replaces, in the model input list, the input
    ///    registered for it in the patch `inputs` map.
    /// 2. For each `(outlet, by)` in `shunt_outlet_by`, the successors, model output
    ///    entries and outlet label of `outlet` are moved to the target outlet `by` maps to.
    ///    Bails if a node ends up listed twice as a model output.
    /// 3. The inputs of the copied nodes are wired.
    /// 4. Every node listed in `obliterate` keeps its slot in `target.nodes` but gets a
    ///    dummy op.
    /// 5. The (possibly updated) model input outlets are reinstalled.
    ///
    /// No node is removed from `target` here, even if it became unreachable.
    pub fn apply(self, target: &mut Graph<F, O>) -> TractResult<()> {
        // Only used by the debug assertions below: a patch must not change the number of
        // model inputs or outputs.
        let prior_target_inputs = target.input_outlets()?.len();
        let prior_target_outputs = target.output_outlets()?.len();
        let ModelPatch {
            model: patch,
            incoming: mut mapping,
            shunt_outlet_by,
            obliterate,
            inputs: replaced_inputs,
            ..
        } = self;
        let mut all_inputs = HashMap::new(); // new_node_id_in_model -> [ patch_outlet_id ]
        let mut model_input_outlets = target.input_outlets()?.to_vec();
        for node in patch.nodes {
            if <Graph<F, O>>::is_source(&node.op)
                && mapping.contains_key(&OutletId::new(node.id, 0))
            {
                // this is a tap
                continue;
            }
            let Node { id: patch_node_id, name, inputs, op, outputs } = node;
            let n_outputs = outputs.len();
            // NOTE(review): this loop looks meant to reuse a node of `target` with the same
            // op and the same (mapped) inputs, but `continue` only advances the inner
            // `for dup` loop. The node is still added below, and the mapping entries written
            // here are overwritten by the `mapping.insert` after `add_node`, so the loop has
            // no lasting effect. A labelled `continue` is not a drop-in fix. Suppose a patch
            // shunts an outlet to a new node that equals an existing successor of that
            // outlet. The shunt below would then wire that successor to itself.
            // Confirm intent before changing.
            for dup in 0..target.nodes.len() {
                if target.node(dup).op().same_as(op.as_ref())
                    && inputs.len() == target.node(dup).inputs.len()
                    && inputs
                        .iter()
                        .zip(target.node(dup).inputs.iter())
                        .all(|(patch_input, d)| mapping[patch_input] == *d)
                {
                    for ix in 0..n_outputs {
                        mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(dup, ix));
                    }
                    continue;
                }
            }
            let facts = outputs.into_iter().map(|of| of.fact).collect();
            let added_node_id = target.add_node(name, op, facts)?;
            for ix in 0..n_outputs {
                mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(added_node_id, ix));
            }
            // Edges are wired later, once every patch node has a target id.
            all_inputs.insert(added_node_id, inputs);
            if <Graph<F, O>>::is_source(&target.node(added_node_id).op) {
                // this is actually an input replacement
                // (`replaced_inputs` must have an entry for this patch node: indexing panics
                // otherwise.)
                model_input_outlets.iter_mut().for_each(|oo| {
                    if oo.node == replaced_inputs[&patch_node_id] {
                        oo.node = added_node_id;
                    }
                });
            }
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Move consumers, model outputs and labels of each replaced outlet to its replacement.
        for (outlet, by) in shunt_outlet_by {
            let replace_by = mapping[&by];
            let succs = target.nodes()[outlet.node].outputs[outlet.slot].successors.clone();
            for succ in succs {
                target.add_edge(replace_by, succ)?;
            }
            for o in target.outputs.iter_mut() {
                if *o == outlet {
                    *o = replace_by;
                }
            }
            if let Some(label) = target.outlet_labels.remove(&outlet) {
                target.set_outlet_label(replace_by, label)?;
            }
        }
        // Shunting two outputs onto the same outlet would make it a duplicated output.
        if target.outputs.len() > target.outputs.iter().sorted().dedup().count() {
            bail!("Duplicate usage of node as output");
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Wire the inputs of every copied node, translating patch outlets through `mapping`.
        for (node, inputs) in all_inputs {
            for (ix, input) in inputs.into_iter().enumerate() {
                target.add_edge(mapping[&input], InletId::new(node, ix))?;
            }
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Obliterated nodes stay in `target.nodes`; only their op is replaced by a dummy.
        for node in obliterate {
            target.node_mut(node).op = target.create_dummy();
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        target.set_input_outlets(&model_input_outlets)?;
        Ok(())
    }

Change model inputs and return self.

Set model inputs by the node name.

Examples found in repository?
src/model/graph.rs (line 182)
178
179
180
181
182
183
184
    /// Builder-style variant of `set_input_names`: set the model inputs from node names,
    /// then hand the model back.
    pub fn with_input_names(
        self,
        inputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<Self> {
        let mut model = self;
        model.set_input_names(inputs)?;
        Ok(model)
    }

Set model inputs by the node name and return self.

Get the ix-th input tensor type information.

Get the ix-th input tensor type information, mutably.

Set the ix-th input tensor type information.

Examples found in repository?
src/model/graph.rs (line 206)
205
206
207
208
    /// Builder-style variant of `set_input_fact`: replace the fact of input #`input`,
    /// then hand the model back.
    pub fn with_input_fact(self, input: usize, fact: F) -> TractResult<Self> {
        let mut model = self;
        model.set_input_fact(input, fact)?;
        Ok(model)
    }

Set the ix-th input tensor type information and return self.

Get model outputs.

Examples found in repository?
src/model/graph.rs (line 284)
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
    /// Get the `ix`-th output tensor type information.
    ///
    /// # Errors
    /// If `ix` is not a valid output index, or the output outlet is invalid.
    pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
        let outputs = self.output_outlets()?;
        // `get` rather than indexing: an out-of-range index is reported as an error, which
        // is what the `TractResult` return type promises, instead of panicking.
        let output = *outputs.get(ix).with_context(|| {
            format!("Invalid output index {} (model has {} outputs)", ix, outputs.len())
        })?;
        self.outlet_fact(output)
    }

    /// Get the `ix`-th output tensor type information, mutably.
    ///
    /// # Errors
    /// If `ix` is not a valid output index, or the output outlet is invalid.
    pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
        let outputs = self.output_outlets()?;
        let output = *outputs.get(ix).with_context(|| {
            format!("Invalid output index {} (model has {} outputs)", ix, outputs.len())
        })?;
        self.outlet_fact_mut(output)
    }

    /// Set the `ix`-th output tensor type information.
    ///
    /// # Errors
    /// If `output` is not a valid output index, or the output outlet is invalid.
    pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()> {
        let outlet = *self.outputs.get(output).with_context(|| {
            format!("Invalid output index {} (model has {} outputs)", output, self.outputs.len())
        })?;
        self.set_outlet_fact(outlet, fact)
    }

    /// Set the `ix`-th output tensor type information and return `self`.
    pub fn with_output_fact(mut self, output: usize, fact: F) -> TractResult<Self> {
        self.set_output_fact(output, fact)?;
        Ok(self)
    }

    // nodes and their facts

    /// Iterate over the names of all nodes, in node-table order.
    pub fn node_names(&self) -> impl Iterator<Item = &str> {
        self.nodes.iter().map(|node| node.name.as_str())
    }

    /// Id of the first node named `name`, or an error if there is none.
    pub fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
        let found = self.nodes.iter().find(|node| node.name == name);
        found.map(|node| node.id).with_context(|| format!("No node found for name: \"{}\"", name))
    }

    /// Borrow the node named `name`.
    pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>> {
        let id = self.node_id_by_name(name.as_ref())?;
        Ok(self.node(id))
    }

    /// Borrow mutably the node named `name`.
    pub fn node_by_name_mut(&mut self, name: impl AsRef<str>) -> TractResult<&mut Node<F, O>> {
        let id = self.node_id_by_name(name.as_ref())?;
        Ok(self.node_mut(id))
    }

    /// Give node `id` a new name. Panics if `id` is out of range.
    pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
        let node = &mut self.nodes[id];
        node.name = name.to_owned();
        Ok(())
    }

    /// Borrow the node with id `id`. Panics if `id` is out of range.
    pub fn node(&self, id: usize) -> &Node<F, O> {
        let node = &self.nodes[id];
        node
    }

    /// Borrow mutably the node with id `id`. Panics if `id` is out of range.
    pub fn node_mut(&mut self, id: usize) -> &mut Node<F, O> {
        let node = &mut self.nodes[id];
        node
    }

    /// The whole node table, indexed by node id.
    pub fn nodes(&self) -> &[Node<F, O>] {
        self.nodes.as_slice()
    }

    /// The whole node table, mutably.
    pub fn nodes_mut(&mut self) -> &mut [Node<F, O>] {
        self.nodes.as_mut_slice()
    }

    /// Input facts and output facts of node `id`, as a pair.
    pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
        let inputs = self.node_input_facts(id)?;
        let outputs = self.node_output_facts(id)?;
        Ok((inputs, outputs))
    }

    /// Facts of the outlets feeding node `node_id`, in input order.
    pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        let node = &self.nodes[node_id];
        node.inputs.iter().map(|&outlet| self.outlet_fact(outlet)).collect()
    }

    /// Facts of the outputs of node `node_id`, in slot order. Never fails once the id is valid.
    pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        let node = &self.nodes[node_id];
        let facts = node.outputs.iter().map(|outlet| &outlet.fact).collect();
        Ok(facts)
    }

    // outlets

    /// Get tensor information for a single outlet.
    ///
    /// # Errors
    /// If `outlet` does not reference an existing node output.
    pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
        anyhow::ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
        let outlets = &self.nodes[outlet.node].outputs;
        outlets
            .get(outlet.slot)
            .map(|o| &o.fact)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
    }

    /// Get tensor information for a single outlet, mutably.
    ///
    /// # Errors
    /// If `outlet` does not reference an existing node output.
    pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
        // Same node bound check as `outlet_fact`: an unknown node id is an error, not a panic.
        anyhow::ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
        let outlets = &mut self.nodes[outlet.node].outputs;
        outlets
            .get_mut(outlet.slot)
            .map(|o| &mut o.fact)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
    }

    /// Get multiple mutable tensor information for outlets.
    ///
    /// The returned references are in the same order as `outlets`.
    ///
    /// # Panics
    /// If `outlets` contains the same outlet twice (the borrows would alias).
    ///
    /// # Errors
    /// If an outlet does not reference an existing node output.
    pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
        assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
        // The previous implementation cast `&F` (obtained through `&self`) to `&mut F`.
        // Writing through a pointer derived from a shared borrow is undefined behaviour.
        // Instead, visit the requests sorted by (node, slot) and only ever move the node and
        // output iterators forward. Each fact is then reached exactly once, and the borrow
        // checker proves the returned borrows are disjoint.
        let mut order: TVec<usize> = (0..outlets.len()).collect();
        order.sort_by_key(|&ix| (outlets[ix].node, outlets[ix].slot));
        let mut found: TVec<Option<&mut F>> = outlets.iter().map(|_| None).collect();
        let mut pending = order.into_iter().peekable();
        let mut nodes = self.nodes.iter_mut();
        let mut next_node = 0;
        while let Some(&first) = pending.peek() {
            let node_id = outlets[first].node;
            // Requests are sorted and distinct, so `node_id >= next_node`: no underflow.
            // `nth` on a slice iterator is O(1).
            let node = nodes.nth(node_id - next_node).context("Invalid outlet for graph")?;
            next_node = node_id + 1;
            let mut slots = node.outputs.iter_mut();
            let mut next_slot = 0;
            while let Some(ix) = pending.next_if(|&ix| outlets[ix].node == node_id) {
                let outlet = outlets[ix];
                let output = slots
                    .nth(outlet.slot - next_slot)
                    .with_context(|| format!("Invalid outlet reference: {:?}", outlet))?;
                next_slot = outlet.slot + 1;
                found[ix] = Some(&mut output.fact);
            }
        }
        Ok(found
            .into_iter()
            .map(|fact| fact.expect("every requested outlet is visited exactly once"))
            .collect())
    }

    /// Set tensor information for a single outlet.
    ///
    /// # Errors
    /// If `outlet` does not reference an existing node output.
    pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
        // `get_mut` at both levels: an unknown node id is reported as an error, like an
        // unknown slot already was, instead of panicking on the index.
        let output = self
            .nodes
            .get_mut(outlet.node)
            .and_then(|node| node.outputs.get_mut(outlet.slot))
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))?;
        output.fact = fact;
        Ok(())
    }

    /// Set tensor information for a single outlet and return `self`.
    pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
        self.set_outlet_fact(outlet, fact)?;
        Ok(self)
    }

    // outlet labels

    /// Label attached to `outlet`, if any.
    pub fn outlet_label(&self, outlet: OutletId) -> Option<&str> {
        self.outlet_labels.get(&outlet).map(String::as_str)
    }

    /// Attach `label` to `outlet`, replacing any previous label. Never fails.
    pub fn set_outlet_label(&mut self, outlet: OutletId, label: String) -> TractResult<()> {
        self.outlet_labels.insert(outlet, label);
        Ok(())
    }

    /// Builder-style variant of `set_outlet_label`.
    pub fn with_outlet_label(self, outlet: OutletId, label: String) -> TractResult<Self> {
        let mut model = self;
        model.set_outlet_label(outlet, label)?;
        Ok(model)
    }

    /// Some outlet carrying exactly `label`, or `None` if no outlet has it.
    pub fn find_outlet_label(&self, label: &str) -> Option<OutletId> {
        self.outlet_labels.iter().find_map(|(outlet, name)| (name == label).then(|| *outlet))
    }

    // misc

    /// Computes an evaluation order (a list of node ids) for the graph inputs and outputs.
    pub fn eval_order(&self) -> TractResult<Vec<usize>> {
        let order = eval_order(self)?;
        Ok(order)
    }

    /// Edge sanity check, compiled out unless both debug assertions and the
    /// `paranoid_assertions` feature are enabled: always succeeds.
    #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
    #[inline]
    pub fn check_edges(&self) -> TractResult<()> {
        Ok(())
    }

    /// Performs a sanity check on network connections: for every node in eval order,
    /// each input must be listed among its precursor's successors, and each listed
    /// successor must read back from the corresponding output.
    #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
    #[inline]
    pub fn check_edges(&self) -> TractResult<()> {
        for node in self.eval_order()?.into_iter().map(|id| &self.nodes[id]) {
            for (ix, input) in node.inputs.iter().enumerate() {
                let prec = &self.nodes[input.node];
                let inlet = InletId::new(node.id, ix);
                if !prec.outputs[input.slot].successors.contains(&inlet) {
                    bail!(
                        "Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
                        node.id,
                        ix,
                        prec
                    )
                }
            }
            for (ix, output) in node.outputs.iter().enumerate() {
                let outlet = OutletId::new(node.id, ix);
                let stray = output
                    .successors
                    .iter()
                    .find(|succ| self.nodes[succ.node].inputs[succ.slot] != outlet);
                if let Some(succ) = stray {
                    bail!(
                        "Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
                        node.id,
                        ix,
                        succ
                    )
                }
            }
        }
        Ok(())
    }

    /// Turn the model into a `RunnableModel` (a plan over its declared inputs and outputs)
    /// that can be fed data.
    pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>> {
        let plan = crate::plan::SimplePlan::new(self)?;
        Ok(plan)
    }

    /// The node feeding `id`, when `id` has exactly one input and that precursor has
    /// exactly one consumer across all of its outputs; `None` otherwise.
    pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let node = self.node(id);
        if node.inputs.len() != 1 {
            return Ok(None);
        }
        let prec = self.node(node.inputs[0].node);
        let consumers: usize = prec.outputs.iter().map(|of| of.successors.len()).sum();
        Ok((consumers == 1).then(|| prec))
    }

    /// Follow `single_prec` `count` times from `id`; `None` as soon as one hop fails.
    pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut current = self.node(id);
        for _ in 0..count {
            match self.single_prec(current.id)? {
                Some(prec) => current = prec,
                None => return Ok(None),
            }
        }
        Ok(Some(current))
    }

    /// Follow `single_succ` `count` times from `id`; `None` as soon as one hop fails.
    pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut current = self.node(id);
        for _ in 0..count {
            match self.single_succ(current.id)? {
                Some(succ) => current = succ,
                None => return Ok(None),
            }
        }
        Ok(Some(current))
    }

    /// The node consuming `id`, when `id` has exactly one successor across all of its
    /// output slots and that successor has exactly one input; `None` otherwise.
    pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let node = &self.nodes()[id];
        // The lone successor may hang off any output slot. Indexing
        // `outputs[0].successors[0]` panicked, e.g. for a two-output node whose only
        // consumer reads output #1.
        let mut succs = node.outputs.iter().flat_map(|of| of.successors.iter());
        let succ = match (succs.next(), succs.next()) {
            (Some(succ), None) => succ,
            _ => return Ok(None),
        };
        let succ = &self.nodes()[succ.node];
        if succ.inputs.len() != 1 {
            return Ok(None);
        }
        Ok(Some(succ))
    }

    /// Inlets consuming `outlet`. Panics if `outlet` does not exist.
    pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
        let output = &self.nodes[outlet.node].outputs[outlet.slot];
        &output.successors
    }
}

impl<F: Fact + Clone + 'static, O> Graph<F, O>
where
    F: Fact + Clone + 'static + From<std::sync::Arc<Tensor>> + Hash,
    O: fmt::Debug
        + fmt::Display
        + From<crate::ops::konst::Const>
        + AsRef<dyn Op>
        + AsMut<dyn Op>
        + Clone
        + Hash
        + 'static,
{
    /// Add a `Const` node named `name` holding `v`. Its single output fact is built from
    /// the tensor itself. Returns the outlet of that output.
    pub fn add_const(
        &mut self,
        name: impl Into<String>,
        v: impl IntoArcTensor,
    ) -> TractResult<OutletId> {
        let tensor = v.into_arc_tensor();
        let fact = F::from(tensor.clone());
        let name: String = name.into();
        let id = self.add_node(name, crate::ops::konst::Const::new(tensor), tvec!(fact))?;
        Ok(id.into())
    }
}

impl<F, O> fmt::Display for Graph<F, O>
where
    F: Fact + Hash + Clone + 'static,
    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
{
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        for i in 0..self.nodes.len() {
            let input_1 = self.nodes[i]
                .inputs
                .get(0)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let input_2 = self.nodes[i]
                .inputs
                .get(1)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let output_1 = self
                .outlet_successors(OutletId::new(i, 0))
                .get(0)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let output_2 = self
                .outlet_successors(OutletId::new(i, 0))
                .get(1)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            writeln!(
                fmt,
                "{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {:?} => {:?}",
                i,
                input_1,
                input_2,
                output_1,
                output_2,
                self.nodes[i].op().name(),
                self.nodes[i].name,
                self.node_input_facts(i).unwrap(),
                self.node_output_facts(i).unwrap(),
            )?;
            if self.nodes[i].inputs.len() > 2 {
                writeln!(
                    fmt,
                    "                                               |   * inputs: {}",
                    self.nodes[i].inputs.iter().map(|s| format!("{:?}", s)).join(", ")
                )?;
            }
            if self.nodes[i].outputs.len() > 1
                || self.outlet_successors((i, 0).into()).len() > 2
                || (self.outlet_label(i.into()).is_some()
                    && self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
            {
                for o in 0..self.nodes[i].outputs.len() {
                    if self.outlet_successors((i, o).into()).len() > 0 {
                        writeln!(
                                    fmt,
                                    "                                               |   * output #{}: {} {}",
                                    o,
                                    self.outlet_label((i, o).into()).unwrap_or(""),
                                    self.outlet_successors((i, o).into())
                                    .iter()
                                    .map(|s| format!("{:?}", s))
                                    .join(", "),
                                    )?;
                    }
                }
            }
        }
        writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{:?}", o)).join(", "))?;
        Ok(())
    }
}

impl<F, O> Graph<F, O>
where
    F: Fact + Clone + 'static + std::hash::Hash + for<'a> std::convert::From<&'a F>,
    O: std::fmt::Display
        + std::fmt::Debug
        + Clone
        + AsRef<dyn Op>
        + AsMut<dyn Op>
        + Clone
        + 'static
        + std::hash::Hash
        + for<'a> std::convert::From<&'a O>,
    Graph<F, O>: SpecialOps<F, O>,
{
    /// Debug-only check that the graph is compact.
    ///
    /// It checks that:
    /// - the eval order lists exactly the nodes `0..order.len()` in id order, where the
    ///   only nodes allowed to be missing are input sources that feed nothing and are not
    ///   outputs;
    /// - every node's id equals its position in the node table;
    /// - node names are unique. On a duplicate name the graph is printed to stderr before
    ///   bailing.
    #[cfg(debug_assertions)]
    pub fn check_compact(&self) -> TractResult<()> {
        let order = self.eval_order()?;
        let inputs = self.input_outlets()?;
        // Fetched once, with `?`, instead of an `unwrap()` per input inside the filter.
        let outputs = self.output_outlets()?;
        let useless_sources = inputs
            .iter()
            .filter(|&&io| self.outlet_successors(io).is_empty() && !outputs.contains(&io))
            .count();
        if order.len() + useless_sources != self.nodes.len() {
            bail!(
                "Eval order is {} long, nodes are {}, including {} unused sources",
                order.len(),
                self.nodes.len(),
                useless_sources
            );
        }
        if order.iter().enumerate().any(|(ix, &node)| node != ix) {
            bail!("eval order is not trivial");
        }
        let mut seen = std::collections::HashSet::new();
        for (ix, n) in self.nodes.iter().enumerate() {
            if ix != n.id {
                bail!("Invalid node id: position is {}, node is {}", ix, n);
            }
            // `insert` returns false when the name was already present.
            if !seen.insert(&n.name) {
                eprintln!("{}", self);
                bail!("duplicate name {}", n.name);
            }
        }
        Ok(())
    }
More examples
Hide additional examples
src/plan.rs (line 59)
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
    /// Build a plan computing all the outputs declared by the model.
    pub fn new(model: M) -> TractResult<SimplePlan<F, O, M>> {
        let declared: Vec<OutletId> = model.borrow().output_outlets()?.to_vec();
        Self::new_for_outputs_and_deps(model, &declared, &[])
    }

    /// This constructor returns a plan that will compute the specified output.
    pub fn new_for_output(model: M, output: OutletId) -> TractResult<SimplePlan<F, O, M>> {
        Self::new_for_outputs(model, &[output])
    }

    /// This constructor returns a plan that will compute all specified outputs in one pass.
    pub fn new_for_outputs(model: M, outputs: &[OutletId]) -> TractResult<SimplePlan<F, O, M>> {
        let no_extra_deps: &[(usize, usize)] = &[];
        Self::new_for_outputs_and_deps(model, outputs, no_extra_deps)
    }

    pub fn new_for_outputs_and_deps(
        model: M,
        outputs: &[OutletId],
        deps: &[(usize, usize)],
    ) -> TractResult<SimplePlan<F, O, M>> {
        let inputs = model.borrow().input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
        let outputs_nodes = outputs.iter().map(|n| n.node).collect::<Vec<usize>>();
        let order = eval_order_for_nodes(model.borrow().nodes(), &inputs, &outputs_nodes, deps)?;
        let mut values_needed_until_step = vec![0; model.borrow().nodes().len()];
        for (step, node) in order.iter().enumerate() {
            for i in &model.borrow().node(*node).inputs {
                values_needed_until_step[i.node] = step;
            }
        }
        for o in outputs.iter() {
            values_needed_until_step[o.node] = order.len();
        }
        let mut flush_lists: Vec<TVec<usize>> = vec![tvec!(); order.len() + 1];
        for (node, &flush_at) in values_needed_until_step.iter().enumerate() {
            if flush_at != 0 {
                flush_lists[flush_at].push(node)
            }
        }
        let mut symbols: std::collections::HashSet<Symbol> = Default::default();
        for node in &model.borrow().nodes {
            for output in &node.outputs {
                if let Ok(fact) = output.fact.to_typed_fact() {
                    symbols.extend(fact.shape.iter().flat_map(|d| d.symbols()))
                }
            }
        }
        Ok(SimplePlan {
            model,
            order,
            flush_lists,
            outputs: outputs.to_vec(),
            has_unresolved_symbols: !symbols.is_empty(),
            _casper: PhantomData,
        })
    }

    /// Run the plan once on `inputs` with a throw-away state and return the outputs.
    pub fn run(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        SimpleState::new(self)?.run(inputs)
    }

    /// The model this plan executes.
    pub fn model(&self) -> &Graph<F, O> {
        let graph = self.model.borrow();
        graph
    }
}

/// Mutable execution state for a `SimplePlan`: per-node op states, the session state
/// shared by all ops, and the node values produced during the current turn.
#[derive(Clone, Debug)]
pub struct SimpleState<F, O, M, P>
where
    F: Fact + Hash + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
    M: Borrow<Graph<F, O>> + Hash,
    P: Borrow<SimplePlan<F, O, M>>,
{
    // The plan being executed; `P` lets it be owned or borrowed.
    plan: P,
    /// One optional op state per node, indexed by node id (built with `Op::state`).
    pub states: Vec<Option<Box<dyn OpState>>>,
    /// State shared by all ops of a run, e.g. the provided inputs and resolved symbols.
    pub session_state: SessionState,
    /// Output values of each node for the current turn, indexed by node id. `None` until
    /// the node has run, and again once its values have been flushed or the turn reset.
    pub values: Vec<Option<TVec<TValue>>>,
    _phantom: PhantomData<(M, F, O)>,
}

impl<F, O, M, P> SimpleState<F, O, M, P>
where
    F: Fact + Hash + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
    M: Borrow<Graph<F, O>> + Hash,
    P: Borrow<SimplePlan<F, O, M>> + Clone,
{
    /// Fresh state for `plan`. No node value is present yet. Each node gets the op state
    /// returned by its op's `state`, created against a new default session.
    pub fn new(plan: P) -> TractResult<SimpleState<F, O, M, P>> {
        let mut session = SessionState::default();
        let model = plan.borrow().model();
        let node_count = model.nodes().len();
        let mut states: Vec<Option<Box<dyn OpState>>> = Vec::with_capacity(node_count);
        for node in model.nodes() {
            states.push(node.op().state(&mut session, node.id)?);
        }
        let values = vec![None; node_count];
        Ok(SimpleState { plan, states, session_state: session, values, _phantom: PhantomData })
    }

    /// Reset wires state: forget every node value of the current turn.
    pub fn reset_turn(&mut self) -> TractResult<()> {
        for value in self.values.iter_mut() {
            *value = None;
        }
        Ok(())
    }

    /// Reset op inner state: rebuild every op state against the current session state.
    pub fn reset_op_states(&mut self) -> TractResult<()> {
        let SimpleState { plan, session_state, states, .. } = self;
        let model = (*plan).borrow().model();
        *states = model
            .nodes()
            .iter()
            .map(|node| node.op().state(session_state, node.id))
            .collect::<TractResult<_>>()?;
        Ok(())
    }

    /// Set `inputs`, run the whole plan with the module's default `eval`, and return
    /// the outputs.
    pub fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let evaluator = self::eval;
        self.run_plan_with_eval(inputs, evaluator)
    }

    /// Run the whole plan with the module's default `eval`, on inputs set beforehand.
    pub fn exec(&mut self) -> TractResult<()> {
        let evaluator = self::eval;
        self.exec_plan_with_eval(evaluator)
    }

    /// Set `inputs`, execute the plan with the custom evaluator `eval`, collect the outputs,
    /// then clear the values of this turn.
    pub fn run_plan_with_eval<Eval, E>(
        &mut self,
        inputs: TVec<TValue>,
        eval: Eval,
    ) -> TractResult<TVec<TValue>>
    where
        Eval: for<'a, 'b, 'c> FnMut(
            &'a mut SessionState,
            Option<&'b mut (dyn OpState + 'static)>,
            &'c Node<F, O>,
            TVec<TValue>,
        ) -> Result<TVec<TValue>, E>,
        E: Into<anyhow::Error> + Send + Sync + 'static,
    {
        self.set_inputs(inputs)?;
        self.exec_plan_with_eval(eval)?;
        let results = self.outputs()?;
        // The turn's values are no longer needed once the outputs have been extracted.
        self.reset_turn().map(|()| results)
    }

    /// Execute every step of the plan, in `plan.order`, computing each node with `eval`.
    ///
    /// Each step does the following:
    /// 1. Gather the node's inputs from `values`, bailing if a precursor has not run.
    /// 2. Drop the values listed in `flush_lists[step]`.
    /// 3. Evaluate the node.
    /// 4. Store its outputs in `values[node.id]`.
    ///
    /// When the plan has unresolved symbols, the concrete output shapes are used to resolve
    /// them in the session. In debug builds, input and output values are also checked
    /// against the model facts. This does not clear `values` afterwards
    /// (`run_plan_with_eval` does, via `reset_turn`).
    pub fn exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
    where
        Eval: for<'a, 'b, 'c> FnMut(
            &'a mut SessionState,
            Option<&'b mut (dyn OpState + 'static)>,
            &'c Node<F, O>,
            TVec<TValue>,
        ) -> Result<TVec<TValue>, E>,
        E: Into<anyhow::Error> + Send + Sync + 'static,
    {
        {
            // Split the borrow: the plan is read while the session, op states and values
            // are mutated.
            let &mut SimpleState {
                ref plan,
                ref mut session_state,
                ref mut states,
                ref mut values,
                ..
            } = self;
            let plan = plan.borrow();
            let model = plan.model().borrow();
            for (step, n) in plan.order.iter().enumerate() {
                let node = model.node(*n);
                trace!("Running step {}, node {}", step, node);
                // Clone (cheaply, presumably refcounted — TODO confirm for TValue) the input
                // values out of the precursors' slots.
                let mut inputs: TVec<TValue> = tvec![];
                for i in &node.inputs {
                    trace!("  use input {:?}", i);
                    let prec_node = model.node(i.node);
                    let prec = values[i.node].as_ref().ok_or_else(|| {
                        format_err!("Computing {}, precursor {} not done:", node, prec_node)
                    })?;
                    inputs.push(prec[i.slot].clone())
                }

                // Inputs of this step are already gathered, so values whose last reader is
                // this step can be dropped now.
                for flush in &plan.flush_lists[step] {
                    trace!("  Ran {} can now flush {}", node, model.node(*flush));
                    values[*flush] = None;
                }

                // Debug builds: check input count and that each value matches its fact.
                if cfg!(debug_assertions) {
                    let facts = model.node_input_facts(node.id)?;
                    if facts.len() != inputs.len() {
                        bail!(
                            "Evaluating {}: expected {} inputs, got {}",
                            node,
                            facts.len(),
                            inputs.len()
                        );
                    }
                    for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
                        if !f.matches(v, Some(&session_state.resolved_symbols))? {
                            bail!(
                                "Evaluating {}: input {:?}, expected {:?}, got {:?}",
                                node,
                                ix,
                                f,
                                v
                            );
                        }
                    }
                }

                let vs = eval(session_state, states[node.id].as_deref_mut(), node, inputs)
                    .map_err(|e| e.into())?;

                // Learn symbol values from the concrete shapes of the produced outputs.
                if plan.has_unresolved_symbols {
                    for (o, v) in node.outputs.iter().zip(vs.iter()) {
                        if let Ok(f) = o.fact.to_typed_fact() {
                            for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
                                Self::resolve(
                                    &mut session_state.resolved_symbols,
                                    &dim_abstract,
                                    *dim_concrete as i64,
                                );
                            }
                        }
                    }
                }
                // Debug builds: check output count, and match consumed outputs against
                // their facts (outputs nobody reads are skipped).
                if cfg!(debug_assertions) {
                    let facts = model.node_output_facts(node.id)?;
                    if facts.len() != vs.len() {
                        bail!(
                            "Evaluating {}: expected {} outputs, got {}",
                            node,
                            facts.len(),
                            vs.len()
                        );
                    }
                    for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
                        if node.outputs[ix].successors.len() == 0 {
                            continue;
                        }
                        if !f.matches(v, Some(&session_state.resolved_symbols))? {
                            bail!(
                                "Evaluating {}: output {:?}, expected {:?}, got {:?}",
                                node,
                                ix,
                                f,
                                v
                            );
                        }
                    }
                }

                values[node.id] = Some(vs);
            }
        }
        Ok(())
    }

    /// Provide all model inputs at once, in model input order. Fails if the count differs
    /// from the model's input count, or if any single input is rejected by `set_input`.
    pub fn set_inputs(&mut self, inputs: TVec<TValue>) -> TractResult<()> {
        let expected = self.model().inputs.len();
        ensure!(
            inputs.len() == expected,
            "Wrong number of inputs for model. Expected {} got {}",
            expected,
            inputs.len()
        );
        inputs.into_iter().enumerate().try_for_each(|(ix, t)| self.set_input(ix, t))
    }

    /// Record in `symbols` the symbol value implied by matching the symbolic dimension
    /// `expected` against the concrete size `provided`.
    ///
    /// Only a bare symbol `S` and `k * expr` (solved recursively with `provided / k`) are
    /// handled; any other form is left unresolved.
    ///
    /// NOTE(review): when `provided` is not a multiple of `k`, the division truncates.
    /// Presumably the later `fact.matches` checks catch that mismatch. Confirm.
    fn resolve(symbols: &mut SymbolValues, expected: &TDim, provided: i64) {
        match expected {
            TDim::Sym(s) => symbols[s] = Some(provided),
            // A zero factor says nothing about `expr`, and dividing by it would panic.
            TDim::MulInt(x, expr) if *x != 0 => Self::resolve(symbols, expr, provided / *x),
            _ => (),
        }
    }

    /// Binds the runtime value `t` to the `input`-th model input.
    ///
    /// When the input fact can be viewed as a typed fact, each of its dimensions is paired with
    /// the matching concrete dimension of `t` and handed to `resolve`. This binds symbols in the
    /// session. The value is then checked against the input fact under those resolved symbols
    /// and stored in the session inputs.
    ///
    /// # Errors
    /// Fails if `input` is not a valid input index, if the outlet fact cannot be fetched, or if
    /// `t` does not match the input fact.
    pub fn set_input(&mut self, input: usize, t: TValue) -> TractResult<()> {
        let outlet: OutletId = *self
            .model()
            .input_outlets()?
            .get(input)
            .ok_or_else(|| format_err!("Invalid input id for model ({}).", input))?;
        // Split the borrow of `self`: the plan is only read while `session_state` is mutated.
        let SimpleState { plan, session_state, .. } = self;
        let plan = (*plan).borrow();
        let model = plan.model.borrow();
        // Facts that cannot be viewed as typed facts skip symbol resolution.
        if let Ok(fact) = model.outlet_fact(outlet)?.to_typed_fact() {
            for (expected, provided) in fact.shape.iter().zip(t.shape()) {
                Self::resolve(&mut session_state.resolved_symbols, &expected, *provided as i64)
            }
        }
        // Look the fact up again through `self` for the validity check below.
        let fact = self.plan.borrow().model().outlet_fact(outlet)?;
        ensure!(
            fact.matches(&t, Some(&self.session_state.resolved_symbols))
            .with_context(|| format!("Setting input {}", input))?,
            "Input at index {} has incorrect dtype or shape (got shape {:?} and dtype {:?}, expected to match fact {:?})",
            input,
            t.shape(),
            t.datum_type(),
            fact
            );
        self.session_state.inputs.insert(outlet.node, t);
        Ok(())
    }

    /// Borrows the value computed for the `id`-th model output.
    ///
    /// # Errors
    /// Fails if `id` is not a valid output index, or if no value is stored for the
    /// corresponding node and slot.
    pub fn output(&self, id: usize) -> TractResult<&TValue> {
        let outlets = self.model().output_outlets()?;
        let outlet = outlets
            .get(id)
            .with_context(|| format!("Required output {}, only have {}", id, outlets.len()))?;
        let node_values = self
            .values
            .get(outlet.node)
            .context("node id for output beyond node values array")?;
        let slots = node_values.as_ref().context("node is not an output")?;
        slots.get(outlet.slot).context("slot id too high")
    }
src/model/order.rs (line 14)
8
9
10
11
12
13
14
15
16
/// Computes an evaluation order for the nodes of `model`.
///
/// This is a thin wrapper around `eval_order_for_nodes`. It passes the nodes feeding the model
/// inputs as starting points, the nodes producing the model outputs as targets, and an empty
/// slice as the last argument.
pub fn eval_order<F, O>(model: &super::Graph<F, O>) -> TractResult<Vec<usize>>
where
    F: Fact + Hash + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
{
    let inputs: Vec<usize> = model.input_outlets()?.iter().map(|outlet| outlet.node).collect();
    let targets: Vec<usize> = model.output_outlets()?.iter().map(|outlet| outlet.node).collect();
    eval_order_for_nodes(model.nodes(), &inputs, &targets, &[])
}
src/ops/invariants.rs (line 348)
341
342
343
344
345
346
347
348
349
350
351
352
pub fn for_model(model: &TypedModel) -> TractResult<Invariants> {
    full_axis_tracking(model)?
        .into_iter()
        .map(|tracking| {
            let inputs =
                model.input_outlets()?.iter().map(|i| tracking.outlets.get(i).cloned()).collect();
            let outputs =
                model.output_outlets()?.iter().map(|i| tracking.outlets.get(i).cloned()).collect();
            Ok(AxisInfo { inputs, outputs, disposable: tracking.disposable, period: 1 })
        })
        .collect()
}
src/ops/cnn/maxpool.rs (line 54)
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
    /// Drops the secondary (index) output when nothing uses it.
    ///
    /// Applies only when `with_index_outputs` is set, output #1 has no successor, and output #1
    /// is not a model output. The node is then replaced by a copy of this op with
    /// `with_index_outputs: None`, wired on the same first input.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.with_index_outputs.is_none() {
            return Ok(None);
        }
        let index_outlet = OutletId::new(node.id, 1);
        if node.outputs[1].successors.len() != 0
            || model.output_outlets()?.contains(&index_outlet)
        {
            return Ok(None);
        }
        let pruned = Self { with_index_outputs: None, ..self.clone() };
        let mut patch = TypedModelPatch::default();
        let tapped = patch.tap_model(model, node.inputs[0])?;
        let rewired = patch.wire_node(&node.name, pruned, &[tapped])?[0];
        patch.shunt_outside(model, node.id.into(), rewired)?;
        Ok(Some(patch))
    }
src/optim/change_axes.rs (line 29)
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
    /// Returns a patch for the first not-yet-tried axis change suggested by any node.
    ///
    /// Nodes are visited in evaluation order. Each `(node, suggestion)` pair is recorded in
    /// `self.0` so it is only attempted once across calls. Model inputs and outputs are passed
    /// to `change_axes` as the interface outlets.
    fn next(
        &mut self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut interfaces = model.output_outlets()?.to_vec();
        interfaces.extend(model.input_outlets()?.iter());
        for n in model.eval_order()? {
            let node = model.node(n);
            for suggestion in node.op.suggested_axis_changes()? {
                if !self.0.insert((n, suggestion.clone())) {
                    continue;
                }
                let change =
                    AxisChange { outlet: suggestion.0.as_outlet(node), op: suggestion.1.clone() };
                let attempt = change_axes(model, &change, &interfaces, &[])
                    .with_context(|| format!("Making patch for {:?} from {}", change, node))?;
                if let Some((patch, _)) = attempt {
                    return Ok(Some(patch));
                }
            }
        }
        Ok(None)
    }

Guess outputs from the topology: node or nodes with no successors.

Change model outputs.

Examples found in repository?
src/model/graph.rs (line 242)
241
242
243
244
    pub fn with_output_outlets(mut self, outputs: &[OutletId]) -> TractResult<Self> {
        self.set_output_outlets(outputs)?;
        Ok(self)
    }
More examples
Hide additional examples
src/ops/cnn/deconv/unary.rs (line 161)
156
157
158
159
160
161
162
163
    /// Evaluates the op eagerly by wiring it into a throwaway `TypedModel`.
    ///
    /// The single input is declared as a source typed from the concrete value. The op is
    /// lowered with `wire_with_deconv_sum`, and the resulting model is run on that input.
    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let source_fact = input.datum_type().fact(input.shape());
        let mut adhoc = TypedModel::default();
        let source = adhoc.add_source("source", source_fact)?;
        let wires = self.wire_with_deconv_sum("adhoc", &mut adhoc, source)?;
        adhoc.set_output_outlets(&wires)?;
        let runnable = adhoc.into_runnable()?;
        runnable.run(tvec!(input))
    }
src/ops/cnn/conv/unary.rs (line 803)
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
    /// Evaluates the convolution eagerly by wiring its im2col lowering into a throwaway
    /// `TypedModel` and running that model on `inputs`.
    ///
    /// Every runtime input becomes a model source typed from its concrete dtype and shape.
    /// When `q_params` is set, the quantized im2col path is used (with the op returned by
    /// `kernel_offset_u8_as_i8`, if any, in place of `self`). Otherwise the plain im2col pair
    /// is wired on the first source.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut model = TypedModel::default();

        // One source per runtime input, typed from the concrete value.
        let mut wires: TVec<OutletId> = inputs
            .iter()
            .enumerate()
            .map(|(ix, v)| {
                model.add_source(format!("source.{}", ix), v.datum_type().fact(v.shape()))
            })
            .collect::<TractResult<_>>()?;
        // May rewrite `wires` and return a substitute op; presumably this shifts u8 kernels to
        // i8 — TODO confirm against `kernel_offset_u8_as_i8`.
        let new_op = self.kernel_offset_u8_as_i8(&mut wires, &mut model)?;
        // NOTE(review): the wiring helpers called here are presumably `unsafe fn`s; their
        // safety contract is not visible from this block and a `// SAFETY:` justification
        // should be added once confirmed.
        let wire = unsafe {
            if self.q_params.is_some() {
                // Quantized path: prefer the substitute op when one was produced.
                let op_ref = if let Some(op) = new_op.as_ref() { op } else { self };
                op_ref.wire_as_quant_im2col(
                    &mut model,
                    "im2col-adhoc",
                    inputs[0].datum_type(),
                    &wires,
                )?
            } else {
                self.wire_as_im2col_pair(&mut model, "im2col-adhoc", wires[0])?
            }
        };
        model.set_output_outlets(&[wire])?;
        model.into_runnable()?.run(inputs)
    }
src/ops/matmul/mir_quant.rs (line 306)
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
    /// Evaluates this quantized matmul eagerly by lowering it into a throwaway `TypedModel`.
    ///
    /// Every runtime input is embedded as a constant: `a`, `b`, bias, then every input from
    /// index 3 on. Quantization parameters are turned into outlets by
    /// `self.params.as_outlet_ids`. `a` and `b` go through `wire_offset_u8_as_i8`, and the
    /// output of a plain `MatMul` is post-processed by `wire_matmul_quant`. The model has no
    /// sources, so it is run with an empty input list.
    ///
    /// # Errors
    /// Fails if `a` and `b` have different ranks.
    ///
    /// NOTE(review): `inputs[2]` is indexed unconditionally, so fewer than three inputs panics.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        ensure!(
            inputs[0].rank() == inputs[1].rank(),
            "Rank mismatch {:?} vs {:?}",
            inputs[0],
            inputs[1]
        );

        let mut model = TypedModel::default();
        let a = model.add_const("source_a", inputs[0].clone().into_arc_tensor())?;
        let b = model.add_const("source_b", inputs[1].clone().into_arc_tensor())?;
        let bias = model.add_const("source_bias", inputs[2].clone().into_arc_tensor())?;

        // Remaining inputs are embedded as constants too, keeping their original index in
        // the node name.
        let mut input_outlets = tvec![a, b, bias];
        for (i, t) in inputs.iter().enumerate().skip(3) {
            input_outlets
                .push(model.add_const(format!("source_{}", i), t.clone().into_arc_tensor())?)
        }

        let mut params = self.params.as_outlet_ids(
            &mut model,
            "qmatmul_unary",
            &input_outlets,
            inputs[0].datum_type(),
            inputs[1].datum_type(),
            self.output_type,
        )?;

        // Presumably params[0] / params[2] are the zero points of a / b (named "a0" / "b0")
        // — TODO confirm against `as_outlet_ids`.
        let a = wire_offset_u8_as_i8(&mut model, "adhoc", a, "a", &mut params[0], "a0")?;
        let b = wire_offset_u8_as_i8(&mut model, "adhoc", b, "b", &mut params[2], "b0")?;

        let new_op = MatMul { axes: self.axes };
        let result = model.wire_node("adhoc.matmul", new_op, &[a, b])?[0];
        let result = wire_matmul_quant(
            &mut model,
            "adhoc",
            a,
            b,
            Some(bias),
            self.axes,
            result,
            self.output_type,
            &params,
        )?;
        model.set_output_outlets(&[result])?;
        model.into_runnable()?.run(tvec![])
    }
src/ops/matmul/mir_quant_unary.rs (line 77)
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
    /// Evaluates this quantized unary matmul eagerly by lowering it into a throwaway
    /// `TypedModel`.
    ///
    /// `self.a` (the stored operand) and `inputs[0]` (`b`) are embedded as constants, along
    /// with the optional bias and every input from index 1 on. Quantization parameters become
    /// outlets via `self.params.as_outlet_ids`. A `MatMulUnary` built from
    /// `self.a.offset_u8_as_i8()` is applied to the offset-adjusted `b`, and its output is
    /// post-processed by `wire_matmul_quant`. The model has no sources, so it is run with an
    /// empty input list.
    ///
    /// # Errors
    /// Fails if `inputs[0]` and `self.a` have different ranks.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        ensure!(inputs[0].rank() == self.a.rank(), "Rank mismatch {:?} vs {:?}", inputs[0], self.a);

        let mut model = TypedModel::default();
        let t_a = self.a.offset_u8_as_i8();
        let a = model.add_const("source_a", self.a.clone())?;
        let b = model.add_const("source_b", inputs[0].clone().into_arc_tensor())?;
        let bias = if let Some(bias) = self.bias.clone() {
            Some(model.add_const("source_bias", bias)?)
        } else {
            None
        };

        // Parameter inputs (index 1 onward) follow `a` in the outlet list handed to
        // `as_outlet_ids`.
        let mut input_outlets = tvec![a];
        for (i, t) in inputs.iter().enumerate().skip(1) {
            input_outlets
                .push(model.add_const(format!("source_{}", i), t.clone().into_arc_tensor())?)
        }

        let mut params = self.params.as_outlet_ids(
            &mut model,
            "qmatmul_unary",
            &input_outlets,
            self.a.datum_type(),
            inputs[0].datum_type(),
            self.output_type,
        )?;
        // Presumably params[0] / params[2] are the zero points of a / b (named "a0" / "b0")
        // — TODO confirm against `as_outlet_ids`.
        let a = wire_offset_u8_as_i8(&mut model, "adhoc", a, "a", &mut params[0], "a0")?;
        let b = wire_offset_u8_as_i8(&mut model, "adhoc", b, "b", &mut params[2], "b0")?;

        let new_op = MatMulUnary { a: t_a, axes: self.axes };
        let result = model.wire_node("adhoc.matmul", new_op, &[b])?[0];
        let result = wire_matmul_quant(
            &mut model,
            "adhoc",
            a,
            b,
            bias,
            self.axes,
            result,
            self.output_type,
            &params,
        )?;
        model.set_output_outlets(&[result])?;
        model.into_runnable()?.run(tvec![])
    }
src/ops/downsample/scan.rs (line 32)
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
/// Tries to move a `Downsample` that follows a `Scan` above the scan.
///
/// A copy of the downsample is appended to every output of the scan body, and the body is
/// decluttered. The rewrite is accepted only if those downsamples have all moved up to the body
/// inputs. In that case:
/// - the body inputs are downsampled instead;
/// - scan chunk sizes and constant state initializers are rescaled;
/// - a patch is returned that downsamples the scan inputs outside and feeds them to the new
///   scan, whose outputs replace the downsamples that followed the old scan.
///
/// Returns `Ok(None)` whenever the transformation does not apply.
pub fn pull_downsample_over_scan(
    model: &TypedModel,
    scan_node: &TypedNode,
    scan_op: &ops::scan::Scan,
    down_node: &TypedNode,
    down_op: &Downsample,
) -> TractResult<Option<TypedModelPatch>> {
    // Negative strides are not handled here, and a zero stride would divide by zero in the
    // chunk checks below.
    if down_op.stride <= 0 {
        return Ok(None);
    }

    // The final patch replaces the first consumer of every scan output with the new scan
    // output, so every output must be consumed, and first by this same downsample.
    for output in &scan_node.outputs {
        match output.successors.first() {
            Some(succ) if model.node(succ.node).op().same_as(down_op) => (),
            _ => return Ok(None),
        }
    }

    // introduce downsample at end of body
    let mut downsampled_body = scan_op.body.clone();
    downsampled_body.check_consistency()?;
    let outputs = downsampled_body.output_outlets()?.to_owned();
    let downsample_outputs = outputs
        .into_iter()
        .enumerate()
        .map(|(ix, oo)| {
            Ok(downsampled_body.wire_node(
                format!("{}-{}", &down_node.name, ix),
                down_op.clone(),
                &[oo],
            )?[0])
        })
        .collect::<TractResult<Vec<_>>>()?;
    downsampled_body.set_output_outlets(&downsample_outputs)?;
    downsampled_body.declutter()?;
    downsampled_body.check_consistency()?;

    // check if downsample ops introduced at end have swimmed up to scan inputs during declutter
    for input in downsampled_body.input_outlets()? {
        let input = downsampled_body.node(input.node);
        if input.outputs[0]
            .successors
            .iter()
            .any(|succ| !downsampled_body.node(succ.node).op().same_as(down_op))
        {
            return Ok(None);
        }
    }

    // Downsample the body input facts and drop the now-redundant downsample nodes.
    let inputs = downsampled_body.input_outlets()?.to_vec();
    for input in inputs {
        let node = &mut downsampled_body.node_mut(input.node);
        let fact = &mut node.outputs[0].fact;
        *fact = down_op.transform_fact(fact)?;
        node.op_as_mut::<crate::ops::source::TypedSource>().unwrap().fact = fact.clone();
        let downsamples = downsampled_body.node(input.node).outputs[0].successors.clone();
        for ds in downsamples {
            TypedModelPatch::shunt_one_op(&downsampled_body as _, downsampled_body.node(ds.node))?
                .apply(&mut downsampled_body)?;
        }
    }

    downsampled_body.check_consistency()?;
    let inner_model = downsampled_body.into_decluttered()?;

    let mut new_scan = scan_op.clone();
    new_scan.body = inner_model;
    for input in &mut new_scan.input_mapping {
        match input {
            InputMapping::State { ref mut initializer } => {
                if let StateInitializer::Value(ref v) = initializer {
                    let mut new_v = down_op.eval(tvec!(v.clone().into_tvalue()))?;
                    *initializer = StateInitializer::Value(new_v.remove(0).into_arc_tensor());
                }
            }
            InputMapping::Scan(info) => {
                // Each chunk must hold a whole number of strides, whatever the scan direction.
                if info.chunk.unsigned_abs() % down_op.stride as usize != 0 {
                    return Ok(None);
                }
                info.chunk = info.chunk.unsigned_abs().divceil(down_op.stride as usize) as isize
                    * info.chunk.signum()
            }
            _ => (),
        }
    }
    for output in &mut new_scan.output_mapping {
        if let Some(d) = output.full_dim_hint.as_mut() {
            *d = down_op.transform_dim(d)
        }
        if let Some(info) = &mut output.scan {
            // Same rule as for inputs: compare magnitudes, never a negative chunk cast to usize.
            if info.chunk.unsigned_abs() % down_op.stride as usize != 0 {
                return Ok(None);
            }
            info.chunk = info.chunk.unsigned_abs().divceil(down_op.stride as usize) as isize
                * info.chunk.signum()
        }
    }

    let mut patch = TypedModelPatch::default();
    let mut inputs = tvec!();
    for (ix, &i) in scan_node.inputs.iter().enumerate() {
        let tap = patch.tap_model(model, i)?;
        let ds = patch.wire_node(format!("{}-{}", down_node.name, ix), down_op.clone(), &[tap])?[0];
        inputs.push(ds);
    }
    let scan = patch.wire_node(&*scan_node.name, new_scan, &inputs)?;
    for ix in 0..scan_node.outputs.len() {
        // Checked at the top: every output has a first successor, and it is a matching
        // downsample.
        let succ = scan_node.outputs[ix].successors[0].node;
        patch.shunt_outside(model, OutletId::new(succ, 0), scan[ix])?;
    }
    Ok(Some(patch))
}

Change model outputs and return self.

Set model outputs by node names.

Examples found in repository?
src/model/graph.rs (line 278)
274
275
276
277
278
279
280
    pub fn with_output_names(
        mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<Self> {
        self.set_output_names(outputs)?;
        Ok(self)
    }

Set model outputs by node names and return self.

Get the ix-th input tensor type information.

Examples found in repository?
src/ops/scan/mir.rs (line 388)
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
    /// Replaces a last-value output whose body fact is a known constant with a constant node
    /// outside the scan.
    ///
    /// Only the first matching output is handled per call; returns `Ok(None)` if none matches.
    fn declutter_pull_constant_outputs(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (ix, mapping) in self.output_mapping.iter().enumerate() {
            let Some(slot) = mapping.last_value_slot else { continue };
            let Some(konst) = self.body.output_fact(ix)?.konst.clone() else { continue };
            let body_outlet = self.body.output_outlets()?[ix];
            let body_node = self.body.node(body_outlet.node);
            let mut patch = TypedModelPatch::new(format!("Extract const node {}", body_node));
            let wire = patch.add_const(format!("{}.{}", &node.name, &body_node.name), konst)?;
            patch.shunt_outside(model, OutletId::new(node.id, slot), wire)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

    /// Pulls the producer of a scanned output out of the scan body.
    ///
    /// Looks for a body output that qualifies on all of these counts:
    /// - it is exported as a scan output;
    /// - it does not feed the loop state;
    /// - it is scanned by chunks of at most 1;
    /// - its producing node has no other consumer inside the body;
    /// - its scan axis can be tracked up through the producer's invariants.
    ///
    /// The producer's inputs are then exported as extra outer outputs: last-value outputs for
    /// constants, scan outputs otherwise. The producer itself is re-wired outside the scan on
    /// those outputs.
    ///
    /// Returns `Ok(None)` when no output qualifies.
    fn declutter_pull_batcheable_output(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (model_ix, mapping) in self.output_mapping.iter().enumerate() {
            if let Some(info) = mapping.scan {
                let emitter_outlet = self.body.output_outlets()?[model_ix];
                let emitter_node = self.body.node(emitter_outlet.node);
                // Skip outputs whose producer is also consumed inside the body, outputs that
                // feed the loop state, and outputs scanned by chunks larger than 1.
                if emitter_node.outputs[emitter_outlet.slot].successors.len() > 0
                    || mapping.state
                    || info.chunk > 1
                {
                    continue;
                }
                let (input_facts, output_facts) = self.body.node_facts(emitter_node.id)?;
                let invariants = emitter_node.op.invariants(&input_facts, &output_facts)?;
                let Some(axis_before) = invariants.unary_track_axis_up(info.axis, false)
                else {
                    continue;
                };

                let mut new_body = self.body.clone();
                let mut new_output_mapping = self.output_mapping.clone();
                // Next free outer output slot: new outputs are appended after existing ones.
                let mut new_scan_outputs = node.outputs.len();
                let mut outer_slots = vec![];

                for input in &emitter_node.inputs {
                    if new_body.outputs.iter().all(|o| o != input) {
                        new_output_mapping.push(OutputMapping::default());
                        new_body.outputs.push(*input);
                    }
                    let body_output_id = new_body.outputs.iter().position(|o| o == input).unwrap();
                    let body_mapping = &mut new_output_mapping[body_output_id];
                    let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
                        // Constants are exported as a last value. Only consume a new outer slot
                        // when the mapping does not already expose one; otherwise slot numbers
                        // would have a gap.
                        if body_mapping.last_value_slot.is_none() {
                            body_mapping.last_value_slot = Some(new_scan_outputs);
                            new_scan_outputs += 1;
                        }
                        body_mapping.last_value_slot.unwrap()
                    } else {
                        if body_mapping.scan.is_none() {
                            body_mapping.scan = Some(ScanInfo {
                                slot: new_scan_outputs,
                                axis: axis_before,
                                chunk: info.chunk,
                            });
                            new_scan_outputs += 1;
                        }
                        body_mapping.scan.unwrap().slot
                    };
                    outer_slots.push(outer_slot);
                }
                let mut outside_patch = TypedModelPatch::new(format!(
                    "Outside patch for output extraction of {}",
                    emitter_node
                ));
                let inputs = node
                    .inputs
                    .iter()
                    .map(|&i| outside_patch.tap_model(model, i))
                    .collect::<TractResult<TVec<_>>>()?;
                let new_op = Self {
                    input_mapping: self.input_mapping.clone(),
                    output_mapping: new_output_mapping,
                    decluttered: false,
                    body: new_body,
                    skip: self.skip,
                    seq_length_input_slot: self.seq_length_input_slot,
                };
                let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
                let inputs =
                    outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
                let wire = outside_patch.wire_node(
                    &*emitter_node.name,
                    emitter_node.op.clone(),
                    &inputs,
                )?[0];
                // The extracted output now comes from the re-wired producer; every other outer
                // output keeps its slot on the new scan.
                outside_patch.shunt_outside(model, OutletId::new(node.id, info.slot), wire)?;
                for output_slot in 0..node.outputs.len() {
                    if output_slot != info.slot {
                        outside_patch.shunt_outside(
                            model,
                            OutletId::new(node.id, output_slot),
                            OutletId::new(scan_outputs[0].node, output_slot),
                        )?;
                    }
                }
                return Ok(Some(outside_patch));
            }
        }
        Ok(None)
    }

    /// Pairs each state input of the body with the corresponding state output.
    ///
    /// The n-th body input mapped as state is paired with the n-th body output flagged as
    /// state. Each pair is returned as a two-element `[input, output]` list.
    fn body_bounds(&self) -> TractResult<TVec<TVec<OutletId>>> {
        let body_inputs = self.body.input_outlets()?;
        let body_outputs = self.body.output_outlets()?;
        let state_inputs = self
            .input_mapping
            .iter()
            .zip(body_inputs)
            .filter_map(|(mapping, outlet)| mapping.as_state().map(|_| *outlet));
        let state_outputs = self
            .output_mapping
            .iter()
            .zip(body_outputs)
            .filter(|(mapping, _)| mapping.state)
            .map(|(_, outlet)| *outlet);
        Ok(state_inputs.zip(state_outputs).map(|(i, o)| tvec!(i, o)).collect())
    }

    fn body_exposed_outlets(&self) -> TractResult<TVec<OutletId>> {
        let input_outlets = self
            .input_mapping
            .iter()
            .zip(self.body.input_outlets()?.iter())
            .filter(|(m, _)| !m.invisible())
            .map(|(_, o)| o);
        let output_outlets = self
            .output_mapping
            .iter()
            .zip(self.body.output_outlets()?.iter())
            .filter(|(m, _)| !m.invisible())
            .map(|(_, o)| o);
        Ok(input_outlets.chain(output_outlets).cloned().collect())
    }

    /// Attempts to propagate an axis change through the scan body.
    ///
    /// `change_axes` is run on the body. The exposed body outlets are locked when
    /// `locked_interface` is set, and the state input/output pairs from `body_bounds` are
    /// passed along. The resulting patch is applied to a copy of the body, and body interface
    /// changes are then translated into outer consequences:
    /// - scan axes of input/output mappings are remapped;
    /// - constant state initializers are transformed;
    /// - every affected outer wire is reported in `wire_changes`.
    ///
    /// Returns `Ok(None)` if `change_axes` finds no way to apply the change, or if a scanned
    /// axis has no image under it.
    fn try_body_axes_change(
        &self,
        change: AxisChange,
        locked_interface: bool,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        self.body.check_consistency()?;
        let interface = self.body_exposed_outlets()?;
        let (patch, body_changed_wires) = if let Some(changes) =
            crate::ops::change_axes::change_axes(
                &self.body,
                &change,
                if locked_interface { &interface } else { &[] },
                &self.body_bounds()?,
            )? {
            changes
        } else {
            return Ok(None);
        };
        let mut body = self.body.clone();
        patch.apply(&mut body)?;
        body.compact()?;
        let mut wire_changes = tvec!();
        // Translate changes on body inputs into outer input changes and updated mappings.
        let mut input_mapping: Vec<InputMapping> = self.input_mapping.clone();
        for (ix, m) in input_mapping.iter_mut().enumerate() {
            if let Some(change) = body_changed_wires
                .iter()
                .find(|(iface, _change)| iface == &InOut::In(ix))
                .map(|pair| pair.1.clone())
            {
                if let Some(slot) = m.slot() {
                    wire_changes.push((InOut::In(slot), change.clone()));
                }
                match &*m {
                    InputMapping::Full { .. } => (),
                    &InputMapping::Scan(info) => {
                        if let Some(axis) = change.transform_axis(info.axis) {
                            *m = InputMapping::Scan(ScanInfo { axis, ..info });
                        } else {
                            return Ok(None);
                        };
                    }
                    InputMapping::State { initializer } => match initializer {
                        StateInitializer::FromInput(_) => (),
                        // Constant initializers live inside the op, so they are rewritten here.
                        StateInitializer::Value(ref v) => {
                            let mut v = v.clone().into_tensor();
                            change.change_tensor(&mut v, false)?;
                            *m = InputMapping::State {
                                initializer: StateInitializer::Value(v.into_arc_tensor()),
                            };
                        }
                    },
                };
            }
        }
        // Translate changes on body outputs into outer output changes (scan and last-value).
        let mut output_mapping: Vec<OutputMapping<TDim>> = self.output_mapping.clone();
        for (ix, m) in output_mapping.iter_mut().enumerate() {
            if let Some(change) = body_changed_wires
                .iter()
                .find(|(iface, _change)| iface == &InOut::Out(ix))
                .map(|pair| pair.1.clone())
            {
                if let Some(info) = m.scan.as_mut() {
                    if let Some(new_axis) = change.transform_axis(info.axis) {
                        info.axis = new_axis;
                    } else {
                        return Ok(None);
                    }
                    wire_changes.push((InOut::Out(info.slot), change.clone()));
                }
                if let Some(slot) = m.last_value_slot {
                    wire_changes.push((InOut::Out(slot), change.clone()));
                }
            };
        }
        body.check_consistency()?;
        let op = Some(Box::new(Scan {
            body,
            input_mapping,
            output_mapping,
            decluttered: false,
            ..self.clone()
        }) as _);
        Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes }))
    }
}

impl Op for Scan {
    /// Display name of the op.
    fn name(&self) -> Cow<str> {
        Cow::Borrowed("Scan")
    }

    /// One description line per input mapping, followed by one per output mapping.
    fn info(&self) -> TractResult<Vec<String>> {
        let input_lines = self
            .input_mapping
            .iter()
            .enumerate()
            .map(|(ix, im)| format!("Model input  #{}: {:?}", ix, im));
        let output_lines = self
            .output_mapping
            .iter()
            .enumerate()
            .map(|(ix, om)| format!("Model output #{}: {:?}", ix, om));
        Ok(input_lines.chain(output_lines).collect())
    }

    fn validation(&self) -> Validation {
        Validation::Rounding
    }

    op_as_typed_op!();
}

impl EvalOp for Scan {
    /// Scan is always evaluated through a per-session state.
    fn is_stateless(&self) -> bool {
        false
    }

    /// Builds the evaluation state from the codegen form of this op.
    fn state(
        &self,
        session: &mut SessionState,
        node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>> {
        let codegen = self.to_codegen_op(false)?;
        codegen.state(session, node_id)
    }
}

impl TypedOp for Scan {
    as_op!();

    /// Computes the outer output facts of the scan from the body output facts.
    ///
    /// The iteration count is the first scan input's length along its scanned axis divided,
    /// rounding up, by the absolute chunk size. Scan outputs get their scanned axis set to
    /// `full_dim_hint` when present, otherwise to the body dimension times the iteration count.
    /// Last-value outputs reuse the body fact shape. Facts are returned ordered by outer slot,
    /// and the slots must be exactly `0..n`.
    ///
    /// Panics if no input is mapped as a scan input.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let mut outputs = tvec!();
        let iters = {
            let info = self.input_mapping.iter().flat_map(|it| it.as_scan()).next().unwrap();
            inputs[info.slot].shape[info.axis].clone().div_ceil(info.chunk.unsigned_abs() as u64)
        };
        for (ix, output) in self.output_mapping.iter().enumerate() {
            let fact = self.body.output_fact(ix)?;
            if let Some(info) = output.scan {
                let mut shape = fact.shape.clone();
                let scanning_dim =
                    output.full_dim_hint.clone().unwrap_or(shape[info.axis].clone() * &iters);
                shape.set(info.axis, scanning_dim);
                outputs.push((info.slot, fact.datum_type.fact(shape)));
            }
            if let Some(slot) = output.last_value_slot {
                outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
            }
        }
        // Order by outer slot and make sure no slot is missing or duplicated.
        outputs.sort_by_key(|a| a.0);
        anyhow::ensure!(outputs.iter().enumerate().all(|(ix, (slot, _))| ix == *slot));
        let outputs: TVec<_> = outputs.into_iter().map(|(_slot, v)| v).collect();
        Ok(outputs)
    }
More examples
Hide additional examples
src/ops/scan/lir.rs (line 208)
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
    /// Runs the scan loop over the given inputs.
    ///
    /// On the first call, `hidden_state` is seeded from the state initializers, taken either
    /// from an input slot or from a constant. The iteration count is the first scan input's
    /// length along its axis divided, rounding up, by the absolute chunk size.
    ///
    /// Scan outputs are pre-allocated uninitialized and filled chunk by chunk. Last-value
    /// outputs start as `Tensor::default()` and are overwritten on the final iteration. Body
    /// outputs mapped as state are pushed back to feed the next iteration.
    ///
    /// `position` persists in the state and is bumped on every iteration. Iterations with
    /// `position <= op.skip` do not run the body.
    ///
    /// Panics if no input is mapped as a scan input.
    fn eval(
        &mut self,
        session: &mut SessionState,
        _op: &dyn Op,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let State { op, ref mut hidden_state, ref mut position, ref mut model_state } = self;
        // initialize state at first pass
        if hidden_state.len() == 0 {
            for input in &op.input_mapping {
                if let InputMapping::State { initializer } = input {
                    hidden_state.push(match initializer {
                        StateInitializer::FromInput(slot) => inputs[*slot].clone(),
                        StateInitializer::Value(v) => (**v).to_owned().into_tvalue(),
                    });
                }
            }
        }

        let iters = {
            let info = op
                .input_mapping
                .iter()
                .find_map(|it| match it {
                    InputMapping::Scan(info) => Some(info),
                    _ => None,
                })
                .unwrap();
            inputs[info.slot].shape()[info.axis].divceil(info.chunk.unsigned_abs())
        };

        let mut outputs = tvec!();
        for (ix, output) in op.output_mapping.iter().enumerate() {
            if let Some(info) = output.scan {
                let fact = op.plan.model().output_fact(ix)?;
                let mut shape: TVec<usize> =
                    fact.shape.eval_to_usize(&session.resolved_symbols)?.into_owned();
                let scanning_dim = output
                    .full_dim_hint
                    .as_ref()
                    .and_then(|d| d.to_usize().ok())
                    .unwrap_or(shape[info.axis] * iters);
                shape[info.axis] = scanning_dim;
                // NOTE(review): contents start uninitialized and are only written by
                // `assign_output` in non-skipped iterations; confirm regions covered by
                // skipped iterations are never read.
                let t = unsafe { Tensor::uninitialized_dt(fact.datum_type, &shape)? };
                outputs.push((info.slot, t));
            }
            if let Some(slot) = output.last_value_slot {
                outputs.push((slot, Tensor::default()));
            }
        }
        // Order pre-allocated outputs by outer slot so they can be indexed by slot below.
        outputs.sort_by_key(|a| a.0);
        let mut outputs: TVec<Tensor> = outputs.into_iter().map(|(_slot, v)| v).collect();

        for i in 0..iters {
            *position += 1;
            if *position <= op.skip {
                continue;
            }
            // Reverse so that successive `pop()` calls below yield the states in the order
            // they were pushed.
            hidden_state.reverse();

            let iter_inputs: TVec<TValue> = op
                .input_mapping
                .iter()
                .map(|m| {
                    Ok(match m {
                        InputMapping::State { .. } => Some(hidden_state.pop().unwrap()),
                        InputMapping::Scan(info) => Some(
                            Self::slice_input(&inputs[info.slot], info.axis, i, info.chunk)?
                                .into_tvalue(),
                        ),
                        InputMapping::Full { slot } => Some(inputs[*slot].clone()),
                    })
                })
                .collect::<TractResult<Vec<_>>>()?
                .into_iter()
                .flatten()
                .collect();

            trace!("iter_inputs #{}: {:?}", i, iter_inputs);
            let iter_outputs =
                model_state.run(iter_inputs).with_context(|| "Evaluating inner body")?;
            trace!("iter_outputs #{}: {:?}", i, iter_outputs);

            // Dispatch each body output to its scan output, last-value output and/or state.
            for (v, mapping) in iter_outputs.into_iter().zip(&op.output_mapping) {
                if let Some(info) = mapping.scan {
                    Self::assign_output(&mut outputs[info.slot], info.axis, &v, i, info.chunk < 0);
                }
                if i == iters - 1 {
                    if let Some(slot) = mapping.last_value_slot {
                        outputs[slot] = v.clone().into_tensor();
                    }
                }
                if mapping.state {
                    hidden_state.push(v);
                }
            }
        }

        Ok(outputs.into_iter().map(|t| t.into_tvalue()).collect())
    }
}

impl TypedOp for LirScan {
    as_op!();

    /// Compute the outer output facts of the lowered scan.
    ///
    /// The iteration count is derived from the first scanned input as
    /// `ceil(len_along_axis / |chunk|)`. Last-value outputs reuse the body output fact as is;
    /// scanned outputs get their scanning axis replaced by `full_dim_hint` when present, or by
    /// the body dimension multiplied by the iteration count. Facts are returned ordered by
    /// outer slot.
    ///
    /// # Errors
    /// Fails when no input is mapped as a scan (the iteration count is then undefined), or
    /// when the inner plan's output facts cannot be read.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let scan_input = self
            .input_mapping
            .iter()
            .find_map(|it| it.as_scan())
            .context("LirScan requires at least one scanned input")?;
        let iters = inputs[scan_input.slot].shape[scan_input.axis]
            .clone()
            .div_ceil(scan_input.chunk.unsigned_abs() as _);
        let mut outputs = tvec!();
        for (ix, output) in self.output_mapping.iter().enumerate() {
            let fact = self.plan.model().output_fact(ix)?;
            if let Some(slot) = output.last_value_slot {
                outputs.push((slot, fact.datum_type.fact(fact.shape.clone())));
            }
            if let Some(info) = output.scan {
                let mut shape = fact.shape.clone();
                // Only compute the product when no hint overrides it.
                let scanning_dim = output
                    .full_dim_hint
                    .clone()
                    .unwrap_or_else(|| shape[info.axis].clone() * &iters);
                shape.set(info.axis, scanning_dim);
                outputs.push((info.slot, fact.datum_type.fact(shape)));
            }
        }
        outputs.sort_by_key(|a| a.0);
        Ok(outputs.into_iter().map(|(_slot, v)| v).collect())
    }

Get the ix-th input tensor type information, mutably.

Set the ix-th output tensor type information.

Examples found in repository?
src/model/graph.rs (line 302)
301
302
303
304
    /// Builder-style variant of `set_output_fact`: sets the fact of output number `output`
    /// and hands the graph back on success.
    pub fn with_output_fact(mut self, output: usize, fact: F) -> TractResult<Self> {
        self.set_output_fact(output, fact).map(|()| self)
    }

Set the ix-th output tensor type information and return self.

Iterate over all node names.

Examples found in repository?
src/model/graph.rs (line 323)
322
323
324
325
326
327
328
329
330
331
    /// Find a node by its name.
    ///
    /// The lookup is delegated to `node_id_by_name`, whose error is propagated when no node
    /// carries that name.
    pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>> {
        self.node_id_by_name(name.as_ref()).map(|node_ix| &self.nodes[node_ix])
    }

    /// Borrow mutably a node by its name.
    ///
    /// The lookup is delegated to `node_id_by_name`, whose error is propagated when no node
    /// carries that name.
    pub fn node_by_name_mut(&mut self, name: impl AsRef<str>) -> TractResult<&mut Node<F, O>> {
        let node_ix = self.node_id_by_name(name.as_ref())?;
        let node = &mut self.nodes[node_ix];
        Ok(node)
    }

Find a node by its name.

Examples found in repository?
src/plan.rs (line 478)
477
478
479
480
    /// Take the values held for the node called `name`, resolving the name through the
    /// state's model and then delegating to `take`.
    pub fn take_by_name(&mut self, name: &str) -> TractResult<TVec<Tensor>> {
        let node_id = self.model().node_by_name(name).map(|found| found.id)?;
        Self::take(self, node_id)
    }
More examples
Hide additional examples
src/model/graph.rs (line 168)
162
163
164
165
166
167
168
169
170
171
172
173
174
175
    /// Replace the model inputs by every output outlet of the nodes named in `inputs`,
    /// in the order the names are given (and slot order within each node).
    ///
    /// Fails, leaving the current inputs untouched, if any name does not match a node.
    pub fn set_input_names(
        &mut self,
        inputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<()> {
        let mut outlets = Vec::new();
        for name in inputs {
            let node = self.node_by_name(name)?;
            outlets.extend((0..node.outputs.len()).map(|slot| OutletId::new(node.id, slot)));
        }
        self.inputs = outlets;
        Ok(())
    }

Borrow mutably a node by its name.

Find a node by its id.

Examples found in repository?
src/model/graph.rs (line 509)
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
    /// Walk `count` steps upstream from node `id`, following `single_prec` at each step.
    ///
    /// Returns `Ok(None)` as soon as a step has no single predecessor; `count == 0` yields
    /// the starting node itself.
    pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut current = self.node(id);
        for _step in 0..count {
            let Some(prec) = self.single_prec(current.id)? else {
                return Ok(None);
            };
            current = prec;
        }
        Ok(Some(current))
    }

    /// Walk `count` steps downstream from node `id`, following `single_succ` at each step.
    ///
    /// Returns `Ok(None)` as soon as a step has no single successor; `count == 0` yields
    /// the starting node itself.
    pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut current = self.node(id);
        for _step in 0..count {
            let Some(succ) = self.single_succ(current.id)? else {
                return Ok(None);
            };
            current = succ;
        }
        Ok(Some(current))
    }
More examples
Hide additional examples
src/ops/downsample/mod.rs (line 103)
94
95
96
97
98
99
100
101
102
103
104
    /// Simplify a `Downsample`: a stride of 1 is bypassed entirely, otherwise the op is
    /// pulled upstream via `pull_downsample_up`.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if self.stride != 1 {
            return pull_downsample_up(model, node).with_context(|| {
                let producer = model.node(node.inputs[0].node);
                format!("Pulling {} over {}", node, producer)
            });
        }
        let bypass = TypedModelPatch::shunt_one_op(model, node)?;
        Ok(Some(bypass))
    }
src/ops/scan/mir.rs (line 86)
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
    /// Try the axis changes suggested by the body's ops, in body eval order.
    ///
    /// All suggestions are gathered first; the first one for which `try_body_axes_change`
    /// yields a substitute op is turned into a single-op replacement patch.
    fn declutter_body_axes(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut candidates = vec![];
        for body_node_id in self.body.eval_order()? {
            let body_node = self.body.node(body_node_id);
            for suggestion in body_node.op.suggested_axis_changes()? {
                let outlet = suggestion.0.as_outlet(body_node);
                candidates.push(AxisChange { outlet, op: suggestion.1 });
            }
        }
        for change in candidates {
            let substitute =
                self.try_body_axes_change(change, true)?.and_then(|c| c.substitute_op);
            if let Some(op) = substitute {
                let patch = TypedModelPatch::replace_single_op(model, node, &node.inputs, op)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }

    /// Renumber the outer input slots referenced by `mappings` after outer input
    /// `discarded` has been removed: every slot above it moves down by one.
    fn remove_outer_input_from_mappings(
        mappings: &[InputMapping],
        discarded: usize,
    ) -> Vec<InputMapping> {
        let shift = |slot: usize| if slot > discarded { slot - 1 } else { slot };
        mappings
            .iter()
            .map(|mapping| match mapping {
                InputMapping::Full { slot } => InputMapping::Full { slot: shift(*slot) },
                InputMapping::Scan(info) => {
                    InputMapping::Scan(ScanInfo { slot: shift(info.slot), ..*info })
                }
                InputMapping::State { initializer: StateInitializer::FromInput(n) } => {
                    InputMapping::State { initializer: StateInitializer::FromInput(shift(*n)) }
                }
                InputMapping::State { initializer: StateInitializer::Value(v) } => {
                    InputMapping::State { initializer: StateInitializer::Value(v.clone()) }
                }
            })
            .collect()
    }

    /// Renumber the outer output slots referenced by `mappings` after outer output
    /// `discarded` has been removed: every slot above it moves down by one.
    fn remove_outer_output_from_mappings(
        mappings: &[OutputMapping<TDim>],
        discarded: usize,
    ) -> Vec<OutputMapping<TDim>> {
        let shift = |slot: usize| if slot > discarded { slot - 1 } else { slot };
        let mut renumbered = Vec::with_capacity(mappings.len());
        for m in mappings {
            renumbered.push(OutputMapping {
                scan: m.scan.map(|info| ScanInfo { slot: shift(info.slot), ..info }),
                last_value_slot: m.last_value_slot.map(shift),
                full_dim_hint: m.full_dim_hint.clone(),
                state: m.state,
            });
        }
        renumbered
    }

    /// Turn a state whose initializer comes from a constant outer input into a state with
    /// an embedded `Value` initializer, and drop that outer input from the node.
    ///
    /// Only the first such state is rewritten per call.
    fn declutter_const_initializer(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let outer_facts = model.node_input_facts(node.id)?;
        for (ix, mapping) in self.input_mapping.iter().enumerate() {
            let InputMapping::State { initializer: StateInitializer::FromInput(outer_slot) } =
                mapping
            else {
                continue;
            };
            let Some(konst) = outer_facts[*outer_slot].konst.as_ref() else {
                continue;
            };
            let mut op = self.clone();
            op.input_mapping[ix] =
                InputMapping::State { initializer: StateInitializer::Value(konst.clone()) };
            op.input_mapping =
                Self::remove_outer_input_from_mappings(&op.input_mapping, *outer_slot);
            let mut outer_inputs = node.inputs.clone();
            outer_inputs.remove(*outer_slot);
            let patch = TypedModelPatch::replace_single_op(model, node, &outer_inputs, op)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

    /// Remove the first body input that nothing reads (no successor and not a body output),
    /// together with its input mapping and, when it had one, the outer input feeding it.
    ///
    /// The body is patched to obliterate the source, re-declared with the remaining inputs
    /// and decluttered before the scan op is replaced.
    fn declutter_discard_unused_input_mapping(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let body_inputs = self.body.input_outlets()?;
        for (inner_input_id, input) in body_inputs.iter().enumerate() {
            let source_node = self.body.node(input.node);
            let unused = source_node.outputs[0].successors.is_empty()
                && !self.body.output_outlets()?.contains(input);
            if !unused {
                continue;
            }
            // Outer slot feeding this body input; a state with a `Value` initializer has none.
            let outer_slot = match &self.input_mapping[inner_input_id] {
                InputMapping::Full { slot } => Some(*slot),
                InputMapping::Scan(info) => Some(info.slot),
                InputMapping::State { initializer: StateInitializer::FromInput(n) } => Some(*n),
                InputMapping::State { .. } => None,
            };
            let mut new_mappings = self.input_mapping.clone();
            new_mappings.remove(inner_input_id);
            let mut new_inputs = node.inputs.clone();
            if let Some(slot) = outer_slot {
                new_mappings = Self::remove_outer_input_from_mappings(&new_mappings, slot);
                new_inputs.remove(slot);
            }
            let mut kept_body_inputs = body_inputs.to_vec();
            kept_body_inputs.remove(inner_input_id);

            let mut body = self.body.clone();
            let mut patch = TypedModelPatch::default();
            patch.obliterate(source_node.id)?;
            patch.apply(&mut body)?;
            body.set_input_outlets(&kept_body_inputs)?;
            body.declutter()?;
            let op = Self {
                body,
                skip: self.skip,
                seq_length_input_slot: self.seq_length_input_slot,
                input_mapping: new_mappings,
                decluttered: true,
                output_mapping: self.output_mapping.clone(),
            };
            let replacement = TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?;
            return Ok(Some(replacement));
        }
        Ok(None)
    }

    /// Drop the first outer output of the scan that has no successor and is not a model
    /// output.
    ///
    /// The output mappings forget that slot and the higher slots are renumbered; the other
    /// outer outputs are shunted to the matching outputs of the rewired scan.
    fn declutter_discard_useless_outer_output(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (ix, outlet) in node.outputs.iter().enumerate() {
            if !outlet.successors.is_empty()
                || model.output_outlets()?.contains(&OutletId::new(node.id, ix))
            {
                continue;
            }
            let pruned: Vec<_> = self
                .output_mapping
                .iter()
                .map(|m| OutputMapping {
                    scan: m.scan.filter(|info| info.slot != ix),
                    last_value_slot: m.last_value_slot.filter(|s| *s != ix),
                    full_dim_hint: m.full_dim_hint.clone(),
                    state: m.state,
                })
                .collect();
            let mut op = self.clone();
            op.output_mapping = Self::remove_outer_output_from_mappings(&pruned, ix);

            let mut patch = TypedModelPatch::default();
            let mut taps = Vec::with_capacity(node.inputs.len());
            for &input in &node.inputs {
                taps.push(patch.tap_model(model, input)?);
            }
            let wires = patch.wire_node(&*node.name, op, &taps)?;
            for oix in (0..node.outputs.len()).filter(|&oix| oix != ix) {
                // Slots above the dropped one moved down by one on the new scan.
                let new_slot = if oix > ix { oix - 1 } else { oix };
                patch.shunt_outside(model, OutletId::new(node.id, oix), wires[new_slot])?;
            }
            return Ok(Some(patch));
        }
        Ok(None)
    }

    /// Remove the first body output that is exported neither as a last value, nor as a
    /// scan, nor as a state, along with its (empty) output mapping.
    ///
    /// The op is marked as not decluttered so its body gets decluttered again.
    fn declutter_discard_empty_output_mapping_with_body_output(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let Some(ix) = self
            .output_mapping
            .iter()
            .position(|om| om.last_value_slot.is_none() && om.scan.is_none() && !om.state)
        else {
            return Ok(None);
        };
        let mut new_op = self.clone();
        new_op.output_mapping.remove(ix);
        new_op.body.outputs.remove(ix);
        new_op.decluttered = false;
        let patch = TypedModelPatch::replace_single_op(model, node, &node.inputs, new_op)?;
        Ok(Some(patch))
    }

    /// Hoist a body op that directly consumes a scanned input out of the scan.
    ///
    /// For each scanned input, looks at the successors of the corresponding body source.
    /// The first successor that meets both conditions below is re-wired outside the scan on
    /// the full outer input:
    /// - it has exactly one input and one output;
    /// - its invariants can track the scanning axis down through it.
    ///
    /// Its outer result is then fed back into the body as an extra scanned input that
    /// replaces the successor's output. Returns the first such rewrite, or `None`.
    fn declutter_pull_batcheable_input(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // Input mappings are indexed like the body inputs (see the `input_outlets()` lookup).
        for (model_input, input) in self.input_mapping.iter().enumerate() {
            if let Some(info) = input.as_scan() {
                let scan_source = self.body.input_outlets()?[model_input];
                let scan_source_node = self.body.node(scan_source.node);
                for successor in &scan_source_node.outputs[0].successors {
                    let successor_node = self.body.node(successor.node);
                    // Only single-input, single-output ops are candidates.
                    if successor_node.inputs.len() != 1 || successor_node.outputs.len() != 1 {
                        continue;
                    }
                    let (input_facts, output_facts) = self.body.node_facts(successor_node.id)?;
                    let invariants = successor_node.op.invariants(&input_facts, &output_facts)?;
                    // `axis_after` is where the scanning axis lands on the op's output.
                    if let Some(axis_after) = invariants.unary_track_axis_down(info.axis, false) {
                        let mut outside_patch = TypedModelPatch::new(format!(
                            "Outer patch for input extraction of {}",
                            successor_node
                        ));
                        let mut patch_inputs = node
                            .inputs
                            .iter()
                            .map(|&i| outside_patch.tap_model(model, i))
                            .collect::<TractResult<TVec<_>>>()?;
                        // Apply the extracted op on the whole outer scanned input.
                        let input = patch_inputs[info.slot];
                        let new_input_wire = outside_patch.wire_node(
                            format!("{}.extracted.{}", node.name, successor_node.name),
                            successor_node.op.clone(),
                            &[input],
                        )?[0];
                        // The extracted result becomes a new outer input, appended last
                        // (slot `node.inputs.len()`, used in the new mapping below).
                        patch_inputs.push(new_input_wire);
                        let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
                        // Inside the body, each iteration sees only one chunk along the axis.
                        let mut new_input_inner_fact = new_input_outer_fact.clone();
                        new_input_inner_fact.shape.set(axis_after, info.chunk.abs().to_dim());

                        let mut new_body = self.body.clone();
                        // New body source, appended after the existing body inputs.
                        let new_source_wire = new_body.add_source(
                            format!("{}.extracted.{}", node.name, successor_node.name),
                            new_input_inner_fact,
                        )?;
                        let mut inner_patch = TypedModelPatch::new(format!(
                            "Inner body patch for extraction of {}",
                            successor_node
                        ));
                        let new_source_wire_in_patch =
                            inner_patch.tap_model(&new_body, new_source_wire)?;
                        // Consumers of the successor's output now read the new source.
                        // NOTE(review): the successor node itself is left in the body;
                        // presumably later declutter passes prune it — confirm.
                        inner_patch
                            .shunt_outside(
                                &new_body,
                                OutletId::new(successor.node, 0),
                                new_source_wire_in_patch,
                            )
                            .with_context(|| "patching inner model")?;
                        inner_patch.apply(&mut new_body)?;

                        // Map the appended body source to the appended outer input.
                        let mut input_mapping = self.input_mapping.clone();
                        input_mapping.push(InputMapping::Scan(ScanInfo {
                            axis: axis_after,
                            chunk: info.chunk,
                            slot: node.inputs.len(),
                        }));

                        let new_op = Self {
                            input_mapping,
                            output_mapping: self.output_mapping.clone(),
                            decluttered: false,
                            body: new_body,
                            skip: self.skip,
                            seq_length_input_slot: self.seq_length_input_slot,
                        };
                        let output_wires =
                            outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
                        // Outer outputs are unchanged: forward each one slot for slot.
                        for w in output_wires {
                            outside_patch
                                .shunt_outside(model, OutletId::new(node.id, w.slot), w)
                                .with_context(|| "patching outer model")?;
                        }
                        return Ok(Some(outside_patch));
                    }
                }
            }
        }
        Ok(None)
    }

    /// Replace a last-value output whose body value is a known constant by a `Const` node in
    /// the outer model.
    ///
    /// For every body output exported as a last value (`last_value_slot`) whose fact carries
    /// a constant (`konst`), the corresponding outer outlet is shunted to a freshly added
    /// const. Outer outlets that nothing consumes are skipped: they have no successor and
    /// are not a model output, so the patch would only add a dangling const. Returns the
    /// first applicable patch, or `None`.
    fn declutter_pull_constant_outputs(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (model_output_ix, mapping) in self.output_mapping.iter().enumerate() {
            let Some(slot) = mapping.last_value_slot else {
                continue;
            };
            let Some(k) = self.body.output_fact(model_output_ix)?.konst.clone() else {
                continue;
            };
            let outer = OutletId::new(node.id, slot);
            if node.outputs[slot].successors.is_empty()
                && !model.output_outlets()?.contains(&outer)
            {
                continue;
            }
            let inner_node = self.body.output_outlets()?[model_output_ix].node;
            let inner_node = self.body.node(inner_node);
            let mut patch = TypedModelPatch::new(format!("Extract const node {}", inner_node));
            let cst = patch.add_const(format!("{}.{}", &node.name, &inner_node.name), k)?;
            patch.shunt_outside(model, outer, cst)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

    /// Pull the op producing a scanned output out of the scan body.
    ///
    /// A scan output qualifies when its body emitter output has no other consumer in the
    /// body, is not carried as a state, and uses a chunk of at most 1. The emitter's
    /// invariants must also let the scanning axis be tracked back up through it.
    ///
    /// Each of the emitter's inputs is then exported from the body:
    /// - constants once, as a last value;
    /// - everything else once, as a scan along the tracked axis.
    ///
    /// The emitter is re-wired after the scan in the outer model, on those exports. New
    /// outer slots are allocated contiguously after the existing outputs. Returns the first
    /// applicable patch, or `None`.
    fn declutter_pull_batcheable_output(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (model_ix, mapping) in self.output_mapping.iter().enumerate() {
            let Some(info) = mapping.scan else {
                continue;
            };
            let emitter_outlet = self.body.output_outlets()?[model_ix];
            let emitter_node = self.body.node(emitter_outlet.node);
            // Skip outputs that are also consumed inside the body, carried as a state, or
            // produced more than one step per iteration.
            if !emitter_node.outputs[emitter_outlet.slot].successors.is_empty()
                || mapping.state
                || info.chunk > 1
            {
                continue;
            }
            let (input_facts, output_facts) = self.body.node_facts(emitter_node.id)?;
            let invariants = emitter_node.op.invariants(&input_facts, &output_facts)?;
            let Some(axis_before) = invariants.unary_track_axis_up(info.axis, false) else {
                continue;
            };

            let mut new_body = self.body.clone();
            let mut new_output_mapping = self.output_mapping.clone();
            // Next free outer slot on the rewritten scan; only bumped when a slot is
            // actually allocated so outer slots stay contiguous.
            let mut new_scan_outputs = node.outputs.len();
            let mut outer_slots = vec![];

            for input in &emitter_node.inputs {
                let body_output_id = match new_body.outputs.iter().position(|o| o == input) {
                    Some(id) => id,
                    None => {
                        new_output_mapping.push(OutputMapping::default());
                        new_body.outputs.push(*input);
                        new_body.outputs.len() - 1
                    }
                };
                let body_mapping = &mut new_output_mapping[body_output_id];
                let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
                    if body_mapping.last_value_slot.is_none() {
                        body_mapping.last_value_slot = Some(new_scan_outputs);
                        new_scan_outputs += 1;
                    }
                    body_mapping.last_value_slot.unwrap()
                } else {
                    if body_mapping.scan.is_none() {
                        body_mapping.scan = Some(ScanInfo {
                            slot: new_scan_outputs,
                            axis: axis_before,
                            chunk: info.chunk,
                        });
                        new_scan_outputs += 1;
                    }
                    body_mapping.scan.unwrap().slot
                };
                outer_slots.push(outer_slot);
            }
            let mut outside_patch = TypedModelPatch::new(format!(
                "Outside patch for output extraction of {}",
                emitter_node
            ));
            let inputs = node
                .inputs
                .iter()
                .map(|&i| outside_patch.tap_model(model, i))
                .collect::<TractResult<TVec<_>>>()?;
            let new_op = Self {
                input_mapping: self.input_mapping.clone(),
                output_mapping: new_output_mapping,
                decluttered: false,
                body: new_body,
                skip: self.skip,
                seq_length_input_slot: self.seq_length_input_slot,
            };
            let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
            // Re-create the emitter after the scan, fed by the newly exported outputs.
            let emitter_inputs =
                outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
            let wire = outside_patch.wire_node(
                &*emitter_node.name,
                emitter_node.op.clone(),
                &emitter_inputs,
            )?[0];
            outside_patch.shunt_outside(model, OutletId::new(node.id, info.slot), wire)?;
            // Every other original output is forwarded from the rewritten scan unchanged.
            for output_slot in 0..node.outputs.len() {
                if output_slot != info.slot {
                    outside_patch.shunt_outside(
                        model,
                        OutletId::new(node.id, output_slot),
                        OutletId::new(scan_outputs[0].node, output_slot),
                    )?;
                }
            }
            return Ok(Some(outside_patch));
        }
        Ok(None)
    }
src/optim/change_axes.rs (line 32)
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
    /// Try each op-suggested axis change once and return the first one that yields a patch.
    ///
    /// Model inputs and outputs are passed to `change_axes` as interfaces. Pairs already
    /// recorded in `self.0` are skipped.
    fn next(
        &mut self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut interfaces = model.output_outlets()?.to_vec();
        interfaces.extend(model.input_outlets()?.iter());
        for node_id in model.eval_order()? {
            let node = model.node(node_id);
            for suggestion in node.op.suggested_axis_changes()? {
                if !self.0.insert((node_id, suggestion.clone())) {
                    continue;
                }
                let outlet = suggestion.0.as_outlet(node);
                let change = AxisChange { outlet, op: suggestion.1.clone() };
                let attempt = change_axes(model, &change, &interfaces, &[]).with_context(|| {
                    format!("Making patch for {:?} from {}", change, node)
                })?;
                if let Some((patch, _)) = attempt {
                    return Ok(Some(patch));
                }
            }
        }
        Ok(None)
    }
src/optim/push_split_down.rs (line 16)
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
    /// Merge sibling nodes that compute the same thing.
    ///
    /// For every pair of distinct successors of the same outlet that are `same_as` each
    /// other, the outputs of the second are rerouted to the first, and the second is
    /// obliterated. The patch is only returned if it is non-empty.
    fn next(
        &mut self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut patch = TypedModelPatch::default();
        for node in model.eval_order()? {
            for output in &model.node(node).outputs {
                for (a, b) in output.successors.iter().tuple_combinations() {
                    // A node consuming the same outlet twice (e.g. `x * x`) shows up as two
                    // successors; it must not be merged into itself.
                    if a.node == b.node {
                        continue;
                    }
                    // Don't merge into or from a node already scheduled for removal: that
                    // could reroute two nodes onto each other and leave both dangling.
                    if patch.obliterate.contains(&a.node) || patch.obliterate.contains(&b.node)
                    {
                        continue;
                    }
                    let a = model.node(a.node);
                    let b = model.node(b.node);
                    if a.same_as(b) {
                        for slot in 0..b.outputs.len() {
                            let tap = patch.tap_model(model, OutletId::new(a.id, slot))?;
                            patch.shunt_outside(model, OutletId::new(b.id, slot), tap)?;
                        }
                        patch.obliterate(b.id)?;
                    }
                }
            }
        }
        Ok(Some(patch).filter(|p| !p.is_empty()))
    }
src/optim/prop_const.rs (line 21)
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
    fn next(
        &mut self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut patch = TypedModelPatch::default();
        for n in model.eval_order()? {
            let node = model.node(n);
            if node.op.is_stateless() && !node.op_is::<Const>() {
                if let Some(inputs) = model
                    .node_input_facts(n)?
                    .iter()
                    .map(|f| f.konst.clone().map(|t| t.into_tvalue()))
                    .collect()
                {
                    match node.op.eval(inputs) {
                        Ok(res) => {
                            for (ix, output) in res.into_iter().enumerate() {
                                let mut name = node.name.clone();
                                if ix > 0 {
                                    name = format!("{}.{}", name, ix);
                                }
                                let wire = patch.add_const(name, output.into_arc_tensor())?;
                                patch.shunt_outside(model, (n, ix).into(), wire)?;
                            }
                        }
                        Err(e) => {
                            if !e.root_cause().is::<UndeterminedSymbol>() {
                                Err(e).with_context(|| {
                                    format!("Eager eval {} during optimisation", model.node(n))
                                })?;
                            }
                        }
                    }
                }
            }
        }
        Ok(Some(patch).filter(|p| p.nodes.len() > 0))
    }

Find a node by its id.

Examples found in repository?
src/model/graph.rs (line 334)
333
334
335
336
    /// Give the node with id `id` a new name.
    pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
        let target = self.node_mut(id);
        target.name = name.to_owned();
        Ok(())
    }
More examples
Hide additional examples
src/model/patch.rs (line 332)
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
    /// Apply this patch to `target`, consuming the patch.
    ///
    /// Steps, in order:
    /// 1. Copy every non-tap patch node into `target`.
    /// 2. Reroute each shunted outlet:
    ///    - its successors are re-attached to the replacement outlet;
    ///    - model outputs pointing at it are updated;
    ///    - its label is moved.
    /// 3. Wire the inputs of the copied nodes.
    /// 4. Replace the op of every obliterated node by `create_dummy()`.
    /// 5. Reset the model input outlets, including replaced sources.
    ///
    /// Fails if the rerouting makes the same outlet appear twice among the model outputs.
    pub fn apply(self, target: &mut Graph<F, O>) -> TractResult<()> {
        let prior_target_inputs = target.input_outlets()?.len();
        let prior_target_outputs = target.output_outlets()?.len();
        let ModelPatch {
            model: patch,
            incoming: mut mapping,
            shunt_outlet_by,
            obliterate,
            inputs: replaced_inputs,
            ..
        } = self;
        // `mapping`: patch outlet -> target outlet. It starts with the taps (`incoming`) and
        // is extended as patch nodes are copied.
        let mut all_inputs = HashMap::new(); // new_node_id_in_model -> [ patch_outlet_id ]
        let mut model_input_outlets = target.input_outlets()?.to_vec();
        for node in patch.nodes {
            if <Graph<F, O>>::is_source(&node.op)
                && mapping.contains_key(&OutletId::new(node.id, 0))
            {
                // this is a tap
                continue;
            }
            let Node { id: patch_node_id, name, inputs, op, outputs } = node;
            let n_outputs = outputs.len();
            // NOTE(review): this looks like an attempt to reuse an identical target node.
            // However the `continue` below only continues this inner `for dup` loop.
            // Execution always falls through to `add_node`, and the mapping entries written
            // here are overwritten a few lines down. The dedup is therefore inert. Also,
            // `mapping[patch_input]` panics if an input is not mapped yet. Confirm intent
            // (a labeled `continue` on the outer loop may have been meant).
            for dup in 0..target.nodes.len() {
                if target.node(dup).op().same_as(op.as_ref())
                    && inputs.len() == target.node(dup).inputs.len()
                    && inputs
                        .iter()
                        .zip(target.node(dup).inputs.iter())
                        .all(|(patch_input, d)| mapping[patch_input] == *d)
                {
                    for ix in 0..n_outputs {
                        mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(dup, ix));
                    }
                    continue;
                }
            }
            let facts = outputs.into_iter().map(|of| of.fact).collect();
            let added_node_id = target.add_node(name, op, facts)?;
            for ix in 0..n_outputs {
                mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(added_node_id, ix));
            }
            // Edges are added later, once every patch node has a target id.
            all_inputs.insert(added_node_id, inputs);
            if <Graph<F, O>>::is_source(&target.node(added_node_id).op) {
                // this is actually an input replacement
                // (assumes every non-tap source in the patch is registered in `inputs`,
                // otherwise the index below panics — TODO confirm)
                model_input_outlets.iter_mut().for_each(|oo| {
                    if oo.node == replaced_inputs[&patch_node_id] {
                        oo.node = added_node_id;
                    }
                });
            }
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Reroute shunted outlets: successors, model outputs and labels follow the
        // replacement.
        for (outlet, by) in shunt_outlet_by {
            let replace_by = mapping[&by];
            let succs = target.nodes()[outlet.node].outputs[outlet.slot].successors.clone();
            for succ in succs {
                target.add_edge(replace_by, succ)?;
            }
            for o in target.outputs.iter_mut() {
                if *o == outlet {
                    *o = replace_by;
                }
            }
            if let Some(label) = target.outlet_labels.remove(&outlet) {
                target.set_outlet_label(replace_by, label)?;
            }
        }
        if target.outputs.len() > target.outputs.iter().sorted().dedup().count() {
            bail!("Duplicate usage of node as output");
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Wire the inputs of the nodes copied from the patch.
        for (node, inputs) in all_inputs {
            for (ix, input) in inputs.into_iter().enumerate() {
                target.add_edge(mapping[&input], InletId::new(node, ix))?;
            }
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        // Obliterated nodes are not removed: their op is swapped for a dummy.
        for node in obliterate {
            target.node_mut(node).op = target.create_dummy();
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        target.set_input_outlets(&model_input_outlets)?;
        Ok(())
    }
src/ops/downsample/scan.rs (line 50)
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
pub fn pull_downsample_over_scan(
    model: &TypedModel,
    scan_node: &TypedNode,
    scan_op: &ops::scan::Scan,
    down_node: &TypedNode,
    down_op: &Downsample,
) -> TractResult<Option<TypedModelPatch>> {
    if down_op.stride < 0 {
        return Ok(None);
    }

    // introduce downsample at end of body
    let mut downsampled_body = scan_op.body.clone();
    downsampled_body.check_consistency()?;
    let outputs = downsampled_body.output_outlets()?.to_owned();
    let downsample_outputs = outputs
        .into_iter()
        .enumerate()
        .map(|(ix, oo)| {
            Ok(downsampled_body.wire_node(
                format!("{}-{}", &down_node.name, ix),
                down_op.clone(),
                &[oo],
            )?[0])
        })
        .collect::<TractResult<Vec<_>>>()?;
    downsampled_body.set_output_outlets(&downsample_outputs)?;
    downsampled_body.declutter()?;
    downsampled_body.check_consistency()?;

    // check if downsample ops introduced at end have swimmed up to scan inputs during declutter
    for input in downsampled_body.input_outlets()? {
        let input = downsampled_body.node(input.node);
        if input.outputs[0]
            .successors
            .iter()
            .any(|succ| !downsampled_body.node(succ.node).op().same_as(down_op))
        {
            return Ok(None);
        }
    }

    let inputs = downsampled_body.input_outlets()?.to_vec();
    for input in inputs {
        let node = &mut downsampled_body.node_mut(input.node);
        let fact = &mut node.outputs[0].fact;
        *fact = down_op.transform_fact(fact)?;
        node.op_as_mut::<crate::ops::source::TypedSource>().unwrap().fact = fact.clone();
        let downsamples = downsampled_body.node(input.node).outputs[0].successors.clone();
        for ds in downsamples {
            TypedModelPatch::shunt_one_op(&downsampled_body as _, downsampled_body.node(ds.node))?
                .apply(&mut downsampled_body)?;
        }
    }

    downsampled_body.check_consistency()?;
    let inner_model = downsampled_body.into_decluttered()?;

    let mut new_scan = scan_op.clone();
    new_scan.body = inner_model;
    for input in &mut new_scan.input_mapping {
        match input {
            InputMapping::State { ref mut initializer } => {
                if let StateInitializer::Value(ref v) = initializer {
                    let mut new_v = down_op.eval(tvec!(v.clone().into_tvalue()))?;
                    *initializer = StateInitializer::Value(new_v.remove(0).into_arc_tensor());
                }
            }
            InputMapping::Scan(info) => {
                if info.chunk > 0 && info.chunk as usize % down_op.stride as usize != 0 {
                    return Ok(None);
                }
                info.chunk = info.chunk.unsigned_abs().divceil(down_op.stride as usize) as isize
                    * info.chunk.signum()
            }
            _ => (),
        }
    }
    for output in &mut new_scan.output_mapping {
        if let Some(d) = output.full_dim_hint.as_mut() {
            *d = down_op.transform_dim(d)
        }
        if let Some(info) = &mut output.scan {
            if info.chunk as usize % down_op.stride as usize != 0 {
                return Ok(None);
            }
            info.chunk = info.chunk.unsigned_abs().divceil(down_op.stride as usize) as isize
                * info.chunk.signum()
        }
    }

    let mut patch = TypedModelPatch::default();
    let mut inputs = tvec!();
    for (ix, &i) in scan_node.inputs.iter().enumerate() {
        let tap = patch.tap_model(model, i)?;
        let ds = patch.wire_node(format!("{}-{}", down_node.name, ix), down_op.clone(), &[tap])?[0];
        inputs.push(ds);
    }
    let scan = patch.wire_node(&*scan_node.name, new_scan, &inputs)?;
    for ix in 0..scan_node.outputs.len() {
        // FIXME need to check earlier on that all output are followed by a ds
        let succ = scan_node.outputs[ix].successors[0].node;
        patch.shunt_outside(model, OutletId::new(succ, 0), scan[ix])?;
    }
    Ok(Some(patch))
}

Access the nodes table.

Examples found in repository?
src/model/order.rs (line 15)
8
9
10
11
12
13
14
15
16
/// Compute an evaluation order (a list of node ids) for `model`, going from the nodes of its
/// input outlets to the nodes of its output outlets, with no extra ordering constraints.
pub fn eval_order<F, O>(model: &super::Graph<F, O>) -> TractResult<Vec<usize>>
where
    F: Fact + Hash + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
{
    let source_nodes: Vec<usize> =
        model.input_outlets()?.iter().map(|outlet| outlet.node).collect();
    let sink_nodes: Vec<usize> =
        model.output_outlets()?.iter().map(|outlet| outlet.node).collect();
    let no_extra_deps: &[(usize, usize)] = &[];
    eval_order_for_nodes(model.nodes(), &source_nodes, &sink_nodes, no_extra_deps)
}
More examples
Hide additional examples
src/optim/op_optim.rs (line 24)
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
    /// Walk the model in evaluation order, skipping the first `self.2` positions, and return
    /// the first patch produced by the function stored in `self.1`.
    ///
    /// When a patch is found, the resume position `self.2` is set to the patching node's
    /// position, or one past it when the patch has `dont_apply_twice` set.
    fn full_pass(
        &mut self,
        session: &mut OptimizerSession,
        new: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        let order = new.eval_order()?;
        for (position, &node_id) in order.iter().enumerate().skip(self.2) {
            let node = &new.nodes()[node_id];
            let outcome = (self.1)(node.op.as_ref(), session, new, node)
                .with_context(|| format!("{:?} node {}", self, node))?;
            if let Some(mut patch) = outcome {
                patch.push_context(format!("{:?} {}", self, node));
                let skip_this_node = usize::from(patch.dont_apply_twice.is_some());
                self.2 = position + skip_this_node;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
src/model/graph.rs (line 253)
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
    pub fn set_output_names(
        &mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<()> {
        let mut labels: HashMap<Cow<str>, OutletId> =
            self.outlet_labels.iter().map(|(o, s)| (Cow::Borrowed(&**s), *o)).collect();
        for n in self.nodes() {
            for ix in 0..n.outputs.len() {
                labels.insert(Cow::Owned(format!("{}:{}", &n.name, ix)), OutletId::new(n.id, ix));
            }
        }
        let ids: Vec<OutletId> = outputs
            .into_iter()
            .map(|s| {
                let s = s.as_ref();
                labels
                    .get(s)
                    .cloned()
                    .or_else(|| self.nodes.iter().find(|n| n.name == s).map(|n| n.id.into()))
                    .ok_or_else(|| format_err!("Node {} not found", s))
            })
            .collect::<TractResult<_>>()?;
        self.outputs = ids;
        Ok(())
    }

    /// Builder-style variant of `set_output_names`: select outputs by name, then hand the
    /// model back.
    pub fn with_output_names(
        mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<Self> {
        self.set_output_names(outputs).map(|()| self)
    }

    /// Get the `ix`-th output tensor type information.
    ///
    /// Fails (instead of panicking) when `ix` is not a valid output index, or when the
    /// output outlets or the outlet fact cannot be retrieved.
    pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
        let output = *self
            .output_outlets()?
            .get(ix)
            .with_context(|| format!("Invalid output index {}", ix))?;
        self.outlet_fact(output)
    }

    /// Get the `ix`-th output tensor type information, mutably.
    ///
    /// Fails (instead of panicking) when `ix` is not a valid output index, or when the
    /// output outlets or the outlet fact cannot be retrieved.
    pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
        let output = *self
            .output_outlets()?
            .get(ix)
            .with_context(|| format!("Invalid output index {}", ix))?;
        self.outlet_fact_mut(output)
    }

    /// Set the tensor type information of the `output`-th model output.
    ///
    /// Fails (instead of panicking) when `output` is not a valid output index, or when the
    /// referenced outlet is invalid.
    pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()> {
        let outlet = *self
            .outputs
            .get(output)
            .with_context(|| format!("Invalid output index {}", output))?;
        self.set_outlet_fact(outlet, fact)
    }

    /// Builder-style variant of `set_output_fact`: set the fact of the `output`-th model
    /// output, then hand the model back.
    pub fn with_output_fact(mut self, output: usize, fact: F) -> TractResult<Self> {
        self.set_output_fact(output, fact).map(|()| self)
    }

    // nodes and their facts

    /// Iterate over the names of all nodes, in node-table order.
    pub fn node_names(&self) -> impl Iterator<Item = &str> {
        self.nodes.iter().map(|node| node.name.as_str())
    }

    /// Look up the id of the first node whose name is exactly `name`.
    ///
    /// Fails with a "No node found" error when no node carries that name.
    pub fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
        let found = self.nodes.iter().find(|node| node.name == name);
        found
            .map(|node| node.id)
            .with_context(|| format!("No node found for name: \"{}\"", name))
    }

    /// Find a node by its name; fails when no node carries that name.
    pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>> {
        let name = name.as_ref();
        let id = self.node_id_by_name(name)?;
        Ok(&self.nodes[id])
    }

    /// Borrow mutably a node by its name; fails when no node carries that name.
    pub fn node_by_name_mut(&mut self, name: impl AsRef<str>) -> TractResult<&mut Node<F, O>> {
        let name = name.as_ref();
        let id = self.node_id_by_name(name)?;
        Ok(&mut self.nodes[id])
    }

    /// Rename the node with id `id` to `name`.
    ///
    /// Fails (instead of panicking) when `id` is not a valid node id.
    pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
        let node = self.nodes.get_mut(id).with_context(|| format!("Invalid node id: {}", id))?;
        node.name = name.to_string();
        Ok(())
    }

    /// Borrow the node whose id is `id`.
    ///
    /// Panics if `id` is out of range for the node table.
    pub fn node(&self, id: usize) -> &Node<F, O> {
        let table = &self.nodes;
        &table[id]
    }

    /// Borrow mutably the node whose id is `id`.
    ///
    /// Panics if `id` is out of range for the node table.
    pub fn node_mut(&mut self, id: usize) -> &mut Node<F, O> {
        &mut self.nodes[id]
    }

    /// Borrow the whole node table as a slice; `node(id)` is the same as `nodes()[id]`.
    pub fn nodes(&self) -> &[Node<F, O>] {
        self.nodes.as_slice()
    }

    /// Access the nodes table mutably, as a slice (nodes can be edited, not added or removed).
    pub fn nodes_mut(&mut self) -> &mut [Node<F, O>] {
        &mut self.nodes
    }

    /// Get input and output tensor information for a node, as `(inputs, outputs)`.
    pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
        let input_facts = self.node_input_facts(id)?;
        let output_facts = self.node_output_facts(id)?;
        Ok((input_facts, output_facts))
    }

    /// Get the facts of the outlets feeding each input of node `node_id`, in input order.
    pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        let feeding = &self.nodes[node_id].inputs;
        feeding.iter().map(|&outlet| self.outlet_fact(outlet)).collect()
    }

    /// Get the facts of every output of node `node_id`, in slot order.
    pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        let facts: TVec<&F> =
            self.nodes[node_id].outputs.iter().map(|output| &output.fact).collect();
        Ok(facts)
    }

    // outlets

    /// Get tensor information for a single outlet.
    ///
    /// Fails when `outlet.node` is not in the node table or `outlet.slot` is not one of that
    /// node's outputs.
    pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
        anyhow::ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
        let output = self.nodes[outlet.node]
            .outputs
            .get(outlet.slot)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))?;
        Ok(&output.fact)
    }

    /// Get tensor information for a single outlet.
    pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
        let outlets = &mut self.nodes[outlet.node].outputs;
        outlets
            .get_mut(outlet.slot)
            .map(|o| &mut o.fact)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
    }

    /// Get multiple mutable tensor information for outlets, in the order given.
    ///
    /// # Panics
    ///
    /// Panics if `outlets` contains the same outlet twice (handing out two `&mut F` to the
    /// same fact would alias).
    ///
    /// # Errors
    ///
    /// Fails if any outlet is invalid, as reported by `outlet_fact`.
    pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
        // All requested outlets must be pairwise distinct so the returned references are disjoint.
        assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
        // SAFETY: the assert above guarantees each pointer targets a different fact, and the
        // returned references are tied to the `&mut self` borrow, so nothing else can touch
        // the graph while they live.
        // NOTE(review): the pointer is derived from a shared borrow (`outlet_fact` takes
        // `&self`) and then written through; Miri / Stacked Borrows likely reject this.
        // Deriving the pointers from `&mut self` (or a safe disjoint split) would be sounder — TODO.
        unsafe {
            outlets
                .iter()
                .map(|o| Ok((self.outlet_fact(*o)? as *const F as *mut F).as_mut().unwrap()))
                .collect()
        }
    }

    /// Set tensor information for a single outlet.
    ///
    /// Fails (instead of panicking) when `outlet.node` is not in the node table or
    /// `outlet.slot` is not one of that node's outputs.
    pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
        let output = self
            .nodes
            .get_mut(outlet.node)
            .and_then(|node| node.outputs.get_mut(outlet.slot))
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))?;
        output.fact = fact;
        Ok(())
    }

    /// Builder-style variant of `set_outlet_fact`: set the fact of `outlet`, then hand the
    /// model back.
    pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
        self.set_outlet_fact(outlet, fact).map(|()| self)
    }

    // outlet labels

    /// Get the label attached to `outlet`, if any.
    pub fn outlet_label(&self, outlet: OutletId) -> Option<&str> {
        self.outlet_labels.get(&outlet).map(String::as_str)
    }

    /// Attach `label` to `outlet`, replacing any previous label. Never fails.
    pub fn set_outlet_label(&mut self, outlet: OutletId, label: String) -> TractResult<()> {
        let _previous = self.outlet_labels.insert(outlet, label);
        Ok(())
    }

    /// Builder-style variant of `set_outlet_label`: label `outlet`, then hand the model back.
    pub fn with_outlet_label(mut self, outlet: OutletId, label: String) -> TractResult<Self> {
        self.set_outlet_label(outlet, label).map(|()| self)
    }

    /// Find an outlet carrying `label`.
    ///
    /// If several outlets share the label, which one is returned depends on the label map's
    /// iteration order.
    pub fn find_outlet_label(&self, label: &str) -> Option<OutletId> {
        for (outlet, candidate) in &self.outlet_labels {
            if candidate == label {
                return Some(*outlet);
            }
        }
        None
    }

    // misc

    /// Computes an evaluation order (a list of node ids) linking the graph inputs to its
    /// outputs.
    pub fn eval_order(&self) -> TractResult<Vec<usize>> {
        let order = eval_order(self)?;
        Ok(order)
    }

    /// Sanity check on network connections — no-op variant.
    ///
    /// Compiled unless both `debug_assertions` and the `paranoid_assertions` feature are
    /// enabled; it always succeeds without inspecting the graph.
    #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
    #[inline]
    pub fn check_edges(&self) -> TractResult<()> {
        Ok(())
    }

    /// Performs a sanity check on network connections.
    ///
    /// For every node in evaluation order, verifies that each input edge is recorded as a
    /// successor on the producing outlet, and that each recorded successor points back at
    /// the outlet it hangs off. Dangling node or slot references are reported as errors
    /// instead of panicking on an out-of-bounds index.
    #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
    #[inline]
    pub fn check_edges(&self) -> TractResult<()> {
        for node_id in self.eval_order()? {
            let node = &self.nodes[node_id];
            for (ix, input) in node.inputs.iter().enumerate() {
                let prec = self.nodes.get(input.node).with_context(|| {
                    format!(
                        "Dangling oncoming edge, node:{} input:{} refers to missing node {:?}",
                        node.id, ix, input
                    )
                })?;
                let reciprocated = prec
                    .outputs
                    .get(input.slot)
                    .map_or(false, |of| of.successors.contains(&InletId::new(node.id, ix)));
                if !reciprocated {
                    bail!(
                        "Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
                        node.id,
                        ix,
                        prec
                    )
                }
            }
            for (ix, output) in node.outputs.iter().enumerate() {
                for succ in &output.successors {
                    let points_back = self
                        .nodes
                        .get(succ.node)
                        .and_then(|n| n.inputs.get(succ.slot))
                        .map_or(false, |o| *o == OutletId::new(node.id, ix));
                    if !points_back {
                        bail!(
                            "Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
                            node.id,
                            ix,
                            succ
                        )
                    }
                }
            }
        }
        Ok(())
    }

    /// Consume the model and wrap it into a `RunnableModel` (a `SimplePlan`), which fixes
    /// the inputs and outputs and allows passing data through the model.
    pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>> {
        let plan = crate::plan::SimplePlan::new(self)?;
        Ok(plan)
    }

    /// Return the predecessor of node `id` when they are linked by a single, exclusive edge:
    /// `id` has exactly one input, and the producing node feeds exactly one consumer in
    /// total across all its outputs. Returns `Ok(None)` otherwise.
    pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let nodes = self.nodes();
        let input = match &nodes[id].inputs[..] {
            [only] => *only,
            _ => return Ok(None),
        };
        let prec = &nodes[input.node];
        let consumers: usize = prec.outputs.iter().map(|output| output.successors.len()).sum();
        Ok(if consumers == 1 { Some(prec) } else { None })
    }

    /// Walk `count` steps upstream from node `id` through `single_prec` links.
    ///
    /// Returns `Ok(None)` as soon as a step has no exclusive predecessor; `count == 0` yields
    /// node `id` itself.
    pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut current = self.node(id);
        let mut remaining = count;
        while remaining > 0 {
            match self.single_prec(current.id)? {
                Some(prec) => current = prec,
                None => return Ok(None),
            }
            remaining -= 1;
        }
        Ok(Some(current))
    }

    /// Walk `count` steps downstream from node `id` through `single_succ` links.
    ///
    /// Returns `Ok(None)` as soon as a step has no exclusive successor; `count == 0` yields
    /// node `id` itself.
    pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut current = Some(self.node(id));
        for _ in 0..count {
            current = match current {
                Some(node) => self.single_succ(node.id)?,
                None => break,
            };
        }
        Ok(current)
    }

    /// Return the successor of node `id` when they are linked by a single, exclusive edge:
    /// `id` feeds exactly one consumer in total, that consumer is attached to output slot 0,
    /// and it has exactly one input. Returns `Ok(None)` otherwise.
    pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let node = &self.nodes()[id];
        if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
            return Ok(None);
        }
        // The lone consumer may hang off another slot, leaving slot 0 unused: indexing
        // `outputs[0].successors[0]` would then panic. Only slot 0 is considered, as before.
        let succ = match node.outputs.first().and_then(|of| of.successors.first()) {
            Some(succ) => *succ,
            None => return Ok(None),
        };
        let succ = &self.nodes()[succ.node];
        if succ.inputs.len() != 1 {
            return Ok(None);
        }
        Ok(Some(succ))
    }
src/plan.rs (line 80)
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
    pub fn new_for_outputs_and_deps(
        model: M,
        outputs: &[OutletId],
        deps: &[(usize, usize)],
    ) -> TractResult<SimplePlan<F, O, M>> {
        let inputs = model.borrow().input_outlets()?.iter().map(|n| n.node).collect::<Vec<usize>>();
        let outputs_nodes = outputs.iter().map(|n| n.node).collect::<Vec<usize>>();
        let order = eval_order_for_nodes(model.borrow().nodes(), &inputs, &outputs_nodes, deps)?;
        let mut values_needed_until_step = vec![0; model.borrow().nodes().len()];
        for (step, node) in order.iter().enumerate() {
            for i in &model.borrow().node(*node).inputs {
                values_needed_until_step[i.node] = step;
            }
        }
        for o in outputs.iter() {
            values_needed_until_step[o.node] = order.len();
        }
        let mut flush_lists: Vec<TVec<usize>> = vec![tvec!(); order.len() + 1];
        for (node, &flush_at) in values_needed_until_step.iter().enumerate() {
            if flush_at != 0 {
                flush_lists[flush_at].push(node)
            }
        }
        let mut symbols: std::collections::HashSet<Symbol> = Default::default();
        for node in &model.borrow().nodes {
            for output in &node.outputs {
                if let Ok(fact) = output.fact.to_typed_fact() {
                    symbols.extend(fact.shape.iter().flat_map(|d| d.symbols()))
                }
            }
        }
        Ok(SimplePlan {
            model,
            order,
            flush_lists,
            outputs: outputs.to_vec(),
            has_unresolved_symbols: !symbols.is_empty(),
            _casper: PhantomData,
        })
    }

    /// Run the plan once on `inputs` with a brand new `SimpleState`, returning the outputs.
    pub fn run(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        SimpleState::new(self)?.run(inputs)
    }

    pub fn model(&self) -> &Graph<F, O> {
        self.model.borrow()
    }
}

/// Mutable execution state for running a `SimplePlan`.
///
/// Holds one optional op state per node, the session state shared across evaluations, and
/// the output values computed so far, both indexed by node id.
#[derive(Clone, Debug)]
pub struct SimpleState<F, O, M, P>
where
    F: Fact + Hash + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
    M: Borrow<Graph<F, O>> + Hash,
    P: Borrow<SimplePlan<F, O, M>>,
{
    /// The plan being executed (owned or borrowed, depending on `P`).
    plan: P,
    /// Per-node op state, indexed by node id; `None` when the op's `state()` returned none.
    pub states: Vec<Option<Box<dyn OpState>>>,
    /// Session-wide state (resolved symbols, input values, ...) passed to every evaluation.
    pub session_state: SessionState,
    /// Output values of each node, indexed by node id; `None` until computed, and again
    /// after being flushed or reset.
    pub values: Vec<Option<TVec<TValue>>>,
    _phantom: PhantomData<(M, F, O)>,
}

impl<F, O, M, P> SimpleState<F, O, M, P>
where
    F: Fact + Hash + Clone + 'static,
    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
    M: Borrow<Graph<F, O>> + Hash,
    P: Borrow<SimplePlan<F, O, M>> + Clone,
{
    /// Create a fresh state for `plan`: no node value computed yet, and one op state per
    /// node obtained from each op's `state()` against a new default session.
    pub fn new(plan: P) -> TractResult<SimpleState<F, O, M, P>> {
        let mut session_state = SessionState::default();
        let model = plan.borrow().model();
        let node_count = model.nodes().len();
        let mut states: Vec<Option<Box<dyn OpState>>> = Vec::with_capacity(node_count);
        for node in model.nodes().iter() {
            let node: &Node<F, O> = node;
            states.push(node.op().state(&mut session_state, node.id)?);
        }
        let values = vec![None; node_count];
        Ok(SimpleState { plan, states, session_state, values, _phantom: PhantomData })
    }

    /// Reset wires state: forget every computed node value.
    pub fn reset_turn(&mut self) -> TractResult<()> {
        for value in &mut self.values {
            *value = None;
        }
        Ok(())
    }

    /// Reset op inner state: ask every op for a fresh state against the current session.
    pub fn reset_op_states(&mut self) -> TractResult<()> {
        let SimpleState { ref plan, ref mut session_state, ref mut states, .. } = *self;
        let fresh = plan
            .borrow()
            .model()
            .nodes()
            .iter()
            .map(|node| node.op().state(session_state, node.id))
            .collect::<TractResult<Vec<_>>>()?;
        *states = fresh;
        Ok(())
    }

    pub fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        self.run_plan_with_eval(inputs, self::eval)
    }

    pub fn exec(&mut self) -> TractResult<()> {
        self.exec_plan_with_eval(self::eval)
    }

    /// Set `inputs`, execute the plan with the caller-supplied `eval`, collect the outputs,
    /// and clear the computed values for the next turn.
    ///
    /// If collecting the outputs fails, the error is returned before the values are cleared.
    pub fn run_plan_with_eval<Eval, E>(
        &mut self,
        inputs: TVec<TValue>,
        eval: Eval,
    ) -> TractResult<TVec<TValue>>
    where
        Eval: for<'a, 'b, 'c> FnMut(
            &'a mut SessionState,
            Option<&'b mut (dyn OpState + 'static)>,
            &'c Node<F, O>,
            TVec<TValue>,
        ) -> Result<TVec<TValue>, E>,
        E: Into<anyhow::Error> + Send + Sync + 'static,
    {
        self.set_inputs(inputs)?;
        self.exec_plan_with_eval(eval)?;
        self.outputs().and_then(|outputs| {
            self.reset_turn()?;
            Ok(outputs)
        })
    }

    /// Execute every step of the plan in order, evaluating each node with `eval`.
    ///
    /// For each step: gather the node's inputs from already computed values, drop the values
    /// listed in the plan's flush list for this step, evaluate, bind symbolic output
    /// dimensions to the concrete shapes produced (when the plan has unresolved symbols), and
    /// store the outputs. With `debug_assertions`, input and output arity and facts are
    /// checked against the model.
    ///
    /// # Errors
    ///
    /// Fails if a precursor value is missing, if a debug fact check fails, or if `eval` fails.
    pub fn exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
    where
        Eval: for<'a, 'b, 'c> FnMut(
            &'a mut SessionState,
            Option<&'b mut (dyn OpState + 'static)>,
            &'c Node<F, O>,
            TVec<TValue>,
        ) -> Result<TVec<TValue>, E>,
        E: Into<anyhow::Error> + Send + Sync + 'static,
    {
        {
            // Split `self` into disjoint field borrows so values/states can be mutated while
            // the plan and model are read.
            let &mut SimpleState {
                ref plan,
                ref mut session_state,
                ref mut states,
                ref mut values,
                ..
            } = self;
            let plan = plan.borrow();
            let model = plan.model().borrow();
            for (step, n) in plan.order.iter().enumerate() {
                let node = model.node(*n);
                trace!("Running step {}, node {}", step, node);
                // Clone this node's inputs out of the precursors' stored values.
                let mut inputs: TVec<TValue> = tvec![];
                for i in &node.inputs {
                    trace!("  use input {:?}", i);
                    let prec_node = model.node(i.node);
                    let prec = values[i.node].as_ref().ok_or_else(|| {
                        format_err!("Computing {}, precursor {} not done:", node, prec_node)
                    })?;
                    inputs.push(prec[i.slot].clone())
                }

                // Values scheduled for this step are no longer needed: the inputs above are
                // already cloned, so they can be dropped before evaluating.
                for flush in &plan.flush_lists[step] {
                    trace!("  Ran {} can now flush {}", node, model.node(*flush));
                    values[*flush] = None;
                }

                if cfg!(debug_assertions) {
                    let facts = model.node_input_facts(node.id)?;
                    if facts.len() != inputs.len() {
                        bail!(
                            "Evaluating {}: expected {} inputs, got {}",
                            node,
                            facts.len(),
                            inputs.len()
                        );
                    }
                    for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
                        if !f.matches(v, Some(&session_state.resolved_symbols))? {
                            bail!(
                                "Evaluating {}: input {:?}, expected {:?}, got {:?}",
                                node,
                                ix,
                                f,
                                v
                            );
                        }
                    }
                }

                let vs = eval(session_state, states[node.id].as_deref_mut(), node, inputs)
                    .map_err(|e| e.into())?;

                // Bind symbolic dimensions of the output facts to the concrete shapes just produced.
                if plan.has_unresolved_symbols {
                    for (o, v) in node.outputs.iter().zip(vs.iter()) {
                        if let Ok(f) = o.fact.to_typed_fact() {
                            for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
                                Self::resolve(
                                    &mut session_state.resolved_symbols,
                                    &dim_abstract,
                                    *dim_concrete as i64,
                                );
                            }
                        }
                    }
                }
                if cfg!(debug_assertions) {
                    let facts = model.node_output_facts(node.id)?;
                    if facts.len() != vs.len() {
                        bail!(
                            "Evaluating {}: expected {} outputs, got {}",
                            node,
                            facts.len(),
                            vs.len()
                        );
                    }
                    for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
                        // Outputs nobody consumes are not checked against their fact.
                        if node.outputs[ix].successors.len() == 0 {
                            continue;
                        }
                        if !f.matches(v, Some(&session_state.resolved_symbols))? {
                            bail!(
                                "Evaluating {}: output {:?}, expected {:?}, got {:?}",
                                node,
                                ix,
                                f,
                                v
                            );
                        }
                    }
                }

                values[node.id] = Some(vs);
            }
        }
        Ok(())
    }

    /// Set all model inputs at once, in declaration order.
    ///
    /// Fails if the number of values differs from the model's input count, or as soon as
    /// `set_input` rejects one of them.
    pub fn set_inputs(&mut self, inputs: TVec<TValue>) -> TractResult<()> {
        let expected = self.model().inputs.len();
        ensure!(
            inputs.len() == expected,
            "Wrong number of inputs for model. Expected {} got {}",
            expected,
            inputs.len()
        );
        inputs.into_iter().enumerate().try_for_each(|(ix, value)| self.set_input(ix, value))
    }

    /// Record in `symbols` the value implied by matching the concrete dimension `provided`
    /// against the expected dimension `expected`.
    ///
    /// A bare symbol gets `provided` directly. For an integer multiple `x * expr`, the
    /// resolution recurses into `expr` with `provided / x`. Any other expression is left
    /// untouched.
    fn resolve(symbols: &mut SymbolValues, expected: &TDim, provided: i64) {
        match expected {
            TDim::Sym(s) => symbols[s] = Some(provided),
            TDim::MulInt(x, expr) => {
                // `checked_div` avoids panicking on a zero multiplier (or `i64::MIN / -1`);
                // nothing is resolved in those degenerate cases.
                if let Some(quotient) = provided.checked_div(*x) {
                    Self::resolve(symbols, expr, quotient)
                }
            }
            _ => (),
        }
    }

    /// Set the value of the `input`-th model input.
    ///
    /// Symbolic dimensions of the input's fact are first bound to the concrete shape of `t`.
    /// The value is then checked against the fact before being stored in the session inputs,
    /// keyed by the input node id.
    ///
    /// # Errors
    ///
    /// Fails if `input` is not a valid input index, or if `t` does not match the input fact.
    pub fn set_input(&mut self, input: usize, t: TValue) -> TractResult<()> {
        let outlet: OutletId = *self
            .model()
            .input_outlets()?
            .get(input)
            .ok_or_else(|| format_err!("Invalid input id for model ({}).", input))?;
        let SimpleState { plan, session_state, .. } = self;
        let plan = (*plan).borrow();
        let model = plan.model.borrow();
        // NOTE(review): symbols are resolved before the match check below, so a rejected input
        // still leaves its resolved symbol values in the session state.
        if let Ok(fact) = model.outlet_fact(outlet)?.to_typed_fact() {
            for (expected, provided) in fact.shape.iter().zip(t.shape()) {
                Self::resolve(&mut session_state.resolved_symbols, &expected, *provided as i64)
            }
        }
        let fact = self.plan.borrow().model().outlet_fact(outlet)?;
        ensure!(
            fact.matches(&t, Some(&self.session_state.resolved_symbols))
            .with_context(|| format!("Setting input {}", input))?,
            "Input at index {} has incorrect dtype or shape (got shape {:?} and dtype {:?}, expected to match fact {:?})",
            input,
            t.shape(),
            t.datum_type(),
            fact
            );
        self.session_state.inputs.insert(outlet.node, t);
        Ok(())
    }

    /// Borrow the currently stored value of the `id`-th model output.
    ///
    /// Fails if `id` is not a valid output index, or if no value is stored for that outlet.
    pub fn output(&self, id: usize) -> TractResult<&TValue> {
        let outlets = self.model().output_outlets()?;
        let outlet = outlets
            .get(id)
            .with_context(|| format!("Required output {}, only have {}", id, outlets.len()))?;
        let node_values = self
            .values
            .get(outlet.node)
            .context("node id for output beyond node values array")?
            .as_ref()
            .context("node is not an output")?;
        let value: &TValue = node_values.get(outlet.slot).context("slot id too high")?;
        Ok(value)
    }

    /// Collects a clone of every plan output value, in the plan's output order.
    ///
    /// # Errors
    /// Fails when the node producing any of the outputs has not been computed yet.
    pub fn outputs(&mut self) -> TractResult<TVec<TValue>> {
        let SimpleState { ref plan, ref mut values, .. } = self;
        let plan = plan.borrow();
        plan.outputs
            .iter()
            .map(|outlet| -> TractResult<TValue> {
                let computed = values[outlet.node].as_mut().ok_or_else(|| {
                    format_err!(
                        "Outputs of {:?} are not computed",
                        &plan.model().nodes()[outlet.node]
                    )
                })?;
                Ok(computed[outlet.slot].clone())
            })
            .collect()
    }

    /// Stores `values` as the computed outputs of node `id`, replacing any previous ones.
    ///
    /// # Errors
    /// Fails (instead of panicking) when `id` is beyond the node values array.
    pub fn set_values(&mut self, id: usize, values: TVec<TValue>) -> TractResult<()> {
        let slot = self
            .values
            .get_mut(id)
            .with_context(|| format!("Node id {} beyond node values array", id))?;
        *slot = Some(values);
        Ok(())
    }

    pub fn set_value(&mut self, id: usize, value: TValue) -> TractResult<()> {
        self.set_values(id, tvec!(value))
    }

    pub fn prepare_inputs(&self, node: usize) -> TractResult<TVec<TValue>> {
        let SimpleState { ref plan, ref values, .. } = self;
        let plan = plan.borrow();
        let nodes = plan.model().nodes();
        let node = &nodes[node];
        let mut inputs: TVec<TValue> = tvec![];
        for i in &node.inputs {
            let prec_node = &nodes[i.node];
            let prec = values[i.node].as_ref().ok_or_else(|| {
                format_err!("Computing {}, precursor {} not done.", node, prec_node)
            })?;
            inputs.push(prec[i.slot].clone())
        }
        Ok(inputs)
    }

    /// Evaluates `node` on the values already computed for its precursors.
    ///
    /// # Errors
    /// Fails if a precursor is not computed yet, or if evaluation fails.
    pub fn compute_one(&mut self, node: usize) -> TractResult<()> {
        self.prepare_inputs(node).and_then(|inputs| self.compute_one_with_inputs(node, inputs))
    }

    pub fn compute_one_with_inputs(
        &mut self,
        node: usize,
        inputs: TVec<TValue>,
    ) -> TractResult<()> {
        let SimpleState { ref plan, ref mut session_state, ref mut values, .. } = self;
        let plan = plan.borrow();
        let nodes = plan.model().nodes();
        let node = &nodes[node];
        let vs = match self.states[node.id] {
            Some(ref mut state) => state.eval(session_state, node.op(), inputs),
            None => node.op().eval(inputs),
        }
        .with_context(|| format!("Evaluating {}", node))?;
        values[node.id] = Some(vs);
        Ok(())
    }

    /// Evaluates `node`, after recursively evaluating any precursor that has no computed
    /// values yet.
    ///
    /// The outputs are stored in the node values array and returned.
    ///
    /// # Errors
    /// Propagates any evaluation error from `node` or its precursors. The failing node is
    /// added as context, displayed the same way as in `compute_one_with_inputs`.
    pub fn compute_recursively(&mut self, node: usize) -> TractResult<&[TValue]> {
        let values = {
            // Collecting is required: recursing needs `&mut self`, which would conflict with
            // an iterator still borrowing the model's node inputs.
            #[allow(clippy::needless_collect)]
            let precs: Vec<usize> =
                self.model().nodes()[node].inputs.iter().map(|i| i.node).collect();
            for i in precs.into_iter() {
                if self.values[i].is_none() {
                    let _ = self.compute_recursively(i)?;
                }
            }
            let mut inputs: TVec<TValue> = tvec![];
            {
                let node = &self.model().nodes()[node];
                for i in &node.inputs {
                    // All precursors were computed by the loop above.
                    inputs.push(self.values[i.node].as_ref().unwrap()[i.slot].clone())
                }
            }
            let Self { ref mut states, ref mut session_state, ref plan, .. } = self;
            let plan = plan.borrow();
            let target = &plan.model().nodes()[node];
            match states[node] {
                Some(ref mut state) => state.eval(session_state, target.op(), inputs),
                None => target.op().eval(inputs),
            }
            .with_context(|| format!("Evaluating {}", target))?
        };
        self.values[node] = Some(values);
        Ok(self.values[node].as_ref().unwrap())
    }
src/ops/matmul/mir_unary.rs (line 234)
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
    fn declutter_precusor_is_concat(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if let Some(concat) = model.nodes()[node.inputs[0].node].op().downcast_ref::<TypedConcat>()
        {
            let mut patch = TypedModelPatch::new("split over k-concatenated input");
            if concat.axis == self.axes.b_k {
                let concat_node = model.node(node.inputs[0].node);
                let offsets = concat
                    .offsets(&model.node_input_facts(concat_node.id)?)?
                    .iter()
                    .map(|x| x.to_usize())
                    .collect::<TractResult<Vec<usize>>>()?;
                let mut wires = vec![];
                for (ix, input) in concat_node.inputs.iter().enumerate() {
                    let wire = patch.tap_model(model, *input)?;
                    let a = self.a.slice(self.axes.a_k, offsets[ix], offsets[ix + 1])?;
                    let wire = patch.wire_node(
                        format!("{}.k-{}-{}", node.name, offsets[ix], offsets[ix + 1]),
                        MatMulUnary { a: a.into_arc_tensor(), ..self.clone() },
                        &[wire],
                    )?[0];
                    wires.push(wire)
                }
                let mut wire = wires[0];
                for (ix, w) in wires[1..].iter().enumerate() {
                    wire = patch.wire_node(
                        format!("{}.k-add-{}", node.name, ix),
                        crate::ops::binary::TypedBinOp(Box::new(crate::ops::math::Add)),
                        &[wire, *w],
                    )?[0];
                }
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
src/ops/matmul/mir_quant_unary.rs (line 163)
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
    /// Splits this quantized matmul over a k-concatenated input.
    ///
    /// Applies when the first input is produced by a `TypedConcat` along `self.axes.b_k`.
    /// One partial op is wired per concatenated chunk, using the slice of `self.a` along
    /// `self.axes.a_k` that matches the chunk's offsets.
    ///
    /// The partials have these properties:
    /// - each outputs `I32`;
    /// - each uses unit a/b/c scales and a zero `c0`;
    /// - only the first keeps `self.bias`.
    ///
    /// The partials are summed with `Add`. `requant` then converts the sum to
    /// `self.output_type`, using the scale combined from the original parameters and the
    /// original `c0`.
    ///
    /// Returns `None` when the precursor is not such a concat.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        use crate::ops::array::TypedConcat;
        if let Some(concat) = model.nodes()[node.inputs[0].node].op().downcast_ref::<TypedConcat>()
        {
            let mut patch = TypedModelPatch::new("split over k-concatenated input");
            let k_axis = self.axes.a_k;
            if concat.axis == self.axes.b_k {
                let concat_node = model.node(node.inputs[0].node);
                // Boundaries of each concatenated chunk along the concat axis.
                let offsets = concat
                    .offsets(&model.node_input_facts(concat_node.id)?)?
                    .iter()
                    .map(|x| x.to_usize())
                    .collect::<TractResult<Vec<usize>>>()?;
                let mut wires = vec![];
                // Partials accumulate raw I32 products: scales are neutralized here and
                // applied once, by `requant`, after the sum.
                let mut params_for_split = self.params.clone();
                params_for_split.a_scale = tensor0(1.0f32).into();
                params_for_split.b_scale = tensor0(1.0f32).into();
                params_for_split.c_scale = tensor0(1.0f32).into();
                params_for_split.c0 = tensor0(0i32).into();
                // Taps every node input except the first (presumably the quantization
                // parameter inputs consumed by `as_outlet_ids` — confirm).
                let input_outlets = node
                    .inputs
                    .iter()
                    .skip(1)
                    .map(|o| patch.tap_model(model, *o))
                    .collect::<TractResult<TVec<_>>>()?;
                let params_outlets = self.params.as_outlet_ids(
                    &mut patch,
                    &node.name,
                    &input_outlets,
                    self.a.datum_type(),
                    model.node_input_facts(node.id)?[0].datum_type,
                    self.output_type,
                )?;

                // NOTE(review): indices 1, 3, 5 (scales) and 4 (c0) are assumed to follow the
                // output layout of `as_outlet_ids` — confirm.
                let scale = combine_scales(
                    &mut patch,
                    &node.name,
                    params_outlets[1],
                    params_outlets[3],
                    params_outlets[5],
                )?;
                let c0 = params_outlets[4];

                for (ix, input) in concat_node.inputs.iter().enumerate() {
                    let wire = patch.tap_model(model, *input)?;
                    let a = self.a.slice(k_axis, offsets[ix], offsets[ix + 1])?;
                    let wire = patch
                        .wire_node(
                            format!("{}.k-{}-{}", node.name, offsets[ix], offsets[ix + 1]),
                            Self {
                                a: a.into_arc_tensor(),
                                output_type: DatumType::I32,
                                // The bias must be added only once, so only the first partial
                                // keeps it.
                                bias: self.bias.clone().filter(|_| ix == 0),
                                params: params_for_split.clone(),
                                ..self.clone()
                            },
                            &[wire],
                        )
                        .context("wiring new matmulunary")?[0];
                    wires.push(wire)
                }
                // Sum the partial products pairwise.
                let mut wire = wires[0];
                for (ix, w) in wires[1..].iter().enumerate() {
                    wire = patch.wire_node(
                        format!("{}.k-add-{}", node.name, ix),
                        crate::ops::binary::TypedBinOp(Box::new(crate::ops::math::Add)),
                        &[wire, *w],
                    )?[0];
                }
                wire = requant(&mut patch, &node.name, wire, self.output_type, scale, c0)?;
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }

Access the nodes table.

Get input and output tensor information for a node.

Examples found in repository?
src/ops/downsample/mod.rs (line 116)
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
fn pull_downsample_up(
    model: &TypedModel,
    down_node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    model.check_consistency()?;
    let down_op = down_node.op_as::<Downsample>().unwrap();
    if let Some(prec) = model.single_prec(down_node.id)? {
        let (input_facts, output_facts) = model.node_facts(prec.id)?;
        let invariants = prec.op.invariants(&input_facts, &output_facts)?;
        debug!("Consider pull {:?} over {:?} (invariants: {:?})", down_op, prec, invariants);
        if let Some(slice_op) = prec.op_as::<ops::array::Slice>() {
            if let Some(p) = array::pull_downsample_over_slice(model, prec, slice_op, down_node, down_op)? {
                return Ok(Some(p))
            }
        } else if let Some(other_op) = prec.op_as::<AxisOp>() {
            return array::pull_downsample_over_axis_op(model, prec, other_op, down_node, down_op);
        } else if let Some(conv_op) = prec.op_as::<ops::cnn::conv::ConvUnary>() {
            return conv::fuse_downsample_into_conv(model, prec, conv_op, down_node, down_op);
        } else if let Some(other_op) = prec.op_as::<ops::scan::Scan>() {
            return scan::pull_downsample_over_scan(model, prec, other_op, down_node, down_op);
        }
        if let Some(above_axis) = invariants.unary_track_axis_up(down_op.axis, false) {
            let mut patch = TypedModelPatch::default();
            let mut inputs = vec![];
            for (ix, &oo) in prec.inputs.iter().enumerate() {
                let source = patch.tap_model(model, oo)?;
                let mut op = down_op.clone();
                op.axis = above_axis;
                let ds = patch.wire_node(
                    format!("{}.{}-{}", down_node.name, prec.name, ix),
                    op,
                    [source].as_ref(),
                )?;
                inputs.push(ds[0]);
            }
            let other = patch.wire_node(&prec.name, prec.op.clone(), &inputs)?;
            patch.shunt_outside(model, OutletId::new(down_node.id, 0), other[0])?;
            return Ok(Some(patch));
        }
    }
    Ok(None)
}
More examples
Hide additional examples
src/optim/slice.rs (line 19)
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
    /// Looks for a single-output node whose output should be split along one of its axes,
    /// and builds a patch that computes the node piecewise.
    ///
    /// Nodes are visited in evaluation order. The candidate axis and the boundaries of the
    /// pieces come from `should_slice_output`.
    ///
    /// For each consecutive `[start, end)` range, starting at 0:
    /// - every input whose axis is tracked from the sliced output axis (through the op's
    ///   invariants) gets a `Slice` node;
    /// - the op then wires its own sliced version through `node.op.slice`.
    ///
    /// If `node.op.slice` returns `None` for any range, the partially built patch is dropped
    /// and the next axis is tried. Otherwise the per-range outputs are reconnected by
    /// `rewire_sliced_outputs`.
    ///
    /// Returns the first complete patch, or `None`.
    fn next(
        &mut self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        for n in model.eval_order()? {
            let (ifacts, ofacts) = model.node_facts(n)?;
            if ofacts.len() != 1 {
                continue;
            }
            let node = model.node(n);
            let invariants = node.op.invariants(&ifacts, &ofacts)?;
            'axis: for axis in 0..ofacts[0].rank() {
                if let Some(boundaries) = should_slice_output(model, node, axis)? {
                    let mut splits = tvec!();
                    let mut patch = TypedModelPatch::new("push slice up");
                    let inputs = node
                        .inputs
                        .iter()
                        .map(|i| patch.tap_model(model, *i))
                        .collect::<TractResult<TVec<OutletId>>>()?;
                    let mut start = 0;
                    // Which input axes (if any) correspond to the sliced output axis.
                    let axis_info = invariants.track_output_axis(0, axis);
                    for end in &boundaries {
                        let mut wires = tvec!();
                        for input_ix in 0..inputs.len() {
                            let mut wire = inputs[input_ix];
                            // Inputs without a matching axis are passed through whole.
                            if let Some(input_axis) = axis_info.and_then(|it| it.inputs[input_ix]) {
                                wire = patch.wire_node(
                                    format!(
                                        "{}.split-{}-over-{}.{}..{}.slice",
                                        &node.name, input_ix, input_axis, start, end
                                    ),
                                    Slice {
                                        axis: input_axis,
                                        start: start.to_dim(),
                                        end: end.to_dim(),
                                    },
                                    &[wire],
                                )?[0];
                            }
                            wires.push(wire);
                        }
                        // The op may decline to be sliced; then give up on this axis.
                        let Some(wire) = node.op.slice(
                            &mut patch,
                            &format!(
                                "{}.split-over-{}.{}..{}",
                                &node.name, axis, start, end
                                ),
                                &wires,
                                axis,
                                start,
                                *end,
                                )? else {
                            continue 'axis };
                        splits.push(wire[0]);
                        start = *end;
                    }
                    rewire_sliced_outputs(model, node, axis, &mut patch, &boundaries, &splits)?;
                    return Ok(Some(patch));
                }
            }
        }
        Ok(None)
    }
src/ops/invariants.rs (line 270)
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
    /// Follows `axis` of `outlet` through the graph, using each node's axis invariants.
    ///
    /// Starting from `outlet`, every reached outlet is tracked in two directions:
    /// - backwards through its emitter node (`track_output_axis`);
    /// - forwards through each successor (`track_input_axis`).
    ///
    /// The input and output outlets these invariants map to are queued with their axis.
    ///
    /// The result records:
    /// - `creators`: reached outlets whose axis could not be tracked through their emitter;
    /// - `destructors`: successor inputs through which the axis could not be tracked;
    /// - `outlets`: every reached outlet with its axis;
    /// - `disposable`: whether every tracked axis info was marked `disposable`.
    ///
    /// # Errors
    /// Fails if computing invariants fails, or with "Inconsistent network" if an outlet is
    /// reached with two different axes.
    ///
    /// # Panics
    /// Panics if an op reports invariants whose per-axis input or output count does not
    /// match the node's number of inputs or outputs.
    pub fn for_outlet_and_axis(
        model: &TypedModel,
        outlet: OutletId,
        axis: usize,
    ) -> TractResult<AxisTracking> {
        let mut mapped_outlets = OutletMap::default();
        // Work list of outlets still to visit (the values are unused).
        let mut todo = OutletMap::default();
        let mut disposable = true;
        let mut creators = tvec!();
        let mut destructors = tvec!();
        mapped_outlets.insert(outlet, axis);
        todo.insert(outlet, ());
        // Pop the first pending outlet returned by `keys()` until the work list is empty.
        while let Some(wire) = todo.keys().next() {
            todo.remove(&wire);
            let axis = mapped_outlets[&wire];
            let emiter_node = model.node(wire.node);
            // (node id, axis info) pairs through which the axis is known to flow.
            let mut nodes = vec![];
            let (input_facts, output_facts) = model.node_facts(emiter_node.id)?;
            let invs = emiter_node
                .op
                .invariants(&input_facts, &output_facts)
                .with_context(|| format!("Computing invariants for {}", emiter_node))?;
            assert!(invs.axes.iter().all(|axis| axis.inputs.len() == emiter_node.inputs.len()));
            assert!(invs.axes.iter().all(|axis| axis.outputs.len() == emiter_node.outputs.len()));
            // Backwards: through the node that emits `wire`.
            if let Some(info) = invs.track_output_axis(wire.slot, axis) {
                nodes.push((wire.node, info.clone()));
            } else {
                creators.push(wire);
            };
            // Forwards: through every consumer of `wire`.
            for succ in &emiter_node.outputs[wire.slot].successors {
                let succ_node = model.node(succ.node);
                let (input_facts, output_facts) = model.node_facts(succ_node.id)?;
                let invs = succ_node.op.invariants(&input_facts, &output_facts)?;
                assert!(invs.axes.iter().all(|axis| axis.inputs.len() == succ_node.inputs.len()));
                assert!(invs.axes.iter().all(|axis| axis.outputs.len() == succ_node.outputs.len()));
                if let Some(info) = invs.track_input_axis(succ.slot, axis) {
                    nodes.push((succ_node.id, info.clone()));
                } else {
                    destructors.push(*succ);
                };
            }
            // Collect every input/output outlet of the traversed nodes that carries the axis.
            let mut new_outlets = vec![];
            for (n, axes) in nodes {
                disposable = disposable && axes.disposable;
                let node = model.node(n);
                for slot in 0..node.outputs.len() {
                    if let Some(axis) = axes.outputs[slot] {
                        new_outlets.push((OutletId::new(n, slot), axis));
                    }
                }
                for slot in 0..node.inputs.len() {
                    if let Some(axis) = axes.inputs[slot] {
                        new_outlets.push((node.inputs[slot], axis));
                    }
                }
            }
            // Queue unseen outlets; an already-seen outlet must agree on the axis.
            for (outlet, axis) in new_outlets {
                if let Some(prev) = mapped_outlets.get(&outlet) {
                    if *prev != axis {
                        bail!("Inconsistent network");
                    }
                } else {
                    mapped_outlets.insert(outlet, axis);
                    todo.insert(outlet, ());
                }
            }
        }
        Ok(AxisTracking { creators, destructors, outlets: mapped_outlets, disposable })
    }
src/ops/scan/mir.rs (line 308)
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
    /// Moves an op that consumes a scanned body input out of the scan loop.
    ///
    /// For each scan input, looks at the successors of the matching body source. A
    /// successor qualifies if it has exactly one input and one output, and the scan axis can
    /// be tracked through it (`unary_track_axis_down`). The first qualifying successor is:
    /// - applied outside the scan, on the whole outer input, with its result becoming a new
    ///   scan input (same chunk, tracked axis);
    /// - replaced inside the body by a new source, whose fact is the outer fact with the
    ///   tracked axis set to `|chunk|`.
    ///
    /// Returns the outer patch, or `None` if no successor qualifies.
    fn declutter_pull_batcheable_input(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (model_input, input) in self.input_mapping.iter().enumerate() {
            if let Some(info) = input.as_scan() {
                let scan_source = self.body.input_outlets()?[model_input];
                let scan_source_node = self.body.node(scan_source.node);
                for successor in &scan_source_node.outputs[0].successors {
                    let successor_node = self.body.node(successor.node);
                    if successor_node.inputs.len() != 1 || successor_node.outputs.len() != 1 {
                        continue;
                    }
                    let (input_facts, output_facts) = self.body.node_facts(successor_node.id)?;
                    let invariants = successor_node.op.invariants(&input_facts, &output_facts)?;
                    if let Some(axis_after) = invariants.unary_track_axis_down(info.axis, false) {
                        let mut outside_patch = TypedModelPatch::new(format!(
                            "Outer patch for input extraction of {}",
                            successor_node
                        ));
                        let mut patch_inputs = node
                            .inputs
                            .iter()
                            .map(|&i| outside_patch.tap_model(model, i))
                            .collect::<TractResult<TVec<_>>>()?;
                        // Apply the successor op outside the loop, on the full scanned input.
                        let input = patch_inputs[info.slot];
                        let new_input_wire = outside_patch.wire_node(
                            format!("{}.extracted.{}", node.name, successor_node.name),
                            successor_node.op.clone(),
                            &[input],
                        )?[0];
                        patch_inputs.push(new_input_wire);
                        // Inside the body, the new input is seen one chunk at a time.
                        let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
                        let mut new_input_inner_fact = new_input_outer_fact.clone();
                        new_input_inner_fact.shape.set(axis_after, info.chunk.abs().to_dim());

                        let mut new_body = self.body.clone();
                        let new_source_wire = new_body.add_source(
                            format!("{}.extracted.{}", node.name, successor_node.name),
                            new_input_inner_fact,
                        )?;
                        // Redirect consumers of the successor's output to the new source.
                        let mut inner_patch = TypedModelPatch::new(format!(
                            "Inner body patch for extraction of {}",
                            successor_node
                        ));
                        let new_source_wire_in_patch =
                            inner_patch.tap_model(&new_body, new_source_wire)?;
                        inner_patch
                            .shunt_outside(
                                &new_body,
                                OutletId::new(successor.node, 0),
                                new_source_wire_in_patch,
                            )
                            .with_context(|| "patching inner model")?;
                        inner_patch.apply(&mut new_body)?;

                        // The new outer input (appended last) is scanned like the original one.
                        let mut input_mapping = self.input_mapping.clone();
                        input_mapping.push(InputMapping::Scan(ScanInfo {
                            axis: axis_after,
                            chunk: info.chunk,
                            slot: node.inputs.len(),
                        }));

                        let new_op = Self {
                            input_mapping,
                            output_mapping: self.output_mapping.clone(),
                            decluttered: false,
                            body: new_body,
                            skip: self.skip,
                            seq_length_input_slot: self.seq_length_input_slot,
                        };
                        let output_wires =
                            outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
                        for w in output_wires {
                            outside_patch
                                .shunt_outside(model, OutletId::new(node.id, w.slot), w)
                                .with_context(|| "patching outer model")?;
                        }
                        return Ok(Some(outside_patch));
                    }
                }
            }
        }
        Ok(None)
    }

    /// Replaces a scan last-value output with a constant in the outer model.
    ///
    /// Applies when the matching body output fact carries a constant (`konst`). The scan
    /// output's consumers are redirected to a constant node holding that value.
    ///
    /// Returns a patch for the first such output, or `None`.
    fn declutter_pull_constant_outputs(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (body_output_ix, mapping) in self.output_mapping.iter().enumerate() {
            let Some(outer_slot) = mapping.last_value_slot else {
                continue;
            };
            let Some(konst) = self.body.output_fact(body_output_ix)?.konst.clone() else {
                continue;
            };
            let emitter_id = self.body.output_outlets()?[body_output_ix].node;
            let emitter = self.body.node(emitter_id);
            let mut patch = TypedModelPatch::new(format!("Extract const node {}", emitter));
            let cst = patch.add_const(format!("{}.{}", &node.name, &emitter.name), konst)?;
            patch.shunt_outside(model, OutletId::new(node.id, outer_slot), cst)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

    /// Moves the op emitting a scanned body output out of the scan loop.
    ///
    /// A body output qualifies when all of the following hold:
    /// - it is mapped as a scan output, without `state`;
    /// - its chunk is not greater than 1;
    /// - its emitting outlet has no successors inside the body;
    /// - the scan axis can be tracked up through the emitter's invariants.
    ///
    /// For a qualifying output, the emitter's inputs are exported from the body: constant
    /// inputs as last-value outputs, others as scan outputs along the tracked axis. The
    /// emitter op is then re-applied outside the loop, on the matching outer outputs.
    ///
    /// Returns a patch for the first qualifying output, or `None`.
    fn declutter_pull_batcheable_output(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (model_ix, mapping) in self.output_mapping.iter().enumerate() {
            if let Some(info) = mapping.scan {
                let emitter_outlet = self.body.output_outlets()?[model_ix];
                let emitter_node = self.body.node(emitter_outlet.node);
                // Skip outputs whose emitter is also consumed inside the body, outputs
                // flagged as state, and outputs scanned with a chunk greater than 1.
                if emitter_node.outputs[emitter_outlet.slot].successors.len() > 0
                    || mapping.state
                    || mapping.scan.map(|i| i.chunk > 1).unwrap_or(true)
                {
                    continue;
                }
                let (input_facts, output_facts) = self.body.node_facts(emitter_node.id)?;
                let invariants = emitter_node.op.invariants(&input_facts, &output_facts)?;
                let Some(axis_before) = invariants.unary_track_axis_up(info.axis, false)
                else {
                    continue;
                };

                let mut new_body = self.body.clone();
                let mut new_output_mapping = self.output_mapping.clone();
                // Next free outer output slot on the rewritten scan op.
                let mut new_scan_outputs = node.outputs.len();
                let mut outer_slots = vec![];

                for input in &emitter_node.inputs {
                    // Export the emitter's input from the body unless it already is an output.
                    if new_body.outputs.iter().all(|o| o != input) {
                        new_output_mapping.push(OutputMapping::default());
                        new_body.outputs.push(*input);
                    }
                    let body_output_id = new_body.outputs.iter().position(|o| o == input).unwrap();
                    let out_mapping = &mut new_output_mapping[body_output_id];
                    // A new outer slot is consumed only when a new mapping is created; an
                    // existing mapping keeps its slot. This keeps the slot numbering free of
                    // holes.
                    let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
                        if out_mapping.last_value_slot.is_none() {
                            out_mapping.last_value_slot = Some(new_scan_outputs);
                            new_scan_outputs += 1;
                        }
                        out_mapping.last_value_slot.unwrap()
                    } else {
                        // NOTE(review): a pre-existing scan mapping is reused as-is, even if
                        // its axis/chunk differ from `axis_before`/`info.chunk`.
                        if out_mapping.scan.is_none() {
                            out_mapping.scan = Some(ScanInfo {
                                slot: new_scan_outputs,
                                axis: axis_before,
                                chunk: info.chunk,
                            });
                            new_scan_outputs += 1;
                        }
                        out_mapping.scan.unwrap().slot
                    };
                    outer_slots.push(outer_slot);
                }
                let mut outside_patch = TypedModelPatch::new(format!(
                    "Outside patch for output extraction of {}",
                    emitter_node
                ));
                let inputs = node
                    .inputs
                    .iter()
                    .map(|&i| outside_patch.tap_model(model, i))
                    .collect::<TractResult<TVec<_>>>()?;
                let new_op = Self {
                    input_mapping: self.input_mapping.clone(),
                    output_mapping: new_output_mapping,
                    decluttered: false,
                    body: new_body,
                    skip: self.skip,
                    seq_length_input_slot: self.seq_length_input_slot,
                };
                let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
                let output = mapping.scan.unwrap();
                // Re-apply the emitter op outside the loop, on the exported values.
                let inputs =
                    outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
                let wire = outside_patch.wire_node(
                    &*emitter_node.name,
                    emitter_node.op.clone(),
                    &inputs,
                )?[0];
                outside_patch.shunt_outside(model, OutletId::new(node.id, output.slot), wire)?;
                // Every other pre-existing output keeps its slot on the new scan node.
                for output_slot in 0..node.outputs.len() {
                    if output_slot != output.slot {
                        outside_patch.shunt_outside(
                            model,
                            OutletId::new(node.id, output_slot),
                            OutletId::new(scan_outputs[0].node, output_slot),
                        )?;
                    }
                }
                return Ok(Some(outside_patch));
            }
        }
        Ok(None)
    }
src/ops/quant.rs (line 233)
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
    fn declutter(
        &self,
        model: &TypedModel,
        dequant: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut current = dequant;
        let incoming_dt = model.node_input_facts(dequant.id)?[0].datum_type;
        while let Some(quant) = model.single_succ(current.id)? {
            let q_params = if let Some(op) = quant.op_as::<ElementWiseOp>() {
                if let Some(mop) = op.0.downcast_ref::<QuantizeLinearU8>() {
                    Some((mop.scale, mop.zero_point as i32, u8::datum_type()))
                } else {
                    op.0.downcast_ref::<QuantizeLinearI8>()
                        .map(|mop| (mop.scale, mop.zero_point as i32, i8::datum_type()))
                }
            } else {
                None
            };
            if let Some((scale, zero_point, dt)) = q_params {
                // first, try Op::quantize() on all ops in the chain
                let mut patch = TypedModelPatch::default();
                let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;
                let mut next = model.single_succ(dequant.id)?.unwrap();
                loop {
                    if let Some(op) = next
                        .op
                        .quantize(model, dequant, dt, scale, zero_point)
                        .with_context(|| format!("Quantizing {}", next))?
                    {
                        wire = patch.wire_node(&*next.name, op, [wire].as_ref())?[0];
                    } else {
                        break;
                    }
                    if next.id == current.id {
                        patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
                        return Ok(Some(patch));
                    } else {
                        next = model.single_succ(next.id)?.unwrap();
                    }
                }
                // or else make a lookup table
                if incoming_dt == DatumType::I8 || incoming_dt == DatumType::U8 {
                    let mut adhoc_model = TypedModel::default();
                    let mut wire = adhoc_model.add_source("ad-hoc", dt.fact([256]))?;
                    let mut next = model.single_succ(dequant.id)?.unwrap();
                    let mut name = None;
                    // plug in dequant
                    wire = adhoc_model.wire_node(
                        &*dequant.name,
                        dequant.op.clone(),
                        [wire].as_ref(),
                    )?[0];
                    while next.id != quant.id {
                        name.get_or_insert(&*next.name);
                        wire =
                            adhoc_model.wire_node(&*next.name, next.op.clone(), [wire].as_ref())?
                                [0];
                        next = model.single_succ(next.id)?.unwrap();
                    }
                    // plug in quant
                    wire =
                        adhoc_model.wire_node(&*quant.name, quant.op.clone(), [wire].as_ref())?[0];
                    adhoc_model.set_output_outlets(&[wire])?;
                    let input = (0u8..=255).collect::<Vec<u8>>();
                    let input = match dt {
                        DatumType::I8 => unsafe {
                            tensor1(std::mem::transmute::<&[u8], &[i8]>(&*input))
                        },
                        DatumType::U8 => tensor1(&input),
                        _ => unreachable!(),
                    };
                    let output =
                        SimplePlan::new(adhoc_model)?.run(tvec!(input.into_tvalue()))?.remove(0);
                    let table: &[u8] = match dt {
                        DatumType::I8 => unsafe { std::mem::transmute(output.as_slice::<i8>()?) },
                        DatumType::U8 => output.as_slice::<u8>()?,
                        _ => unreachable!(),
                    };
                    let op = lookup_table((tract_linalg::ops().lut_u8)(table));
                    let mut patch = TypedModelPatch::default();
                    let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;

                    wire = patch.wire_node(name.unwrap_or(&*dequant.name), op, [wire].as_ref())?[0];
                    patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
                    return Ok(Some(patch));
                }
            }
            let (input_facts, output_facts) = model.node_facts(quant.id)?;
            let invariants = quant
                .op
                .invariants(&input_facts, &output_facts)
                .with_context(|| format!("Querying invariants for {}", quant))?;
            if invariants.element_wise() {
                current = quant;
            } else {
                break;
            }
        }
        Ok(None)
    }

Get input tensor information for a node.

Examples found in repository?
src/model/graph.rs (line 360)
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
    /// Returns the facts of node `id` as an `(inputs, outputs)` pair.
    ///
    /// Input facts are resolved first; any error from that lookup is propagated.
    pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
        let inputs = self.node_input_facts(id)?;
        let outputs = self.node_output_facts(id)?;
        Ok((inputs, outputs))
    }

    /// Get input tensor information for a node.
    pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        self.nodes[node_id].inputs.iter().map(|o| self.outlet_fact(*o)).collect()
    }

    /// Get output tensor information for a node.
    pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        Ok(self.nodes[node_id].outputs.iter().map(|o| &o.fact).collect())
    }

    // outlets

    /// Get tensor information for a single outlet.
    ///
    /// # Errors
    /// Fails if `outlet.node` is not a node of this graph, or if that node has no output
    /// at `outlet.slot`.
    pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
        anyhow::ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
        let found = self.nodes[outlet.node].outputs.get(outlet.slot);
        let output = found.with_context(|| format!("Invalid outlet reference: {:?}", outlet))?;
        Ok(&output.fact)
    }

    /// Get tensor information for a single outlet.
    pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
        let outlets = &mut self.nodes[outlet.node].outputs;
        outlets
            .get_mut(outlet.slot)
            .map(|o| &mut o.fact)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
    }

    /// Get multiple mutable tensor information for outlets.
    ///
    /// The returned references are in the same order as `outlets`.
    ///
    /// # Panics
    /// Panics if `outlets` contains the same outlet more than once.
    ///
    /// # Errors
    /// Fails, with the same error as `outlet_fact`, on the first outlet (in order) that
    /// does not exist.
    pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
        assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
        // Validate every outlet up front so errors surface exactly as before.
        for outlet in outlets {
            self.outlet_fact(*outlet)?;
        }
        // The previous implementation turned a shared `&F` into `&mut F` through a raw
        // pointer cast, which is undefined behaviour. Instead, hand out disjoint mutable
        // borrows safely: visit the distinct node ids in increasing order, peeling each
        // node off the remaining slice with `split_at_mut` / `split_first_mut`.
        let mut touched: Vec<usize> = outlets.iter().map(|o| o.node).collect();
        touched.sort_unstable();
        touched.dedup();
        let mut facts: TVec<Option<&mut F>> = outlets.iter().map(|_| None).collect();
        let mut rest = &mut self.nodes[..];
        // Graph index of `rest[0]`.
        let mut offset = 0;
        for node_id in touched {
            let (_, tail) = std::mem::take(&mut rest).split_at_mut(node_id - offset);
            let (node, tail) = tail.split_first_mut().expect("node id validated above");
            rest = tail;
            offset = node_id + 1;
            for (slot, output) in node.outputs.iter_mut().enumerate() {
                if let Some(ix) = outlets.iter().position(|o| o.node == node_id && o.slot == slot)
                {
                    facts[ix] = Some(&mut output.fact);
                }
            }
        }
        Ok(facts.into_iter().map(|f| f.expect("outlet validated above")).collect())
    }

    /// Set tensor information for a single outlet.
    pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
        let outlets = &mut self.nodes[outlet.node].outputs;
        if outlets.len() <= outlet.slot {
            bail!("Invalid outlet refererence: {:?}", outlet)
        }
        outlets[outlet.slot].fact = fact;
        Ok(())
    }

    /// Builder-style form of `set_outlet_fact`.
    ///
    /// Stores `fact` on `outlet` and returns the graph, or propagates the setter's error.
    pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
        self.set_outlet_fact(outlet, fact).map(|()| self)
    }

    // outlet labels

    /// Looks up the label attached to `outlet`, if one was set.
    pub fn outlet_label(&self, outlet: OutletId) -> Option<&str> {
        self.outlet_labels.get(&outlet).map(String::as_str)
    }

    /// Attaches `label` to `outlet`, replacing any label it already had.
    ///
    /// Currently never fails.
    pub fn set_outlet_label(&mut self, outlet: OutletId, label: String) -> TractResult<()> {
        let _replaced = self.outlet_labels.insert(outlet, label);
        Ok(())
    }

    /// Builder-style form of `set_outlet_label`: labels `outlet` and returns the graph.
    pub fn with_outlet_label(mut self, outlet: OutletId, label: String) -> TractResult<Self> {
        self.set_outlet_label(outlet, label).map(|()| self)
    }

    /// Finds an outlet carrying exactly `label`.
    ///
    /// If several outlets share the label, which one is returned depends on the label
    /// map's iteration order.
    pub fn find_outlet_label(&self, label: &str) -> Option<OutletId> {
        for (outlet, name) in &self.outlet_labels {
            if name == label {
                return Some(*outlet);
            }
        }
        None
    }

    // misc

    /// Computes an evaluation order for the graph inputs and outputs.
    ///
    /// Thin wrapper over the module-level `eval_order` function.
    pub fn eval_order(&self) -> TractResult<Vec<usize>> {
        let order = eval_order(self)?;
        Ok(order)
    }

    /// No-op stand-in for the edge sanity check.
    ///
    /// This version is compiled unless both `debug_assertions` and the
    /// `paranoid_assertions` feature are enabled.
    #[inline]
    #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
    pub fn check_edges(&self) -> TractResult<()> {
        Ok(())
    }

    /// Performs a sanity check on network connections.
    ///
    /// Walks nodes in evaluation order and checks both ends of every edge:
    /// - each input must be listed among its producer's successors;
    /// - each successor must read back from the exact outlet that lists it.
    ///
    /// Bails on the first edge that is not reciprocated.
    #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
    #[inline]
    pub fn check_edges(&self) -> TractResult<()> {
        for node_id in self.eval_order()? {
            let node = &self.nodes[node_id];
            // Incoming edges: the producer must know about this inlet.
            for (slot, source) in node.inputs.iter().enumerate() {
                let producer = &self.nodes[source.node];
                let inlet = InletId::new(node.id, slot);
                if !producer.outputs[source.slot].successors.contains(&inlet) {
                    bail!(
                        "Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
                        node.id,
                        slot,
                        producer
                    )
                }
            }
            // Outgoing edges: every consumer must point back at this outlet.
            for (slot, output) in node.outputs.iter().enumerate() {
                let here = OutletId::new(node.id, slot);
                let stray = output
                    .successors
                    .iter()
                    .find(|consumer| self.nodes[consumer.node].inputs[consumer.slot] != here);
                if let Some(consumer) = stray {
                    bail!(
                        "Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
                        node.id,
                        slot,
                        consumer
                    )
                }
            }
        }
        Ok(())
    }

    /// Consumes the model and wraps it into a `RunnableModel` (a `SimplePlan`), fixing
    /// its inputs and outputs so data can be run through it.
    pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>> {
        let plan = crate::plan::SimplePlan::new(self)?;
        Ok(plan)
    }

    /// Returns the predecessor of node `id` when the link between them is exclusive.
    ///
    /// Yields `Ok(None)` unless:
    /// - node `id` has exactly one input, and
    /// - the node producing that input has exactly one consumer, counted over all of
    ///   its outputs.
    pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let node = &self.nodes()[id];
        let source = match &node.inputs[..] {
            [only] => only,
            _ => return Ok(None),
        };
        let prec = &self.nodes()[source.node];
        let consumers: usize = prec.outputs.iter().map(|of| of.successors.len()).sum();
        Ok(if consumers == 1 { Some(prec) } else { None })
    }

    /// Follows `single_prec` upstream `count` times starting from node `id`.
    ///
    /// Returns `Ok(None)` as soon as one step has no exclusive predecessor. With
    /// `count == 0`, returns node `id` itself.
    pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut current = self.node(id);
        for _ in 0..count {
            match self.single_prec(current.id)? {
                Some(prec) => current = prec,
                None => return Ok(None),
            }
        }
        Ok(Some(current))
    }

    /// Follows `single_succ` downstream `count` times starting from node `id`.
    ///
    /// Returns `Ok(None)` as soon as one step has no exclusive successor. With
    /// `count == 0`, returns node `id` itself.
    pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut current = self.node(id);
        for _ in 0..count {
            match self.single_succ(current.id)? {
                Some(succ) => current = succ,
                None => return Ok(None),
            }
        }
        Ok(Some(current))
    }

    /// Returns the successor of node `id` when the link between them is exclusive.
    ///
    /// Yields `Ok(None)` unless all three hold:
    /// - node `id` has exactly one consumer, counted over all of its outputs;
    /// - that consumer reads from output slot 0;
    /// - the consumer itself has exactly one input.
    pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let node = &self.nodes()[id];
        if node.outputs.iter().map(|of| of.successors.len()).sum::<usize>() != 1 {
            return Ok(None);
        }
        // The sum above guarantees `outputs[0]` exists. However, on a multi-output node
        // the lone consumer may hang off another slot, in which case
        // `outputs[0].successors` is empty. Indexing `[0]` there used to panic.
        let succ = match node.outputs[0].successors.first() {
            Some(inlet) => &self.nodes()[inlet.node],
            None => return Ok(None),
        };
        if succ.inputs.len() != 1 {
            return Ok(None);
        }
        Ok(Some(succ))
    }

    /// Inlets consuming `outlet`, as recorded on that output slot.
    ///
    /// Panics if `outlet` does not designate an existing node output.
    pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
        let outputs = &self.nodes[outlet.node].outputs;
        &outputs[outlet.slot].successors
    }
}

impl<F: Fact + Clone + 'static, O> Graph<F, O>
where
    F: Fact + Clone + 'static + From<std::sync::Arc<Tensor>> + Hash,
    O: fmt::Debug
        + fmt::Display
        + From<crate::ops::konst::Const>
        + AsRef<dyn Op>
        + AsMut<dyn Op>
        + Clone
        + Hash
        + 'static,
{
    /// Adds a `Const` node holding `v` and returns its (single) output outlet.
    ///
    /// The outlet fact is built from the constant tensor through `F::from`.
    pub fn add_const(
        &mut self,
        name: impl Into<String>,
        v: impl IntoArcTensor,
    ) -> TractResult<OutletId> {
        let tensor = v.into_arc_tensor();
        let fact = F::from(tensor.clone());
        let name: String = name.into();
        let konst = crate::ops::konst::Const::new(tensor);
        let id = self.add_node(name, konst, tvec!(fact))?;
        Ok(id.into())
    }
}

impl<F, O> fmt::Display for Graph<F, O>
where
    F: Fact + Hash + Clone + 'static,
    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
{
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        for i in 0..self.nodes.len() {
            let input_1 = self.nodes[i]
                .inputs
                .get(0)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let input_2 = self.nodes[i]
                .inputs
                .get(1)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let output_1 = self
                .outlet_successors(OutletId::new(i, 0))
                .get(0)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let output_2 = self
                .outlet_successors(OutletId::new(i, 0))
                .get(1)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            writeln!(
                fmt,
                "{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {:?} => {:?}",
                i,
                input_1,
                input_2,
                output_1,
                output_2,
                self.nodes[i].op().name(),
                self.nodes[i].name,
                self.node_input_facts(i).unwrap(),
                self.node_output_facts(i).unwrap(),
            )?;
            if self.nodes[i].inputs.len() > 2 {
                writeln!(
                    fmt,
                    "                                               |   * inputs: {}",
                    self.nodes[i].inputs.iter().map(|s| format!("{:?}", s)).join(", ")
                )?;
            }
            if self.nodes[i].outputs.len() > 1
                || self.outlet_successors((i, 0).into()).len() > 2
                || (self.outlet_label(i.into()).is_some()
                    && self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
            {
                for o in 0..self.nodes[i].outputs.len() {
                    if self.outlet_successors((i, o).into()).len() > 0 {
                        writeln!(
                                    fmt,
                                    "                                               |   * output #{}: {} {}",
                                    o,
                                    self.outlet_label((i, o).into()).unwrap_or(""),
                                    self.outlet_successors((i, o).into())
                                    .iter()
                                    .map(|s| format!("{:?}", s))
                                    .join(", "),
                                    )?;
                    }
                }
            }
        }
        writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{:?}", o)).join(", "))?;
        Ok(())
    }
More examples
Hide additional examples
src/ops/matmul/mir_unary.rs (line 132)
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
    /// Lowers to the concrete matrix-multiplication kernels once the shape of the single
    /// input `b` is fully known. Otherwise returns `Ok(None)` and leaves the node alone.
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let b = args_1!(model.node_input_facts(node.id)?);
        match b.shape.as_concrete() {
            Some(b_shape) => {
                let patch = self.new_mat_mul_unary_finite(model, node, b_shape, b.datum_type)?;
                Ok(Some(patch))
            }
            None => Ok(None),
        }
    }

    as_op!();
}

impl MatMulUnary {
    /// Builds the patch lowering this `MatMulUnary` for a concrete `b` shape.
    ///
    /// The node is replaced by two nodes:
    /// - a `MatMatMulPack` that packs `b` along its k/n axes;
    /// - a `LirMatMulUnary` that multiplies it with `a`, pre-packed here at build time.
    ///
    /// `b_shape` / `b_dt` are the concrete shape and datum type of the `b` input.
    ///
    /// # Errors
    /// Fails if:
    /// - `compute_shape` rejects the shapes;
    /// - no matrix-multiplication kernel exists for the datum types involved;
    /// - any of the patch wiring steps fails.
    ///
    /// # Panics
    /// Panics if allocating an aligned buffer for a packed `a` slice fails (`unwrap`).
    fn new_mat_mul_unary_finite(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        b_shape: &[usize],
        b_dt: DatumType,
    ) -> TractResult<TypedModelPatch> {
        let mut patch = TypedModelPatch::default();
        let mut wire = patch.tap_model(model, node.inputs[0])?;

        let c_dt = output_type(self.a.datum_type());
        let (m, k, n, c_shape) = compute_shape(self.a.shape(), b_shape, self.axes)?;

        let mmm = tract_linalg::ops()
            .mmm(self.a.datum_type(), b_dt, c_dt, Some(m), Some(k), Some(n))
            .with_context(|| {
                format!(
                    "No matrix multiplier for {:?}x{:?} to {:?}",
                    self.a.datum_type(),
                    b_dt,
                    c_dt
                )
            })?;

        // One packed copy of `a` per position along its non-m/k axes: those two axes are
        // collapsed to 1 to get the shape of the iteration space.
        let mut a_iter_shape: TVec<usize> = self.a.shape().into();
        a_iter_shape[self.axes.a_m] = 1;
        a_iter_shape[self.axes.a_k] = 1;
        let packed_as = Array::from_shape_fn(&*a_iter_shape, |a_prefix| unsafe {
            // Byte offset of this prefix inside `a`.
            let offset = a_prefix
                .as_array_view()
                .iter()
                .zip(self.a.strides())
                .map(|(x, s)| *x as isize * s)
                .sum::<isize>()
                * self.a.datum_type().size_of() as isize;
            // NOTE(review): buffer is uninitialized; presumably `pack` fills all of it.
            let mut pa = Tensor::uninitialized_aligned_dt(
                self.a.datum_type(),
                &[mmm.a_pack().len(k, m)],
                mmm.a_pack().alignment(),
            )
            .unwrap();
            mmm.a_pack().pack(
                &mut pa.view_mut(),
                TensorView::from_bytes(&self.a, offset, self.a.shape(), self.a.strides()),
                self.axes.a_k,
                self.axes.a_m,
            );
            (pa.into_arc_tensor(), vec![ProtoFusedSpec::Store])
        });
        // The former `packed_b_shape` computation here was never used and has been
        // removed; `MatMatMulPack` derives its own output shape.
        // NOTE(review): presumably `mmm.b_packed` is the call requiring `unsafe`; confirm
        // and narrow this block.
        unsafe {
            wire = patch.wire_node(
                format!("{}.pack", &*node.name),
                super::MatMatMulPack {
                    packer: mmm.b_pack(),
                    k_axis: self.axes.b_k,
                    mn_axis: self.axes.b_n,
                },
                &[wire],
            )?[0];
            let b_storage = mmm.b_packed(b_dt.size_of(), k);
            let geometry = ConcreteMatMulGeometry { m, k, n, b_storage };
            wire = patch.wire_node(
                format!("{}.matmatmul", &*node.name),
                LirMatMulUnary {
                    c_fact: c_dt.fact(&c_shape),
                    geometry: MatMulGeometry::Concrete(geometry),
                    micro_ops: packed_as,
                    c_m_axis: self.axes.c_m,
                    c_n_axis: self.axes.c_n,
                    c_final_shape: c_shape.into(),
                    reshape_post: vec![],
                    mmm,
                },
                &[wire],
            )?[0];
            patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
            patch.obliterate(node.id)?;
        }
        Ok(patch)
    }

    /// Splits the matmul when its `b` input is a `TypedConcat` along the `k` axis.
    ///
    /// Each concatenated piece gets its own `MatMulUnary` against the matching `k`-slice
    /// of `a`, and the partial products are chained together with `Add` nodes. Returns
    /// `Ok(None)` when the input is not a concat, or is concatenated along another axis.
    fn declutter_precusor_is_concat(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let concat_id = node.inputs[0].node;
        let concat = match model.nodes()[concat_id].op().downcast_ref::<TypedConcat>() {
            Some(concat) if concat.axis == self.axes.b_k => concat,
            _ => return Ok(None),
        };
        let concat_node = model.node(concat_id);
        let offsets: Vec<usize> = concat
            .offsets(&model.node_input_facts(concat_node.id)?)?
            .iter()
            .map(|x| x.to_usize())
            .collect::<TractResult<_>>()?;

        let mut patch = TypedModelPatch::new("split over k-concatenated input");
        let mut partials = Vec::with_capacity(concat_node.inputs.len());
        for (ix, input) in concat_node.inputs.iter().enumerate() {
            let (start, end) = (offsets[ix], offsets[ix + 1]);
            let tapped = patch.tap_model(model, *input)?;
            let a_slice = self.a.slice(self.axes.a_k, start, end)?;
            let partial = patch.wire_node(
                format!("{}.k-{}-{}", node.name, start, end),
                MatMulUnary { a: a_slice.into_arc_tensor(), ..self.clone() },
                &[tapped],
            )?[0];
            partials.push(partial);
        }

        let mut sum = partials[0];
        for (ix, partial) in partials.iter().skip(1).enumerate() {
            sum = patch.wire_node(
                format!("{}.k-add-{}", node.name, ix),
                crate::ops::binary::TypedBinOp(Box::new(crate::ops::math::Add)),
                &[sum, *partial],
            )?[0];
        }
        patch.shunt_outside(model, OutletId::new(node.id, 0), sum)?;
        Ok(Some(patch))
    }
src/ops/binary.rs (line 229)
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
    /// Replaces the node with a `MergeOpUnicast` when both conditions hold:
    /// - the op's result keeps the first input's datum type;
    /// - both input facts are identical.
    ///
    /// Otherwise returns `Ok(None)`.
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let facts = model.node_input_facts(node.id)?;
        let keeps_type = self.0.result_datum_type(facts[0].datum_type, facts[1].datum_type)?
            == facts[0].datum_type;
        if !(keeps_type && facts[0] == facts[1]) {
            return Ok(None);
        }
        let merged = MergeOpUnicast(self.0.clone());
        let patch = TypedModelPatch::replace_single_op(model, node, &node.inputs, merged)?;
        Ok(Some(patch))
    }

    as_op!();
}

/// Binary mini-op evaluated element-wise on two inputs of the same shape.
///
/// The result is written in place into the second input's tensor.
#[derive(Clone, Debug, Hash)]
pub struct MergeOpUnicast(pub Box<dyn BinMiniOp>);
impl_dyn_hash!(MergeOpUnicast);

impl Op for MergeOpUnicast {
    fn name(&self) -> Cow<str> {
        format!("{}Unicast", self.0.name()).into()
    }

    op_as_typed_op!();
}

impl EvalOp for MergeOpUnicast {
    fn is_stateless(&self) -> bool {
        true
    }

    /// Applies the mini-op to `(a, b)`, mutating `b`'s tensor in place and returning it.
    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (lhs, rhs) = args_2!(inputs);
        let mut out = rhs.into_tensor();
        self.0.eval_unicast_in_place(&lhs, &mut out)?;
        Ok(tvec!(out.into_tvalue()))
    }
}

impl TypedOp for MergeOpUnicast {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        debug_assert_eq!(inputs[0].shape, inputs[1].shape);
        Ok(tvec!(inputs[0].clone()))
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let count: TDim = self.output_facts(inputs)?[0].shape.iter().product();
        Ok(self
            .0
            .cost_per_element(inputs[0].datum_type)
            .into_iter()
            .map(|(c, n)| (c, count.clone() * n))
            .collect())
    }

    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        self.0.declutter(model, node)
    }

    as_op!();
}

/// Defines a `BinMiniOp` named `$Op` and a `$func()` constructor that wraps it in a
/// `TypedBinOp`.
///
/// Each `[types] => closure` arm gives the element kernel `|c, a, b|` for the listed
/// datum types. The optional `q:` section gives kernels for quantized types. These kernels
/// also receive the zero point and scale of `a`'s datum type, defaulting to `(0, 1.)`.
/// The other optional sections override the corresponding `BinMiniOp` methods.
#[macro_export]
macro_rules! bin_to_super_type {
    ($func:ident, $Op:ident,
     $(codegen: $codegen:expr,)?
     $(cost: $cost:expr,)?
     $(declutter: $declutter:expr,)?
     $(eval_override: $eval_override: expr,)?
     $(linalg: $linalg:ident,)?
     $(operating_datum_type: $operating_datum_type:expr,)?
     $(out_of_place: $out_of_place:expr,)?
     $(validation: $validation:expr,)?
     $(q: $([$($typ_dt:ident),*] => $cab_dt:expr),* ;)?
     $( [$($typ:ident),*] => $cab:expr),*) => {
        #[derive(Debug, Clone, Hash)]
        pub struct $Op;
        tract_data::internal::impl_dyn_hash!($Op);
        #[allow(clippy::redundant_closure_call)]
        impl $crate::ops::binary::BinMiniOp for $Op {
            fn name(&self) -> &'static str {
                stringify!($Op)
            }

            fn eval_uniform_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()> {
                $(
                    $(if a.datum_type() == $typ::datum_type() {
                        let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
                        let a = a.to_scalar::<$typ>()?;
                        let b = b.as_slice_mut::<$typ>()?;
                        for y in b.iter_mut() {
                            let mut c = $typ::default();
                            cab(&mut c, a, &*y);
                            *y = c;
                        }
                        return Ok(())
                    }
                    )*
                 )*

                    $(
                        $(
                            $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                                let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
                                let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                                let a = a.to_scalar::<$typ_dt>()?;
                                let b = b.as_slice_mut::<$typ_dt>()?;
                                for y in b.iter_mut() {
                                    let mut c = $typ_dt::default();
                                    cab(&mut c, a, &*y, zp, scale);
                                    *y = c;
                                }
                                return Ok(())
                            }
                            )*
                         )*
                     )?
                    bail!("{} does not support {:?} (inplace uniform)", self.name(), a.datum_type());
            }

            fn eval_unicast_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()> {
                $(
                    $(if a.datum_type() == $typ::datum_type() {
                        let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
                        let a = a.as_slice::<$typ>()?;
                        let b = b.as_slice_mut::<$typ>()?;
                        // Operands are paired one to one: a length mismatch must be an
                        // error, not an out-of-bounds read or a partial update.
                        if a.len() != b.len() {
                            bail!("{} unicast operands differ in length ({} vs {})", self.name(), a.len(), b.len());
                        }
                        for (x, y) in a.iter().zip(b.iter_mut()) {
                            let mut c = $typ::default();
                            cab(&mut c, x, &*y);
                            *y = c;
                        }
                        return Ok(())
                    }
                    )*
                 )*
                    $(
                        $(
                            $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                                let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
                                let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                                let a = a.as_slice::<$typ_dt>()?;
                                let b = b.as_slice_mut::<$typ_dt>()?;
                                if a.len() != b.len() {
                                    bail!("{} unicast operands differ in length ({} vs {})", self.name(), a.len(), b.len());
                                }
                                for (x, y) in a.iter().zip(b.iter_mut()) {
                                    let mut c = $typ_dt::default();
                                    cab(&mut c, x, &*y, zp, scale);
                                    *y = c;
                                }
                                return Ok(())
                            }
                            )*
                         )*
                     )?
                    bail!("{} does not support {:?} (inplace)", self.name(), a.datum_type());
            }

            fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
                $(if $out_of_place(c, a, b)? { return Ok(()) } )?
                    $(
                        $(if c.datum_type() == $typ::datum_type() {
                            let a = a.to_array_view::<$typ>()?;
                            let b = b.to_array_view::<$typ>()?;
                            let mut c = c.to_array_view_mut::<$typ>()?;
                            $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each($cab);
                            return Ok(())
                        })*
                     )*
                    $(
                        $(
                            $(if a.datum_type().unquantized() == <$typ_dt>::datum_type().unquantized() {
                                let cab: fn(&mut $typ_dt, &$typ_dt, &$typ_dt, i32, f32) -> () = $cab_dt;
                                let (zp, scale) = a.datum_type().qparams().map(|q| q.zp_scale()).unwrap_or((0, 1.));
                                let a = a.to_array_view::<$typ_dt>()?;
                                let b = b.to_array_view::<$typ_dt>()?;
                                let mut c = c.to_array_view_mut::<$typ_dt>()?;
                                $crate::ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(|c, a, b| cab(c, a, b, zp, scale));
                                return Ok(())
                            }
                            )*
                         )*
                     )?
                    bail!("{} does not support {:?} (out of place)", self.name(), c.datum_type());
            }

            fn eval_in_a(&self, a: &mut Tensor, b: &Tensor) -> TractResult<()> {
                // c and a are same type
                $(
                    $(if b.datum_type() == $typ::datum_type() {
                        let cab: fn(&mut $typ, &$typ, &$typ) -> () = $cab;
                        let b = b.to_array_view::<$typ>()?;
                        let mut a = a.to_array_view_mut::<$typ>()?;
                        // `a` is both input and output: clone the element before the kernel
                        // writes to it.
                        $crate::ndarray::Zip::from(&mut a).and_broadcast(b).for_each(|a, b| cab(a, &a.clone(), b));
                        return Ok(())
                    })*
                 )*
                    // Quantized datum types are not handled by this in-place path; they fall
                    // through to the error below.
                    bail!("{} does not support {:?} (in a)", self.name(), a.datum_type());
            }

            $(fn eval(&self, a: TValue, b: TValue) -> TractResult<Tensor> {
                $eval_override(a, b)
            })?

            fn result_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                if a.unquantized() == b.unquantized() {
                    if a.is_quantized() || !b.is_quantized() {
                        return Ok(a)
                    }
                    else {
                        return Ok(b)
                    }
                }
                self.operating_datum_type(a, b)
            }

                $(
                    fn declutter(
                        &self,
                        model: &TypedModel,
                        node: &TypedNode,
                        ) -> TractResult<Option<TypedModelPatch>> {
                        ($declutter)(self, model, node)
                    }
                 )?
                $(
                    fn codegen(
                        &self,
                        model: &TypedModel,
                        node: &TypedNode,
                        a: &Arc<Tensor>,
                        ) -> TractResult<Option<TypedModelPatch>> {
                        ($codegen)(self, model, node, a)
                    }
                 )?
                $(
                    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                        ($cost)(dt)
                    }
                 )?
                $(
                    fn validation(&self) -> Validation {
                        $validation
                    }
                 )?
                $(
                    fn as_linalg_binop(&self) -> Option<tract_linalg::mmm::BinOp> {
                        Some(tract_linalg::mmm::BinOp::$linalg)
                    }
                 )?
                $(
                    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                        ($operating_datum_type)(a, b)
                    })?
        }

        pub fn $func() -> $crate::ops::binary::TypedBinOp {
            $crate::ops::binary::TypedBinOp(Box::new($Op))
        }
    };
}

/// Defines a boolean-valued `BinMiniOp` named `$Op` and a `$func()` constructor that
/// wraps it in a `TypedBinOp`.
///
/// Each `[types] => closure` arm gives the `|c: &mut bool, a, b|` kernel used by the
/// out-of-place path. The in-place paths apply the same kernel to `bool` buffers only.
/// The optional sections override the corresponding `BinMiniOp` methods.
macro_rules! bin_to_bool {
    ($func:ident, $Op:ident,
     $( codegen: $codegen:expr, )?
     $( cost: $cost:expr, )?
     $( declutter: $declutter:expr, )?
     $( operating_datum_type: $operating_datum_type:expr, )?
     $( [$($typ:ident),*] => $cab:expr),*) => {
        #[derive(Debug, Clone, Hash)]
        pub struct $Op;
        tract_data::internal::impl_dyn_hash!($Op);
        impl $crate::ops::binary::BinMiniOp for $Op {
            fn name(&self) -> &'static str {
                stringify!($Op)
            }

            fn eval_uniform_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()> {
                $(
                    $(if a.datum_type() == $typ::datum_type() {
                        let cab: fn(&mut bool, &bool, &bool) -> () = $cab;
                        let a = a.to_scalar::<bool>()?;
                        let b = b.as_slice_mut::<bool>()?;
                        for y in b.iter_mut() {
                            let mut c = bool::default();
                            cab(&mut c, a, &*y);
                            *y = c;
                        }
                        return Ok(())
                    }
                    )*
                 )*
                    bail!("{} does not support {:?} (inplace uniform)", self.name(), a.datum_type());
            }

            #[allow(unreachable_code)]
            fn eval_unicast_in_place(&self, a: &Tensor, b: &mut Tensor) -> TractResult<()> {
                $(
                    $(if a.datum_type() == $typ::datum_type() {
                        let cab: fn(&mut bool, &bool, &bool) -> () = $cab;
                        let a = a.as_slice::<bool>()?;
                        let b = b.as_slice_mut::<bool>()?;
                        // Operands are paired one to one: a length mismatch must be an
                        // error, not an out-of-bounds read or a partial update.
                        if a.len() != b.len() {
                            bail!("{} unicast operands differ in length ({} vs {})", self.name(), a.len(), b.len());
                        }
                        for (x, y) in a.iter().zip(b.iter_mut()) {
                            let mut c = bool::default();
                            cab(&mut c, x, &*y);
                            *y = c;
                        }
                        return Ok(())
                    }
                    )*
                 )*
                    bail!("{} does not support {:?}", self.name(), a.datum_type());
            }

            fn eval_out_of_place(&self, c: &mut Tensor, a: &Tensor, b: &Tensor) -> TractResult<()> {
                $(
                    $(if a.datum_type() == $typ::datum_type() {
                        let cab: fn(&mut bool, &$typ, &$typ) -> () = $cab;
                        let a = a.to_array_view::<$typ>()?;
                        let b = b.to_array_view::<$typ>()?;
                        let mut c = c.to_array_view_mut::<bool>()?;
                        ndarray::Zip::from(&mut c).and_broadcast(a).and_broadcast(b).for_each(cab);
                        return Ok(())
                    }
                    )*
                 )*
                    bail!("{} does not support {:?}", self.name(), a.datum_type());
            }

            fn eval_in_a(&self, a: &mut Tensor, _b: &Tensor) -> TractResult<()> {
                bail!("{} does not support {:?}", self.name(), a.datum_type());
            }

            fn result_datum_type(&self, _a: DatumType, _b: DatumType) -> TractResult<DatumType> {
                Ok(bool::datum_type())
            }

            $(
                fn codegen(
                    &self,
                    model: &TypedModel,
                    node: &TypedNode,
                    ) -> TractResult<Option<TypedModelPatch>> {
                    ($codegen)(self, model, node)
                }
             )?


                $(
                    fn declutter(
                        &self,
                        model: &TypedModel,
                        node: &TypedNode,
                        ) -> TractResult<Option<TypedModelPatch>> {
                        ($declutter)(self, model, node)
                    }
                 )?

                $(
                    fn cost_per_element(&self, dt: DatumType) -> TVec<(Cost, usize)> {
                        ($cost)(dt)
                    }
                 )?

                $(
                    fn operating_datum_type(&self, a: DatumType, b: DatumType) -> TractResult<DatumType> {
                        ($operating_datum_type)(a, b)
                    })?

        }

        pub fn $func() -> $crate::ops::binary::TypedBinOp {
            $crate::ops::binary::TypedBinOp(Box::new($Op))
        }
    };
}

/// Result of `one_input_is_uniform`: a binary node where one operand carries a uniform value.
#[derive(Debug)]
pub(crate) struct OneUniformInput {
    /// The uniform value taken from the uniform operand's fact.
    pub uni: Arc<Tensor>,
    /// The node's other input wire.
    pub var: OutletId,
    /// `true` when input #0 is the uniform one, i.e. the node computes `op(uni, var)`.
    pub left_is_uniform: bool,
}

/// Detects a two-input node where one operand is uniform and broadcasts onto the other.
///
/// Input #0 is examined first. The uniform operand qualifies only if, at each position
/// compared pairwise with the other operand's shape, its dimension is 1 or equal to the
/// other one. Returns `None` for nodes without exactly two inputs.
///
/// NOTE(review): positions are paired from the left and surplus dims of the longer shape
/// are ignored; presumably typed binary ops guarantee equal ranks — confirm.
pub(crate) fn one_input_is_uniform(
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<OneUniformInput>> {
    let facts = model.node_input_facts(node.id)?;
    let (lhs, rhs) = match &*facts {
        &[lhs, rhs] => (lhs, rhs),
        _ => return Ok(None),
    };
    let (found, var_fact, uni_fact) = if let Some(value) = &lhs.uniform {
        let found =
            OneUniformInput { uni: value.clone(), var: node.inputs[1], left_is_uniform: true };
        (found, rhs, lhs)
    } else if let Some(value) = &rhs.uniform {
        let found =
            OneUniformInput { uni: value.clone(), var: node.inputs[0], left_is_uniform: false };
        (found, lhs, rhs)
    } else {
        return Ok(None);
    };
    let broadcasts = var_fact
        .shape
        .iter()
        .zip(uni_fact.shape.iter())
        .all(|(v, u)| u.is_one() || u == v);
    Ok(if broadcasts { Some(found) } else { None })
}
src/ops/array/dyn_slice.rs (line 115)
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // When both bounds are known constants, the dynamic slice becomes a static `Slice`
        // fed only by the data input.
        let facts = model.node_input_facts(node.id)?;
        let start = match self.start_input {
            true => facts[1].konst.clone(),
            false => Some(rctensor0(TDim::zero())),
        };
        let end = match self.end_input {
            true => facts[1 + self.start_input as usize].konst.clone(),
            false => Some(rctensor0(facts[0].shape[self.axis].clone())),
        };
        match (start, end) {
            (Some(start), Some(end)) => {
                let slice = crate::ops::array::Slice {
                    axis: self.axis,
                    start: start.cast_to::<TDim>()?.to_scalar::<TDim>()?.clone(),
                    end: end.cast_to::<TDim>()?.to_scalar::<TDim>()?.clone(),
                };
                let patch =
                    TypedModelPatch::replace_single_op(model, node, &[node.inputs[0]], slice)?;
                Ok(Some(patch))
            }
            _ => Ok(None),
        }
    }
src/ops/scan/mir.rs (line 158)
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
    fn declutter_const_initializer(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // Find the first state whose initializer comes from an outer input known to be
        // constant; that value is folded into the op and the outer input is dropped.
        let facts = model.node_input_facts(node.id)?;
        let found = self.input_mapping.iter().enumerate().find_map(|(ix, mapping)| match mapping {
            InputMapping::State { initializer: StateInitializer::FromInput(n) } => {
                facts[*n].konst.as_ref().map(|value| (ix, *n, value.clone()))
            }
            _ => None,
        });
        if let Some((ix, outer, value)) = found {
            let mut op = self.clone();
            op.input_mapping[ix] =
                InputMapping::State { initializer: StateInitializer::Value(value) };
            op.input_mapping = Self::remove_outer_input_from_mappings(&op.input_mapping, outer);
            let mut wires = node.inputs.clone();
            wires.remove(outer);
            return Ok(Some(TypedModelPatch::replace_single_op(model, node, &wires, op)?));
        }
        Ok(None)
    }
src/ops/logic.rs (line 57)
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
/// Rewrites a comparison against a uniform zero into a unary "compare to zero" op.
///
/// Applies only when one operand is a uniform zero (see `one_input_is_uniform`) and the
/// datum type is signed or floating point. When the zero is the *left* operand, the
/// comparison is mirrored (operands swapped), not negated: `0 < x` is `x > 0`,
/// `0 <= x` is `x >= 0`, `0 > x` is `x < 0`, and `0 >= x` is `x <= 0`.
fn codegen_compare_to_zero(
    op: &dyn BinMiniOp,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    let facts = model.node_input_facts(node.id)?;
    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
        let dt = facts[0].datum_type;
        if (dt.is_signed() || dt.is_float()) && *uniform.uni == Tensor::zero_scalar_dt(dt)? {
            let reversed = uniform.left_is_uniform;
            let mapped = || -> Box<dyn ElementWiseMiniOp> {
                macro_rules! m {
                    ($bin: ty, $same: expr, $mirrored: expr) => {
                        if op.is::<$bin>() {
                            return if reversed { Box::new($mirrored) } else { Box::new($same) }
                        };
                    }
                }
                // (op, form for `x op 0`, form for `0 op x`)
                m!(Less, LessThanZero {}, GreaterThanZero {});
                m!(LessEqual, LessEqualThanZero {}, GreaterEqualThanZero {});
                m!(Greater, GreaterThanZero {}, LessThanZero {});
                m!(GreaterEqual, GreaterEqualThanZero {}, LessEqualThanZero {});
                unreachable!();
            };
            return Ok(Some(TypedModelPatch::replace_single_op(
                model,
                node,
                &[uniform.var],
                ElementWiseOp(mapped()),
            )?));
        }
    }
    Ok(None)
}

Get output tensor information for a node.

Examples found in repository?
src/model/graph.rs (line 360)
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
    /// Get input and output tensor information for a node, as `(inputs, outputs)`.
    pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
        let inputs = self.node_input_facts(id)?;
        let outputs = self.node_output_facts(id)?;
        Ok((inputs, outputs))
    }

    /// Get input tensor information for a node, in input order.
    ///
    /// Each fact is looked up through `outlet_fact`, whose errors are propagated.
    pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        let node = &self.nodes[node_id];
        node.inputs.iter().map(|&outlet| self.outlet_fact(outlet)).collect()
    }

    /// Get output tensor information for a node, in output slot order.
    pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        let outputs = &self.nodes[node_id].outputs;
        let facts: TVec<&F> = outputs.iter().map(|outlet| &outlet.fact).collect();
        Ok(facts)
    }

    // outlets

    /// Get tensor information for a single outlet.
    ///
    /// # Errors
    /// Fails if the node does not exist or has no output at `outlet.slot`.
    pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
        anyhow::ensure!(outlet.node < self.nodes.len(), "Invalid outlet for graph");
        let output = self.nodes[outlet.node]
            .outputs
            .get(outlet.slot)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))?;
        Ok(&output.fact)
    }

    /// Get tensor information for a single outlet.
    pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
        let outlets = &mut self.nodes[outlet.node].outputs;
        outlets
            .get_mut(outlet.slot)
            .map(|o| &mut o.fact)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
    }

    /// Get multiple mutable tensor information for outlets.
    ///
    /// # Panics
    /// Panics if `outlets` contains the same outlet twice, as the returned references must
    /// not alias.
    ///
    /// # Errors
    /// Fails, before any reference is handed out, if one of the outlets does not exist.
    pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
        assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
        // Validate every outlet first, with `outlet_fact`'s errors, so that the mapping below
        // only ever touches existing outlets.
        for outlet in outlets {
            self.outlet_fact(*outlet)?;
        }
        outlets
            .iter()
            // SAFETY: the raw pointer comes from a *mutable* borrow of the fact. Writing through
            // a pointer cast from a shared `&F` would be undefined behaviour. The outlets are
            // pairwise distinct (asserted above), so each reference targets a different fact
            // and none alias. All of them are bounded by this method's `&mut self` borrow.
            .map(|o| Ok(unsafe { &mut *(self.outlet_fact_mut(*o)? as *mut F) }))
            .collect()
    }

    /// Set tensor information for a single outlet.
    pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
        let outlets = &mut self.nodes[outlet.node].outputs;
        if outlets.len() <= outlet.slot {
            bail!("Invalid outlet refererence: {:?}", outlet)
        }
        outlets[outlet.slot].fact = fact;
        Ok(())
    }

    /// Builder-style variant of `set_outlet_fact`: consumes the graph and returns it
    /// updated, or the error `set_outlet_fact` produced.
    pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
        self.set_outlet_fact(outlet, fact).map(|()| self)
    }

    // outlet labels

    /// Label attached to `outlet`, or `None` if it has none.
    pub fn outlet_label(&self, outlet: OutletId) -> Option<&str> {
        let label = self.outlet_labels.get(&outlet)?;
        Some(label.as_str())
    }

    /// Attach `label` to `outlet`, replacing any previous label. Always returns `Ok(())`.
    pub fn set_outlet_label(&mut self, outlet: OutletId, label: String) -> TractResult<()> {
        let _replaced = self.outlet_labels.insert(outlet, label);
        Ok(())
    }

    /// Builder-style variant of `set_outlet_label`: consumes the graph and returns it
    /// with the label applied.
    pub fn with_outlet_label(mut self, outlet: OutletId, label: String) -> TractResult<Self> {
        self.set_outlet_label(outlet, label).map(|()| self)
    }

    /// Find an outlet carrying `label`.
    ///
    /// Labels live in a `HashMap`. If several outlets share the label, which one is
    /// returned depends on the map's iteration order.
    pub fn find_outlet_label(&self, label: &str) -> Option<OutletId> {
        self.outlet_labels
            .iter()
            .find_map(|(outlet, name)| if name == label { Some(*outlet) } else { None })
    }

    // misc

    /// Computes an evaluation order for the graph inputs and outputs.
    ///
    /// Delegates to the free function `eval_order`. The returned values are node ids
    /// (indices into `self.nodes`).
    pub fn eval_order(&self) -> TractResult<Vec<usize>> {
        eval_order(self)
    }

    /// Sanity check on network connections: a no-op in this build configuration.
    ///
    /// The checking variant is only compiled with both `debug_assertions` and the
    /// `paranoid_assertions` feature enabled.
    #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
    #[inline]
    pub fn check_edges(&self) -> TractResult<()> {
        Ok(())
    }

    /// Performs a sanity check on network connections.
    ///
    /// Visits nodes in evaluation order and checks that every edge is recorded at both ends.
    /// Each input must be listed among its producer's successors, and each successor of an
    /// output must point back at that output. The first mismatch is reported as an error.
    #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
    #[inline]
    pub fn check_edges(&self) -> TractResult<()> {
        for id in self.eval_order()? {
            let current = &self.nodes[id];
            for (slot, wire) in current.inputs.iter().enumerate() {
                let producer = &self.nodes[wire.node];
                let expected = InletId::new(current.id, slot);
                if !producer.outputs[wire.slot].successors.contains(&expected) {
                    bail!(
                        "Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
                        current.id,
                        slot,
                        producer
                    )
                }
            }
            for (slot, output) in current.outputs.iter().enumerate() {
                let expected = OutletId::new(current.id, slot);
                let broken = output
                    .successors
                    .iter()
                    .find(|succ| self.nodes[succ.node].inputs[succ.slot] != expected);
                if let Some(succ) = broken {
                    bail!(
                        "Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
                        current.id,
                        slot,
                        succ
                    )
                }
            }
        }
        Ok(())
    }

    /// Converts the model into a `RunnableModel` which fixes the inputs and outputs and allows passing data through the model.
    ///
    /// Consumes the graph. Plan construction is delegated to `SimplePlan::new`, and its
    /// error, if any, is returned as-is.
    pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>> {
        crate::plan::SimplePlan::new(self)
    }

    /// The unique predecessor of node `id`, if any.
    ///
    /// Returns `Some` only when node `id` has exactly one input and the producing node
    /// feeds exactly one consumer in total, counting all of its outputs.
    pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let node = &self.nodes()[id];
        let input = match &*node.inputs {
            [input] => input,
            _ => return Ok(None),
        };
        let prec = &self.nodes()[input.node];
        let fanout: usize = prec.outputs.iter().map(|of| of.successors.len()).sum();
        Ok(if fanout == 1 { Some(prec) } else { None })
    }

    /// Follows `single_prec` `count` times from node `id`.
    ///
    /// Returns `None` as soon as one step has no unique predecessor. With `count == 0`
    /// this returns node `id` itself.
    pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut cursor = Some(self.node(id));
        for _ in 0..count {
            cursor = match cursor {
                Some(current) => self.single_prec(current.id)?,
                None => break,
            };
        }
        Ok(cursor)
    }

    /// Follows `single_succ` `count` times from node `id`.
    ///
    /// Returns `None` as soon as one step has no unique successor. With `count == 0`
    /// this returns node `id` itself.
    pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut cursor = Some(self.node(id));
        for _ in 0..count {
            cursor = match cursor {
                Some(current) => self.single_succ(current.id)?,
                None => break,
            };
        }
        Ok(cursor)
    }

    /// The unique successor of node `id`, if any.
    ///
    /// Returns `Some` only when node `id` feeds exactly one consumer in total, counting all
    /// of its outputs, and that consumer has exactly one input.
    pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let node = &self.nodes()[id];
        // The lone successor may hang off any output slot, not necessarily slot 0. Reading
        // `outputs[0].successors[0]` would panic when slot 0 has no consumer.
        let mut successors = node.outputs.iter().flat_map(|of| of.successors.iter());
        let succ = match (successors.next(), successors.next()) {
            (Some(succ), None) => succ,
            _ => return Ok(None),
        };
        let succ = &self.nodes()[succ.node];
        if succ.inputs.len() != 1 {
            return Ok(None);
        }
        Ok(Some(succ))
    }

    /// The inlets consuming `outlet`.
    ///
    /// # Panics
    /// Panics if `outlet` does not name an existing node output (direct indexing).
    pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
        &self.nodes[outlet.node].outputs[outlet.slot].successors
    }
}

impl<F: Fact + Clone + 'static, O> Graph<F, O>
where
    F: Fact + Clone + 'static + From<std::sync::Arc<Tensor>> + Hash,
    O: fmt::Debug
        + fmt::Display
        + From<crate::ops::konst::Const>
        + AsRef<dyn Op>
        + AsMut<dyn Op>
        + Clone
        + Hash
        + 'static,
{
    /// Adds a `Const` node holding `v` and returns its (single) output outlet.
    ///
    /// The output fact is built from the tensor itself through `F::from`.
    pub fn add_const(
        &mut self,
        name: impl Into<String>,
        v: impl IntoArcTensor,
    ) -> TractResult<OutletId> {
        let tensor = v.into_arc_tensor();
        let op = crate::ops::konst::Const::new(tensor.clone());
        let fact = F::from(tensor);
        let node = self.add_node(name.into(), op, tvec!(fact))?;
        Ok(node.into())
    }
}

impl<F, O> fmt::Display for Graph<F, O>
where
    F: Fact + Hash + Clone + 'static,
    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
{
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        for i in 0..self.nodes.len() {
            let input_1 = self.nodes[i]
                .inputs
                .get(0)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let input_2 = self.nodes[i]
                .inputs
                .get(1)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let output_1 = self
                .outlet_successors(OutletId::new(i, 0))
                .get(0)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let output_2 = self
                .outlet_successors(OutletId::new(i, 0))
                .get(1)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            writeln!(
                fmt,
                "{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {:?} => {:?}",
                i,
                input_1,
                input_2,
                output_1,
                output_2,
                self.nodes[i].op().name(),
                self.nodes[i].name,
                self.node_input_facts(i).unwrap(),
                self.node_output_facts(i).unwrap(),
            )?;
            if self.nodes[i].inputs.len() > 2 {
                writeln!(
                    fmt,
                    "                                               |   * inputs: {}",
                    self.nodes[i].inputs.iter().map(|s| format!("{:?}", s)).join(", ")
                )?;
            }
            if self.nodes[i].outputs.len() > 1
                || self.outlet_successors((i, 0).into()).len() > 2
                || (self.outlet_label(i.into()).is_some()
                    && self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
            {
                for o in 0..self.nodes[i].outputs.len() {
                    if self.outlet_successors((i, o).into()).len() > 0 {
                        writeln!(
                                    fmt,
                                    "                                               |   * output #{}: {} {}",
                                    o,
                                    self.outlet_label((i, o).into()).unwrap_or(""),
                                    self.outlet_successors((i, o).into())
                                    .iter()
                                    .map(|s| format!("{:?}", s))
                                    .join(", "),
                                    )?;
                    }
                }
            }
        }
        writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{:?}", o)).join(", "))?;
        Ok(())
    }
More examples
Hide additional examples
src/plan.rs (line 284)
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
    /// Execute every step of the plan in order, delegating each node's computation to `eval`.
    ///
    /// For each node id in `plan.order`:
    /// * its inputs are cloned out of `values` (an error if a precursor has no value yet);
    /// * the values listed in `plan.flush_lists[step]` are dropped (set to `None`);
    /// * in debug builds, the input count and input facts are checked against the model;
    /// * `eval` is called with the session state, the node's op state (if any), the node and
    ///   its inputs;
    /// * if the plan has unresolved symbols, the concrete output shapes are used to bind
    ///   symbols in `session_state.resolved_symbols`;
    /// * in debug builds, the output count and the facts of outputs that have successors are
    ///   checked;
    /// * the outputs are stored in `values[node.id]`.
    ///
    /// # Errors
    /// Fails if a precursor value is missing, if `eval` fails (its error is converted with
    /// `Into<anyhow::Error>`), or, in debug builds, on an input/output count or fact mismatch.
    pub fn exec_plan_with_eval<Eval, E>(&mut self, mut eval: Eval) -> TractResult<()>
    where
        Eval: for<'a, 'b, 'c> FnMut(
            &'a mut SessionState,
            Option<&'b mut (dyn OpState + 'static)>,
            &'c Node<F, O>,
            TVec<TValue>,
        ) -> Result<TVec<TValue>, E>,
        E: Into<anyhow::Error> + Send + Sync + 'static,
    {
        {
            // Destructure so the plan is borrowed immutably while the session state, op
            // states and values are borrowed mutably, as independent fields.
            let &mut SimpleState {
                ref plan,
                ref mut session_state,
                ref mut states,
                ref mut values,
                ..
            } = self;
            let plan = plan.borrow();
            let model = plan.model().borrow();
            for (step, n) in plan.order.iter().enumerate() {
                let node = model.node(*n);
                trace!("Running step {}, node {}", step, node);
                // Gather this node's inputs from the outputs of already-executed nodes.
                let mut inputs: TVec<TValue> = tvec![];
                for i in &node.inputs {
                    trace!("  use input {:?}", i);
                    let prec_node = model.node(i.node);
                    let prec = values[i.node].as_ref().ok_or_else(|| {
                        format_err!("Computing {}, precursor {} not done:", node, prec_node)
                    })?;
                    inputs.push(prec[i.slot].clone())
                }

                // Drop the values scheduled for flushing at this step; this node's inputs
                // have already been cloned above.
                for flush in &plan.flush_lists[step] {
                    trace!("  Ran {} can now flush {}", node, model.node(*flush));
                    values[*flush] = None;
                }

                // Debug builds only: check the gathered inputs against the model's facts.
                if cfg!(debug_assertions) {
                    let facts = model.node_input_facts(node.id)?;
                    if facts.len() != inputs.len() {
                        bail!(
                            "Evaluating {}: expected {} inputs, got {}",
                            node,
                            facts.len(),
                            inputs.len()
                        );
                    }
                    for (ix, (v, f)) in inputs.iter().zip(facts.iter()).enumerate() {
                        if !f.matches(v, Some(&session_state.resolved_symbols))? {
                            bail!(
                                "Evaluating {}: input {:?}, expected {:?}, got {:?}",
                                node,
                                ix,
                                f,
                                v
                            );
                        }
                    }
                }

                let vs = eval(session_state, states[node.id].as_deref_mut(), node, inputs)
                    .map_err(|e| e.into())?;

                // Bind symbolic dimensions of the declared output facts to the concrete
                // output shapes. NOTE(review): `usize as i64` would wrap for dims above
                // i64::MAX — presumably unreachable in practice.
                if plan.has_unresolved_symbols {
                    for (o, v) in node.outputs.iter().zip(vs.iter()) {
                        if let Ok(f) = o.fact.to_typed_fact() {
                            for (dim_abstract, dim_concrete) in f.shape.iter().zip(v.shape()) {
                                Self::resolve(
                                    &mut session_state.resolved_symbols,
                                    &dim_abstract,
                                    *dim_concrete as i64,
                                );
                            }
                        }
                    }
                }
                // Debug builds only: check the outputs against the model's facts. Outputs
                // with no successors are not checked.
                if cfg!(debug_assertions) {
                    let facts = model.node_output_facts(node.id)?;
                    if facts.len() != vs.len() {
                        bail!(
                            "Evaluating {}: expected {} outputs, got {}",
                            node,
                            facts.len(),
                            vs.len()
                        );
                    }
                    for (ix, (v, f)) in vs.iter().zip(facts.iter()).enumerate() {
                        if node.outputs[ix].successors.len() == 0 {
                            continue;
                        }
                        if !f.matches(v, Some(&session_state.resolved_symbols))? {
                            bail!(
                                "Evaluating {}: output {:?}, expected {:?}, got {:?}",
                                node,
                                ix,
                                f,
                                v
                            );
                        }
                    }
                }

                values[node.id] = Some(vs);
            }
        }
        Ok(())
    }

Get tensor information for a single outlet.

Examples found in repository?
src/model/graph.rs (line 189)
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
    /// Get the `ix`-th input tensor type information.
    ///
    /// # Errors
    /// Fails if `ix` is not a valid model input index, or if `outlet_fact` rejects the
    /// corresponding input outlet.
    pub fn input_fact(&self, ix: usize) -> TractResult<&F> {
        // Checked lookup: an out-of-range index becomes an error rather than a panic,
        // as the fallible signature promises.
        let input = *self
            .input_outlets()?
            .get(ix)
            .with_context(|| format!("Invalid input index: {}", ix))?;
        self.outlet_fact(input)
    }

    /// Get the `ix`-th input tensor type information, mutably.
    ///
    /// # Errors
    /// Fails if `ix` is not a valid model input index, or if `outlet_fact_mut` rejects the
    /// corresponding input outlet.
    pub fn input_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
        // Copy the outlet id out (checked) so the shared borrow ends before the mutable one.
        let input = *self
            .input_outlets()?
            .get(ix)
            .with_context(|| format!("Invalid input index: {}", ix))?;
        self.outlet_fact_mut(input)
    }

    /// Set the `input`-th input tensor type information.
    ///
    /// # Errors
    /// Fails if `input` is not a valid model input index, or if `set_outlet_fact` rejects the
    /// corresponding input outlet.
    pub fn set_input_fact(&mut self, input: usize, fact: F) -> TractResult<()> {
        let outlet = *self
            .inputs
            .get(input)
            .with_context(|| format!("Invalid input index: {}", input))?;
        self.set_outlet_fact(outlet, fact)
    }

    /// Builder-style `set_input_fact`: set the `input`-th input tensor type information and
    /// hand back the updated model.
    pub fn with_input_fact(mut self, input: usize, fact: F) -> TractResult<Self> {
        self.set_input_fact(input, fact).map(|()| self)
    }

    // Outputs
    /// Get model outputs, in declaration order. Never fails.
    pub fn output_outlets(&self) -> TractResult<&[OutletId]> {
        Ok(self.outputs.as_slice())
    }

    /// Guess outputs from the topology: every outlet that feeds no other node becomes a
    /// model output, listed in node order then slot order. Never fails.
    pub fn auto_outputs(&mut self) -> TractResult<()> {
        let mut dangling = Vec::new();
        for node in &self.nodes {
            for (slot, outlet) in node.outputs.iter().enumerate() {
                if outlet.successors.is_empty() {
                    dangling.push(OutletId::new(node.id, slot));
                }
            }
        }
        self.outputs = dangling;
        Ok(())
    }

    /// Replace the model outputs with a copy of `outputs`. Never fails.
    pub fn set_output_outlets(&mut self, outputs: &[OutletId]) -> TractResult<()> {
        self.outputs.clear();
        self.outputs.extend_from_slice(outputs);
        Ok(())
    }

    /// Builder-style `set_output_outlets`: replace the model outputs and hand back the model.
    pub fn with_output_outlets(mut self, outputs: &[OutletId]) -> TractResult<Self> {
        self.set_output_outlets(outputs).map(|()| self)
    }

    /// Set model outputs by name.
    ///
    /// Each name is first looked up among outlet labels and `"<node name>:<slot>"`
    /// references (the latter win when both spell the same string). Failing that, it is
    /// matched against plain node names, which resolve to the node id's `OutletId`
    /// conversion.
    ///
    /// # Errors
    /// Fails on the first name that cannot be resolved; the outputs are left unchanged then.
    pub fn set_output_names(
        &mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<()> {
        let mut labels: HashMap<Cow<str>, OutletId> = self
            .outlet_labels
            .iter()
            .map(|(outlet, label)| (Cow::Borrowed(label.as_str()), *outlet))
            .collect();
        for node in self.nodes() {
            for slot in 0..node.outputs.len() {
                let key = format!("{}:{}", &node.name, slot);
                labels.insert(Cow::Owned(key), OutletId::new(node.id, slot));
            }
        }
        let mut resolved: Vec<OutletId> = Vec::new();
        for requested in outputs {
            let requested = requested.as_ref();
            let outlet: OutletId = match labels.get(requested) {
                Some(outlet) => *outlet,
                None => self
                    .nodes
                    .iter()
                    .find(|node| node.name == requested)
                    .map(|node| node.id.into())
                    .ok_or_else(|| format_err!("Node {} not found", requested))?,
            };
            resolved.push(outlet);
        }
        self.outputs = resolved;
        Ok(())
    }

    /// Builder-style `set_output_names`: set outputs by name and hand back the model.
    pub fn with_output_names(
        mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<Self> {
        self.set_output_names(outputs).map(|()| self)
    }

    /// Get the `ix`-th output tensor type information.
    ///
    /// # Errors
    /// Fails if `ix` is not a valid model output index, or if `outlet_fact` rejects the
    /// corresponding output outlet.
    pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
        // Checked lookup: an out-of-range index becomes an error rather than a panic.
        let output = *self
            .output_outlets()?
            .get(ix)
            .with_context(|| format!("Invalid output index: {}", ix))?;
        self.outlet_fact(output)
    }

    /// Get the `ix`-th output tensor type information, mutably.
    ///
    /// # Errors
    /// Fails if `ix` is not a valid model output index, or if `outlet_fact_mut` rejects the
    /// corresponding output outlet.
    pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
        // Copy the outlet id out (checked) so the shared borrow ends before the mutable one.
        let output = *self
            .output_outlets()?
            .get(ix)
            .with_context(|| format!("Invalid output index: {}", ix))?;
        self.outlet_fact_mut(output)
    }

    /// Set the `output`-th output tensor type information.
    ///
    /// # Errors
    /// Fails if `output` is not a valid model output index, or if `set_outlet_fact` rejects
    /// the corresponding output outlet.
    pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()> {
        let outlet = *self
            .outputs
            .get(output)
            .with_context(|| format!("Invalid output index: {}", output))?;
        self.set_outlet_fact(outlet, fact)
    }

    /// Builder-style `set_output_fact`: set the `output`-th output tensor type information
    /// and hand back the updated model.
    pub fn with_output_fact(mut self, output: usize, fact: F) -> TractResult<Self> {
        self.set_output_fact(output, fact).map(|()| self)
    }

    // nodes and their facts

    /// Iterate over the names of all nodes, in nodes-table order.
    pub fn node_names(&self) -> impl Iterator<Item = &str> {
        self.nodes.iter().map(|node| node.name.as_str())
    }

    /// Look up the id of the first node whose name is exactly `name`.
    ///
    /// # Errors
    /// Fails if no node carries that name.
    pub fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
        let found = self.nodes.iter().find(|node| node.name == name);
        found
            .map(|node| node.id)
            .with_context(|| format!("No node found for name: \"{}\"", name))
    }

    /// Borrow the node called `name`.
    ///
    /// # Errors
    /// Fails if no node carries that name.
    pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>> {
        self.node_id_by_name(name.as_ref()).map(|id| &self.nodes[id])
    }

    /// Mutably borrow the node called `name`.
    ///
    /// # Errors
    /// Fails if no node carries that name.
    pub fn node_by_name_mut(&mut self, name: impl AsRef<str>) -> TractResult<&mut Node<F, O>> {
        let name = name.as_ref();
        let id = self.node_id_by_name(name)?;
        Ok(&mut self.nodes[id])
    }

    /// Give node `id` a new name. Panics if `id` is not a valid index into the nodes table.
    pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
        let node = &mut self.nodes[id];
        node.name = name.to_owned();
        Ok(())
    }

    /// Borrow the node with id `id`.
    ///
    /// # Panics
    /// Panics if `id` is not a valid index into the nodes table.
    pub fn node(&self, id: usize) -> &Node<F, O> {
        let nodes = &self.nodes;
        &nodes[id]
    }

    /// Mutably borrow the node with id `id`.
    ///
    /// # Panics
    /// Panics if `id` is not a valid index into the nodes table.
    pub fn node_mut(&mut self, id: usize) -> &mut Node<F, O> {
        let nodes = &mut self.nodes;
        &mut nodes[id]
    }

    /// Borrow the whole nodes table as a slice.
    pub fn nodes(&self) -> &[Node<F, O>] {
        self.nodes.as_slice()
    }

    /// Mutably borrow the whole nodes table as a slice (nodes can be edited, not added).
    pub fn nodes_mut(&mut self) -> &mut [Node<F, O>] {
        self.nodes.as_mut_slice()
    }

    /// Get `(input facts, output facts)` for node `id`.
    pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
        let inputs = self.node_input_facts(id)?;
        let outputs = self.node_output_facts(id)?;
        Ok((inputs, outputs))
    }

    /// Get input tensor information for a node.
    pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        self.nodes[node_id].inputs.iter().map(|o| self.outlet_fact(*o)).collect()
    }

    /// Collect the facts of every output of node `node_id`, in slot order. Never returns an
    /// error; panics if `node_id` is out of range.
    pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        let node = &self.nodes[node_id];
        let facts = node.outputs.iter().map(|output| &output.fact).collect();
        Ok(facts)
    }

    // outlets

    /// Get tensor information for a single outlet.
    ///
    /// # Errors
    /// Fails if `outlet.node` is not a node of this graph, or if that node has no output at
    /// `outlet.slot`.
    pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
        if outlet.node >= self.nodes.len() {
            anyhow::bail!("Invalid outlet for graph");
        }
        let output = self.nodes[outlet.node]
            .outputs
            .get(outlet.slot)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))?;
        Ok(&output.fact)
    }

    /// Get tensor information for a single outlet.
    pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
        let outlets = &mut self.nodes[outlet.node].outputs;
        outlets
            .get_mut(outlet.slot)
            .map(|o| &mut o.fact)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
    }

    /// Get mutable tensor information for several distinct outlets at once.
    ///
    /// The returned references are in the same order as `outlets`.
    ///
    /// # Panics
    /// Panics if `outlets` contains the same outlet twice (that would hand out two aliasing
    /// `&mut F`).
    ///
    /// # Errors
    /// Fails if any outlet is rejected by `outlet_fact` (bad node id or slot).
    pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
        assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
        // Validate everything up front so the borrowing walk below cannot fail midway.
        for outlet in outlets {
            self.outlet_fact(*outlet)?;
        }
        // Visit the requested outlets sorted by (node, slot). Walking `iter_mut()` with `nth`
        // yields each node and each output slot at most once, so the mutable borrows are
        // provably disjoint without `unsafe` (casting a `&F` obtained through a shared
        // borrow to `&mut F` and writing through it would be undefined behaviour).
        // Slice iterators implement `nth` in O(1).
        let mut order: Vec<(usize, usize, usize)> =
            outlets.iter().enumerate().map(|(pos, o)| (o.node, o.slot, pos)).collect();
        order.sort_unstable();
        let mut found: Vec<Option<&mut F>> = outlets.iter().map(|_| None).collect();
        let mut nodes = self.nodes.iter_mut();
        let mut next_node = 0;
        let mut ix = 0;
        while ix < order.len() {
            let node_id = order[ix].0;
            // Node ids are strictly increasing across groups, so this never underflows.
            let node = nodes.nth(node_id - next_node).expect("outlet node validated above");
            next_node = node_id + 1;
            let mut slots = node.outputs.iter_mut();
            let mut next_slot = 0;
            while ix < order.len() && order[ix].0 == node_id {
                let (_, slot, pos) = order[ix];
                // Duplicates were rejected, so slots are strictly increasing within a node.
                let output = slots.nth(slot - next_slot).expect("outlet slot validated above");
                next_slot = slot + 1;
                found[pos] = Some(&mut output.fact);
                ix += 1;
            }
        }
        Ok(found
            .into_iter()
            .map(|fact| fact.expect("every requested outlet is visited"))
            .collect())
    }
More examples
Hide additional examples
src/ops/cast.rs (line 75)
70
71
72
73
74
75
76
77
78
79
80
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // A cast to the datum type the input already has is a no-op: shunt the node away.
        let input_fact = model.outlet_fact(node.inputs[0])?;
        if input_fact.datum_type != self.to {
            return Ok(None);
        }
        TypedModelPatch::shunt_one_op(model, node).map(Some)
    }
src/model/patch.rs (line 103)
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
    /// Add to the patch a source node standing for `outlet` of `model`.
    ///
    /// The source is named `incoming-<node>/<slot>` and carries a clone of the outlet's
    /// fact. The new id → `outlet` pair is recorded in `self.incoming`. Returns the outlet
    /// of the new source.
    pub fn tap_model(&mut self, model: &Graph<F, O>, outlet: OutletId) -> TractResult<OutletId> {
        let fact = dyn_clone::clone(model.outlet_fact(outlet)?);
        let name = format!("incoming-{}/{}", outlet.node, outlet.slot);
        let id = self.add_source(name, fact)?;
        self.incoming.insert(id, outlet);
        Ok(id)
    }

    /// Replace `outlet` in the target model by `by` from the patch, skipping the
    /// fact-compatibility check that `shunt_outside` performs.
    ///
    /// # Safety
    /// The caller must guarantee that the fact of `by` in this patch is compatible with the
    /// fact of `outlet` in the target model (the `compatible_with` check done by
    /// `shunt_outside`). Otherwise the patched model may carry inconsistent facts.
    /// NOTE(review): confirm whether callers rely on any further invariant.
    pub unsafe fn shunt_outside_unchecked(
        &mut self,
        outlet: OutletId,
        by: OutletId,
    ) -> TractResult<()> {
        self.shunt_outlet_by.insert(outlet, by);
        Ok(())
    }

    /// Replace an outlet of the target model by one from the patch.
    ///
    /// # Errors
    /// Fails if either outlet is invalid, or if the patch-side fact of `by` is not
    /// `compatible_with` the model-side fact of `outlet`.
    pub fn shunt_outside(
        &mut self,
        model: &Graph<F, O>,
        outlet: OutletId,
        by: OutletId,
    ) -> TractResult<()> {
        let original_fact = model.outlet_fact(outlet)?;
        let new_fact = self.model.outlet_fact(by)?;
        let compatible = original_fact.compatible_with(new_fact);
        if compatible {
            self.shunt_outlet_by.insert(outlet, by);
            return Ok(());
        }
        bail!("Trying to substitute a {:?} by {:?}.\n{:?}", original_fact, new_fact, self)
    }
src/ops/array/slice.rs (line 141)
136
137
138
139
140
141
142
143
144
145
146
147
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // start == 0 and end == the input's extent on `axis`: the slice is a no-op.
        if !self.start.is_zero() {
            return Ok(None);
        }
        let input_fact = model.outlet_fact(node.inputs[0])?;
        if self.end != input_fact.shape[self.axis] {
            return Ok(None);
        }
        let patch = TypedModelPatch::shunt_one_op(model, node)?;
        Ok(Some(patch.with_context("noop")))
    }
src/ops/cnn/conv/im2col.rs (line 232)
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let input_fact = model.outlet_fact(node.inputs[0])?;
        // Only applies when a second input (b0) is wired in.
        if node.inputs.len() != 2 {
            return Ok(None);
        }
        // If b0 is a constant uniformly equal to zero, rewire the op on its first input only.
        let b0_uniform =
            model.outlet_fact(node.inputs[1])?.konst.as_ref().and_then(|t| t.as_uniform());
        let zero = Tensor::zero_scalar_dt(input_fact.datum_type)?;
        if b0_uniform != Some(zero) {
            return Ok(None);
        }
        let patch =
            TypedModelPatch::replace_single_op(model, node, &node.inputs[0..1], self.clone())?;
        Ok(Some(patch.with_context("b0 is zero")))
    }
src/ops/invariants.rs (line 329)
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
/// Build an `AxisTracking` for every axis of every node outlet of `model`.
///
/// Outlets are visited in evaluation order, slot by slot. An (outlet, axis) pair already
/// present in a previously built tracking is skipped.
pub fn full_axis_tracking(model: &TypedModel) -> TractResult<Vec<AxisTracking>> {
    let mut trackings: Vec<AxisTracking> = Vec::new();
    for node_id in model.eval_order()? {
        let output_count = model.node(node_id).outputs.len();
        for slot in 0..output_count {
            let outlet = OutletId::new(node_id, slot);
            let rank = model.outlet_fact(outlet)?.rank();
            for axis in 0..rank {
                let already_tracked =
                    trackings.iter().any(|t| t.outlets.get(&outlet) == Some(&axis));
                if !already_tracked {
                    trackings.push(AxisTracking::for_outlet_and_axis(model, outlet, axis)?);
                }
            }
        }
    }
    Ok(trackings)
}

Get tensor information for a single outlet, mutably.

Examples found in repository?
src/model/graph.rs (line 195)
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
    /// Get the `ix`-th input tensor type information, mutably.
    ///
    /// # Errors
    /// Fails if `ix` is not a valid model input index, or if `outlet_fact_mut` rejects the
    /// corresponding input outlet.
    pub fn input_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
        // Copy the outlet id out (checked) so the shared borrow ends before the mutable one.
        let input = *self
            .input_outlets()?
            .get(ix)
            .with_context(|| format!("Invalid input index: {}", ix))?;
        self.outlet_fact_mut(input)
    }

    /// Set the `input`-th input tensor type information.
    ///
    /// # Errors
    /// Fails if `input` is not a valid model input index, or if `set_outlet_fact` rejects the
    /// corresponding input outlet.
    pub fn set_input_fact(&mut self, input: usize, fact: F) -> TractResult<()> {
        let outlet = *self
            .inputs
            .get(input)
            .with_context(|| format!("Invalid input index: {}", input))?;
        self.set_outlet_fact(outlet, fact)
    }

    /// Builder-style `set_input_fact`: set the `input`-th input tensor type information and
    /// hand back the updated model.
    pub fn with_input_fact(mut self, input: usize, fact: F) -> TractResult<Self> {
        self.set_input_fact(input, fact).map(|()| self)
    }

    // Outputs
    /// Get model outputs, in declaration order. Never fails.
    pub fn output_outlets(&self) -> TractResult<&[OutletId]> {
        Ok(self.outputs.as_slice())
    }

    /// Guess outputs from the topology: every outlet that feeds no other node becomes a
    /// model output, listed in node order then slot order. Never fails.
    pub fn auto_outputs(&mut self) -> TractResult<()> {
        let mut dangling = Vec::new();
        for node in &self.nodes {
            for (slot, outlet) in node.outputs.iter().enumerate() {
                if outlet.successors.is_empty() {
                    dangling.push(OutletId::new(node.id, slot));
                }
            }
        }
        self.outputs = dangling;
        Ok(())
    }

    /// Replace the model outputs with a copy of `outputs`. Never fails.
    pub fn set_output_outlets(&mut self, outputs: &[OutletId]) -> TractResult<()> {
        self.outputs.clear();
        self.outputs.extend_from_slice(outputs);
        Ok(())
    }

    /// Builder-style `set_output_outlets`: replace the model outputs and hand back the model.
    pub fn with_output_outlets(mut self, outputs: &[OutletId]) -> TractResult<Self> {
        self.set_output_outlets(outputs).map(|()| self)
    }

    /// Set model outputs by name.
    ///
    /// Each name is first looked up among outlet labels and `"<node name>:<slot>"`
    /// references (the latter win when both spell the same string). Failing that, it is
    /// matched against plain node names, which resolve to the node id's `OutletId`
    /// conversion.
    ///
    /// # Errors
    /// Fails on the first name that cannot be resolved; the outputs are left unchanged then.
    pub fn set_output_names(
        &mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<()> {
        let mut labels: HashMap<Cow<str>, OutletId> = self
            .outlet_labels
            .iter()
            .map(|(outlet, label)| (Cow::Borrowed(label.as_str()), *outlet))
            .collect();
        for node in self.nodes() {
            for slot in 0..node.outputs.len() {
                let key = format!("{}:{}", &node.name, slot);
                labels.insert(Cow::Owned(key), OutletId::new(node.id, slot));
            }
        }
        let mut resolved: Vec<OutletId> = Vec::new();
        for requested in outputs {
            let requested = requested.as_ref();
            let outlet: OutletId = match labels.get(requested) {
                Some(outlet) => *outlet,
                None => self
                    .nodes
                    .iter()
                    .find(|node| node.name == requested)
                    .map(|node| node.id.into())
                    .ok_or_else(|| format_err!("Node {} not found", requested))?,
            };
            resolved.push(outlet);
        }
        self.outputs = resolved;
        Ok(())
    }

    /// Builder-style `set_output_names`: set outputs by name and hand back the model.
    pub fn with_output_names(
        mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<Self> {
        self.set_output_names(outputs).map(|()| self)
    }

    /// Get the `ix`-th output tensor type information.
    ///
    /// # Errors
    /// Fails if `ix` is not a valid model output index, or if `outlet_fact` rejects the
    /// corresponding output outlet.
    pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
        // Checked lookup: an out-of-range index becomes an error rather than a panic.
        let output = *self
            .output_outlets()?
            .get(ix)
            .with_context(|| format!("Invalid output index: {}", ix))?;
        self.outlet_fact(output)
    }

    /// Get the `ix`-th output tensor type information, mutably.
    ///
    /// # Errors
    /// Fails if `ix` is not a valid model output index, or if `outlet_fact_mut` rejects the
    /// corresponding output outlet.
    pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
        // Copy the outlet id out (checked) so the shared borrow ends before the mutable one.
        let output = *self
            .output_outlets()?
            .get(ix)
            .with_context(|| format!("Invalid output index: {}", ix))?;
        self.outlet_fact_mut(output)
    }

Get mutable tensor information for several distinct outlets at once.

Set tensor information for a single outlet.

Examples found in repository?
src/model/graph.rs (line 201)
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
    /// Set the `input`-th input tensor type information.
    ///
    /// # Errors
    /// Fails if `input` is not a valid model input index, or if `set_outlet_fact` rejects the
    /// corresponding input outlet.
    pub fn set_input_fact(&mut self, input: usize, fact: F) -> TractResult<()> {
        let outlet = *self
            .inputs
            .get(input)
            .with_context(|| format!("Invalid input index: {}", input))?;
        self.set_outlet_fact(outlet, fact)
    }

    /// Builder-style `set_input_fact`: set the `input`-th input tensor type information and
    /// hand back the updated model.
    pub fn with_input_fact(mut self, input: usize, fact: F) -> TractResult<Self> {
        self.set_input_fact(input, fact).map(|()| self)
    }

    // Outputs
    /// Get model outputs, in declaration order. Never fails.
    pub fn output_outlets(&self) -> TractResult<&[OutletId]> {
        Ok(self.outputs.as_slice())
    }

    /// Guess outputs from the topology: every outlet that feeds no other node becomes a
    /// model output, listed in node order then slot order. Never fails.
    pub fn auto_outputs(&mut self) -> TractResult<()> {
        let mut dangling = Vec::new();
        for node in &self.nodes {
            for (slot, outlet) in node.outputs.iter().enumerate() {
                if outlet.successors.is_empty() {
                    dangling.push(OutletId::new(node.id, slot));
                }
            }
        }
        self.outputs = dangling;
        Ok(())
    }

    /// Replace the model outputs with a copy of `outputs`. Never fails.
    pub fn set_output_outlets(&mut self, outputs: &[OutletId]) -> TractResult<()> {
        self.outputs.clear();
        self.outputs.extend_from_slice(outputs);
        Ok(())
    }

    /// Builder-style `set_output_outlets`: replace the model outputs and hand back the model.
    pub fn with_output_outlets(mut self, outputs: &[OutletId]) -> TractResult<Self> {
        self.set_output_outlets(outputs).map(|()| self)
    }

    /// Set model outputs by name.
    ///
    /// Each name is first looked up among outlet labels and `"<node name>:<slot>"`
    /// references (the latter win when both spell the same string). Failing that, it is
    /// matched against plain node names, which resolve to the node id's `OutletId`
    /// conversion.
    ///
    /// # Errors
    /// Fails on the first name that cannot be resolved; the outputs are left unchanged then.
    pub fn set_output_names(
        &mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<()> {
        let mut labels: HashMap<Cow<str>, OutletId> = self
            .outlet_labels
            .iter()
            .map(|(outlet, label)| (Cow::Borrowed(label.as_str()), *outlet))
            .collect();
        for node in self.nodes() {
            for slot in 0..node.outputs.len() {
                let key = format!("{}:{}", &node.name, slot);
                labels.insert(Cow::Owned(key), OutletId::new(node.id, slot));
            }
        }
        let mut resolved: Vec<OutletId> = Vec::new();
        for requested in outputs {
            let requested = requested.as_ref();
            let outlet: OutletId = match labels.get(requested) {
                Some(outlet) => *outlet,
                None => self
                    .nodes
                    .iter()
                    .find(|node| node.name == requested)
                    .map(|node| node.id.into())
                    .ok_or_else(|| format_err!("Node {} not found", requested))?,
            };
            resolved.push(outlet);
        }
        self.outputs = resolved;
        Ok(())
    }

    /// Builder-style `set_output_names`: set outputs by name and hand back the model.
    pub fn with_output_names(
        mut self,
        outputs: impl IntoIterator<Item = impl AsRef<str>>,
    ) -> TractResult<Self> {
        self.set_output_names(outputs).map(|()| self)
    }

    /// Get the `ix`-th output tensor type information.
    ///
    /// # Errors
    /// Fails if `ix` is not a valid model output index, or if `outlet_fact` rejects the
    /// corresponding output outlet.
    pub fn output_fact(&self, ix: usize) -> TractResult<&F> {
        // Checked lookup: an out-of-range index becomes an error rather than a panic.
        let output = *self
            .output_outlets()?
            .get(ix)
            .with_context(|| format!("Invalid output index: {}", ix))?;
        self.outlet_fact(output)
    }

    /// Get the `ix`-th output tensor type information, mutably.
    ///
    /// # Errors
    /// Fails if `ix` is not a valid model output index, or if `outlet_fact_mut` rejects the
    /// corresponding output outlet.
    pub fn output_fact_mut(&mut self, ix: usize) -> TractResult<&mut F> {
        // Copy the outlet id out (checked) so the shared borrow ends before the mutable one.
        let output = *self
            .output_outlets()?
            .get(ix)
            .with_context(|| format!("Invalid output index: {}", ix))?;
        self.outlet_fact_mut(output)
    }

    /// Set the `output`-th output tensor type information.
    ///
    /// # Errors
    /// Fails if `output` is not a valid model output index, or if `set_outlet_fact` rejects
    /// the corresponding output outlet.
    pub fn set_output_fact(&mut self, output: usize, fact: F) -> TractResult<()> {
        let outlet = *self
            .outputs
            .get(output)
            .with_context(|| format!("Invalid output index: {}", output))?;
        self.set_outlet_fact(outlet, fact)
    }

    /// Builder-style `set_output_fact`: set the `output`-th output tensor type information
    /// and hand back the updated model.
    pub fn with_output_fact(mut self, output: usize, fact: F) -> TractResult<Self> {
        self.set_output_fact(output, fact).map(|()| self)
    }

    // nodes and their facts

    /// Iterate over the names of all nodes, in nodes-table order.
    pub fn node_names(&self) -> impl Iterator<Item = &str> {
        self.nodes.iter().map(|node| node.name.as_str())
    }

    /// Look up the id of the first node whose name is exactly `name`.
    ///
    /// # Errors
    /// Fails if no node carries that name.
    pub fn node_id_by_name(&self, name: &str) -> TractResult<usize> {
        let found = self.nodes.iter().find(|node| node.name == name);
        found
            .map(|node| node.id)
            .with_context(|| format!("No node found for name: \"{}\"", name))
    }

    /// Borrow the node called `name`.
    ///
    /// # Errors
    /// Fails if no node carries that name.
    pub fn node_by_name(&self, name: impl AsRef<str>) -> TractResult<&Node<F, O>> {
        self.node_id_by_name(name.as_ref()).map(|id| &self.nodes[id])
    }

    /// Mutably borrow the node called `name`.
    ///
    /// # Errors
    /// Fails if no node carries that name.
    pub fn node_by_name_mut(&mut self, name: impl AsRef<str>) -> TractResult<&mut Node<F, O>> {
        let name = name.as_ref();
        let id = self.node_id_by_name(name)?;
        Ok(&mut self.nodes[id])
    }

    /// Give node `id` a new name. Panics if `id` is not a valid index into the nodes table.
    pub fn rename_node(&mut self, id: usize, name: &str) -> TractResult<()> {
        let node = &mut self.nodes[id];
        node.name = name.to_owned();
        Ok(())
    }

    /// Borrow the node with id `id`.
    ///
    /// # Panics
    /// Panics if `id` is not a valid index into the nodes table.
    pub fn node(&self, id: usize) -> &Node<F, O> {
        let nodes = &self.nodes;
        &nodes[id]
    }

    /// Mutably borrow the node with id `id`.
    ///
    /// # Panics
    /// Panics if `id` is not a valid index into the nodes table.
    pub fn node_mut(&mut self, id: usize) -> &mut Node<F, O> {
        let nodes = &mut self.nodes;
        &mut nodes[id]
    }

    /// Borrow the whole nodes table as a slice.
    pub fn nodes(&self) -> &[Node<F, O>] {
        self.nodes.as_slice()
    }

    /// Mutably borrow the whole nodes table as a slice (nodes can be edited, not added).
    pub fn nodes_mut(&mut self) -> &mut [Node<F, O>] {
        self.nodes.as_mut_slice()
    }

    /// Get `(input facts, output facts)` for node `id`.
    pub fn node_facts(&self, id: usize) -> TractResult<(TVec<&F>, TVec<&F>)> {
        let inputs = self.node_input_facts(id)?;
        let outputs = self.node_output_facts(id)?;
        Ok((inputs, outputs))
    }

    /// Get input tensor information for a node.
    pub fn node_input_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        self.nodes[node_id].inputs.iter().map(|o| self.outlet_fact(*o)).collect()
    }

    /// Collect the facts of every output of node `node_id`, in slot order. Never returns an
    /// error; panics if `node_id` is out of range.
    pub fn node_output_facts(&self, node_id: usize) -> TractResult<TVec<&F>> {
        let node = &self.nodes[node_id];
        let facts = node.outputs.iter().map(|output| &output.fact).collect();
        Ok(facts)
    }

    // outlets

    /// Get tensor information for a single outlet.
    ///
    /// # Errors
    /// Fails if `outlet.node` is not a node of this graph, or if that node has no output at
    /// `outlet.slot`.
    pub fn outlet_fact(&self, outlet: OutletId) -> TractResult<&F> {
        if outlet.node >= self.nodes.len() {
            anyhow::bail!("Invalid outlet for graph");
        }
        let output = self.nodes[outlet.node]
            .outputs
            .get(outlet.slot)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))?;
        Ok(&output.fact)
    }

    /// Get tensor information for a single outlet.
    pub fn outlet_fact_mut(&mut self, outlet: OutletId) -> TractResult<&mut F> {
        let outlets = &mut self.nodes[outlet.node].outputs;
        outlets
            .get_mut(outlet.slot)
            .map(|o| &mut o.fact)
            .with_context(|| format!("Invalid outlet reference: {:?}", outlet))
    }

    /// Get mutable tensor information for several distinct outlets at once.
    ///
    /// The returned references are in the same order as `outlets`.
    ///
    /// # Panics
    /// Panics if `outlets` contains the same outlet twice (that would hand out two aliasing
    /// `&mut F`).
    ///
    /// # Errors
    /// Fails if any outlet is rejected by `outlet_fact` (bad node id or slot).
    pub fn outlets_fact_mut(&mut self, outlets: &[OutletId]) -> TractResult<TVec<&mut F>> {
        assert!(outlets.iter().tuple_combinations().all(|(a, b)| a != b));
        // Validate everything up front so the borrowing walk below cannot fail midway.
        for outlet in outlets {
            self.outlet_fact(*outlet)?;
        }
        // Visit the requested outlets sorted by (node, slot). Walking `iter_mut()` with `nth`
        // yields each node and each output slot at most once, so the mutable borrows are
        // provably disjoint without `unsafe` (casting a `&F` obtained through a shared
        // borrow to `&mut F` and writing through it would be undefined behaviour).
        // Slice iterators implement `nth` in O(1).
        let mut order: Vec<(usize, usize, usize)> =
            outlets.iter().enumerate().map(|(pos, o)| (o.node, o.slot, pos)).collect();
        order.sort_unstable();
        let mut found: Vec<Option<&mut F>> = outlets.iter().map(|_| None).collect();
        let mut nodes = self.nodes.iter_mut();
        let mut next_node = 0;
        let mut ix = 0;
        while ix < order.len() {
            let node_id = order[ix].0;
            // Node ids are strictly increasing across groups, so this never underflows.
            let node = nodes.nth(node_id - next_node).expect("outlet node validated above");
            next_node = node_id + 1;
            let mut slots = node.outputs.iter_mut();
            let mut next_slot = 0;
            while ix < order.len() && order[ix].0 == node_id {
                let (_, slot, pos) = order[ix];
                // Duplicates were rejected, so slots are strictly increasing within a node.
                let output = slots.nth(slot - next_slot).expect("outlet slot validated above");
                next_slot = slot + 1;
                found[pos] = Some(&mut output.fact);
                ix += 1;
            }
        }
        Ok(found
            .into_iter()
            .map(|fact| fact.expect("every requested outlet is visited"))
            .collect())
    }

    /// Set tensor information for a single outlet.
    pub fn set_outlet_fact(&mut self, outlet: OutletId, fact: F) -> TractResult<()> {
        let outlets = &mut self.nodes[outlet.node].outputs;
        if outlets.len() <= outlet.slot {
            bail!("Invalid outlet refererence: {:?}", outlet)
        }
        outlets[outlet.slot].fact = fact;
        Ok(())
    }

    /// Builder-style `set_outlet_fact`: set the fact of `outlet` and hand back the model.
    pub fn with_outlet_fact(mut self, outlet: OutletId, fact: F) -> TractResult<Self> {
        self.set_outlet_fact(outlet, fact).map(|()| self)
    }

Set tensor information for a single outlet and return self.

Get label for an outlet.

Examples found in repository?
src/model/translator.rs (line 40)
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
    fn translate_model_with_mappings(
        &self,
        source: &Graph<TI1, O1>,
    ) -> TractResult<(Graph<TI2, O2>, HashMap<OutletId, OutletId>)> {
        let mut target = Graph::default();
        let mut mapping = HashMap::new();
        // Translate nodes in evaluation order, recording where each source outlet lands in
        // the target and carrying outlet labels across.
        for old_id in source.eval_order()? {
            let node = source.node(old_id);
            trace!("Translating {} {:?}", node, self);
            let new_outlets = self
                .translate_node(source, node, &mut target, &mapping)
                .with_context(|| format!("Translating node {} {:?}", node, self))?;
            for (slot, new_outlet) in new_outlets.into_iter().enumerate() {
                let old_outlet = OutletId::new(node.id, slot);
                mapping.insert(old_outlet, new_outlet);
                if let Some(label) = source.outlet_label(old_outlet) {
                    target.set_outlet_label(new_outlet, label.to_string())?;
                }
            }
        }
        // Inputs the loop above did not translate (presumably unused ones) are translated
        // anyway, so the model interface is preserved.
        for input in source.input_outlets()? {
            if mapping.contains_key(input) {
                continue;
            }
            let node = source.node(input.node);
            trace!("Translate useless source {}", node);
            let new_outlets = self
                .translate_node(source, node, &mut target, &mapping)
                .with_context(|| format!("Translating input {} {:?}", node, self))?;
            mapping.insert(*input, new_outlets[0]);
        }
        // Keep the inputs and outputs in the same order as in the source model.
        let inputs = source.input_outlets()?.iter().map(|i| mapping[i]).collect();
        target.inputs = inputs;
        let outputs = source.output_outlets()?.iter().map(|o| mapping[o]).collect();
        target.outputs = outputs;
        target.symbol_table = source.symbol_table.clone();
        target.properties = source.properties.clone();
        Ok((target, mapping))
    }
More examples
Hide additional examples
src/model/graph.rs (line 623)
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        for i in 0..self.nodes.len() {
            let input_1 = self.nodes[i]
                .inputs
                .get(0)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let input_2 = self.nodes[i]
                .inputs
                .get(1)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let output_1 = self
                .outlet_successors(OutletId::new(i, 0))
                .get(0)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let output_2 = self
                .outlet_successors(OutletId::new(i, 0))
                .get(1)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            writeln!(
                fmt,
                "{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {:?} => {:?}",
                i,
                input_1,
                input_2,
                output_1,
                output_2,
                self.nodes[i].op().name(),
                self.nodes[i].name,
                self.node_input_facts(i).unwrap(),
                self.node_output_facts(i).unwrap(),
            )?;
            if self.nodes[i].inputs.len() > 2 {
                writeln!(
                    fmt,
                    "                                               |   * inputs: {}",
                    self.nodes[i].inputs.iter().map(|s| format!("{:?}", s)).join(", ")
                )?;
            }
            if self.nodes[i].outputs.len() > 1
                || self.outlet_successors((i, 0).into()).len() > 2
                || (self.outlet_label(i.into()).is_some()
                    && self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
            {
                for o in 0..self.nodes[i].outputs.len() {
                    if self.outlet_successors((i, o).into()).len() > 0 {
                        writeln!(
                                    fmt,
                                    "                                               |   * output #{}: {} {}",
                                    o,
                                    self.outlet_label((i, o).into()).unwrap_or(""),
                                    self.outlet_successors((i, o).into())
                                    .iter()
                                    .map(|s| format!("{:?}", s))
                                    .join(", "),
                                    )?;
                    }
                }
            }
        }
        writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{:?}", o)).join(", "))?;
        Ok(())
    }

Set label for an outlet.

Examples found in repository?
src/model/graph.rs (line 436)
435
436
437
438
    /// Consuming variant of `set_outlet_label`: labels `outlet` with `label` and hands the
    /// graph back, propagating any error from the setter.
    pub fn with_outlet_label(mut self, outlet: OutletId, label: String) -> TractResult<Self> {
        self.set_outlet_label(outlet, label).map(|_| self)
    }
More examples
Hide additional examples
src/model/translator.rs (line 41)
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
    /// Translates every node of `source` into a new graph and returns it together with the
    /// mapping from `source` outlets to their counterparts in the new graph.
    ///
    /// Nodes are translated in evaluation order through `translate_node`, and outlet labels
    /// are carried over. Inputs that are unreachable from the outputs are still translated,
    /// with their labels, so the input interface is preserved. The input/output order, the
    /// symbol table and the properties of `source` are copied as well.
    fn translate_model_with_mappings(
        &self,
        source: &Graph<TI1, O1>,
    ) -> TractResult<(Graph<TI2, O2>, HashMap<OutletId, OutletId>)> {
        let mut target = Graph::default();
        let mut mapping = HashMap::new();
        for old_id in source.eval_order()? {
            let node = source.node(old_id);
            trace!("Translating {} {:?}", node, self);
            let outlets = self
                .translate_node(source, node, &mut target, &mapping)
                .with_context(|| format!("Translating node {} {:?}", node, self))?;
            for (ix, outlet) in outlets.into_iter().enumerate() {
                mapping.insert(OutletId::new(node.id, ix), outlet);
                if let Some(label) = source.outlet_label(OutletId::new(node.id, ix)) {
                    target.set_outlet_label(outlet, label.to_string())?;
                }
            }
        }
        // do not drop inputs, even if they are useless, to maintain interface
        for i in source.input_outlets()? {
            if !mapping.contains_key(i) {
                let node = source.node(i.node);
                trace!("Translate useless source {}", node);
                let outlets = self
                    .translate_node(source, node, &mut target, &mapping)
                    .with_context(|| format!("Translating input {} {:?}", node, self))?;
                let outlet = outlets[0];
                mapping.insert(*i, outlet);
                // Keep the input's label too, as the loop above does for reachable outlets;
                // otherwise unused inputs silently lose their name.
                if let Some(label) = source.outlet_label(*i) {
                    target.set_outlet_label(outlet, label.to_string())?;
                }
            }
        }
        // maintaining order of i/o interface
        target.inputs = source.input_outlets()?.iter().map(|i| mapping[i]).collect();
        target.outputs = source.output_outlets()?.iter().map(|o| mapping[o]).collect();
        target.symbol_table = source.symbol_table.clone();
        target.properties = source.properties.clone();
        Ok((target, mapping))
    }
src/model/patch.rs (line 316)
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
    /// Applies the patch to `target`, consuming the patch.
    ///
    /// Steps:
    /// - Add every patch node to `target`, except taps. Taps are patch sources already
    ///   mapped through `incoming`.
    /// - A patch source that is not a tap replaces the model input named in the patch's
    ///   `inputs` map.
    /// - For each shunted outlet, move its successors, output-list slots and label to the
    ///   replacement.
    /// - Wire the inputs of the added nodes.
    /// - Give obliterated nodes a dummy op.
    /// - Restore the (possibly updated) input list.
    ///
    /// Fails if shunting leaves the same outlet twice among the model outputs.
    pub fn apply(self, target: &mut Graph<F, O>) -> TractResult<()> {
        let prior_target_inputs = target.input_outlets()?.len();
        let prior_target_outputs = target.output_outlets()?.len();
        let ModelPatch {
            model: patch,
            incoming: mut mapping,
            shunt_outlet_by,
            obliterate,
            inputs: replaced_inputs,
            ..
        } = self;
        let mut all_inputs = HashMap::new(); // new_node_id_in_model -> [ patch_outlet_id ]
        let mut model_input_outlets = target.input_outlets()?.to_vec();
        for node in patch.nodes {
            if <Graph<F, O>>::is_source(&node.op)
                && mapping.contains_key(&OutletId::new(node.id, 0))
            {
                // this is a tap
                continue;
            }
            let Node { id: patch_node_id, name, inputs, op, outputs } = node;
            let n_outputs = outputs.len();
            // Every non-tap patch node becomes a new node, even when an identical node already
            // exists in `target`; merging duplicates is presumably left to optimizer passes.
            let facts = outputs.into_iter().map(|of| of.fact).collect();
            let added_node_id = target.add_node(name, op, facts)?;
            for ix in 0..n_outputs {
                mapping.insert(OutletId::new(patch_node_id, ix), OutletId::new(added_node_id, ix));
            }
            all_inputs.insert(added_node_id, inputs);
            if <Graph<F, O>>::is_source(&target.node(added_node_id).op) {
                // this is actually an input replacement
                model_input_outlets.iter_mut().for_each(|oo| {
                    if oo.node == replaced_inputs[&patch_node_id] {
                        oo.node = added_node_id;
                    }
                });
            }
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        for (outlet, by) in shunt_outlet_by {
            let replace_by = mapping[&by];
            let succs = target.nodes()[outlet.node].outputs[outlet.slot].successors.clone();
            for succ in succs {
                target.add_edge(replace_by, succ)?;
            }
            for o in target.outputs.iter_mut() {
                if *o == outlet {
                    *o = replace_by;
                }
            }
            if let Some(label) = target.outlet_labels.remove(&outlet) {
                target.set_outlet_label(replace_by, label)?;
            }
        }
        if target.outputs.len() > target.outputs.iter().sorted().dedup().count() {
            bail!("Duplicate usage of node as output");
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        for (node, inputs) in all_inputs {
            for (ix, input) in inputs.into_iter().enumerate() {
                target.add_edge(mapping[&input], InletId::new(node, ix))?;
            }
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        for node in obliterate {
            target.node_mut(node).op = target.create_dummy();
        }
        debug_assert_eq!(target.input_outlets()?.len(), prior_target_inputs);
        debug_assert_eq!(target.output_outlets()?.len(), prior_target_outputs);
        target.set_input_outlets(&model_input_outlets)?;
        Ok(())
    }

Set label for an outlet and return self.

Find outlet by label.

Computes an evaluation order for the graph inputs and outputs.

Examples found in repository?
src/ops/invariants.rs (line 326)
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
/// Builds one `AxisTracking` for every axis of every outlet of the nodes in evaluation order,
/// skipping any (outlet, axis) pair already covered by a tracking built earlier.
pub fn full_axis_tracking(model: &TypedModel) -> TractResult<Vec<AxisTracking>> {
    let mut trackings: Vec<AxisTracking> = vec![];
    for node_id in model.eval_order()? {
        let output_count = model.node(node_id).outputs.len();
        for slot in 0..output_count {
            let outlet = OutletId::new(node_id, slot);
            let rank = model.outlet_fact(outlet)?.rank();
            for axis in 0..rank {
                let already_tracked =
                    trackings.iter().any(|t| t.outlets.get(&outlet) == Some(&axis));
                if !already_tracked {
                    trackings.push(AxisTracking::for_outlet_and_axis(model, outlet, axis)?);
                }
            }
        }
    }
    Ok(trackings)
}
More examples
Hide additional examples
src/optim/op_optim.rs (line 23)
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
    /// Scans nodes in evaluation order, resuming at the position stored in `self.2`, and
    /// returns the first patch produced by the wrapped closure `self.1`, tagged with context.
    ///
    /// The resume position becomes the patching node's index. It is one past that index when
    /// the patch sets `dont_apply_twice`.
    fn full_pass(
        &mut self,
        session: &mut OptimizerSession,
        new: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        let order = new.eval_order()?;
        for (position, &node_id) in order.iter().enumerate().skip(self.2) {
            let node = &new.nodes()[node_id];
            let maybe_patch = (self.1)(node.op.as_ref(), session, new, node)
                .with_context(|| format!("{:?} node {}", self, node))?;
            if let Some(mut patch) = maybe_patch {
                patch.push_context(format!("{:?} {}", self, node));
                let step_over = patch.dont_apply_twice.is_some() as usize;
                self.2 = position + step_over;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
src/ops/scan/mir.rs (line 85)
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
    /// Collects the axis changes suggested by the nodes of the scan body.
    ///
    /// Returns a patch replacing this node with the substitute op of the first suggestion
    /// that yields one, or `None` if no suggestion does.
    fn declutter_body_axes(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut suggestions = vec![];
        for body_node_id in self.body.eval_order()? {
            let body_node = self.body.node(body_node_id);
            for suggestion in body_node.op.suggested_axis_changes()? {
                suggestions.push(AxisChange {
                    outlet: suggestion.0.as_outlet(body_node),
                    op: suggestion.1,
                });
            }
        }
        for suggestion in suggestions {
            let change = self.try_body_axes_change(suggestion, true)?;
            if let Some(op) = change.and_then(|c| c.substitute_op) {
                let patch = TypedModelPatch::replace_single_op(model, node, &node.inputs, op)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
src/optim/change_axes.rs (line 31)
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
    /// Looks for the first not-yet-tried axis change suggested by a node, in evaluation order,
    /// that `change_axes` can realize while keeping the model inputs and outputs unchanged.
    ///
    /// Tried suggestions are remembered in `self.0` so each one is attempted only once.
    fn next(
        &mut self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut interfaces = model.output_outlets()?.to_vec();
        interfaces.extend(model.input_outlets()?.iter());
        for node_id in model.eval_order()? {
            let node = model.node(node_id);
            for suggestion in node.op.suggested_axis_changes()? {
                if !self.0.insert((node_id, suggestion.clone())) {
                    // already attempted on a previous call
                    continue;
                }
                let change =
                    AxisChange { outlet: suggestion.0.as_outlet(node), op: suggestion.1.clone() };
                let attempt = change_axes(model, &change, &interfaces, &[])
                    .with_context(|| format!("Making patch for {:?} from {}", change, node))?;
                if let Some((patch, _)) = attempt {
                    return Ok(Some(patch));
                }
            }
        }
        Ok(None)
    }
src/optim/push_split_down.rs (line 15)
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
    /// Merges duplicated computations.
    ///
    /// When two distinct successors of the same outlet are identical nodes (`same_as`), every
    /// output of the second is shunted to the matching output of the first, and the second is
    /// obliterated. Returns `None` when nothing was merged.
    fn next(
        &mut self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut patch = TypedModelPatch::default();
        for node in model.eval_order()? {
            for output in &model.node(node).outputs {
                for (a, b) in output.successors.iter().tuple_combinations() {
                    // A node reading the same outlet on two inputs (e.g. x * x) appears twice
                    // among the successors; it must not be merged with (and obliterated into)
                    // itself.
                    if a.node == b.node {
                        continue;
                    }
                    // Never shunt to, or re-merge, a node this patch already obliterates.
                    if patch.obliterate.contains(&a.node) || patch.obliterate.contains(&b.node) {
                        continue;
                    }
                    let a = model.node(a.node);
                    let b = model.node(b.node);
                    if a.same_as(b) {
                        for slot in 0..b.outputs.len() {
                            let tap = patch.tap_model(model, OutletId::new(a.id, slot))?;
                            patch.shunt_outside(model, OutletId::new(b.id, slot), tap)?;
                            patch.obliterate(b.id)?;
                        }
                    }
                }
            }
        }
        Ok(Some(patch).filter(|p| !p.is_empty()))
    }
src/model/graph.rs (line 462)
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
    /// Checks that every edge of the nodes in evaluation order is recorded on both ends.
    ///
    /// Each input of a node must appear among the successors of the outlet it reads, and each
    /// successor of an outlet must read that outlet. Dangling node or slot references are
    /// reported as errors rather than panicking, since this is a sanity check meant to
    /// diagnose malformed graphs.
    pub fn check_edges(&self) -> TractResult<()> {
        for node_id in self.eval_order()? {
            let node = &self.nodes[node_id];
            for (ix, input) in node.inputs.iter().enumerate() {
                let prec = match self.nodes.get(input.node) {
                    Some(prec) => prec,
                    None => bail!(
                        "Dangling input, node:{} input:{} refers to missing node {:?}",
                        node.id,
                        ix,
                        input
                    ),
                };
                let reciprocated = prec
                    .outputs
                    .get(input.slot)
                    .map_or(false, |o| o.successors.contains(&InletId::new(node.id, ix)));
                if !reciprocated {
                    bail!(
                        "Mismatched oncoming edge, node:{} input:{} to {:?} not reciprocated",
                        node.id,
                        ix,
                        prec
                    )
                }
            }
            for (ix, output) in node.outputs.iter().enumerate() {
                for succ in &output.successors {
                    let reciprocated = self
                        .nodes
                        .get(succ.node)
                        .and_then(|n| n.inputs.get(succ.slot))
                        .map_or(false, |o| *o == OutletId::new(node.id, ix));
                    if !reciprocated {
                        bail!(
                            "Mismatched outgoing edge, node:{} output:{} to {:?} not reciprocated",
                            node.id,
                            ix,
                            succ
                        )
                    }
                }
            }
        }
        Ok(())
    }

    /// Converts the model into a `RunnableModel` which fixes the inputs and outputs and allows passing data through the model.
    ///
    /// Builds a `SimplePlan` over the graph; any plan-construction error is returned.
    pub fn into_runnable(self) -> TractResult<RunnableModel<F, O, Self>> {
        let plan = crate::plan::SimplePlan::new(self)?;
        Ok(plan)
    }

    /// Returns the only predecessor of node `id`.
    ///
    /// Requires `id` to have exactly one input, and that predecessor to feed exactly one
    /// inlet in total across all its outputs. Returns `None` otherwise.
    pub fn single_prec(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let nodes = self.nodes();
        let inputs = &nodes[id].inputs;
        if inputs.len() == 1 {
            let prec = &nodes[inputs[0].node];
            let fan_out: usize = prec.outputs.iter().map(|of| of.successors.len()).sum();
            if fan_out == 1 {
                return Ok(Some(prec));
            }
        }
        Ok(None)
    }

    /// Walks `count` steps upstream from node `id` through `single_prec`.
    ///
    /// Returns the node reached, or `None` as soon as a step has no single predecessor.
    /// `count == 0` yields node `id` itself.
    pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut current = self.node(id);
        for _ in 0..count {
            match self.single_prec(current.id)? {
                Some(prec) => current = prec,
                None => return Ok(None),
            }
        }
        Ok(Some(current))
    }

    /// Walks `count` steps downstream from node `id` through `single_succ`.
    ///
    /// Returns the node reached, or `None` as soon as a step has no single successor.
    /// `count == 0` yields node `id` itself.
    pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut current = self.node(id);
        for _ in 0..count {
            match self.single_succ(current.id)? {
                Some(succ) => current = succ,
                None => return Ok(None),
            }
        }
        Ok(Some(current))
    }

    /// Returns the only successor of node `id`.
    ///
    /// Requires `id` to feed exactly one inlet in total across all its outputs, that inlet
    /// to be fed by output #0, and the successor to have a single input. Returns `None`
    /// otherwise.
    pub fn single_succ(&self, id: usize) -> TractResult<Option<&Node<F, O>>> {
        let node = &self.nodes()[id];
        let total: usize = node.outputs.iter().map(|of| of.successors.len()).sum();
        // When the lone successor hangs off an output other than #0,
        // `outputs[0].successors[0]` would be out of bounds (and callers presumably assume
        // output #0), so report `None` instead of panicking.
        let succ = match node.outputs.first().and_then(|of| of.successors.first()) {
            Some(succ) if total == 1 => succ,
            _ => return Ok(None),
        };
        let succ = &self.nodes()[succ.node];
        if succ.inputs.len() != 1 {
            return Ok(None);
        }
        Ok(Some(succ))
    }

    /// Returns the inlets fed by `outlet`.
    ///
    /// Panics if `outlet` names a node or an output slot that does not exist.
    pub fn outlet_successors(&self, outlet: OutletId) -> &[InletId] {
        let output = &self.nodes[outlet.node].outputs[outlet.slot];
        &output.successors
    }
}

impl<F: Fact + Clone + 'static, O> Graph<F, O>
where
    F: Fact + Clone + 'static + From<std::sync::Arc<Tensor>> + Hash,
    O: fmt::Debug
        + fmt::Display
        + From<crate::ops::konst::Const>
        + AsRef<dyn Op>
        + AsMut<dyn Op>
        + Clone
        + Hash
        + 'static,
{
    /// Adds a `Const` node named `name` that holds `v`.
    ///
    /// The node's output fact is built from the tensor itself. Returns the outlet of the new
    /// node.
    pub fn add_const(
        &mut self,
        name: impl Into<String>,
        v: impl IntoArcTensor,
    ) -> TractResult<OutletId> {
        let tensor = v.into_arc_tensor();
        let fact = F::from(tensor.clone());
        let name: String = name.into();
        let op = crate::ops::konst::Const::new(tensor);
        let node_id = self.add_node(name, op, tvec!(fact))?;
        Ok(node_id.into())
    }
}

impl<F, O> fmt::Display for Graph<F, O>
where
    F: Fact + Hash + Clone + 'static,
    O: fmt::Debug + fmt::Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static + Hash,
{
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        for i in 0..self.nodes.len() {
            let input_1 = self.nodes[i]
                .inputs
                .get(0)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let input_2 = self.nodes[i]
                .inputs
                .get(1)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let output_1 = self
                .outlet_successors(OutletId::new(i, 0))
                .get(0)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let output_2 = self
                .outlet_successors(OutletId::new(i, 0))
                .get(1)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            writeln!(
                fmt,
                "{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {:?} => {:?}",
                i,
                input_1,
                input_2,
                output_1,
                output_2,
                self.nodes[i].op().name(),
                self.nodes[i].name,
                self.node_input_facts(i).unwrap(),
                self.node_output_facts(i).unwrap(),
            )?;
            if self.nodes[i].inputs.len() > 2 {
                writeln!(
                    fmt,
                    "                                               |   * inputs: {}",
                    self.nodes[i].inputs.iter().map(|s| format!("{:?}", s)).join(", ")
                )?;
            }
            if self.nodes[i].outputs.len() > 1
                || self.outlet_successors((i, 0).into()).len() > 2
                || (self.outlet_label(i.into()).is_some()
                    && self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
            {
                for o in 0..self.nodes[i].outputs.len() {
                    if self.outlet_successors((i, o).into()).len() > 0 {
                        writeln!(
                                    fmt,
                                    "                                               |   * output #{}: {} {}",
                                    o,
                                    self.outlet_label((i, o).into()).unwrap_or(""),
                                    self.outlet_successors((i, o).into())
                                    .iter()
                                    .map(|s| format!("{:?}", s))
                                    .join(", "),
                                    )?;
                    }
                }
            }
        }
        writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{:?}", o)).join(", "))?;
        Ok(())
    }
}

impl<F, O> Graph<F, O>
where
    F: Fact + Clone + 'static + std::hash::Hash + for<'a> std::convert::From<&'a F>,
    O: std::fmt::Display
        + std::fmt::Debug
        + Clone
        + AsRef<dyn Op>
        + AsMut<dyn Op>
        + Clone
        + 'static
        + std::hash::Hash
        + for<'a> std::convert::From<&'a O>,
    Graph<F, O>: SpecialOps<F, O>,
{
    /// Debug-only check that the graph is compact. It verifies that:
    /// - every node is in the evaluation order, except unused model inputs;
    /// - the evaluation order is the identity (node `ix` at position `ix`);
    /// - node ids match their positions;
    /// - node names are unique.
    ///
    /// On a duplicate name, the whole graph is printed to stderr before failing.
    #[cfg(debug_assertions)]
    pub fn check_compact(&self) -> TractResult<()> {
        let order = self.eval_order()?;
        // Fetched once with `?`: the output list cannot change while counting, and unwrapping
        // inside the filter would panic instead of propagating the error.
        let outputs = self.output_outlets()?;
        let useless_sources = self
            .input_outlets()?
            .iter()
            .filter(|io| self.outlet_successors(**io).is_empty() && !outputs.contains(io))
            .count();
        if order.len() + useless_sources != self.nodes.len() {
            bail!(
                "Eval order is {} long, nodes are {}, including {} unused sources",
                order.len(),
                self.nodes.len(),
                useless_sources
            );
        }
        if order.iter().enumerate().any(|(ix, &id)| id != ix) {
            bail!("eval order is not trivial");
        }
        let mut seen = std::collections::HashSet::new();
        for (ix, n) in self.nodes.iter().enumerate() {
            if ix != n.id {
                bail!("Invalid node id: position is {}, node is {}", ix, n);
            }
            if !seen.insert(&n.name) {
                eprintln!("{}", self);
                bail!("duplicate name {}", n.name);
            }
        }
        Ok(())
    }

Performs a sanity check on network connections.

Examples found in repository?
src/model/typed.rs (line 108)
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
    /// Verifies that the model is self-consistent.
    ///
    /// First checks edge reciprocity (`check_edges`). Then, for every node in evaluation
    /// order, checks that its id matches its position, and that the output facts recomputed by
    /// its op agree with the stored facts in count, datum type and shape. Finally, runs
    /// `consistent()` on every stored output fact of every node.
    pub fn check_consistency(&self) -> TractResult<()> {
        self.check_edges()?;
        let order = self.eval_order()?;
        for &node_id in order.iter() {
            let input_facts = self.node_input_facts(node_id)?;
            let node = &self.nodes[node_id];
            if node.id != node_id {
                bail!("Node at position {} has id {}", node_id, node.id);
            }
            let output_facts = node.op.output_facts(&input_facts)?;
            if node.outputs.len() != output_facts.len() {
                bail!(
                    "Inconsistent model, node output count mismatch. Op says {}, node says {}. {}",
                    output_facts.len(),
                    node.outputs.len(),
                    node
                );
            }
            let types_mismatch =
                node.outputs.iter().zip(output_facts.iter()).any(|(stored, computed)| {
                    stored.fact.datum_type != computed.datum_type
                        || stored.fact.shape != computed.shape
                });
            if types_mismatch {
                bail!(
                    "Inconsistent model, output types mismatch. Op says: {:?}, node says: {:?}. {} with inputs {:?}. {}",
                    output_facts,
                    node.outputs.iter().map(|o| &o.fact).collect::<Vec<_>>(),
                    node,
                    input_facts,
                    node
                )
            }
        }
        for node in &self.nodes {
            for (ix, output) in node.outputs.iter().enumerate() {
                output.fact.consistent().with_context(|| {
                    format!("Inconsistent fact {:?}: {:?}", OutletId::new(node.id, ix), output.fact)
                })?;
            }
        }
        Ok(())
    }

Converts the model into a RunnableModel which fixes the inputs and outputs and allows passing data through the model.

Examples found in repository?
src/ops/cnn/deconv/unary.rs (line 162)
156
157
158
159
160
161
162
163
    /// Evaluates the op by building a throwaway `TypedModel` and running it on the input.
    ///
    /// The model has a single source shaped like the input, with the deconvolution wired
    /// onto it.
    fn eval(&self, mut inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        let source_fact = input.datum_type().fact(input.shape());
        let mut model = TypedModel::default();
        let source = model.add_source("source", source_fact)?;
        let output = self.wire_with_deconv_sum("adhoc", &mut model, source)?;
        model.set_output_outlets(&output)?;
        let runnable = model.into_runnable()?;
        runnable.run(tvec!(input))
    }
More examples
Hide additional examples
src/ops/cnn/conv/unary.rs (line 804)
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
    /// Evaluates the convolution by building and running an ad-hoc `TypedModel` on the inputs.
    ///
    /// The model has one source per input and an optional u8-to-i8 kernel-offset rewrite.
    /// It then applies either the quantized im2col lowering (when `q_params` is set) or the
    /// plain im2col pair.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut model = TypedModel::default();

        let mut wires: TVec<OutletId> = inputs
            .iter()
            .enumerate()
            .map(|(slot, value)| {
                let fact = value.datum_type().fact(value.shape());
                model.add_source(format!("source.{}", slot), fact)
            })
            .collect::<TractResult<_>>()?;
        let offset_op = self.kernel_offset_u8_as_i8(&mut wires, &mut model)?;
        // NOTE(review): the wiring helpers require `unsafe`; their safety contract is defined
        // where they are declared, not visible here.
        let wire = unsafe {
            if self.q_params.is_none() {
                self.wire_as_im2col_pair(&mut model, "im2col-adhoc", wires[0])?
            } else {
                let conv = offset_op.as_ref().unwrap_or(self);
                conv.wire_as_quant_im2col(
                    &mut model,
                    "im2col-adhoc",
                    inputs[0].datum_type(),
                    &wires,
                )?
            }
        };
        model.set_output_outlets(&[wire])?;
        model.into_runnable()?.run(inputs)
    }
src/ops/matmul/mir_quant.rs (line 307)
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
    /// Evaluates the quantized matmul by building and running an ad-hoc `TypedModel`.
    ///
    /// Every model input is a constant: a, b, bias, then the remaining inputs as
    /// quantization-parameter sources. The u8 operands are shifted to i8, then a plain
    /// `MatMul` and the requantization are wired. The model is run without inputs.
    ///
    /// Fails if the first two inputs do not have the same rank.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        ensure!(
            inputs[0].rank() == inputs[1].rank(),
            "Rank mismatch {:?} vs {:?}",
            inputs[0],
            inputs[1]
        );

        let mut model = TypedModel::default();
        let a = model.add_const("source_a", inputs[0].clone().into_arc_tensor())?;
        let b = model.add_const("source_b", inputs[1].clone().into_arc_tensor())?;
        let bias = model.add_const("source_bias", inputs[2].clone().into_arc_tensor())?;

        let mut input_outlets = tvec![a, b, bias];
        for (slot, tensor) in inputs.iter().enumerate().skip(3) {
            let name = format!("source_{}", slot);
            let outlet = model.add_const(name, tensor.clone().into_arc_tensor())?;
            input_outlets.push(outlet);
        }

        let mut params = self.params.as_outlet_ids(
            &mut model,
            "qmatmul_unary",
            &input_outlets,
            inputs[0].datum_type(),
            inputs[1].datum_type(),
            self.output_type,
        )?;

        let a = wire_offset_u8_as_i8(&mut model, "adhoc", a, "a", &mut params[0], "a0")?;
        let b = wire_offset_u8_as_i8(&mut model, "adhoc", b, "b", &mut params[2], "b0")?;

        let product = model.wire_node("adhoc.matmul", MatMul { axes: self.axes }, &[a, b])?[0];
        let requantized = wire_matmul_quant(
            &mut model,
            "adhoc",
            a,
            b,
            Some(bias),
            self.axes,
            product,
            self.output_type,
            &params,
        )?;
        model.set_output_outlets(&[requantized])?;
        model.into_runnable()?.run(tvec![])
    }
src/ops/matmul/mir_quant_unary.rs (line 78)
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
    /// Evaluates the quantized unary matmul by building and running an ad-hoc `TypedModel`.
    ///
    /// Constants in the model:
    /// - the stored `a`;
    /// - the input `b`;
    /// - the optional stored bias;
    /// - the remaining inputs, as quantization-parameter sources.
    ///
    /// The u8 operands are shifted to i8, then a `MatMulUnary` (built from the i8-offset copy
    /// of `a`) and the requantization are wired. The model is run without inputs.
    ///
    /// Fails if the input rank differs from the rank of `a`.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        ensure!(inputs[0].rank() == self.a.rank(), "Rank mismatch {:?} vs {:?}", inputs[0], self.a);

        let mut model = TypedModel::default();
        let shifted_a = self.a.offset_u8_as_i8();
        let a = model.add_const("source_a", self.a.clone())?;
        let b = model.add_const("source_b", inputs[0].clone().into_arc_tensor())?;
        let bias = self
            .bias
            .clone()
            .map(|tensor| model.add_const("source_bias", tensor))
            .transpose()?;

        let mut input_outlets = tvec![a];
        for (slot, tensor) in inputs.iter().enumerate().skip(1) {
            let name = format!("source_{}", slot);
            let outlet = model.add_const(name, tensor.clone().into_arc_tensor())?;
            input_outlets.push(outlet);
        }

        let mut params = self.params.as_outlet_ids(
            &mut model,
            "qmatmul_unary",
            &input_outlets,
            self.a.datum_type(),
            inputs[0].datum_type(),
            self.output_type,
        )?;
        let a = wire_offset_u8_as_i8(&mut model, "adhoc", a, "a", &mut params[0], "a0")?;
        let b = wire_offset_u8_as_i8(&mut model, "adhoc", b, "b", &mut params[2], "b0")?;

        let matmul = MatMulUnary { a: shifted_a, axes: self.axes };
        let product = model.wire_node("adhoc.matmul", matmul, &[b])?[0];
        let requantized = wire_matmul_quant(
            &mut model,
            "adhoc",
            a,
            b,
            bias,
            self.axes,
            product,
            self.output_type,
            &params,
        )?;
        model.set_output_outlets(&[requantized])?;
        model.into_runnable()?.run(tvec![])
    }
Examples found in repository?
src/model/graph.rs (line 511)
508
509
510
511
512
513
514
515
516
517
518
    /// Walks `count` steps upstream from node `id` through `single_prec`.
    ///
    /// Returns the node reached, or `None` as soon as a step has no single predecessor.
    /// `count == 0` yields node `id` itself.
    pub fn single_prec_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut current = self.node(id);
        for _ in 0..count {
            match self.single_prec(current.id)? {
                Some(prec) => current = prec,
                None => return Ok(None),
            }
        }
        Ok(Some(current))
    }
More examples
Hide additional examples
src/ops/math/mod.rs (line 436)
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
/// Declutter rule for `node` (presumably a reciprocal, per the function name) whose single
/// predecessor is an element-wise `Sqrt` or `Rsqrt`.
///
/// The pair is collapsed into one op applied to the predecessor's first input: `Sqrt`
/// becomes `rsqrt()` and `Rsqrt` becomes `sqrt()`. The new node takes `node`'s name and
/// replaces its output. Returns `Ok(None)` when there is no single predecessor or when it
/// is not one of those two element-wise ops.
fn declutter_recip(model: &TypedModel, node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
    use super::element_wise::*;
    let prec = match model.single_prec(node.id)? {
        Some(prec) => prec,
        None => return Ok(None),
    };
    let ew = match prec.op_as::<ElementWiseOp>() {
        Some(ew) => ew,
        None => return Ok(None),
    };
    let replacement = if ew.0.is::<Sqrt>() {
        rsqrt()
    } else if ew.0.is::<Rsqrt>() {
        sqrt()
    } else {
        return Ok(None);
    };
    let mut patch = TypedModelPatch::default();
    let tapped = patch.tap_model(model, prec.inputs[0])?;
    let wire = patch.wire_node(&node.name, replacement, &[tapped])?[0];
    patch.shunt_outside(model, node.id.into(), wire)?;
    Ok(Some(patch))
}
src/ops/downsample/mod.rs (line 115)
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
/// Tries to move a `Downsample` node above its single predecessor.
///
/// Predecessors that are `Slice`, `AxisOp`, `ConvUnary` or `Scan` are handed to dedicated
/// rules. A slice that cannot absorb the downsample falls through to the generic path.
/// Generic path: when the predecessor's invariants track the downsampled axis back up
/// through it, the returned patch downsamples each of the predecessor's inputs on the
/// tracked axis, re-wires the predecessor op on top, and reroutes the downsample's output
/// to it. Returns `Ok(None)` when no rewrite applies.
///
/// # Panics
/// Panics if `down_node`'s op is not a `Downsample`.
fn pull_downsample_up(
    model: &TypedModel,
    down_node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    model.check_consistency()?;
    let down_op = down_node.op_as::<Downsample>().unwrap();
    let prec = match model.single_prec(down_node.id)? {
        Some(prec) => prec,
        None => return Ok(None),
    };
    let (input_facts, output_facts) = model.node_facts(prec.id)?;
    let invariants = prec.op.invariants(&input_facts, &output_facts)?;
    debug!("Consider pull {:?} over {:?} (invariants: {:?})", down_op, prec, invariants);

    // Predecessors with a dedicated rule.
    if let Some(slice_op) = prec.op_as::<ops::array::Slice>() {
        let over_slice =
            array::pull_downsample_over_slice(model, prec, slice_op, down_node, down_op)?;
        if over_slice.is_some() {
            return Ok(over_slice);
        }
    } else if let Some(axis_op) = prec.op_as::<AxisOp>() {
        return array::pull_downsample_over_axis_op(model, prec, axis_op, down_node, down_op);
    } else if let Some(conv_op) = prec.op_as::<ops::cnn::conv::ConvUnary>() {
        return conv::fuse_downsample_into_conv(model, prec, conv_op, down_node, down_op);
    } else if let Some(scan_op) = prec.op_as::<ops::scan::Scan>() {
        return scan::pull_downsample_over_scan(model, prec, scan_op, down_node, down_op);
    }

    // Generic path: only possible if the downsampled axis can be tracked above `prec`.
    let above_axis = match invariants.unary_track_axis_up(down_op.axis, false) {
        Some(axis) => axis,
        None => return Ok(None),
    };
    let mut lifted_op = down_op.clone();
    lifted_op.axis = above_axis;
    let mut patch = TypedModelPatch::default();
    let mut downsampled = Vec::with_capacity(prec.inputs.len());
    for (ix, &outlet) in prec.inputs.iter().enumerate() {
        let tapped = patch.tap_model(model, outlet)?;
        let wired = patch.wire_node(
            format!("{}.{}-{}", down_node.name, prec.name, ix),
            lifted_op.clone(),
            &[tapped],
        )?;
        downsampled.push(wired[0]);
    }
    let rewired = patch.wire_node(&prec.name, prec.op.clone(), &downsampled)?;
    patch.shunt_outside(model, OutletId::new(down_node.id, 0), rewired[0])?;
    Ok(Some(patch))
}
Examples found in repository?
src/model/graph.rs (line 523)
520
521
522
523
524
525
526
527
528
529
530
    /// Walks `count` hops downstream from node `id`, following `single_succ` at each step.
    ///
    /// Returns `Ok(None)` as soon as `single_succ` reports no single successor for the node
    /// reached so far; otherwise returns the node reached after `count` hops (`count == 0`
    /// yields node `id` itself). Errors from `single_succ` are propagated.
    pub fn single_succ_at(&self, id: usize, count: usize) -> TractResult<Option<&Node<F, O>>> {
        let mut cursor = Some(self.node(id));
        for _ in 0..count {
            cursor = match cursor {
                Some(node) => self.single_succ(node.id)?,
                None => return Ok(None),
            };
        }
        Ok(cursor)
    }
More examples
Hide additional examples
src/model/patch.rs (line 170)
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
    /// Builds a patch fusing `node` with its single successor into one node running `new_op`.
    ///
    /// The fused node is named after `node` and wired to `node`'s inputs, which are tapped
    /// from `patched_model`. It is declared with a single output whose fact is the
    /// successor's output 0 fact. That output replaces the successor's output 0 outside the
    /// patch.
    ///
    /// # Errors
    /// Fails with "Non single successor fuse attempt" when `single_succ` finds no single
    /// successor for `node`. Any error raised while building the patch is propagated.
    pub fn fuse_with_next<IO: Into<O>>(
        patched_model: &Graph<F, O>,
        node: &Node<F, O>,
        new_op: IO,
    ) -> TractResult<ModelPatch<F, O>> {
        let mut patch = ModelPatch::default();
        let succ = if let Some(succ) = patched_model.single_succ(node.id)? {
            succ
        } else {
            bail!("Non single successor fuse attempt")
        };
        let new_op = new_op.into();
        let by = patch.add_node(&*node.name, new_op, tvec!(succ.outputs[0].fact.clone()))?;
        for (ix, i) in node.inputs.iter().enumerate() {
            let o = patch.tap_model(patched_model, *i)?;
            patch.add_edge(o, InletId::new(by, ix))?;
        }
        // `by` is declared with exactly one output (the successor's output 0 fact), so only
        // that outlet can replace the successor. Iterating over `node`'s output count here
        // would reference outlets of `by` that do not exist when `node` has several outputs.
        patch.shunt_outside(patched_model, OutletId::new(succ.id, 0), OutletId::new(by, 0))?;
        Ok(patch)
    }
src/ops/quant.rs (line 153)
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
    /// Declutter rule folding a dequantization node into a downstream quantization.
    ///
    /// Starting from `dequant`, follows the chain of single successors. Whenever the next
    /// successor (`quant`) is an element-wise `QuantizeLinearU8` / `QuantizeLinearI8`, it
    /// tries to replace everything from `dequant`'s input up to `quant`'s output:
    /// 1. by asking the ops between them (through `Op::quantize`) for quantized equivalents.
    ///    If all of them accept, they are wired directly on `dequant`'s input.
    /// 2. otherwise, when `dequant`'s input is i8 or u8, by running the whole
    ///    dequant -> ... -> quant chain on all 256 possible byte values. The chain is then
    ///    replaced by a single lookup-table op.
    ///
    /// The walk only continues past a successor whose invariants report it as element-wise.
    /// It stops with `Ok(None)` otherwise, or when there is no single successor.
    fn declutter(
        &self,
        model: &TypedModel,
        dequant: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // Last node of the element-wise chain explored so far.
        let mut current = dequant;
        // Datum type of the tensor entering `dequant`; gates the lookup-table fallback.
        let incoming_dt = model.node_input_facts(dequant.id)?[0].datum_type;
        while let Some(quant) = model.single_succ(current.id)? {
            // (scale, zero point, output datum type) when `quant` is an element-wise
            // QuantizeLinearU8 / QuantizeLinearI8 op, `None` otherwise.
            let q_params = if let Some(op) = quant.op_as::<ElementWiseOp>() {
                if let Some(mop) = op.0.downcast_ref::<QuantizeLinearU8>() {
                    Some((mop.scale, mop.zero_point as i32, u8::datum_type()))
                } else {
                    op.0.downcast_ref::<QuantizeLinearI8>()
                        .map(|mop| (mop.scale, mop.zero_point as i32, i8::datum_type()))
                }
            } else {
                None
            };
            if let Some((scale, zero_point, dt)) = q_params {
                // first, try Op::quantize() on all ops in the chain
                let mut patch = TypedModelPatch::default();
                // Start from the tensor feeding `dequant`, skipping the dequantization.
                let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;
                // The chain up to `quant` was discovered through `single_succ`, so `dequant`
                // has a single successor here.
                // NOTE(review): when `current == dequant`, `next` starts at `quant` itself, so
                // `quantize` is attempted on the quantize op and the `next.id == current.id`
                // exit below cannot trigger. Confirm that `quantize` returns `None` for it.
                let mut next = model.single_succ(dequant.id)?.unwrap();
                loop {
                    // NOTE(review): `quantize` receives `dequant` rather than `next` as its
                    // node argument. Confirm this is intended.
                    if let Some(op) = next
                        .op
                        .quantize(model, dequant, dt, scale, zero_point)
                        .with_context(|| format!("Quantizing {}", next))?
                    {
                        wire = patch.wire_node(&*next.name, op, [wire].as_ref())?[0];
                    } else {
                        // One op refused to quantize: fall back to the lookup table.
                        break;
                    }
                    if next.id == current.id {
                        // Every op up to `current` was quantized; the quantized chain now
                        // produces `quant`'s output directly.
                        patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
                        return Ok(Some(patch));
                    } else {
                        next = model.single_succ(next.id)?.unwrap();
                    }
                }
                // or else make a lookup table
                if incoming_dt == DatumType::I8 || incoming_dt == DatumType::U8 {
                    // Throw-away model reproducing dequant -> ... -> quant on a 256-element
                    // input that covers every 8-bit value.
                    // NOTE(review): the source fact and the input tensor below are built from
                    // `dt` (the quantize op's output type), not `incoming_dt`. Confirm they
                    // always agree; otherwise the table is computed over the wrong input type.
                    let mut adhoc_model = TypedModel::default();
                    let mut wire = adhoc_model.add_source("ad-hoc", dt.fact([256]))?;
                    let mut next = model.single_succ(dequant.id)?.unwrap();
                    // Name of the first op between dequant and quant; reused for the LUT node.
                    let mut name = None;
                    // plug in dequant
                    wire = adhoc_model.wire_node(
                        &*dequant.name,
                        dequant.op.clone(),
                        [wire].as_ref(),
                    )?[0];
                    while next.id != quant.id {
                        name.get_or_insert(&*next.name);
                        wire =
                            adhoc_model.wire_node(&*next.name, next.op.clone(), [wire].as_ref())?
                                [0];
                        next = model.single_succ(next.id)?.unwrap();
                    }
                    // plug in quant
                    wire =
                        adhoc_model.wire_node(&*quant.name, quant.op.clone(), [wire].as_ref())?[0];
                    adhoc_model.set_output_outlets(&[wire])?;
                    let input = (0u8..=255).collect::<Vec<u8>>();
                    let input = match dt {
                        // SAFETY: u8 and i8 have identical size and alignment, so the byte
                        // slice can be reinterpreted in place (values >= 128 become negative).
                        DatumType::I8 => unsafe {
                            tensor1(std::mem::transmute::<&[u8], &[i8]>(&*input))
                        },
                        DatumType::U8 => tensor1(&input),
                        _ => unreachable!(),
                    };
                    let output =
                        SimplePlan::new(adhoc_model)?.run(tvec!(input.into_tvalue()))?.remove(0);
                    // The LUT kernel takes raw bytes whatever the signedness of `dt`.
                    let table: &[u8] = match dt {
                        // SAFETY: same u8/i8 layout equivalence as above.
                        DatumType::I8 => unsafe { std::mem::transmute(output.as_slice::<i8>()?) },
                        DatumType::U8 => output.as_slice::<u8>()?,
                        _ => unreachable!(),
                    };
                    let op = lookup_table((tract_linalg::ops().lut_u8)(table));
                    let mut patch = TypedModelPatch::default();
                    let mut wire: OutletId = patch.tap_model(model, dequant.inputs[0])?;

                    wire = patch.wire_node(name.unwrap_or(&*dequant.name), op, [wire].as_ref())?[0];
                    patch.shunt_outside(model, OutletId::new(quant.id, 0), wire)?;
                    return Ok(Some(patch));
                }
            }
            // Keep walking only through ops whose invariants say they are element-wise.
            let (input_facts, output_facts) = model.node_facts(quant.id)?;
            let invariants = quant
                .op
                .invariants(&input_facts, &output_facts)
                .with_context(|| format!("Querying invariants for {}", quant))?;
            if invariants.element_wise() {
                current = quant;
            } else {
                break;
            }
        }
        Ok(None)
    }
Examples found in repository?
src/model/graph.rs (line 592)
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        for i in 0..self.nodes.len() {
            let input_1 = self.nodes[i]
                .inputs
                .get(0)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let input_2 = self.nodes[i]
                .inputs
                .get(1)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let output_1 = self
                .outlet_successors(OutletId::new(i, 0))
                .get(0)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            let output_2 = self
                .outlet_successors(OutletId::new(i, 0))
                .get(1)
                .map(|o| format!("{:?}", o))
                .unwrap_or_else(|| "".to_string());
            writeln!(
                fmt,
                "{:5} | {:8} {:8} -> {:8} {:8} | {:25} {:50} {:?} => {:?}",
                i,
                input_1,
                input_2,
                output_1,
                output_2,
                self.nodes[i].op().name(),
                self.nodes[i].name,
                self.node_input_facts(i).unwrap(),
                self.node_output_facts(i).unwrap(),
            )?;
            if self.nodes[i].inputs.len() > 2 {
                writeln!(
                    fmt,
                    "                                               |   * inputs: {}",
                    self.nodes[i].inputs.iter().map(|s| format!("{:?}", s)).join(", ")
                )?;
            }
            if self.nodes[i].outputs.len() > 1
                || self.outlet_successors((i, 0).into()).len() > 2
                || (self.outlet_label(i.into()).is_some()
                    && self.outlet_label(i.into()).unwrap() != self.nodes[i].name)
            {
                for o in 0..self.nodes[i].outputs.len() {
                    if self.outlet_successors((i, o).into()).len() > 0 {
                        writeln!(
                                    fmt,
                                    "                                               |   * output #{}: {} {}",
                                    o,
                                    self.outlet_label((i, o).into()).unwrap_or(""),
                                    self.outlet_successors((i, o).into())
                                    .iter()
                                    .map(|s| format!("{:?}", s))
                                    .join(", "),
                                    )?;
                    }
                }
            }
        }
        writeln!(fmt, "outputs: {}", self.outputs.iter().map(|o| format!("{:?}", o)).join(", "))?;
        Ok(())
    }
}

impl<F, O> Graph<F, O>
where
    F: Fact + Clone + 'static + std::hash::Hash + for<'a> std::convert::From<&'a F>,
    O: std::fmt::Display
        + std::fmt::Debug
        + Clone
        + AsRef<dyn Op>
        + AsMut<dyn Op>
        + Clone
        + 'static
        + std::hash::Hash
        + for<'a> std::convert::From<&'a O>,
    Graph<F, O>: SpecialOps<F, O>,
{
    /// Checks that the graph is "compact" (available in debug builds only).
    ///
    /// A compact graph satisfies all of:
    /// * its evaluation order covers every node, except model inputs that feed nothing and
    ///   are not model outputs;
    /// * the evaluation order is trivial (`eval_order()[ix] == ix`);
    /// * every node id matches its position in `nodes`;
    /// * node names are unique. On a duplicate, the whole graph is printed to stderr first.
    ///
    /// # Errors
    /// Returns an error describing the first violated property. Errors from `eval_order`,
    /// `input_outlets` or `output_outlets` are propagated.
    #[cfg(debug_assertions)]
    pub fn check_compact(&self) -> TractResult<()> {
        let order = self.eval_order()?;
        // Fetched once and propagated with `?`, rather than unwrapped for every input.
        let outputs = self.output_outlets()?;
        let useless_sources = self
            .input_outlets()?
            .iter()
            .filter(|io| self.outlet_successors(**io).len() == 0 && !outputs.contains(io))
            .count();
        if order.len() + useless_sources != self.nodes.len() {
            bail!(
                "Eval order is {} long, nodes are {}, including {} unused sources",
                order.len(),
                self.nodes.len(),
                useless_sources
            );
        }
        if order.iter().enumerate().any(|(ix, &node)| node != ix) {
            bail!("eval order is not trivial");
        }
        let mut seen = std::collections::HashSet::new();
        for (ix, n) in self.nodes.iter().enumerate() {
            if ix != n.id {
                bail!("Invalid node id: position is {}, node is {}", ix, n);
            }
            // `insert` returns false when the name was already present.
            if !seen.insert(&n.name) {
                eprintln!("{}", self);
                bail!("duplicate name {}", n.name);
            }
        }
        Ok(())
    }
More examples
Hide additional examples
src/ops/change_axes.rs (line 665)
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
/// Tries to propagate the axis change `change` through `model`, starting at `change.outlet`.
///
/// Changes are processed from a worklist. Each outlet reached is widened to its whole group
/// when it belongs to one of `bounds`. The outlet is then offered to its producing node and
/// to every consuming node through their op's `change_axes`. The wire changes those nodes
/// request are queued in turn, and the node that emitted a change is not asked about it
/// again.
///
/// Propagation is abandoned with `Ok(None)` when any of these happens:
/// - it reaches an outlet listed in `locked`;
/// - a node refuses the change;
/// - two different changes are requested for the same outlet.
///
/// On success, returns:
/// - a patch rebuilding every affected node, using the substitute op where a node asked for
///   one;
/// - the model inputs/outputs (by interface index) whose axes changed, with their change.
///
/// # Errors
/// Propagates errors from the ops' `change_axes` (with context) and from patch building.
pub fn change_axes(
    model: &TypedModel,
    change: &AxisChange,
    locked: &[OutletId],
    bounds: &[TVec<OutletId>],
) -> TractResult<Option<(TypedModelPatch, TVec<(InOut, AxisOp)>)>> {
    trace!("Considering change {:?}", change);
    // Worklist of (change, emitting node) pairs still to propagate.
    let mut todo_changes = vec![(change.clone(), None)];
    // Every outlet touched so far, with the axis op applied to it.
    let mut changed_wires = HashMap::new();
    changed_wires.insert(change.outlet, change.op.clone());
    // Substitute ops requested by the nodes crossed by the change.
    let mut changed_ops: HashMap<usize, Box<dyn TypedOp>> = HashMap::new();
    while let Some((c, emitter)) = todo_changes.pop() {
        // Outlets grouped together in `bounds` are changed together.
        let outlets = if let Some(group) = bounds.iter().find(|b| b.contains(&c.outlet)) {
            group.clone()
        } else {
            tvec![c.outlet]
        };
        for outlet in outlets {
            if locked.contains(&outlet) {
                trace!("  Change {:?} blocked by locked interface {:?}", change, outlet);
                return Ok(None);
            }
            // The producer of the outlet and each of its consumers must accept the change.
            let mut nodes = vec![(outlet.node, InOut::Out(outlet.slot))];
            for inlet in model.outlet_successors(outlet) {
                nodes.push((inlet.node, InOut::In(inlet.slot)));
            }
            for (node_id, io) in nodes {
                if Some(node_id) == emitter {
                    continue;
                }
                let node = model.node(node_id);
                let more = node
                    .op
                    .change_axes(model, node, io, &c.op)
                    .with_context(|| format!("Propagating {:?} to node {}", change, node))?;
                let AxisChangeConsequence { substitute_op, wire_changes } = match more {
                    Some(consequence) => consequence,
                    None => {
                        trace!("    Propagation of {:?} blocked by {}", change, node);
                        return Ok(None);
                    }
                };
                trace!("    Change {:?} enters {} from {:?}", c.op, node, io);
                trace!("       propagates as {:?}", wire_changes);
                if let Some(op) = substitute_op {
                    trace!("       replace op by {:?}", op);
                    changed_ops.insert(node.id, op);
                }
                for (wire, op) in wire_changes.into_iter() {
                    let outlet = wire.as_outlet(node);
                    match changed_wires.entry(outlet) {
                        Entry::Vacant(entry) => {
                            trace!("         {:?} {:?} change on {:?} is new", wire, op, outlet);
                            entry.insert(op.clone());
                            todo_changes.push((AxisChange { outlet, op }, Some(node_id)));
                        }
                        Entry::Occupied(previous) => {
                            if *previous.get() == op {
                                trace!(
                                    "         {:?} {:?} change on {:?} already done",
                                    wire,
                                    op,
                                    outlet
                                );
                            } else {
                                trace!(
                                    "         {:?} {:?} change on {:?} conflicting with {:?}. Blocked.",
                                    wire,
                                    op,
                                    outlet,
                                    previous
                                );
                                return Ok(None);
                            }
                        }
                    }
                }
            }
        }
    }
    trace!("Translating {:?} to patch", change);
    let mut patch = TypedModelPatch::new(format!("{:?}", change));
    // Original outlet -> outlet standing for it inside the patch.
    let mut replaced_wires: HashMap<OutletId, OutletId> = HashMap::default();
    let nodes_to_replace = changed_wires
        .keys()
        .map(|o| o.node)
        .chain(changed_ops.keys().copied())
        .collect::<std::collections::HashSet<usize>>();
    for node_id in model.eval_order()? {
        let node = model.node(node_id);
        if nodes_to_replace.contains(&node_id) {
            let mut inputs = tvec!();
            for orig in &node.inputs {
                // Tap each outside wire at most once. Errors from `tap_model` are propagated
                // instead of panicking inside an `or_insert_with` closure.
                let tgt = match replaced_wires.entry(*orig) {
                    Entry::Occupied(entry) => *entry.get(),
                    Entry::Vacant(entry) => *entry.insert(patch.tap_model(model, *orig)?),
                };
                inputs.push(tgt);
            }
            let op: Box<dyn TypedOp> =
                changed_ops.get(&node_id).cloned().unwrap_or_else(|| node.op.clone());
            let new_wires = patch.wire_node(&node.name, op, &inputs)?;
            // A rebuilt source node stands for the original model input `node_id`.
            if new_wires.len() == 1
                && patch.node(new_wires[0].node).op_is::<crate::ops::source::TypedSource>()
            {
                patch.inputs.insert(new_wires[0].node, node_id);
            }
            for (ix, w) in new_wires.iter().enumerate() {
                replaced_wires.insert((node_id, ix).into(), *w);
            }
        } else {
            // Untouched node consuming a rebuilt wire: reroute it to the patched version.
            for orig in &node.inputs {
                if let Some(replacement) = replaced_wires.get(orig) {
                    patch.shunt_outside(model, *orig, *replacement)?;
                }
            }
        }
    }
    for output in model.output_outlets()? {
        if let Some(replacement) = replaced_wires.get(output) {
            // NOTE(review): the unchecked variant is presumably used because the axis change
            // legitimately alters the output fact. Confirm the contract of
            // `shunt_outside_unchecked`.
            unsafe {
                patch.shunt_outside_unchecked(*output, *replacement)?;
            }
        }
    }
    let mut interface_change = tvec!();
    for (ix, input) in model.input_outlets()?.iter().enumerate() {
        if let Some(change) = changed_wires.get(input) {
            interface_change.push((InOut::In(ix), change.clone()));
        }
    }
    for (ix, output) in model.output_outlets()?.iter().enumerate() {
        if let Some(change) = changed_wires.get(output) {
            interface_change.push((InOut::Out(ix), change.clone()));
        }
    }
    debug_assert!(
        patch.model.nodes.iter().map(|n| &n.name).collect::<std::collections::HashSet<_>>().len()
            == patch.model.nodes.len()
    );
    Ok(Some((patch, interface_change)))
}
Examples found in repository?
src/ops/scan/mir.rs (line 392)
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
    /// Replaces a scan output that carries a constant last value with a constant node.
    ///
    /// Output mappings are visited in order. Only mappings exposing a last-value slot
    /// qualify, and only when the matching body output fact has a known constant value.
    /// For the first such mapping, the returned patch adds that constant, named
    /// `<scan node>.<inner node>`, and reroutes the scan's output `slot` to it.
    /// Returns `Ok(None)` when no output qualifies.
    fn declutter_pull_constant_outputs(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let last_value_outputs = self
            .output_mapping
            .iter()
            .enumerate()
            .filter_map(|(ix, mapping)| mapping.last_value_slot.map(|slot| (ix, slot)));
        for (body_output_ix, slot) in last_value_outputs {
            let konst = self.body.output_fact(body_output_ix)?.konst.clone();
            if let Some(k) = konst {
                let inner_id = self.body.output_outlets()?[body_output_ix].node;
                let inner = self.body.node(inner_id);
                let mut patch = TypedModelPatch::new(format!("Extract const node {}", inner));
                let cst = patch.add_const(format!("{}.{}", &node.name, &inner.name), k)?;
                patch.shunt_outside(model, OutletId::new(node.id, slot), cst)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }
More examples
Hide additional examples
src/optim/prop_const.rs (line 36)
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
    fn next(
        &mut self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut patch = TypedModelPatch::default();
        for n in model.eval_order()? {
            let node = model.node(n);
            if node.op.is_stateless() && !node.op_is::<Const>() {
                if let Some(inputs) = model
                    .node_input_facts(n)?
                    .iter()
                    .map(|f| f.konst.clone().map(|t| t.into_tvalue()))
                    .collect()
                {
                    match node.op.eval(inputs) {
                        Ok(res) => {
                            for (ix, output) in res.into_iter().enumerate() {
                                let mut name = node.name.clone();
                                if ix > 0 {
                                    name = format!("{}.{}", name, ix);
                                }
                                let wire = patch.add_const(name, output.into_arc_tensor())?;
                                patch.shunt_outside(model, (n, ix).into(), wire)?;
                            }
                        }
                        Err(e) => {
                            if !e.root_cause().is::<UndeterminedSymbol>() {
                                Err(e).with_context(|| {
                                    format!("Eager eval {} during optimisation", model.node(n))
                                })?;
                            }
                        }
                    }
                }
            }
        }
        Ok(Some(patch).filter(|p| p.nodes.len() > 0))
    }
src/ops/matmul/mir_quant_unary.rs (line 39)
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
    /// Evaluates the quantized matmul eagerly by building and running a throw-away model.
    ///
    /// `inputs[0]` is the `b` operand and must have the same rank as `self.a`. The remaining
    /// inputs are added as constants (named `source_<ix>`) and resolved into quantization
    /// parameters by `self.params.as_outlet_ids`.
    ///
    /// Steps:
    /// 1. Both operands pass through `wire_offset_u8_as_i8` (per its name, presumably a u8 to
    ///    i8 shift with zero-point adjustment).
    /// 2. A `MatMulUnary` with the shifted constant `a` is wired on `b`.
    /// 3. `wire_matmul_quant` finishes the computation with the bias and the quantization
    ///    parameters.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        ensure!(inputs[0].rank() == self.a.rank(), "Rank mismatch {:?} vs {:?}", inputs[0], self.a);

        let mut model = TypedModel::default();
        let t_a = self.a.offset_u8_as_i8();
        let a = model.add_const("source_a", self.a.clone())?;
        let b = model.add_const("source_b", inputs[0].clone().into_arc_tensor())?;
        let bias = self.bias.clone().map(|bias| model.add_const("source_bias", bias)).transpose()?;

        // Slot 0 holds `a`; the extra op inputs follow at their original indices.
        let mut input_outlets = tvec![a];
        for (ix, tensor) in inputs.iter().enumerate().skip(1) {
            let name = format!("source_{}", ix);
            input_outlets.push(model.add_const(name, tensor.clone().into_arc_tensor())?);
        }

        let mut qparams = self.params.as_outlet_ids(
            &mut model,
            "qmatmul_unary",
            &input_outlets,
            self.a.datum_type(),
            inputs[0].datum_type(),
            self.output_type,
        )?;
        let a = wire_offset_u8_as_i8(&mut model, "adhoc", a, "a", &mut qparams[0], "a0")?;
        let b = wire_offset_u8_as_i8(&mut model, "adhoc", b, "b", &mut qparams[2], "b0")?;

        let matmul = MatMulUnary { a: t_a, axes: self.axes };
        let product = model.wire_node("adhoc.matmul", matmul, &[b])?[0];
        let output = wire_matmul_quant(
            &mut model,
            "adhoc",
            a,
            b,
            bias,
            self.axes,
            product,
            self.output_type,
            &qparams,
        )?;
        model.set_output_outlets(&[output])?;
        model.into_runnable()?.run(tvec![])
    }
}

impl TypedOp for QMatMulUnary {
    /// Computes the output fact of the unary quantized matmul.
    ///
    /// `inputs[0]` is the `b` operand. It is followed by one input per quantization parameter
    /// read from an input, so `1 + self.params.input_count()` inputs are expected. `b` must
    /// have the same rank as the constant `a`. The optional constant bias must be a scalar
    /// or a vector whose length matches the `c_m` dimension of the output.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let expected_inputs = 1 + self.params.input_count();
        if inputs.len() != expected_inputs {
            bail!(
                "Inconsistent q matmul unary. expects {} inputs, got {}",
                expected_inputs,
                inputs.len()
            );
        }
        if inputs[0].rank() != self.a.rank() {
            bail!("Inconsistent matmul between {:?} and {:?} (rank mismatch)", inputs[0], self.a);
        }
        let (_m, _k, _n, c_shape) = compute_shape(
            &self.a.shape().iter().map(|d| d.to_dim()).collect::<TVec<_>>(),
            &inputs[0].shape,
            self.axes,
        )?;

        if let Some(bias) = &self.bias {
            match bias.rank() {
                0 => (),
                1 => {
                    // Compare as `TDim` (like `QMatMul::output_facts`). A symbolic `m` then
                    // yields a mismatch error instead of a failed usize conversion.
                    let expected_len = &c_shape[self.axes.c_m];
                    anyhow::ensure!(
                        &bias.len().to_dim() == expected_len,
                        "got: {:?} expected len: {:?}",
                        bias,
                        expected_len
                    );
                }
                _ => anyhow::bail!("Bias must be either scalar or vector (rank 0 or 1)."),
            }
        }

        Ok(tvec!(self.output_type.fact(c_shape)))
    }

    /// Reports the axis invariants of the op.
    ///
    /// No invariant is claimed as soon as one quantization parameter is not a single value
    /// (an attribute with more than one element, or an input of volume != 1).
    /// NOTE(review): the original author left "FIXME: why ?" here. Presumably per-channel
    /// parameters would need to be tracked along the axes too. TODO confirm.
    /// Otherwise the plain unary matmul invariants are used, with `None` added for each
    /// extra (parameter) input.
    fn invariants(&self, inputs: &[&TypedFact], outputs: &[&TypedFact]) -> TractResult<Invariants> {
        let has_non_scalar_param = self.params.iter().any(|(_, kind)| match kind {
            QParamKind::Attr(t) => t.len() > 1,
            QParamKind::FromInput(ix) => !inputs[*ix].shape.volume().is_one(),
            QParamKind::FromQType => false,
        });
        if has_non_scalar_param {
            return Ok(Invariants::none());
        }

        let mut invs = super::mir_unary::mir_unary_invariants(inputs[0], outputs[0], self.axes)?;
        let extra_inputs = inputs.len() - 1;
        for axis in invs.axes.iter_mut() {
            axis.inputs.extend(std::iter::repeat(None).take(extra_inputs));
        }
        Ok(invs)
    }

    /// Propagates an axis change through the op.
    ///
    /// The analysis is delegated to `mir_unary_change_axes`. When it accepts the change,
    /// the op is substituted by a copy carrying the transformed constant `a` and the
    /// updated axes.
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let outcome =
            super::mir_unary::mir_unary_change_axes(model, node, io, change, &self.axes, &self.a)?;
        Ok(outcome.map(|(new_a, new_axes, wire_changes)| {
            let substitute = Self { axes: new_axes, a: new_a.into_arc_tensor(), ..self.clone() };
            AxisChangeConsequence { substitute_op: Some(Box::new(substitute)), wire_changes }
        }))
    }

    /// Splits the matmul over `k` when `b` comes from a concatenation along its `k` axis.
    ///
    /// Each concatenated chunk is multiplied by the matching slice of `a`. The split ops use
    /// unit scales and a zero output offset, so partial products stay in i32; the zero points
    /// of `a` and `b` are kept. The partial results are summed, and one requantization with
    /// the original scales and `c0` is applied. The bias, if any, is attached to the first
    /// chunk only, so it is added exactly once.
    ///
    /// The rewrite is skipped (returns `Ok(None)`) when the split ops could not be built
    /// correctly:
    /// * a zero point comes from a dynamic input: the split ops are wired with their `b`
    ///   chunk only and cannot receive it;
    /// * `a` or `b` has a quantized datum type: `as_outlet_ids` would then take the scales
    ///   from the type and reject the unit scales used for the split.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        use crate::ops::array::TypedConcat;
        let concat =
            match model.nodes()[node.inputs[0].node].op().downcast_ref::<TypedConcat>() {
                Some(concat) => concat,
                None => return Ok(None),
            };
        if concat.axis != self.axes.b_k {
            return Ok(None);
        }

        // Partial products stay in i32: unit scales and zero output offset. The zero points
        // of `a` and `b` are kept, since each chunk must compensate for them.
        let mut params_for_split = self.params.clone();
        params_for_split.a_scale = tensor0(1.0f32).into();
        params_for_split.b_scale = tensor0(1.0f32).into();
        params_for_split.c_scale = tensor0(1.0f32).into();
        params_for_split.c0 = tensor0(0i32).into();
        // A split op only gets one input (its `b` chunk). Any remaining `FromInput`
        // parameter would make its `output_facts` fail.
        if params_for_split.input_count() != 0 {
            return Ok(None);
        }
        let b_dt = model.node_input_facts(node.id)?[0].datum_type;
        // With quantized datum types, `as_outlet_ids` requires the scales to match the type
        // and would reject the unit scales above.
        if self.a.datum_type().qparams().is_some() || b_dt.qparams().is_some() {
            return Ok(None);
        }

        let concat_node = model.node(node.inputs[0].node);
        let offsets = concat
            .offsets(&model.node_input_facts(concat_node.id)?)?
            .iter()
            .map(|x| x.to_usize())
            .collect::<TractResult<Vec<usize>>>()?;

        let mut patch = TypedModelPatch::new("split over k-concatenated input");
        // Tap every node input, including slot 0, so that `input_outlets[ix]` is node input
        // `ix`. This matches how `eval` and `codegen` resolve `QParamKind::FromInput(ix)`.
        let input_outlets = node
            .inputs
            .iter()
            .map(|o| patch.tap_model(model, *o))
            .collect::<TractResult<TVec<_>>>()?;
        let params_outlets = self.params.as_outlet_ids(
            &mut patch,
            &node.name,
            &input_outlets,
            self.a.datum_type(),
            b_dt,
            self.output_type,
        )?;

        let scale = combine_scales(
            &mut patch,
            &node.name,
            params_outlets[1],
            params_outlets[3],
            params_outlets[5],
        )?;
        let c0 = params_outlets[4];

        let k_axis = self.axes.a_k;
        let mut wires = vec![];
        for (ix, input) in concat_node.inputs.iter().enumerate() {
            let wire = patch.tap_model(model, *input)?;
            let a = self.a.slice(k_axis, offsets[ix], offsets[ix + 1])?;
            let wire = patch
                .wire_node(
                    format!("{}.k-{}-{}", node.name, offsets[ix], offsets[ix + 1]),
                    Self {
                        a: a.into_arc_tensor(),
                        output_type: DatumType::I32,
                        bias: self.bias.clone().filter(|_| ix == 0),
                        params: params_for_split.clone(),
                        ..self.clone()
                    },
                    &[wire],
                )
                .context("wiring new matmulunary")?[0];
            wires.push(wire)
        }
        // Sum the i32 partial products, then requantize once with the original parameters.
        let mut wire = wires[0];
        for (ix, w) in wires[1..].iter().enumerate() {
            wire = patch.wire_node(
                format!("{}.k-add-{}", node.name, ix),
                crate::ops::binary::TypedBinOp(Box::new(crate::ops::math::Add)),
                &[wire, *w],
            )?[0];
        }
        wire = requant(&mut patch, &node.name, wire, self.output_type, scale, c0)?;
        patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
        Ok(Some(patch))
    }

    /// Estimates the cost from the constant `a` shape and the runtime `b` fact.
    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let b_fact = inputs[0];
        let b_shape = b_fact.shape.to_tvec();
        cost(self.a.shape(), &b_shape, b_fact.datum_type, self.axes)
    }

    /// Lowers the op to integer primitives.
    ///
    /// If `inline_static` can turn the quantization parameters into static values, the
    /// node is only replaced by a copy with those parameters inlined, with `a` and `a0`
    /// shifted to i8. NOTE(review): presumably the copy is lowered when codegen visits it
    /// again. Otherwise the op is replaced by:
    /// * a constant `a` and the tapped `b`, both u8→i8 offset if needed;
    /// * an integer `MatMulUnary`;
    /// * bias addition, zero-point compensation and requantization, via `wire_matmul_quant`.
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let shifted_a = self.a.offset_u8_as_i8();

        if let Some((static_inputs, inlined_params)) = self.params.inline_static(model, node)? {
            let mut patch = TypedModelPatch::new("inlining matmul quantized params");
            let taps = static_inputs
                .iter()
                .map(|outlet| patch.tap_model(model, *outlet))
                .collect::<TractResult<Vec<OutletId>>>()?;
            let a0 = inlined_params.a0.offset_u8_as_i8(&patch, &taps)?;
            let inlined = Self {
                a: shifted_a,
                params: MatMulQParams { a0, ..inlined_params },
                ..self.clone()
            };
            let wire = patch.wire_node(&node.name, inlined, &taps)?[0];
            patch.shunt_outside(model, node.id.into(), wire)?;
            return Ok(Some(patch));
        }

        let mut patch = TypedModelPatch::default();
        let a_wire = patch.wire_node(
            format!("{}.a_const", &node.name),
            ops::konst::Const(self.a.clone()),
            &[],
        )?[0];
        let b_wire = patch.tap_model(model, node.inputs[0])?;
        let bias_wire = match self.bias.clone() {
            Some(bias) => Some(patch.add_const(format!("{}.bias_const", &node.name), bias)?),
            None => None,
        };
        // Slot 0 holds `a`. Slots 1.. mirror the node's parameter inputs, so that
        // `QParamKind::FromInput(ix)` resolves to node input `ix`.
        let mut param_sources = tvec![a_wire];
        for outlet in node.inputs.iter().skip(1) {
            let tapped = patch.tap_model(model, *outlet)?;
            param_sources.push(tapped);
        }
        let b_dt = model.node_input_facts(node.id)?[0].datum_type;
        let mut qparams = self.params.as_outlet_ids(
            &mut patch,
            &node.name,
            &param_sources,
            self.a.datum_type(),
            b_dt,
            self.output_type,
        )?;

        let a_wire =
            wire_offset_u8_as_i8(&mut patch, &node.name, a_wire, "a", &mut qparams[0], "a0")?;
        let b_wire =
            wire_offset_u8_as_i8(&mut patch, &node.name, b_wire, "b", &mut qparams[2], "b0")?;

        let matmul = MatMulUnary { a: shifted_a, axes: self.axes };
        let raw_product =
            patch.wire_node(format!("{}.matmul", &node.name), matmul, &[b_wire])?[0];
        let output = wire_matmul_quant(
            &mut patch,
            &node.name,
            a_wire,
            b_wire,
            bias_wire,
            self.axes,
            raw_product,
            self.output_type,
            &qparams,
        )?;
        patch.shunt_outside(model, node.id.into(), output)?;
        Ok(Some(patch))
    }
src/ops/matmul/mir_quant.rs (line 217)
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
    /// Materializes the quantization parameters as outlets of `model`.
    ///
    /// Parameters are taken pairwise (zero point, then scale) for `a`, `b` and `c`. The
    /// result is `[a0, a_scale, b0, b_scale, c0, c_scale]`. For each operand:
    /// * if its datum type carries quantization parameters, they are used as constants.
    ///   Any explicitly given attribute must match them, and a mismatch is reported with
    ///   the offending parameter names;
    /// * otherwise attributes become constants named `{node_name}.{param}`, and
    ///   `FromInput(ix)` parameters are read from `inputs_wires[ix]`. `FromQType` is an
    ///   error for non-quantized types.
    pub fn as_outlet_ids(
        &self,
        model: &mut TypedModel,
        node_name: &str,
        inputs_wires: &[OutletId],
        a_dt: DatumType,
        b_dt: DatumType,
        c_dt: DatumType,
    ) -> TractResult<TVec<OutletId>> {
        let mut params_outlets = tvec!();
        for (mut params, dt) in self.iter().chunks(2).into_iter().zip([a_dt, b_dt, c_dt].iter()) {
            if let Some(qp) = dt.qparams() {
                // `iter()` always yields an even number of parameters (zero point, scale).
                let (x0_name, x0) = params.next().unwrap();
                let (x_scale_name, x_scale) = params.next().unwrap();
                let (zp, scale) = qp.zp_scale();
                ensure!(
                    (matches!(x0, QParamKind::FromQType)
                        || x0 == &QParamKind::Attr(rctensor0(zp)))
                        && (matches!(x_scale, QParamKind::FromQType)
                            || x_scale == &QParamKind::Attr(rctensor0(scale))),
                    "Params {} and {} conflict with the quantization parameters of {:?}",
                    x0_name,
                    x_scale_name,
                    dt
                );
                let zp = model.add_const(format!("{}.{}", node_name, x0_name), tensor0(zp))?;
                let scale =
                    model.add_const(format!("{}.{}", node_name, x_scale_name), tensor0(scale))?;
                params_outlets.push(zp);
                params_outlets.push(scale)
            } else {
                for (param_name, param) in params {
                    match param {
                        QParamKind::Attr(t) => params_outlets.push(
                            model.add_const(format!("{}.{}", node_name, param_name), t.clone())?,
                        ),
                        QParamKind::FromInput(i) => params_outlets.push(inputs_wires[*i]),
                        QParamKind::FromQType => {
                            bail!("Param {} has no quantization parameters", param_name)
                        }
                    }
                }
            }
        }
        Ok(params_outlets)
    }
}

/// Quantized matrix multiplication where both operands and the bias are op inputs.
///
/// The node inputs are `[a, b, bias, ..]`. After those come one input per quantization
/// parameter that is read from an input (`QParamKind::FromInput`).
#[derive(Debug, Clone, new, Hash)]
pub struct QMatMul {
    /// Positions of the `m`, `k` and `n` axes in `a`, `b` and `c`.
    pub axes: MatMulAxes,
    /// Datum type of the requantized output.
    pub output_type: DatumType,
    /// Zero points and scales of `a`, `b` and `c`.
    pub params: MatMulQParams,
}

impl_dyn_hash!(QMatMul);

impl Op for QMatMul {
    fn name(&self) -> Cow<str> {
        "QMatMul".into()
    }

    op_as_typed_op!();
}

impl EvalOp for QMatMul {
    /// The op keeps no state between evaluations.
    fn is_stateless(&self) -> bool {
        true
    }

    /// Evaluates the op eagerly.
    ///
    /// Every input (`a`, `b`, bias, parameter inputs) becomes a constant of a throw-away
    /// `TypedModel`. That model is lowered like `codegen` does (u8→i8 offsetting, integer
    /// `MatMul`, bias, zero-point compensation, requantization) and then run.
    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (a_value, b_value) = (&inputs[0], &inputs[1]);
        ensure!(a_value.rank() == b_value.rank(), "Rank mismatch {:?} vs {:?}", a_value, b_value);

        let mut adhoc = TypedModel::default();
        let a_wire = adhoc.add_const("source_a", a_value.clone().into_arc_tensor())?;
        let b_wire = adhoc.add_const("source_b", b_value.clone().into_arc_tensor())?;
        let bias_wire = adhoc.add_const("source_bias", inputs[2].clone().into_arc_tensor())?;

        // Mirror the node's inputs so `QParamKind::FromInput(ix)` picks the same value.
        let mut param_sources = tvec![a_wire, b_wire, bias_wire];
        for (ix, value) in inputs.iter().enumerate().skip(3) {
            let wire =
                adhoc.add_const(format!("source_{}", ix), value.clone().into_arc_tensor())?;
            param_sources.push(wire);
        }

        let mut qparams = self.params.as_outlet_ids(
            &mut adhoc,
            "qmatmul_unary",
            &param_sources,
            a_value.datum_type(),
            b_value.datum_type(),
            self.output_type,
        )?;

        let a_wire =
            wire_offset_u8_as_i8(&mut adhoc, "adhoc", a_wire, "a", &mut qparams[0], "a0")?;
        let b_wire =
            wire_offset_u8_as_i8(&mut adhoc, "adhoc", b_wire, "b", &mut qparams[2], "b0")?;

        let matmul = MatMul { axes: self.axes };
        let raw_product = adhoc.wire_node("adhoc.matmul", matmul, &[a_wire, b_wire])?[0];
        let output = wire_matmul_quant(
            &mut adhoc,
            "adhoc",
            a_wire,
            b_wire,
            Some(bias_wire),
            self.axes,
            raw_product,
            self.output_type,
            &qparams,
        )?;
        adhoc.set_output_outlets(&[output])?;
        adhoc.into_runnable()?.run(tvec![])
    }
}

impl TypedOp for QMatMul {
    /// Checks arity and ranks, then computes the output fact.
    ///
    /// Inputs are `[a, b, bias, ..]` followed by one input per `QParamKind::FromInput`
    /// parameter. `a` and `b` must have the same rank. The bias must be a scalar or a vector
    /// whose length matches the `c_m` dimension of the output.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let expected_inputs = 3 + self.params.input_count();
        if inputs.len() != expected_inputs {
            bail!(
                "Inconsistent q matmul. expects {} inputs, got {}",
                expected_inputs,
                inputs.len()
            );
        }
        if inputs[0].rank() != inputs[1].rank() {
            bail!(
                "Inconsistent matmul between {:?} and {:?} (rank mismatch)",
                inputs[0],
                inputs[1]
            );
        }
        let (_m, _k, _n, c_shape) = compute_shape(&inputs[0].shape, &inputs[1].shape, self.axes)?;

        let bias = &inputs[2];
        match bias.rank() {
            0 => (),
            1 => {
                let expected_len = &c_shape[self.axes.c_m];
                anyhow::ensure!(
                    &bias.shape[0] == expected_len,
                    "got: {:?} expected len: {:?}",
                    bias,
                    expected_len
                );
            }
            _ => anyhow::bail!("Bias must be either scalar or vector (rank 0 or 1)."),
        }

        Ok(tvec!(self.output_type.fact(c_shape)))
    }

    /// Turns the op into a `QMatMulUnary` when the bias and one operand are constants.
    ///
    /// If the constant operand is `b`, operands are flipped. `a`/`b` parameters and axes
    /// are swapped, and `m` and `n` exchange roles in the output.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let a_fact = model.outlet_fact(node.inputs[0])?;
        let b_fact = model.outlet_fact(node.inputs[1])?;
        let bias_fact = model.outlet_fact(node.inputs[2])?;

        if bias_fact.konst.is_none() {
            return Ok(None);
        }

        let konst_ix = if a_fact.konst.is_some() {
            0
        } else if b_fact.konst.is_some() {
            1
        } else {
            return Ok(None);
        };

        let flip = konst_ix == 1;
        // Both unwraps are guarded by the `konst` checks above.
        let konst = model.outlet_fact(node.inputs[konst_ix])?.konst.as_ref().unwrap();
        let bias = model.outlet_fact(node.inputs[2])?.konst.clone().unwrap();

        // A uniformly zero bias contributes nothing and can be discarded.
        let bias = Some(bias).filter(|b| {
            b.as_uniform().map(|b| b.cast_to_scalar::<f32>().unwrap() != 0.0).unwrap_or(true)
        });
        // When flipping, the new op aligns a rank-1 bias with its own `c_m`, which is the
        // old `c_n`. The bias was given along the old `c_m`. A uniform bias can be
        // expressed as a scalar; a genuine per-`m` vector cannot, so the rewrite is skipped.
        let bias = match bias {
            Some(b) if flip && b.rank() == 1 => match b.as_uniform() {
                Some(uniform) => Some(uniform.into_arc_tensor()),
                None => return Ok(None),
            },
            other => other,
        };

        let inputs: Vec<_> = node
            .inputs
            .iter()
            .enumerate()
            .filter_map(|(i, out_id)| if i == konst_ix || i == 2 { None } else { Some(*out_id) })
            .collect();

        let new_params = {
            let mut qp = self.params.clone();
            // Two inputs are removed before the parameter inputs (the constant operand
            // and the bias), so every input-backed parameter index shifts down by two.
            for (_, a) in qp.iter_mut() {
                if let QParamKind::FromInput(i) = a {
                    *i -= 2
                }
            }
            if flip {
                MatMulQParams {
                    a0: qp.b0,
                    a_scale: qp.b_scale,
                    b0: qp.a0,
                    b_scale: qp.a_scale,
                    ..qp
                }
            } else {
                qp
            }
        };

        let axes = if flip {
            MatMulAxes {
                a_m: self.axes.b_n,
                a_k: self.axes.b_k,
                b_n: self.axes.a_m,
                b_k: self.axes.a_k,
                c_m: self.axes.c_n,
                c_n: self.axes.c_m,
            }
        } else {
            self.axes
        };

        TypedModelPatch::replace_single_op(
            model,
            node,
            &inputs,
            QMatMulUnary::new(konst.clone(), bias, axes, self.output_type, new_params),
        )
        .map(Some)
    }

    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        cost(
            &inputs[0].shape.to_tvec(),
            &inputs[1].shape.to_tvec(),
            inputs[0].datum_type,
            self.axes,
        )
    }

    /// Lowers the op to integer primitives.
    ///
    /// Static quantization parameters are inlined first when possible. Otherwise the op is
    /// replaced by u8→i8 offsetting, an integer `MatMul`, and then bias addition,
    /// zero-point compensation and requantization.
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if let Some((inputs, qp)) = self.params.inline_static(model, node)? {
            let mut patch = TypedModelPatch::new("inlining matmul quantized params");
            let inputs: Vec<OutletId> =
                inputs.iter().map(|i| patch.tap_model(model, *i)).collect::<TractResult<_>>()?;
            let op = Self { params: qp, ..self.clone() };
            let wire = patch.wire_node(&node.name, op, &inputs)?;
            patch.shunt_outside(model, node.id.into(), wire[0])?;
            return Ok(Some(patch));
        }

        let mut patch = TypedModelPatch::default();
        let a = patch.tap_model(model, node.inputs[0])?;
        let b = patch.tap_model(model, node.inputs[1])?;
        let bias = patch.tap_model(model, node.inputs[2])?;

        // Mirror the node's inputs so `QParamKind::FromInput(ix)` resolves to input `ix`.
        let mut input_outlets = tvec![a, b, bias];
        for i in node.inputs.iter().skip(3) {
            input_outlets.push(patch.tap_model(model, *i)?)
        }
        let input_facts = model.node_input_facts(node.id)?;
        let mut params = self.params.as_outlet_ids(
            &mut patch,
            &node.name,
            &input_outlets,
            input_facts[0].datum_type,
            input_facts[1].datum_type,
            self.output_type,
        )?;

        let a = wire_offset_u8_as_i8(&mut patch, &node.name, a, "a", &mut params[0], "a0")?;
        let b = wire_offset_u8_as_i8(&mut patch, &node.name, b, "b", &mut params[2], "b0")?;

        let new_op = MatMul { axes: self.axes };
        let result = patch.wire_node(format!("{}.matmul", &node.name), new_op, &[a, b])?[0];
        let result = wire_matmul_quant(
            &mut patch,
            &node.name,
            a,
            b,
            Some(bias),
            self.axes,
            result,
            self.output_type,
            &params,
        )?;
        patch.shunt_outside(model, node.id.into(), result)?;
        Ok(Some(patch))
    }

    as_op!();
}

/// Shifts a u8 matrix, and its zero point, into the i8 range.
///
/// If `matrix` is not (possibly quantized) u8, it is returned unchanged and `zero_point`
/// is left untouched. Otherwise `matrix` is routed through `offset_u8_as_i8`, and
/// `zero_point` is updated in place to stay consistent:
/// * a u8 zero point goes through the same offset op;
/// * an i32 zero point gets -128 added;
/// * any other zero point type is left as is.
pub(crate) fn wire_offset_u8_as_i8(
    model: &mut TypedModel,
    model_name: &str,
    matrix: OutletId,
    matrix_name: &str,
    zero_point: &mut OutletId,
    zero_point_name: &str,
) -> TractResult<OutletId> {
    let matrix_dt = model.outlet_fact(matrix)?.datum_type.unquantized();
    if !matches!(matrix_dt, DatumType::U8) {
        return Ok(matrix);
    }

    let zp_fact = model.outlet_fact(*zero_point)?;
    let zp_dt = zp_fact.datum_type.unquantized();
    let zp_rank = zp_fact.rank();
    let zp_node_name = || format!("{}.offset_{}_as_i8", model_name, zero_point_name);
    match zp_dt {
        DatumType::U8 => {
            let shifted =
                model.wire_node(zp_node_name(), ops::quant::offset_u8_as_i8(), &[*zero_point])?;
            *zero_point = shifted[0];
        }
        DatumType::I32 => {
            let minus_128 = model.add_const(
                format!("{}.offset_{}_as_i8.min", model_name, zero_point_name),
                tensor0(-128i32).broadcast_into_rank(zp_rank)?.into_arc_tensor(),
            )?;
            let shifted =
                model.wire_node(zp_node_name(), ops::math::add(), &[*zero_point, minus_128])?;
            *zero_point = shifted[0];
        }
        _ => (),
    }

    let shifted_matrix = model.wire_node(
        format!("{}.offset_{}_as_i8", model_name, matrix_name),
        ops::quant::offset_u8_as_i8(),
        &[matrix],
    )?;
    Ok(shifted_matrix[0])
}

/// Finishes a quantized matmul from its raw integer product.
///
/// Adds the optional bias to `result`, removes the zero-point contributions (using the
/// sums of `a` and `b` over their `k` axes), and requantizes to `output_type`.
///
/// * `a`, `b` — the operands that produced `result`. They are used here only for the
///   length of `k` and for their sums over `k`.
/// * `bias` — optional scalar or vector bias. A vector is aligned with the `c_m` axis.
/// * `result` — the integer matmul output C.
/// * `params` — quantization parameter outlets `[a0, a_scale, b0, b_scale, c0, c_scale]`.
// The wiring needs every operand and parameter outlet, hence the long argument list.
#[allow(clippy::too_many_arguments)]
pub(crate) fn wire_matmul_quant(
    model: &mut TypedModel,
    name: &str,
    a: OutletId,
    b: OutletId,
    bias: Option<OutletId>,
    axes: MatMulAxes,
    mut result: OutletId,
    output_type: DatumType,
    params: &[OutletId],
) -> TractResult<OutletId> {
    if let Some(mut bias) = bias {
        // A scalar bias broadcasts as is. For a vector bias along `m`:
        // - if `m` is the last axis of C, rank broadcasting prepends the missing axes;
        // - otherwise unit axes are appended on the right so the vector lines up with
        //   `c_m`.
        // The rank of C is read from `result` itself rather than assumed equal to b's.
        let bias_rank = model.outlet_fact(bias)?.rank();
        let c_rank = model.outlet_fact(result)?.rank();
        if bias_rank == 1 && axes.c_m + 1 < c_rank {
            for i in 0..(c_rank - axes.c_m - 1) {
                bias = model.wire_node(
                    format!("{}.axis_rank_fix.{}", name, i),
                    AxisOp::Add(bias_rank + i),
                    &[bias],
                )?[0]
            }
        }
        result = wire_with_rank_broadcast(
            &format!("{}.add_bias", &name),
            model,
            ops::math::add(),
            &[result, bias],
        )?[0];
    }

    let k = model.outlet_fact(a)?.shape[axes.a_k].clone();

    let abc_scale = combine_scales(model, name, params[1], params[3], params[5])?;

    let a_i32 =
        model.wire_node(format!("{}.a_as_i32", name), ops::cast::cast(i32::datum_type()), &[a])?[0];
    let b_i32 =
        model.wire_node(format!("{}.b_as_i32", name), ops::cast::cast(i32::datum_type()), &[b])?[0];
    let sum_a = model.wire_node(
        format!("{}.sum_a", name),
        ops::nn::Reduce::new(tvec!(axes.a_k), ops::nn::Reducer::Sum),
        &[a_i32],
    )?[0];
    let sum_a =
        model.wire_node(format!("{}.sum_a_reduced", name), AxisOp::Rm(axes.a_k), &[sum_a])?[0];
    let sum_b = model.wire_node(
        format!("{}.sum_b", name),
        ops::nn::Reduce::new(tvec!(axes.b_k), ops::nn::Reducer::Sum),
        &[b_i32],
    )?[0];
    let sum_b =
        model.wire_node(format!("{}.sum_b_reduced", name), AxisOp::Rm(axes.b_k), &[sum_b])?[0];
    let result = compensate_zero_points(
        model, name, result, k, params[0], params[2], sum_a, sum_b, axes.c_m, axes.c_n,
    )?;
    requant(model, name, result, output_type, abc_scale, params[4])
}

/// Wires the combined requantization factor `a_scale * b_scale / c_scale`.
///
/// Both operations use rank broadcasting, so scalar and per-axis scales can be mixed.
pub(crate) fn combine_scales(
    model: &mut TypedModel,
    name: &str,
    a_scale: OutletId,
    b_scale: OutletId,
    c_scale: OutletId,
) -> TractResult<OutletId> {
    let product_name = format!("{}.ab_scale", name);
    let product =
        wire_with_rank_broadcast(&product_name, model, ops::math::mul(), &[a_scale, b_scale])?;
    let ratio_name = format!("{}.abc_scales", name);
    let ratio =
        wire_with_rank_broadcast(&ratio_name, model, ops::math::div(), &[product[0], c_scale])?;
    Ok(ratio[0])
}

/// Removes the zero-point contributions from a raw integer matrix product.
///
/// Wires `result - a0·sum_b - b0·sum_a + k·a0·b0`. Assuming `result` holds Σₖ a·b, this
/// equals Σₖ (a - a0)(b - b0).
///
/// * `k` — length of the reduction axis, added as a constant and cast to i32.
/// * `a0`, `b0` — zero points of `a` and `b`, cast to i32.
/// * `sum_a`, `sum_b` — `a` and `b` summed over their `k` axis with that axis removed.
///   Each must have rank `rank(result) - 1`, which is checked in debug builds only.
/// * `m_axis`, `n_axis` — positions of `m` and `n` in `result`.
///
/// The returned outlet has the same shape as `result` (checked in debug builds only).
// Every operand of the compensation formula is passed separately, hence the allow.
#[allow(clippy::too_many_arguments)]
pub(crate) fn compensate_zero_points(
    model: &mut TypedModel,
    name: &str,
    result: OutletId,
    k: TDim,
    a0: OutletId,
    b0: OutletId,
    sum_a: OutletId,
    sum_b: OutletId,
    m_axis: usize,
    n_axis: usize,
) -> TractResult<OutletId> {
    let input_shape = model.outlet_fact(result)?.shape.clone();
    let rank = model.outlet_fact(result)?.rank();

    debug_assert_eq!(model.outlet_fact(sum_a)?.rank(), rank - 1);
    debug_assert_eq!(model.outlet_fact(sum_b)?.rank(), rank - 1);

    // Re-insert a unit axis into each sum so it broadcasts against `result`. sum_a gets a
    // unit axis at `n_axis`, so it varies along `m` only. sum_b gets one at `m_axis`, so it
    // varies along `n` only. The debug asserts below check the alignment.
    let sum_a =
        model.wire_node(format!("{}.reshape_sum_a", name), AxisOp::Add(n_axis), &[sum_a])?[0];

    let sum_b =
        model.wire_node(format!("{}.reshape_sum_b", name), AxisOp::Add(m_axis), &[sum_b])?[0];

    debug_assert_eq!(
        model.outlet_fact(sum_a)?.shape[m_axis],
        model.outlet_fact(result)?.shape[m_axis]
    );
    debug_assert_eq!(
        model.outlet_fact(sum_b)?.shape[n_axis],
        model.outlet_fact(result)?.shape[n_axis]
    );

    // All compensation terms are computed in i32.
    let a0 =
        model.wire_node(format!("{}.cast_a0", name), ops::cast::cast(i32::datum_type()), &[a0])?[0];

    let b0 =
        model.wire_node(format!("{}.cast_b0", name), ops::cast::cast(i32::datum_type()), &[b0])?[0];

    let k = model.add_const(format!("{}.k", name), rctensor0(k))?;
    let k =
        model.wire_node(format!("{}.cast_k", name), ops::cast::cast(i32::datum_type()), &[k])?[0];

    // a0 · Σₖ b
    let a0_sum_b = wire_with_rank_broadcast(
        &format!("{}.a0_sum_b", name),
        model,
        ops::math::mul(),
        &[a0, sum_b],
    )?[0];

    // b0 · Σₖ a
    let b0_sum_a = wire_with_rank_broadcast(
        &format!("{}.b0_sum_a", name),
        model,
        ops::math::mul(),
        &[b0, sum_a],
    )?[0];

    let a0_k =
        wire_with_rank_broadcast(&format!("{}.a0_k", name), model, ops::math::mul(), &[a0, k])?[0];

    // k · a0 · b0
    let a0_k_b0 = wire_with_rank_broadcast(
        &format!("{}.a0_k_b0", name),
        model,
        ops::math::mul(),
        &[a0_k, b0],
    )?[0];

    let result = wire_with_rank_broadcast(
        &format!("{}.minus_a0_B", &name),
        model,
        ops::math::sub(),
        &[result, a0_sum_b],
    )?[0];
    let result = wire_with_rank_broadcast(
        &format!("{}.minus_b0_A", &name),
        model,
        ops::math::sub(),
        &[result, b0_sum_a],
    )?[0];

    // The k·a0·b0 term was subtracted twice by the two lines above, so add it back once.
    let result = wire_with_rank_broadcast(
        &format!("{}.plus_a0_k_b0", &name),
        model,
        ops::math::add(),
        &[result, a0_k_b0],
    )?[0];

    debug_assert_eq!(model.outlet_fact(result)?.shape, input_shape);
    Ok(result)
}

/// Requantizes an i32 accumulator to `dt`.
///
/// The value is multiplied by `scale` through the quant `scale` op, the zero point (cast
/// to i32) is added, and the result is clamped and cast by `clamp_and_cast_to`.
pub(crate) fn requant(
    model: &mut TypedModel,
    name: &str,
    wire: OutletId,
    dt: DatumType,
    scale: OutletId,
    zero_point: OutletId,
) -> TractResult<OutletId> {
    let scale_name = format!("{}.scale", name);
    let scaled = wire_with_rank_broadcast(&scale_name, model, ops::quant::scale(), &[scale, wire])?;

    let zp_i32 = model.wire_node(
        format!("{}.cast_c0", name),
        ops::cast::cast(i32::datum_type()),
        &[zero_point],
    )?;

    let offset_name = format!("{}.zeropoint", name);
    let offset =
        wire_with_rank_broadcast(&offset_name, model, ops::math::add(), &[scaled[0], zp_i32[0]])?;

    clamp_and_cast_to(model, name, dt, offset[0])
}

/// Clamps an i32 wire to the value range of `dt` (unquantized) and casts it to `dt`.
///
/// When `dt` is i32 the wire is returned unchanged. The bounds are broadcast to the
/// wire's rank as i32 constants named `{name}.min.const` and `{name}.max.const`.
pub(crate) fn clamp_and_cast_to(
    model: &mut TypedModel,
    name: &str,
    dt: DatumType,
    wire: OutletId,
) -> TractResult<OutletId> {
    if dt == i32::datum_type() {
        return Ok(wire);
    }
    let rank = model.outlet_fact(wire)?.rank();
    let plain_dt = dt.unquantized();

    let lowest = plain_dt.min_value().cast_to_dt(DatumType::I32)?.into_owned();
    let lowest = lowest.broadcast_into_rank(rank)?.into_arc_tensor();
    let lowest = model.add_const(format!("{}.min.const", name), lowest)?;

    let highest = plain_dt.max_value().cast_to_dt(DatumType::I32)?.into_owned();
    let highest = highest.broadcast_into_rank(rank)?.into_arc_tensor();
    let highest = model.add_const(format!("{}.max.const", name), highest)?;

    // Cap from above first (min with the max bound), then from below.
    let capped = model.wire_node(format!("{}.min", name), ops::math::min(), &[wire, highest])?[0];
    let clamped = model.wire_node(format!("{}.max", name), ops::math::max(), &[capped, lowest])?[0];
    let cast = model.wire_node(format!("{}.cast", name), ops::cast::cast(dt), &[clamped])?;
    Ok(cast[0])
}
src/ops/math/mod.rs (line 251)
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
/// Simplifies a multiplication with one uniform input.
///
/// * The neutral-element case (multiplier 1) is delegated to `declutter_neutral`.
/// * A multiplication by zero is replaced by the zero constant broadcast to the
///   multi-broadcast shape of the inputs.
/// * For integer types, a multiplication by an exact power of two becomes a left shift.
fn declutter_mul(
    _op: &Mul,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    if let Some(p) = declutter_neutral(model, node, 1, true).context("decluttering neutral")? {
        return Ok(Some(p));
    }
    if let Some(uniform) = crate::ops::binary::one_input_is_uniform(model, node)? {
        if uniform.uni.cast_to_scalar::<f64>()? == 0.0 {
            let shapes =
                model.node_input_facts(node.id)?.iter().map(|f| &f.shape).collect::<TVec<_>>();
            let shape: ShapeFact =
                crate::broadcast::multi_broadcast(&shapes).context("Failed to broadcast")?.into();
            return Ok(Some(TypedModelPatch::rewire(
                model,
                &[],
                &[node.id.into()],
                &|patch, _| {
                    let scalar =
                        patch.add_const(format!("{}.zero", node.name), uniform.uni.clone())?;
                    let op = MultiBroadcastTo::new(shape.clone());
                    patch.wire_node(&node.name, op, &[scalar])
                },
            )?));
        }
        let dt = uniform.uni.datum_type();
        // The shift rewrite only applies to integer types. Checking that first keeps
        // float multipliers out of the i64 round-trip below.
        if dt.is_integer() {
            let integer = uniform.uni.cast_to_scalar::<i64>()?;
            // A single set bit marks a power of two. The round-trip check ensures the i64
            // value really is the multiplier, with no truncation or wrap-around.
            if integer.count_ones() == 1
                && tensor0(integer).cast_to_dt(dt)?.close_enough(&uniform.uni, false).is_ok()
            {
                let shift = integer.trailing_zeros();
                let var_rank = model.outlet_fact(uniform.var)?.rank();
                return Ok(Some(TypedModelPatch::rewire(
                    model,
                    &[uniform.var],
                    &[node.id.into()],
                    &|patch, taps| {
                        let shift = patch.add_const(
                            format!("{}.shift", node.name),
                            tensor0(shift)
                                .cast_to_dt(dt)?
                                .into_owned()
                                .broadcast_into_rank(var_rank)?,
                        )?;
                        patch.wire_node(&node.name, shift_left(), &[taps[0], shift])
                    },
                )?));
            }
        }
    }
    Ok(None)
}

/// Declutter rule for `Div`.
///
/// Tries, in order:
/// 1. `declutter_neutral(model, node, 1, false)` (presumably: drop a division by a uniform one);
/// 2. if the divisor is a uniform *unsigned* integer equal to `2^s`, rewrite `p / q` as a
///    `shift_right` of `p` by a constant `s` broadcast to `p`'s rank;
/// 3. if the divisor is a uniform float, rewrite `p / q` as `p * recip(q)`.
///
/// The shift rewrite is restricted to unsigned types: assuming `Div` follows Rust's
/// truncating integer division, it rounds toward zero while an arithmetic right shift rounds
/// toward negative infinity, so the two disagree on negative dividends
/// (`-3 / 2 == -1` but `-3 >> 1 == -2`).
///
/// Returns `Ok(None)` when no rewrite applies.
fn declutter_div(
    _op: &Div,
    model: &TypedModel,
    node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    if let Some(p) = declutter_neutral(model, node, 1, false)? {
        return Ok(Some(p));
    }
    if let &[p, q] = &*model.node_input_facts(node.id)? {
        if let Some(q) = &q.uniform {
            let dt = q.datum_type();
            if let Ok(integer) = q.cast_to_scalar::<i64>() {
                // The divisor must round-trip exactly through i64 (no fractional part) and be
                // a power of two, i.e. have exactly one bit set.
                if tensor0(integer).cast_to_dt(dt)?.close_enough(q, false).is_ok()
                    && dt.is_unsigned()
                    && integer.count_ones() == 1
                {
                    let shift = integer.trailing_zeros();
                    return Ok(Some(TypedModelPatch::rewire(
                        model,
                        &[node.inputs[0]],
                        &[node.id.into()],
                        &|patch, taps| {
                            let shift = patch.add_const(
                                format!("{}.shift", node.name),
                                tensor0(shift)
                                    .cast_to_dt(dt)?
                                    .into_owned()
                                    .broadcast_into_rank(p.rank())?,
                            )?;
                            patch.wire_node(&node.name, shift_right(), &[taps[0], shift])
                        },
                    )?));
                }
            }
            if dt.is_float() {
                return Ok(Some(TypedModelPatch::rewire(
                    model,
                    &node.inputs,
                    &[node.id.into()],
                    &|patch, taps| {
                        let q =
                            patch.wire_node(format!("{}-recip", node.name), recip(), &[taps[1]])?
                                [0];
                        patch.wire_node(&node.name, mul(), &[taps[0], q])
                    },
                )?));
            }
        }
    }
    Ok(None)
}
src/ops/cnn/conv/unary.rs (lines 138-145)
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
    /// Converts a `u8` kernel to its `i8` counterpart.
    ///
    /// Returns `Ok(None)` when the kernel's unquantized datum type is not `U8`. Otherwise it
    /// returns a clone of `self` whose kernel went through `offset_u8_as_i8`, together with
    /// an adjusted `a0` zero point:
    /// * `Attr` / `FromQType` zero points are converted with `QParamKind::offset_u8_as_i8`;
    /// * a `FromInput(i)` zero point keeps pointing at slot `i`, but `inputs[i]` is rewired in
    ///   `model`. A `u8` input goes through an `offset_u8_as_i8` node, an `i32` input gets
    ///   `-128` added, and any other type is left as is.
    fn kernel_offset_u8_as_i8(
        &self,
        inputs: &mut [OutletId],
        model: &mut TypedModel,
    ) -> TractResult<Option<Self>> {
        if !matches!(self.kernel.datum_type().unquantized(), DatumType::U8) {
            return Ok(None);
        }
        let q_params = match &self.q_params {
            None => None,
            Some((dt, qp)) => {
                let a0 = match &qp.a0 {
                    QParamKind::Attr(_) | QParamKind::FromQType => {
                        qp.a0.offset_u8_as_i8(model, &[])?
                    }
                    QParamKind::FromInput(slot) => {
                        let slot = *slot;
                        let source = inputs[slot];
                        let source_dt = model.outlet_fact(source)?.datum_type.unquantized();
                        // Node names are taken from the wire *before* it is replaced.
                        let source_name = model.node(source.node).name.clone();
                        match source_dt {
                            DatumType::U8 => {
                                inputs[slot] = model.wire_node(
                                    format!("{}.offset_{}_as_i8", source_name, "a0"),
                                    ops::quant::offset_u8_as_i8(),
                                    &[source],
                                )?[0];
                            }
                            DatumType::I32 => {
                                let minus_128 = model.add_const(
                                    format!("{}.offset_{}_as_i8.cst", source_name, "a0"),
                                    rctensor0(-128i32),
                                )?;
                                inputs[slot] = model.wire_node(
                                    format!("{}.offset_{}_as_i8", source_name, "a0"),
                                    ops::math::add(),
                                    &[source, minus_128],
                                )?[0];
                            }
                            _ => (),
                        }
                        QParamKind::FromInput(slot)
                    }
                };
                Some((*dt, MatMulQParams { a0, ..qp.clone() }))
            }
        };
        Ok(Some(Self { kernel: self.kernel.offset_u8_as_i8(), q_params, ..self.clone() }))
    }

    /// Builds the per-group fused post-op lists that add the bias (cast to `T`) row-wise.
    ///
    /// Group `g` receives a `BinPerRow(Add)` over its own chunk of
    /// `output_channels() / group` bias values. Without a bias, every list is empty.
    ///
    /// The result has one axis of length `group`, which is removed when `group == 1`. A
    /// leading length-1 axis is inserted when the data format has an N axis.
    fn bias_as_non_linear<T>(&self) -> TractResult<ArrayD<Vec<ProtoFusedSpec>>>
    where
        T: Datum + Copy,
    {
        let mut per_group = Array1::from_elem(self.group, vec![]);
        if let Some(bias) = &self.bias {
            let casted = bias.cast_to::<T>()?;
            let values = casted.as_slice::<T>()?;
            let channels_per_group = self.output_channels() / self.group;
            for (group_specs, chunk) in per_group.iter_mut().zip(values.chunks(channels_per_group))
            {
                group_specs.push(ProtoFusedSpec::BinPerRow(
                    rctensor1(chunk).into(),
                    tract_linalg::mmm::BinOp::Add,
                ));
            }
        }

        let mut per_group = per_group.into_dyn();
        if self.group == 1 {
            per_group.index_axis_inplace(Axis(0), 0);
        }
        if self.pool_spec.data_format.has_n() {
            per_group.insert_axis_inplace(Axis(0));
        }
        Ok(per_group)
    }

    /// Lowers this quantized convolution into `model` as im2col + integer matmul, followed by
    /// zero-point compensation and requantization.
    ///
    /// * `name` - prefix used for every node wired here (the last reshape is named `name`).
    /// * `b_dt` - datum type handed to `as_outlet_ids` as the B (data) type.
    /// * `wires` - the conv node inputs; `wires[0]` is the data, and the other entries are
    ///   resolved by the quantization parameters.
    ///
    /// Returns the outlet of the final geometric reshape.
    ///
    /// # Errors
    /// Returns an error if `self.q_params` is `None`, or if any wiring step fails.
    ///
    /// # Safety
    /// NOTE(review): the caller contract is not stated in this file section. TODO confirm
    /// which lowering primitives it calls are `unsafe` and what they require.
    pub unsafe fn wire_as_quant_im2col(
        &self,
        model: &mut TypedModel,
        name: &str,
        b_dt: DatumType,
        wires: &[OutletId],
    ) -> TractResult<OutletId> {
        use crate::ops::matmul::mir_quant as qmm;

        let (c_dt, q_params) = self
            .q_params
            .as_ref()
            .context("wire_as_quant_im2col requires quantization parameters")?;
        let c_dt = *c_dt;

        let params = q_params.as_outlet_ids(
            model,
            name,
            wires,
            self.kernel.datum_type(),
            b_dt,
            c_dt,
        )?;

        // Order of the outlets returned by `as_outlet_ids`.
        let a0 = params[0];
        let a_scale = params[1];
        let mut b0 = params[2];
        let b_scale = params[3];
        let c0 = params[4];
        let c_scale = params[5];

        // May rewire the data input (and update `b0`); presumably shifts u8 data to i8.
        let b = wire_offset_u8_as_i8(model, name, wires[0], "b", &mut b0, "b0")?;
        let b_fact = model.outlet_fact(b)?.clone();
        let (_, m, k, n, mmm) = self.compute_geo(&b_fact)?;
        let output_shape = self.pool_spec.output_shape(&b_fact.shape)?;

        let abc_scale = qmm::combine_scales(model, name, a_scale, b_scale, c_scale)?;

        let im2col = model.wire_node(
            format!("{}.im2col", name),
            Im2Col::new(self.pool_spec.clone(), self.group, k, &b_fact.shape, mmm.clone())?,
            &[b, b0],
        )?[0];

        // Kernel sums over its last axis, fed to `compensate_zero_points` below.
        let a = self.kernel_as_group_o_ihw()?.into_tensor();
        let a = a.cast_to_dt(i32::datum_type())?;
        let a = a.to_array_view::<i32>()?;
        let mut sum_a = a.sum_axis(Axis(a.ndim() - 1));
        if self.group == 1 {
            sum_a.index_axis_inplace(Axis(0), 0);
        }

        if self.pool_spec.data_format.has_n() {
            sum_a.insert_axis_inplace(Axis(0));
        }
        let sum_a = model.add_const(format!("{}.sum_a", name), sum_a)?;

        let mut sum_b = model.wire_node(
            format!("{}.sum_b", name),
            super::QSumB { n: n.clone(), r: mmm.b_pack().panel_width(), k },
            &[im2col],
        )?[0];

        if self.group > 1 && self.pool_spec.data_format.c_is_last() {
            let has_n = self.pool_spec.data_format.has_n() as usize;
            sum_b = model.wire_node(
                format!("{}.transpose_sum_b", name),
                AxisOp::Move(has_n, 1 + has_n),
                &[sum_b],
            )?[0];
        }

        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?;
        let mut geometry = MatMulGeometry::from(SymbolicMatMulGeometry {
            b_datum_type: b_fact.datum_type,
            m: m.to_dim(),
            k: k.to_dim(),
            n: n.clone(),
            mmm: mmm.clone(),
        });
        if n.to_usize().is_ok() {
            geometry = geometry.optimize_if(Some(&SymbolValues::default()))?;
        }
        let wire = self.wire_lir_matmatmul(
            model,
            name,
            im2col,
            mmm,
            i32::datum_type(),
            mmm_output_shape.clone().into(),
            m,
            k,
            geometry,
            c_axis,
            h_axis,
        )?;
        let has_n = self.pool_spec.data_format.has_n() as usize;
        let has_group = (self.group > 1) as usize;
        let (m_axis, n_axis) = if self.pool_spec.data_format.c_is_last() {
            (1 + has_group + has_n, has_n)
        } else {
            (has_group + has_n, 1 + has_n + has_group)
        };
        let wire = qmm::compensate_zero_points(
            model,
            name,
            wire,
            k.to_dim(),
            a0,
            b0,
            sum_a,
            sum_b,
            m_axis,
            n_axis,
        )?;

        let mut wire = qmm::requant(model, name, wire, c_dt, abc_scale, c0)?;
        if self.group > 1 {
            // Merge the (group, channels-per-group) axes back into a single channel axis.
            wire = model.wire_node(
                format!("{}.reshape_group", name),
                AxisOp::Reshape(
                    c_axis - 1,
                    mmm_output_shape[c_axis - 1..][..2].iter().map(|d| d.to_dim()).collect(),
                    tvec!((m * self.group).to_dim()),
                ),
                &[wire],
            )?[0];
        }
        let wire = Self::wire_geo_reshape(model, name, wire, &output_shape)?;
        Ok(wire)
    }

    /// Lowers this (non-quantized) convolution into `model` as an `Im2Col` node followed by
    /// a `LirMatMulUnary`, then reshapes the collapsed spatial axis back.
    ///
    /// The im2col padding value is a zero scalar of the input datum type. The matmul output
    /// type is `matmul::output_type(input_dt)`. Every node is prefixed by `name`.
    ///
    /// # Safety
    /// NOTE(review): the caller contract is not stated in this file section. TODO confirm
    /// which lowering primitives it calls are `unsafe` and what they require.
    pub unsafe fn wire_as_im2col_pair(
        &self,
        model: &mut TypedModel,
        name: &str,
        mut wire: OutletId,
    ) -> TractResult<OutletId> {
        let b_fact = model.outlet_fact(wire)?.clone();
        let b_dt = b_fact.datum_type;
        let c_dt = crate::ops::matmul::output_type(b_dt);

        let output_shape = self.pool_spec.output_shape(&b_fact.shape)?;
        let (_, m, k, n, mmm) = self.compute_geo(&b_fact)?;
        let padding = model.add_const(format!("{}.b0", name), Tensor::zero_dt(b_dt, &[])?)?;

        wire = model.wire_node(
            format!("{}.im2col", name),
            Im2Col::new(self.pool_spec.clone(), self.group, k, &b_fact.shape, mmm.clone())?,
            &[wire, padding],
        )?[0];

        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&output_shape)?;
        let mut geometry = MatMulGeometry::from(SymbolicMatMulGeometry {
            b_datum_type: b_dt,
            m: m.to_dim(),
            k: k.to_dim(),
            n: n.clone(),
            mmm: mmm.clone(),
        });
        if n.to_usize().is_ok() {
            geometry = geometry.optimize_if(Some(&SymbolValues::default()))?;
        }
        // `m` and `k` are already `usize` (from `compute_geo`): pass them straight through.
        let mut wire = self.wire_lir_matmatmul(
            model,
            name,
            wire,
            mmm,
            c_dt,
            mmm_output_shape.clone().into(),
            m,
            k,
            geometry,
            c_axis,
            h_axis,
        )?;

        if self.group > 1 {
            // Merge the (group, channels-per-group) axes back into a single channel axis.
            wire = model.wire_node(
                format!("{}.reshape_group", name),
                AxisOp::Reshape(
                    c_axis - 1,
                    mmm_output_shape[c_axis - 1..][..2].iter().map(|d| d.to_dim()).collect(),
                    tvec!((m * self.group).to_dim()),
                ),
                &[wire],
            )?[0];
        }
        let wire = Self::wire_geo_reshape(model, name, wire, &output_shape)?;
        Ok(wire)
    }

    /// Computes the matmul output shape for `output_shape`.
    ///
    /// All spatial dims are collapsed into one axis at the H position. `n` defaults to 1 when
    /// `output_shape` has no N dim (presumably ignored by formats without an N axis). When
    /// `group > 1`, the C axis is split into `(group, C / group)`.
    ///
    /// Returns `(shape, c_axis, h_axis)`. `c_axis` indexes the per-group channel axis and
    /// `h_axis` the collapsed spatial axis in the returned shape.
    fn mmm_output_shape<D: DimLike>(
        &self,
        output_shape: &BaseDataShape<D, TVec<D>>,
    ) -> TractResult<(TVec<D>, usize, usize)> {
        let geo_collapsed_out: D = output_shape.hw_dims().iter().cloned().product();
        let shape: BaseDataShape<D, TVec<D>> = output_shape.fmt.from_n_c_hw(
            output_shape.n().cloned().unwrap_or_else(|| 1.into()),
            output_shape.c().clone(),
            tvec!(geo_collapsed_out),
        )?;
        let mut mmm_output_shape: TVec<D> = shape.shape.clone();
        let mut c_axis = shape.c_axis();
        let mut h_axis = shape.h_axis();
        if self.group > 1 {
            // The group axis is inserted in front of the per-group channel axis. This shifts
            // c by one, and h too when it came after c.
            mmm_output_shape[c_axis] = mmm_output_shape[c_axis].clone() / self.group;
            mmm_output_shape.insert(c_axis, self.group.into());
            if h_axis > c_axis {
                h_axis += 1;
            }
            c_axis += 1;
        }
        Ok((mmm_output_shape, c_axis, h_axis))
    }

    /// Wires a node named exactly `name` that reshapes the single collapsed spatial axis of
    /// `wire`, located at `output_shape.h_axis()`, back into `output_shape.hw_dims()`.
    fn wire_geo_reshape<D: DimLike>(
        model: &mut TypedModel,
        name: &str,
        wire: OutletId,
        output_shape: &BaseDataShape<D, TVec<D>>,
    ) -> TractResult<OutletId> {
        let spatial = output_shape.hw_dims();
        let collapsed: D = spatial.iter().cloned().product();
        let unflatten = AxisOp::Reshape(
            output_shape.h_axis(),
            tvec!(collapsed.to_dim()),
            spatial.iter().map(|d| d.to_dim()).collect(),
        );
        let outlets = model.wire_node(name, unflatten, &[wire])?;
        Ok(outlets[0])
    }

    /// Lowers this convolution into `model` as a single `LirMatMulUnary` whose B operand is
    /// a lazy im2col virtual input, described by byte offsets into the (possibly padded)
    /// input, followed by the geometric reshape.
    ///
    /// Any non-zero computed padding is materialized with an explicit constant-zero `Pad`
    /// node. The geometry is then recomputed with `Valid` padding on the padded input.
    ///
    /// # Errors
    /// Returns an error (instead of panicking) if the input, padded or not, has no concrete
    /// shape or the output spatial size is symbolic, or if any wiring step fails.
    ///
    /// # Safety
    /// NOTE(review): the caller contract is not stated in this file section. TODO confirm
    /// which lowering primitives it calls are `unsafe` and what they require.
    pub unsafe fn wire_as_lazy_im2col(
        &self,
        model: &mut TypedModel,
        name: &str,
        mut wire: OutletId,
    ) -> TractResult<OutletId> {
        let mut b_fact = model.outlet_fact(wire)?.clone();
        let (geo, m, k, n, mmm) = self.compute_geo(&b_fact)?;
        let input_shape = b_fact
            .shape
            .as_concrete()
            .context("lazy im2col requires a concrete input shape")?
            .to_vec();
        let mut geo = geo.to_concrete(&input_shape)?.into_owned();
        let mut input_shape: DataShape = self.pool_spec.data_format.shape(input_shape.into())?;
        let padding = self.pool_spec.computed_padding(input_shape.hw_dims());
        if padding.iter().any(|axis| axis.pad_before != 0 || axis.pad_after != 0) {
            let mut pads = vec![(0, 0); b_fact.rank()];
            for (ix, ax) in padding.iter().enumerate() {
                pads[input_shape.h_axis() + ix] = (ax.pad_before, ax.pad_after);
            }
            let op = crate::ops::array::Pad {
                mode: crate::ops::array::PadMode::Constant(
                    Tensor::zero_scalar_dt(b_fact.datum_type)?.into_arc_tensor(),
                ),
                pads,
            };
            wire = model.wire_node(format!("{}.pad", name), op, &[wire])?[0];
            let valid_pool_spec =
                PoolSpec { padding: ops::cnn::PaddingSpec::Valid, ..self.pool_spec.clone() };
            b_fact = model.outlet_fact(wire)?.clone();
            let concrete_shape = b_fact
                .shape
                .as_concrete()
                .context("lazy im2col requires a concrete padded input shape")?;
            input_shape = valid_pool_spec.data_format.shape(concrete_shape.into())?;
            geo = valid_pool_spec
                .compute_geo(&b_fact.shape)?
                .to_concrete(concrete_shape)?
                .into_owned();
        }
        let c_dt = crate::ops::matmul::output_type(b_fact.datum_type);
        let c_stride = input_shape.c_stride();
        let size_of_b = b_fact.datum_type.size_of() as isize;
        // Byte offsets of each output position's patch center (n) and of each
        // (input channel, kernel position) pair (k) within the input buffer.
        let n_bytes_offsets: Vec<isize> =
            geo.patch.centers_offsets().into_iter().map(|x| x * size_of_b).collect();
        let k_bytes_offsets: Vec<isize> = (0..self.input_channels())
            .flat_map(|ici| {
                geo.patch
                    .standard_layout_data_field
                    .iter()
                    .map(move |x| (x + (ici * c_stride) as isize) * size_of_b)
            })
            .collect();
        let virtual_input = super::lazy_im2col::LazyIm2colSpec { n_bytes_offsets, k_bytes_offsets };
        let b_storage = mmm.b_virtual_input(Box::new(virtual_input), k);
        let (mmm_output_shape, c_axis, h_axis) = self.mmm_output_shape(&geo.output_shape)?;

        let geometry = MatMulGeometry::Concrete(ConcreteMatMulGeometry {
            m,
            k,
            n: n.to_usize()?,
            b_storage,
        });
        let wire = self.wire_lir_matmatmul(
            model,
            name,
            wire,
            mmm,
            c_dt,
            mmm_output_shape.into(),
            m,
            k,
            geometry,
            c_axis,
            h_axis,
        )?;

        let wire = Self::wire_geo_reshape(model, name, wire, &geo.output_shape)?;
        Ok(wire)
    }

    /// Computes the pooling geometry and matmul dimensions for `input_fact`, and selects a
    /// matrix-multiplication kernel.
    ///
    /// Returns `(geometry, m, k, n, mmm)` where:
    /// * `m = output_channels / group`,
    /// * `k = kernel.len() / output_channels`,
    /// * `n` = product of the output spatial dims (possibly symbolic),
    /// * `mmm` = the linalg multiplier for (kernel dt, input dt) → `matmul::output_type(input dt)`.
    ///
    /// # Errors
    /// Fails if the geometry cannot be computed or no multiplier is available.
    #[allow(clippy::type_complexity)]
    fn compute_geo(
        &self,
        input_fact: &TypedFact,
    ) -> TractResult<(PoolGeometry, usize, usize, TDim, Box<dyn MatMatMul>)> {
        let (a_dt, b_dt) = (self.kernel.datum_type(), input_fact.datum_type);
        let c_dt = crate::ops::matmul::output_type(b_dt);

        let geometry = self.pool_spec.compute_geo(&input_fact.shape)?;

        let output_channels = self.output_channels();
        trace!("output channels: {:?}", output_channels);
        let m = output_channels / self.group;
        let k = self.kernel.len() / output_channels;
        let output_shape = self.pool_spec.output_shape(&input_fact.shape)?;
        let n: TDim = output_shape.hw_dims().iter().cloned().product();

        let candidate = tract_linalg::ops().mmm(a_dt, b_dt, c_dt, Some(m), Some(k), n.to_usize().ok());
        let mmm = candidate
            .with_context(|| format!("No multiplier for {:?}x{:?} to {:?}", a_dt, b_dt, c_dt))?;

        Ok((geometry, m, k, n, mmm))
    }

    /// Wires a `LirMatMulUnary` node named `"{name}.matmatmul"` that multiplies the kernel,
    /// packed for `mmm`'s A operand (with `k`, `m`), by `wire`.
    ///
    /// Each packed kernel block is paired, in iteration order, with its fused post-ops: the
    /// bias ops from `bias_as_non_linear` (in `mmm.internal_type()`) followed by a `Store`.
    /// `mmm_output_shape` is used both for the output fact and as the final shape.
    #[allow(clippy::too_many_arguments)]
    fn wire_lir_matmatmul(
        &self,
        model: &mut TypedModel,
        name: &str,
        wire: OutletId,
        mmm: Box<dyn MatMatMul>,
        c_datum_type: DatumType,
        mmm_output_shape: ShapeFact,
        m: usize,
        k: usize,
        geometry: MatMulGeometry,
        c_m_axis: usize,
        c_n_axis: usize,
    ) -> TractResult<OutletId> {
        let packed_kernels = self.kernel_as_packed_as(&mmm.a_pack(), k, m)?;
        let mut post_ops =
            dispatch_copy!(Self::bias_as_non_linear(mmm.internal_type())(self))?;
        post_ops.iter_mut().for_each(|specs| specs.push(ProtoFusedSpec::Store));

        let mut pairs = packed_kernels.iter().cloned().zip(post_ops.iter().cloned());
        let micro_ops = ArrayD::from_shape_fn(packed_kernels.shape(), |_| pairs.next().unwrap());

        let op = LirMatMulUnary {
            c_fact: c_datum_type.fact(mmm_output_shape.clone()),
            micro_ops,
            c_m_axis,
            c_n_axis,
            c_final_shape: mmm_output_shape,
            reshape_post: vec![],
            geometry,
            mmm,
        };
        let outlets = model.wire_node(format!("{}.matmatmul", name), op, &[wire])?;
        Ok(outlets[0])
    }

    /// Builds an equivalent `DepthWise` op for the concrete `input` fact.
    ///
    /// Uses the conv's bias, or a zero bias of type `T` with one value per input channel
    /// when there is none.
    ///
    /// # Errors
    /// Returns an error (instead of panicking) if `input` has no concrete shape, or if the
    /// geometry or kernel conversion fails.
    pub fn to_depth_wise<T>(&self, input: &TypedFact) -> TractResult<Box<dyn TypedOp>>
    where
        T: Datum + Clone + ::ndarray::LinalgScalar + PartialEq + Sum,
    {
        let concrete_shape = input
            .shape
            .as_concrete()
            .context("depth-wise conv requires a concrete input shape")?;
        let ConcretePoolGeometry { input_shape, patch, output_shape } = self
            .pool_spec
            .compute_geo(&input.shape)?
            .to_concrete(concrete_shape)?
            .into_owned();
        let bias = if let Some(b) = &self.bias {
            b.clone()
        } else {
            Tensor::zero::<T>(&[*input_shape.c()])?.into_arc_tensor()
        };
        let op = DepthWise::new(
            patch,
            input_shape,
            output_shape,
            self.kernel_as_group_o_ihw().context("in kernel_as_group_o_ihw")?,
            bias,
        );
        Ok(Box::new(op))
    }

    /// Pulls a stride out of the convolution as an explicit `Downsample` on its input.
    ///
    /// Picks the first spatial axis whose stride `s > 1` and where either the kernel has
    /// size 1, or the padding is valid on that axis and the dilation is a multiple of `s`.
    /// The input is downsampled by `s` on that axis, and the new conv has its stride (and
    /// dilation, if above 1) divided by `s`. Returns `Ok(None)` if no axis qualifies.
    fn declutter_stride_slice_to_downsample(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let input_fact = model.outlet_fact(node.inputs[0])?;
        let spatial_rank = self.kernel.rank() - 2;
        let spec = &self.pool_spec;
        let qualifies = |ax: usize| {
            let stride = spec.stride(ax);
            if stride <= 1 {
                return false;
            }
            spec.kernel_shape[ax] == 1
                || (spec.padding.valid_dim(ax, stride == 1) && spec.dilation(ax) % stride == 0)
        };
        let axis = match (0..spatial_rank).find(|&ax| qualifies(ax)) {
            Some(axis) => axis,
            None => return Ok(None),
        };

        let factor = spec.stride(axis);
        let mut new_op = self.clone();
        if new_op.pool_spec.dilation(axis) > 1 {
            new_op.pool_spec.dilations.as_mut().unwrap()[axis] /= factor;
        }
        new_op.pool_spec.strides.as_mut().unwrap()[axis] /= factor;

        let mut patch = TypedModelPatch::default();
        let tap = patch.tap_model(model, node.inputs[0])?;
        let data_shape =
            spec.data_format.shape(input_fact.shape.iter().collect::<TVec<TDim>>())?;
        let downsampled = patch.wire_node(
            format!("{}.downsample.{}", node.name, axis),
            crate::ops::Downsample::new(axis + data_shape.h_axis(), factor as isize, 0),
            &[tap],
        )?;
        let replacement = patch.wire_node(&*node.name, new_op, &downsampled)?[0];
        patch.shunt_outside(model, OutletId::new(node.id, 0), replacement)?;
        Ok(Some(patch))
    }

    /// Rewrites a 1-D, ungrouped, stride-1 convolution whose kernel has a single spatial
    /// position (`kernel.len() == input_channels * output_channels`) as a `MatMul`, or as a
    /// `QMatMul` when the op is quantized.
    ///
    /// The kernel becomes a constant `[ci, co]` (HWIO) or `[co, ci]` (otherwise) matrix,
    /// broadcast to the input rank. HWC/NHWC data formats transpose the data and output
    /// operands.
    ///
    /// Returns `Ok(None)` when the conv does not match this pattern.
    fn declutter_as_matmul(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        use crate::ops::matmul::*;
        let input_fact = model.outlet_fact(node.inputs[0])?;
        let full_input_shape = input_fact.shape.to_tvec();
        let input_shape = self.pool_spec.data_format.shape(&full_input_shape)?;
        if input_shape.hw_rank() == 1
            && self.group == 1
            && self.pool_spec.stride(0) == 1
            && self.kernel.len() == self.input_channels() * self.output_channels()
        {
            let ci = self.input_channels();
            let co = self.output_channels();
            let ker = self.kernel.clone().into_tensor();
            let (a_shape, a_trans) = if self.kernel_fmt == KernelFormat::HWIO {
                ([ci, co], true)
            } else {
                ([co, ci], false)
            };
            let a = ker
                .into_shape(&a_shape)?
                .broadcast_into_rank(full_input_shape.len())?
                .into_arc_tensor();
            let trans_data = self.pool_spec.data_format == DataFormat::HWC
                || self.pool_spec.data_format == DataFormat::NHWC;
            let mut patch = TypedModelPatch::new("declutter_as_matmul");
            let a = patch.add_const(format!("{}.filters", &node.name), a)?;
            let mut inputs = node
                .inputs
                .iter()
                .map(|i| patch.tap_model(model, *i))
                .collect::<TractResult<TVec<_>>>()?;
            inputs.insert(0, a);
            let axes = MatMulAxes::default_for_rank(full_input_shape.len())
                .transposing(a_trans, trans_data, trans_data);
            // in Q case, the bias has to be injected inside the QMatMul (as it
            // must be added before requantization)
            let wire = if let Some((output_type, q_params)) = &self.q_params {
                // The kernel (slot 0) and bias (slot 2) wires are inserted in front of the
                // q-param inputs. `insert_input` adjusts the params for that, so the adjusted
                // copy is the one QMatMul must receive.
                let mut params = q_params.clone();
                params.insert_input(0); // kernel as input
                params.insert_input(2); // bias as input
                let bias = self.bias.clone().unwrap_or_else(|| rctensor0(0i32));
                anyhow::ensure!(bias.rank() == 0 || bias.rank() == 1);
                let bias = patch.add_const(format!("{}.bias", &node.name), bias)?;
                inputs.insert(2, bias);
                let op = QMatMul { axes, output_type: *output_type, params };
                patch.wire_node(&*node.name, op, &inputs)?[0]
            } else {
                let op = MatMul { axes };
                let mut wire = patch.wire_node(format!("{}.matmul", node.name), op, &inputs)?[0];
                if let Some(b) = &self.bias {
                    anyhow::ensure!(b.rank() == 0 || b.rank() == 1);
                    // Reshape the bias so it broadcasts along the channel axis.
                    let mut bias_shape = tvec!(1; input_shape.rank());
                    bias_shape[input_shape.c_axis()] = co;
                    let b = b.clone().into_tensor().into_shape(&bias_shape)?;
                    let b =
                        patch.add_const(format!("{}.bias.cst", node.name), b.into_arc_tensor())?;
                    wire = patch.wire_node(
                        format!("{}.bias", node.name),
                        crate::ops::math::add(),
                        &[wire, b],
                    )?[0];
                }
                wire
            };
            patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

    /// Folds a preceding constant-zero `Pad` node into this convolution's explicit padding.
    ///
    /// Applies only when:
    /// * the conv uses `Valid` or `Explicit(_, _, false)` padding,
    /// * the pad value is zero, and
    /// * neither the N axis (if any) nor the C axis is padded.
    ///
    /// The spatial pad amounts are added to the conv's own explicit padding.
    ///
    /// `Explicit(_, _, true)` is rejected because the merged spec is always built with
    /// `false`. Merging it would silently drop that flag together with the conv's own padding.
    fn declutter_precursor_padding(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if !matches!(
            self.pool_spec.padding,
            PaddingSpec::Valid | PaddingSpec::Explicit(_, _, false)
        ) {
            return Ok(None);
        }
        let prec = model.node(node.inputs[0].node);
        let pad = if let Some(pad) = prec.op_as::<Pad>() { pad } else { return Ok(None) };
        let value = if let PadMode::Constant(c) = &pad.mode {
            c
        } else {
            return Ok(None);
        };
        let shape = self.pool_spec.data_format.shape(&model.outlet_fact(node.inputs[0])?.shape)?;
        // Compare float pad values as floats: an i64 cast truncates e.g. 0.5 to 0, which
        // would wrongly treat a non-zero pad value as zero.
        let value_is_zero = if value.datum_type().is_float() {
            value.cast_to_scalar::<f64>()? == 0.0
        } else {
            value.cast_to_scalar::<i64>()? == 0
        };
        if !value_is_zero
            || (self.pool_spec.data_format.has_n() && pad.pads[0] != (0, 0))
            || pad.pads[shape.c_axis()] != (0, 0)
        {
            return Ok(None);
        }
        let mut before: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.0).collect();
        let mut after: TVec<usize> = pad.pads[shape.hw_axes()].iter().map(|pair| pair.1).collect();
        if let PaddingSpec::Explicit(bef, aft, false) = &self.pool_spec.padding {
            izip!(&mut before, bef).for_each(|(pad, cv)| *pad += cv);
            izip!(&mut after, aft).for_each(|(pad, cv)| *pad += cv);
        }
        let padding = PaddingSpec::Explicit(before, after, false);
        let mut new = self.clone();
        new.pool_spec.padding = padding;
        let mut patch = TypedModelPatch::default();
        let wire = patch.tap_model(model, prec.inputs[0])?;
        let wire = patch.wire_node(&node.name, new, &[wire])?;
        patch.shunt_outside(model, node.id.into(), wire[0])?;
        Ok(Some(patch))
    }
}

impl Op for ConvUnary {
    fn name(&self) -> Cow<str> {
        "ConvUnary".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        let mut info = self.pool_spec.info();
        info.push(format!(
            "Kernel {:?} (groups:{}), {:?}",
            self.kernel_fmt, self.group, self.kernel
        ));
        if let Some(b) = &self.bias {
            info.push(format!("Bias: {:?}", b))
        }
        Ok(info)
    }

    fn validation(&self) -> Validation {
        Validation::Rounding
    }

    op_as_typed_op!();
}

impl EvalOp for ConvUnary {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let mut model = TypedModel::default();

        let mut wires: TVec<OutletId> = inputs
            .iter()
            .enumerate()
            .map(|(ix, v)| {
                model.add_source(format!("source.{}", ix), v.datum_type().fact(v.shape()))
            })
            .collect::<TractResult<_>>()?;
        let new_op = self.kernel_offset_u8_as_i8(&mut wires, &mut model)?;
        let wire = unsafe {
            if self.q_params.is_some() {
                let op_ref = if let Some(op) = new_op.as_ref() { op } else { self };
                op_ref.wire_as_quant_im2col(
                    &mut model,
                    "im2col-adhoc",
                    inputs[0].datum_type(),
                    &wires,
                )?
            } else {
                self.wire_as_im2col_pair(&mut model, "im2col-adhoc", wires[0])?
            }
        };
        model.set_output_outlets(&[wire])?;
        model.into_runnable()?.run(inputs)
    }
}

impl TypedOp for ConvUnary {
    /// Validates the op against its input facts and computes the single output fact.
    ///
    /// Checks, in order:
    /// - the input count: one data input plus the quantization-parameter inputs, if any;
    /// - the data input's channel count against the kernel;
    /// - the pool spec's output-channel override against the kernel;
    /// - the bias shape (scalar, or one value per output channel).
    ///
    /// With `q_params` set, the output datum type is the one stored there. Without it,
    /// the input and kernel datum types must match.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let expected_inputs = 1 + match &self.q_params {
            Some((_, qp)) => qp.input_count(),
            None => 0,
        };
        if inputs.len() != expected_inputs {
            bail!("Wrong number of inputs: expected {} got {}", expected_inputs, inputs.len());
        }

        let data_shape = self.pool_spec.data_format.shape(&*inputs[0].shape)?;
        if *data_shape.c() != self.input_channels().to_dim() {
            bail!(
                "Inconsistent convolution: input is {:?}, kernel expects {} input channels, {:?}",
                inputs[0],
                self.input_channels(),
                self
            );
        }
        let output_channels = self.output_channels();
        if self.pool_spec.output_channel_override != Some(output_channels) {
            bail!(
                "Inconsistent convolution: output channels from pool spec is {:?}, kernel expects {} output channels, {:?}",
                self.pool_spec.output_channel_override,
                output_channels,
                self
            );
        }
        if let Some(bias) = &self.bias {
            let is_scalar = bias.rank() == 0;
            let is_per_channel = bias.rank() == 1 && bias.len() == output_channels;
            ensure!(
                is_scalar || is_per_channel,
                "Bias should be scalar or a vector with one value per output channel, got:{:?}",
                bias
            );
        }

        let mut fact = self.pool_spec.output_facts(inputs)?.remove(0);
        match &self.q_params {
            Some((output_dt, _)) => fact.datum_type = *output_dt,
            None => {
                ensure!(
                    inputs[0].datum_type == self.kernel.datum_type(),
                    "Convolution input and weights must have the same type. (resp {:?} and {:?})",
                    inputs[0].datum_type,
                    self.kernel.datum_type(),
                );
            }
        }
        Ok(tvec!(fact))
    }

    fn invariants(
        &self,
        inputs: &[&TypedFact],
        _outputs: &[&TypedFact],
    ) -> TractResult<Invariants> {
        let fact = &inputs[0];
        let shape = self.pool_spec.data_format.shape(fact.shape.iter().collect::<Vec<TDim>>())?;
        let mut axes = vec![];
        if let Some(n_axis) = shape.n_axis() {
            let mut info = AxisInfo::simple(n_axis).disposable(true);
            info.inputs.extend(std::iter::repeat(None).take(inputs.len() - 1));
            axes.push(info);
        }
        let kernel_spatial_shape =
            &self.kernel.shape()[self.kernel_fmt.h_axis()..][..shape.hw_rank()];
        let h_axis = shape.h_axis();
        for (ix, &dim) in kernel_spatial_shape.iter().enumerate() {
            if dim == 1 && self.pool_spec.stride(ix) == 1 {
                let mut info = AxisInfo::simple(ix + h_axis).disposable(true);
                info.inputs.extend(std::iter::repeat(None).take(inputs.len() - 1));
                axes.push(info)
            }
        }
        Ok(axes.into_iter().collect())
    }

    /// Simplifies this convolution, trying each rewrite in turn and returning the
    /// first patch produced.
    ///
    /// The rewrites are tried in this order:
    /// 1. inline static quantization-parameter inputs into `q_params`;
    /// 2. stride/slice to downsample;
    /// 3. conversion to a matmul;
    /// 4. absorbing a preceding zero `Pad` into the padding spec.
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if let Some((_, qp)) = &self.q_params {
            if let Some((inputs, inlined)) = qp.inline_static(model, node)? {
                let mut op = self.clone();
                op.q_params.as_mut().unwrap().1 = inlined;
                return Ok(Some(
                    TypedModelPatch::replace_single_op(model, node, &inputs, op)?
                        .with_context("inlining quantized conv params"),
                ));
            }
        }
        let rewrites = [
            Self::declutter_stride_slice_to_downsample,
            Self::declutter_as_matmul,
            Self::declutter_precursor_padding,
        ];
        for rewrite in rewrites.iter() {
            if let Some(patch) = rewrite(self, model, node)? {
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }

    /// Reports two costs:
    /// - the parameter count: kernel elements plus bias elements;
    /// - the FMA count: batch × input channels × output channels × output spatial
    ///   points × kernel surface / group.
    fn cost(&self, inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        let shape = self.pool_spec.data_format.shape(inputs[0].shape.to_tvec())?;
        let kernel_spatial_shape =
            &self.kernel.shape()[self.kernel_fmt.h_axis()..][..shape.hw_rank()];
        // Missing dilations/strides default to 1 along every spatial axis.
        let all_ones = || tvec!(1; kernel_spatial_shape.len());
        let dilations = self.pool_spec.dilations.clone().unwrap_or_else(all_ones);
        let strides = self.pool_spec.strides.clone().unwrap_or_else(all_ones);
        let output_dims = self.pool_spec.padding.compute(
            shape.hw_dims(),
            kernel_spatial_shape,
            &dilations,
            &strides,
        );
        let n_output_points: TDim = output_dims.iter().map(|d| d.convoluted.clone()).product();
        let n_output_channels = self.output_channels().to_dim();
        let kernel_surface: usize = kernel_spatial_shape.iter().product();
        let batch = shape.n().cloned().unwrap_or_else(|| 1.to_dim());
        let n_params = self.kernel.len() + self.bias.as_ref().map_or(0, |b| b.len());
        let n_fma = batch
            * shape.c()
            * n_output_channels
            * n_output_points
            * kernel_surface.to_dim()
            / self.group;
        Ok(tvec!(
            (Cost::Params(inputs[0].datum_type.unquantized()), n_params.to_dim()),
            (Cost::FMA(inputs[0].datum_type), n_fma)
        ))
    }

    /// Tries to absorb an axis change around this convolution into a substitute op.
    ///
    /// Handled cases (anything else returns `Ok(None)`):
    /// - removing the batch axis N, via `PoolSpec::dispose_n_axis`;
    /// - the one `AxisOp::Move` of the channel axis that switches between a
    ///   channel-first and a channel-last data format (NCHW <-> NHWC, CHW <-> HWC);
    /// - an `Rm`, `Add` or `Move` confined to spatial axes. This is mirrored onto the
    ///   kernel tensor, the dilations, the kernel shape and the strides.
    ///
    /// `_io` is ignored. In the last two cases the consequence reports the same change
    /// on input 0 and output 0.
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        _io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        let full_input_shape = model.outlet_fact(node.inputs[0])?.shape.to_tvec();
        let shape = self.pool_spec.data_format.shape(full_input_shape.clone())?;
        // remove n
        if let Some(n) = shape.n_axis() {
            assert_eq!(n, 0);
            if change == &AxisOp::Rm(n) {
                let op = ConvUnary { pool_spec: self.pool_spec.dispose_n_axis(), ..self.clone() };
                return Ok(Some(AxisChangeConsequence::new(
                    model,
                    node,
                    Some(Box::new(op)),
                    change,
                )));
            }
            // Bail out if the change would move N away from axis 0, or drop it
            // (`transform_axis` returning `None`).
            if change.transform_axis(n).map(|axis| axis > 0).unwrap_or(true) {
                return Ok(None);
            }
        }
        // format swap: chw <-> hwc
        // `axis_move` is the channel-axis move that turns the current format into
        // `new_format`. Only that exact change is absorbed here.
        let (new_format, axis_move) = match self.pool_spec.data_format {
            DataFormat::NCHW => {
                (DataFormat::NHWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
            }
            DataFormat::CHW => {
                (DataFormat::HWC, AxisOp::Move(shape.c_axis(), full_input_shape.len() - 1))
            }
            DataFormat::NHWC => (DataFormat::NCHW, AxisOp::Move(shape.c_axis(), 1)),
            DataFormat::HWC => (DataFormat::CHW, AxisOp::Move(shape.c_axis(), 0)),
        };
        if *change == axis_move {
            let mut new_op = self.clone();
            new_op.pool_spec.data_format = new_format;
            return Ok(Some(AxisChangeConsequence {
                substitute_op: Some(Box::new(new_op)),
                wire_changes: tvec!(
                    (InOut::In(0), change.clone()),
                    (InOut::Out(0), change.clone())
                ),
            }));
        }
        // geo axis manips
        use AxisOp::*;
        let h_axis = shape.h_axis();
        let hw_axes = shape.hw_axes();
        // Position of the first spatial axis in the kernel tensor: 2 for OIHW, 0 otherwise.
        let kh_axis = if self.kernel_fmt == KernelFormat::OIHW { 2 } else { 0 };
        // Express the change relative to the spatial axes (`geo_adjusted`) and relative
        // to the kernel tensor's axes (`kernel_adjusted`). A spatial axis may only be
        // removed when its dilation, stride and kernel extent are all 1.
        let (geo_adjusted, kernel_adjusted) = match change {
            Rm(a)
                if hw_axes.contains(a)
                    && self.pool_spec.dilation(a - h_axis) == 1
                    && self.pool_spec.stride(a - h_axis) == 1
                    && self.pool_spec.kernel_shape[a - h_axis] == 1 =>
            {
                (Rm(a - h_axis), Rm(a - h_axis + kh_axis))
            }
            Add(a) if hw_axes.contains(a) => (Add(a - h_axis), Add(a - h_axis + kh_axis)),
            Move(f, t) if hw_axes.contains(f) && hw_axes.contains(t) => {
                (Move(f - h_axis, t - h_axis), Move(f - h_axis + kh_axis, t - h_axis + kh_axis))
            }
            _ => return Ok(None),
        };
        let mut kernel = self.kernel.clone().into_tensor();
        kernel_adjusted.change_tensor(&mut kernel, false)?;
        let mut dilations = self.pool_spec.dilations().into_owned().into();
        geo_adjusted.change_shape_array(&mut dilations, false)?;
        let mut kernel_shape = self.pool_spec.kernel_shape.clone();
        geo_adjusted.change_shape_array(&mut kernel_shape, false)?;
        let mut strides = self.pool_spec.strides().into_owned().into();
        geo_adjusted.change_shape_array(&mut strides, false)?;
        // NOTE(review): padding is copied unchanged. Explicit per-axis before/after
        // vectors are not re-indexed by `geo_adjusted` — confirm this cannot produce a
        // rank mismatch for explicit padding.
        let new_op = ConvUnary {
            pool_spec: PoolSpec {
                data_format: self.pool_spec.data_format,
                padding: self.pool_spec.padding.clone(), // fixme (explicit padding)
                dilations: Some(dilations),
                kernel_shape,
                strides: Some(strides),
                output_channel_override: self.pool_spec.output_channel_override,
            },
            kernel_fmt: self.kernel_fmt,
            kernel: kernel.into_arc_tensor(),
            group: self.group,
            bias: self.bias.clone(),
            q_params: self.q_params.clone(),
        };
        Ok(Some(AxisChangeConsequence {
            substitute_op: Some(Box::new(new_op)),
            wire_changes: tvec!((InOut::In(0), change.clone()), (InOut::Out(0), change.clone())),
        }))
    }

    /// Lowers this convolution to a more executable form by returning a replacement
    /// patch. The first matching strategy wins:
    /// 1. u8 kernel (after unquantizing the datum type): rewire through
    ///    `kernel_offset_u8_as_i8` and substitute the op it returns;
    /// 2. quantized (`q_params` set): `wire_as_quant_im2col`;
    /// 3. 1x1 kernel, unit strides and dilations, `group == 1`: collapse the spatial
    ///    axes, apply a `MatMulUnary` with the reshaped kernel, add the bias if any,
    ///    then restore the spatial axes;
    /// 4. concrete input shape where `should_use_lazy` holds: `wire_as_lazy_im2col`;
    /// 5. depthwise (`group` equal to both channel counts, `group != 1`) with a
    ///    concrete input shape: a single op from `to_depth_wise`;
    /// 6. otherwise: `wire_as_im2col_pair`.
    ///
    /// Every branch returns `Some(patch)`.
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if let DatumType::U8 = self.kernel.datum_type().unquantized() {
            let mut patch = TypedModelPatch::default();
            let mut inputs = node
                .inputs
                .iter()
                .map(|w| patch.tap_model(model, *w))
                .collect::<TractResult<TVec<_>>>()?;
            // NOTE(review): assumes kernel_offset_u8_as_i8 always yields Some for a u8
            // kernel; the unwrap panics otherwise — confirm.
            let new_op = self.kernel_offset_u8_as_i8(&mut inputs, &mut patch)?.unwrap();
            let wire = patch.wire_node(&node.name, new_op, &inputs)?;
            patch.shunt_outside(model, node.id.into(), wire[0])?;
            patch.obliterate(node.id)?;
            return Ok(Some(patch.with_context("kernel-u8-to-i8")));
        }

        let full_input_shape = model.outlet_fact(node.inputs[0])?.shape.to_tvec();
        let input_fact = model.outlet_fact(node.inputs[0])?;
        let input_shape = self.pool_spec.data_format.shape(&full_input_shape)?;
        let spatial_rank = input_shape.hw_rank();
        let kernel_spatial_shape = &self.kernel.shape()[self.kernel_fmt.h_axis()..][..spatial_rank];
        unsafe {
            let dt = input_fact.datum_type;
            if self.q_params.is_some() {
                let mut patch = TypedModelPatch::default();
                let inputs = node
                    .inputs
                    .iter()
                    .map(|w| patch.tap_model(model, *w))
                    .collect::<TractResult<TVec<_>>>()?;
                let wire = self.wire_as_quant_im2col(
                    &mut patch,
                    &node.name,
                    model.node_input_facts(node.id)?[0].datum_type,
                    &inputs,
                )?;
                patch.shunt_outside(model, node.id.into(), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch.with_context("quantized-codegen")))
            } else if kernel_spatial_shape.iter().product::<usize>() == 1
                && (0..spatial_rank)
                    .all(|i| self.pool_spec.stride(i) == 1 && self.pool_spec.dilation(i) == 1)
                && self.group == 1
            {
                // Pointwise (1x1) convolution expressed as a matrix product.
                use crate::ops::matmul::MatMulUnary;
                let mut patch = TypedModelPatch::default();
                let mut wire = patch.tap_model(model, node.inputs[0])?;
                let input_c_is_last = input_shape.c_axis() == input_shape.rank() - 1;
                let geo_dim: TDim = input_shape.hw_dims().iter().product();
                // Collapse all spatial axes into a single axis of size `geo_dim`.
                wire = patch.wire_node(
                    format!("{}.reshape_input", &*node.name),
                    AxisOp::Reshape(
                        input_shape.h_axis(),
                        input_shape.hw_dims().into(),
                        tvec!(geo_dim.clone()),
                    ),
                    &[wire],
                )?[0];
                // The two non-spatial (channel) dims of the kernel.
                let kernel_shape = match self.kernel_fmt {
                    KernelFormat::HWIO => &self.kernel.shape()[spatial_rank..],
                    KernelFormat::OIHW => &self.kernel.shape()[..2],
                };
                // Input rank once the spatial axes have been collapsed into one.
                let operating_rank = input_fact.rank() + 1 - kernel_spatial_shape.len();
                let kernel = self
                    .kernel
                    .as_ref()
                    .clone()
                    .into_shape(kernel_shape)?
                    .broadcast_into_rank(operating_rank)?;
                wire = patch.wire_node(
                    &format!("{}.matmul", &node.name),
                    MatMulUnary::new(
                        kernel.into_arc_tensor(),
                        MatMulAxes::default_for_rank(operating_rank).transposing(
                            self.kernel_fmt == KernelFormat::HWIO,
                            input_c_is_last,
                            input_c_is_last,
                        ),
                    ),
                    &[wire],
                )?[0];
                if let Some(ref bias) = self.bias {
                    // Lay the bias along the channel axis of the matmul output.
                    let bias_shape =
                        if input_c_is_last { [1, bias.len()] } else { [bias.len(), 1] };
                    let bias = bias
                        .clone()
                        .into_tensor()
                        .into_shape(&bias_shape)?
                        .broadcast_into_rank(operating_rank)?
                        .into_arc_tensor();
                    let bias = patch.add_const(format!("{}.bias.cst", node.name), bias)?;
                    wire = patch.wire_node(
                        format!("{}.bias", node.name),
                        crate::ops::math::add(),
                        &[wire, bias],
                    )?[0];
                }
                // Restore the original spatial axes.
                wire = patch.wire_node(
                    &*node.name,
                    AxisOp::Reshape(
                        input_shape.h_axis(),
                        tvec!(geo_dim),
                        input_shape.hw_dims().into(),
                    ),
                    &[wire],
                )?[0];
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            } else if input_fact
                .shape
                .as_concrete()
                .map(|s| {
                    // NOTE(review): unwrap assumes a concrete shape is always accepted by
                    // the data format here — confirm.
                    should_use_lazy(
                        &self.pool_spec.data_format.shape(s.into()).unwrap(),
                        &self.pool_spec,
                        self.group,
                    )
                })
                .unwrap_or(false)
            {
                let mut patch = TypedModelPatch::new("wire_as_lazy_im2col");
                let mut wire = patch.tap_model(model, node.inputs[0])?;
                wire = self.wire_as_lazy_im2col(&mut patch, &node.name, wire)?;
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            } else if self.group != 1
                && self.group == self.output_channels()
                && self.group == self.input_channels()
                && input_fact.shape.as_concrete().is_some()
            {
                // Depthwise: one group per channel.
                let op = dispatch_floatlike!(Self::to_depth_wise(dt)(self, input_fact))
                    .context("in to_depth_wise")?;
                Ok(Some(TypedModelPatch::single_unary_op(model, node, op)?))
            } else {
                let mut patch = TypedModelPatch::default();
                let wire = patch.tap_model(model, node.inputs[0])?;
                let wire = self
                    .wire_as_im2col_pair(&mut patch, &node.name, wire)
                    .context("in wire_as_im2col_pair")?;
                patch.shunt_outside(model, OutletId::new(node.id, 0), wire)?;
                patch.obliterate(node.id)?;
                Ok(Some(patch))
            }
        }
    }
Examples found in repository:
src/model/graph.rs (line 702; listing covers lines 697–706)
    /// Rebuilds this graph through `IntoTranslator` and replaces `self` with the result.
    /// In debug builds, the rebuilt graph is verified with `check_compact` first.
    pub fn compact(&mut self) -> TractResult<()> {
        use crate::model::translator::{IntoTranslator, Translate};
        let mut compacted = IntoTranslator.translate_model(self)?;
        #[cfg(debug_assertions)]
        {
            compacted.check_compact().context("after graph compaction")?;
        }
        std::mem::swap(self, &mut compacted);
        Ok(())
    }
Examples found in repository:
src/model/graph.rs (line 709; listing covers lines 708–711)
    /// Consuming form of `compact`: compacts the graph and hands it back.
    pub fn into_compact(mut self) -> TractResult<Self> {
        self.compact().map(|()| self)
    }
More examples:
src/optim/mod.rs (line 91; listing covers lines 89–130)
    /// Runs full rounds of optimizer passes until a round applies no patch.
    ///
    /// The model is consistency-checked and compacted before the first round, and
    /// compacted again after every round that changed it. The round index is passed
    /// down to the passes.
    pub fn optimize(&mut self, model: &mut TypedModel) -> TractResult<()> {
        model.check_consistency().context("during optimizer preflight check")?;
        model.compact().context("during optimizer preflight compaction")?;
        let mut round = 0;
        loop {
            let patches_before = self.counter;
            self.run_all_passes(round, model)?;
            if self.counter == patches_before {
                return Ok(());
            }
            model.compact()?;
            round += 1;
        }
    }

    /// Runs each configured pass to its own fixed point.
    ///
    /// After every pass, the model is compacted and consistency-checked. The pass list
    /// is cloned so each pass can be driven mutably while `self` is borrowed mutably.
    pub fn run_all_passes(&mut self, i: usize, model: &mut TypedModel) -> TractResult<()> {
        let mut passes = self.optimizer.passes.clone();
        for pass in &mut passes {
            self.run_one_pass_outer(i, pass.as_mut(), model)
                .with_context(|| format!("running pass {:?}", pass))?;
            model.compact()?;
            model
                .check_consistency()
                .with_context(|| format!("consistency check after pass {:?}", pass))?;
        }
        Ok(())
    }

    /// Repeats `run_one_pass_inner` for pass `p` until an iteration applies no patch.
    /// The model is compacted between productive iterations.
    pub fn run_one_pass_outer(
        &mut self,
        i: usize,
        p: &mut dyn TypedPass,
        model: &mut TypedModel,
    ) -> TractResult<()> {
        loop {
            let applied_before = self.counter;
            self.run_one_pass_inner(i, p, model)?;
            if applied_before == self.counter {
                break Ok(());
            }
            model.compact().with_context(|| format!("after pass {:?}", p))?;
        }
    }
src/ops/scan/mir.rs (line 550; listing covers lines 530–613)
    /// Tries to apply `change` inside the scan body and lift the result into a
    /// substitute `Scan` op.
    ///
    /// When `locked_interface` is set, the body's exposed outlets are passed to
    /// `change_axes` as its interface argument. Otherwise an empty slice is passed.
    ///
    /// Returns `Ok(None)` in these cases:
    /// - `change_axes` finds no applicable change set;
    /// - the change would drop the scan axis of a scanned input;
    /// - the change would drop the scan axis of a scanned output.
    ///
    /// Otherwise, returns the rebuilt op together with the wire changes the outer
    /// node's inputs and outputs must undergo.
    fn try_body_axes_change(
        &self,
        change: AxisChange,
        locked_interface: bool,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        self.body.check_consistency()?;
        let interface = self.body_exposed_outlets()?;
        let (patch, body_changed_wires) = if let Some(changes) =
            crate::ops::change_axes::change_axes(
                &self.body,
                &change,
                if locked_interface { &interface } else { &[] },
                &self.body_bounds()?,
            )? {
            changes
        } else {
            return Ok(None);
        };
        let mut body = self.body.clone();
        patch.apply(&mut body)?;
        body.compact()?;
        let mut wire_changes = tvec!();
        // Propagate body input changes to the outer inputs:
        // - forward the change to the outer slot (when the mapping has one);
        // - re-index scan axes;
        // - apply the change to constant state initializers.
        let mut input_mapping: Vec<InputMapping> = self.input_mapping.clone();
        for (ix, m) in input_mapping.iter_mut().enumerate() {
            if let Some(change) = body_changed_wires
                .iter()
                .find(|(iface, _change)| iface == &InOut::In(ix))
                .map(|pair| pair.1.clone())
            {
                if let Some(slot) = m.slot() {
                    wire_changes.push((InOut::In(slot), change.clone()));
                }
                match &*m {
                    InputMapping::Full { .. } => (),
                    &InputMapping::Scan(info) => {
                        if let Some(axis) = change.transform_axis(info.axis) {
                            *m = InputMapping::Scan(ScanInfo { axis, ..info });
                        } else {
                            return Ok(None);
                        };
                    }
                    InputMapping::State { initializer } => match initializer {
                        StateInitializer::FromInput(_) => (),
                        StateInitializer::Value(ref v) => {
                            let mut v = v.clone().into_tensor();
                            change.change_tensor(&mut v, false)?;
                            *m = InputMapping::State {
                                initializer: StateInitializer::Value(v.into_arc_tensor()),
                            };
                        }
                    },
                };
            }
        }
        // Propagate body output changes to the outer outputs: re-index scan axes, then
        // report the change on both the scan slot and the last-value slot, if present.
        let mut output_mapping: Vec<OutputMapping<TDim>> = self.output_mapping.clone();
        for (ix, m) in output_mapping.iter_mut().enumerate() {
            if let Some(change) = body_changed_wires
                .iter()
                .find(|(iface, _change)| iface == &InOut::Out(ix))
                .map(|pair| pair.1.clone())
            {
                if let Some(info) = m.scan.as_mut() {
                    if let Some(new_axis) = change.transform_axis(info.axis) {
                        info.axis = new_axis;
                    } else {
                        return Ok(None);
                    }
                    wire_changes.push((InOut::Out(info.slot), change.clone()));
                }
                if let Some(slot) = m.last_value_slot {
                    wire_changes.push((InOut::Out(slot), change.clone()));
                }
            };
        }
        body.check_consistency()?;
        // `decluttered` is reset: the rewritten body has not been through declutter.
        let op = Some(Box::new(Scan {
            body,
            input_mapping,
            output_mapping,
            decluttered: false,
            ..self.clone()
        }) as _);
        Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes }))
    }
Examples found in repository:
src/ops/scan/mir.rs (line 24; listing covers lines 21–34)
    /// Builds the lowered `LirScan` for this scan.
    ///
    /// When `optimize_inner` is set, a copy of the body is optimized first. The body
    /// is wrapped in a `SimplePlan`, and the skip count and input/output mappings are
    /// carried over unchanged.
    pub fn to_codegen_op(&self, optimize_inner: bool) -> TractResult<LirScan> {
        let body = if optimize_inner {
            self.body.clone().into_optimized()?
        } else {
            self.body.clone()
        };
        let plan = Arc::new(SimplePlan::new(body)?);
        let params = LirScanOpParams::new(
            self.skip,
            plan,
            self.input_mapping.clone(),
            self.output_mapping.clone(),
        );
        Ok(LirScan::new(Arc::new(params)))
    }
Examples found in repository:
src/optim/mod.rs (line 90; listing covers lines 89–172)
    /// Drives the optimizer to a fixed point: full rounds of passes run until one
    /// round applies no patch at all.
    ///
    /// The model is checked and compacted beforehand, and compacted again after every
    /// productive round.
    pub fn optimize(&mut self, model: &mut TypedModel) -> TractResult<()> {
        model.check_consistency().context("during optimizer preflight check")?;
        model.compact().context("during optimizer preflight compaction")?;
        let mut round: usize = 0;
        loop {
            let applied_so_far = self.counter;
            self.run_all_passes(round, model)?;
            let made_progress = self.counter != applied_so_far;
            if !made_progress {
                break;
            }
            model.compact()?;
            round += 1;
        }
        Ok(())
    }

    /// Runs every configured pass in order, each to its own fixed point.
    ///
    /// Compaction and a consistency check follow each pass. The pass list is cloned
    /// up front so each pass can be mutated while `self` is borrowed mutably.
    pub fn run_all_passes(&mut self, i: usize, model: &mut TypedModel) -> TractResult<()> {
        let mut passes = self.optimizer.passes.clone();
        for pass in passes.iter_mut() {
            let outcome = self.run_one_pass_outer(i, pass.as_mut(), model);
            outcome.with_context(|| format!("running pass {:?}", pass))?;
            model.compact()?;
            let check = model.check_consistency();
            check.with_context(|| format!("consistency check after pass {:?}", pass))?;
        }
        Ok(())
    }

    /// Re-runs pass `p` as long as the previous run applied at least one patch. The
    /// model is compacted before each re-run.
    pub fn run_one_pass_outer(
        &mut self,
        i: usize,
        p: &mut dyn TypedPass,
        model: &mut TypedModel,
    ) -> TractResult<()> {
        let mut before = self.counter;
        self.run_one_pass_inner(i, p, model)?;
        while self.counter != before {
            model.compact().with_context(|| format!("after pass {:?}", p))?;
            before = self.counter;
            self.run_one_pass_inner(i, p, model)?;
        }
        Ok(())
    }

    /// Applies the patches proposed by pass `p` to `model`. Stops when the pass
    /// proposes no more, or when the optimizer's step budget (`optimizer.steps`, when
    /// set) is reached.
    ///
    /// For each patch:
    /// - a `"{pass:?}/{i}"` context entry is pushed;
    /// - the patch's own model is consistency-checked;
    /// - the target model is checked before and after application;
    /// - a patch whose `dont_apply_twice` watchdog was already seen is skipped;
    /// - `self.counter` is incremented once the patch is applied.
    pub fn run_one_pass_inner(
        &mut self,
        i: usize,
        p: &mut dyn TypedPass,
        model: &mut TypedModel,
    ) -> TractResult<()> {
        p.reset()?;
        if let Some(steps) = self.optimizer.steps {
            if self.counter >= steps {
                return Ok(());
            }
        }
        while let Some(mut patch) = p.next(self, model)? {
            patch.push_context(format!("{:?}/{}", p, i));
            patch.model.check_consistency().context("checking patch internal consistency")?;
            model
                .check_consistency()
                .context("Checking target model consistency before patching")?;
            // Loop guard: a patch tagged with a watchdog is applied at most once.
            if let Some(watchdog) = patch.dont_apply_twice.take() {
                if self.seen.contains(&watchdog) {
                    debug!("Loop detected: {} seen before", watchdog);
                    continue;
                } else {
                    self.seen.insert(watchdog);
                }
            }
            debug!("applying patch #{}: {}", self.counter, patch.context.iter().rev().join(" >> "),);
            patch.apply(model)?;
            model
                .check_consistency()
                .context("Checking target model consistency after patching")?;
            self.counter += 1;
            if let Some(steps) = self.optimizer.steps {
                if self.counter >= steps {
                    return Ok(());
                }
            }
        }
        model.check_consistency().with_context(|| format!("after pass {:?}", p))?;
        Ok(())
    }
More examples:
src/ops/scan/mir.rs (line 43; listing covers lines 36–613)
    /// Creates a scan op with `decluttered` unset.
    ///
    /// Fails unless the body is consistent and has exactly one input mapping per body
    /// input and one output mapping per body output.
    pub fn new(
        body: TypedModel,
        input_mapping: Vec<InputMapping>,
        output_mapping: Vec<OutputMapping<TDim>>,
        seq_length_input_slot: Option<usize>,
        skip: usize,
    ) -> TractResult<Scan> {
        body.check_consistency()?;
        ensure!(input_mapping.len() == body.input_outlets()?.len());
        ensure!(output_mapping.len() == body.output_outlets()?.len());
        let scan = Scan {
            body,
            input_mapping,
            output_mapping,
            seq_length_input_slot,
            skip,
            decluttered: false,
        };
        Ok(scan)
    }

    /// Number of iterations this scan will run for the given input facts, when it can
    /// be determined.
    ///
    /// Returns `None` if the count is unknown, including when the (non-optimized)
    /// codegen op cannot be built.
    pub fn iteration_count(&self, inputs: &[&TypedFact]) -> Option<TDim> {
        // `to_codegen_op` is fallible (it builds a `SimplePlan` over the body). Report
        // "unknown" instead of panicking inside an Option-returning query.
        self.to_codegen_op(false).ok()?.iteration_count(inputs)
    }

    fn declutter_body(
        &self,
        session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        if !self.decluttered {
            let mut new = self.clone();
            let mut body = self.body.clone();
            session.optimize(&mut body)?;
            new.body = body;
            new.decluttered = true;
            Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
        } else {
            Ok(None)
        }
    }

    /// Gathers the axis changes suggested by the body's ops, in evaluation order.
    ///
    /// Each suggestion is tried through `try_body_axes_change` with the interface
    /// locked. This node is replaced by the first substitute op obtained that way.
    fn declutter_body_axes(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let mut candidates = vec![];
        for body_node_id in self.body.eval_order()? {
            let body_node = self.body.node(body_node_id);
            for (io, op) in body_node.op.suggested_axis_changes()? {
                candidates.push(AxisChange { outlet: io.as_outlet(body_node), op });
            }
        }
        for candidate in candidates {
            let substitute =
                self.try_body_axes_change(candidate, true)?.and_then(|c| c.substitute_op);
            if let Some(op) = substitute {
                let patch = TypedModelPatch::replace_single_op(model, node, &node.inputs, op)?;
                return Ok(Some(patch));
            }
        }
        Ok(None)
    }

    /// Renumbers outer input slots as if outer input `discarded` had been removed.
    ///
    /// Every slot strictly above `discarded` moves down by one; the others are kept.
    /// Constant state initializers carry no slot and are copied unchanged. The entry
    /// for the discarded slot itself (if any) is not removed here.
    fn remove_outer_input_from_mappings(
        mappings: &[InputMapping],
        discarded: usize,
    ) -> Vec<InputMapping> {
        let shift = |slot: usize| if slot > discarded { slot - 1 } else { slot };
        let mut renumbered = Vec::with_capacity(mappings.len());
        for mapping in mappings {
            let updated = match mapping {
                InputMapping::Full { slot } => InputMapping::Full { slot: shift(*slot) },
                InputMapping::Scan(info) => {
                    InputMapping::Scan(ScanInfo { slot: shift(info.slot), ..*info })
                }
                InputMapping::State { initializer: StateInitializer::FromInput(n) } => {
                    InputMapping::State { initializer: StateInitializer::FromInput(shift(*n)) }
                }
                InputMapping::State { initializer: StateInitializer::Value(v) } => {
                    InputMapping::State { initializer: StateInitializer::Value(v.clone()) }
                }
            };
            renumbered.push(updated);
        }
        renumbered
    }

    /// Renumbers outer output slots as if outer output `discarded` had been removed.
    ///
    /// Scan slots and last-value slots strictly above `discarded` move down by one.
    /// Dim hints and state flags are copied as-is.
    fn remove_outer_output_from_mappings(
        mappings: &[OutputMapping<TDim>],
        discarded: usize,
    ) -> Vec<OutputMapping<TDim>> {
        let shift = |slot: usize| if slot > discarded { slot - 1 } else { slot };
        let mut renumbered = Vec::with_capacity(mappings.len());
        for mapping in mappings {
            renumbered.push(OutputMapping {
                scan: mapping.scan.map(|info| ScanInfo { slot: shift(info.slot), ..info }),
                last_value_slot: mapping.last_value_slot.map(shift),
                full_dim_hint: mapping.full_dim_hint.clone(),
                state: mapping.state,
            });
        }
        renumbered
    }

    /// Folds a constant outer input that initializes a loop state into the op.
    ///
    /// For the first state mapping initialized `FromInput(n)` whose outer input `n`
    /// has a constant fact, the initializer becomes `Value(..)` and outer input `n` is
    /// dropped. The remaining input slots, including `seq_length_input_slot`, are
    /// renumbered to match. An input that is also the sequence-length input is never
    /// dropped.
    fn declutter_const_initializer(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let inputs = model.node_input_facts(node.id)?;
        for (ix, mapping) in self.input_mapping.iter().enumerate() {
            let InputMapping::State { initializer: StateInitializer::FromInput(n) } = mapping
            else {
                continue;
            };
            let n = *n;
            let Some(konst) = inputs[n].konst.as_ref() else { continue };
            // The sequence-length input is addressed by outer slot: removing it would
            // leave the op without it.
            if self.seq_length_input_slot == Some(n) {
                continue;
            }
            let mut op = self.clone();
            op.input_mapping[ix] =
                InputMapping::State { initializer: StateInitializer::Value(konst.clone()) };
            op.input_mapping = Self::remove_outer_input_from_mappings(&op.input_mapping, n);
            // Keep the sequence-length slot pointing at the same outer wire once input
            // `n` is gone.
            op.seq_length_input_slot = self.seq_length_input_slot.map(|s| s - (s > n) as usize);
            let mut inputs = node.inputs.clone();
            inputs.remove(n);
            return Ok(Some(TypedModelPatch::replace_single_op(model, node, &inputs, op)?));
        }
        Ok(None)
    }

    /// Removes a body input that nothing inside the body consumes.
    ///
    /// The first body source with no successors that is also not a body output is
    /// obliterated from the body and its input mapping is dropped. Its outer input slot
    /// (if it has one: constant state initializers have none) is removed from the node.
    /// The remaining slots, including `seq_length_input_slot`, are renumbered. An outer
    /// slot that is also the sequence-length input is never removed. The resulting body
    /// is decluttered and the new op is flagged as decluttered.
    fn declutter_discard_unused_input_mapping(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (inner_input_id, input) in self.body.input_outlets()?.iter().enumerate() {
            let source_node = self.body.node(input.node);
            if !source_node.outputs[0].successors.is_empty()
                || self.body.output_outlets()?.contains(input)
            {
                continue;
            }
            // Outer slot feeding this body input, if any.
            let slot = match &self.input_mapping[inner_input_id] {
                InputMapping::Full { slot } => Some(*slot),
                InputMapping::Scan(info) => Some(info.slot),
                InputMapping::State { initializer: StateInitializer::FromInput(n) } => Some(*n),
                InputMapping::State { initializer: StateInitializer::Value(_) } => None,
            };
            // Never drop the outer input carrying the sequence length.
            if slot.is_some() && slot == self.seq_length_input_slot {
                continue;
            }
            let mut new_mappings: Vec<_> = self.input_mapping.clone();
            new_mappings.remove(inner_input_id);
            let mut new_inputs = node.inputs.clone();
            let mut seq_length_input_slot = self.seq_length_input_slot;
            if let Some(slot) = slot {
                new_mappings = Self::remove_outer_input_from_mappings(&new_mappings, slot);
                new_inputs.remove(slot);
                // Outer inputs above `slot` shift down by one; follow them.
                seq_length_input_slot = seq_length_input_slot.map(|s| s - (s > slot) as usize);
            }
            let mut model_inputs = self.body.input_outlets()?.to_vec();
            model_inputs.remove(inner_input_id);
            let mut body = self.body.clone();
            let mut patch = TypedModelPatch::default();
            patch.obliterate(source_node.id)?;
            patch.apply(&mut body)?;
            body.set_input_outlets(&model_inputs)?;
            body.declutter()?;
            let op = Self {
                body,
                skip: self.skip,
                seq_length_input_slot,
                input_mapping: new_mappings,
                decluttered: true,
                output_mapping: self.output_mapping.clone(),
            };
            return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
        }
        Ok(None)
    }

    /// Drops the first outer output that has no consumers and is not a model output.
    ///
    /// Every mapping that targets that outer slot (as scan or last value) is detached
    /// from it. Higher slots are renumbered, the op is rewired, and each surviving outer
    /// output of the old node is redirected to the matching output of the new one.
    fn declutter_discard_useless_outer_output(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (ix, outlet) in node.outputs.iter().enumerate() {
            let is_dead = outlet.successors.len() == 0
                && !model.output_outlets()?.contains(&OutletId::new(node.id, ix));
            if !is_dead {
                continue;
            }
            let mut detached = Vec::with_capacity(self.output_mapping.len());
            for m in &self.output_mapping {
                detached.push(OutputMapping {
                    scan: m.scan.filter(|info| info.slot != ix),
                    last_value_slot: m.last_value_slot.filter(|s| *s != ix),
                    full_dim_hint: m.full_dim_hint.clone(),
                    state: m.state,
                });
            }
            let mut op = self.clone();
            op.output_mapping = Self::remove_outer_output_from_mappings(&detached, ix);
            let mut patch = TypedModelPatch::default();
            let mut taps = Vec::with_capacity(node.inputs.len());
            for &outer in &node.inputs {
                taps.push(patch.tap_model(model, outer)?);
            }
            let wires = patch.wire_node(&*node.name, op, &taps)?;
            for oix in (0..node.outputs.len()).filter(|&oix| oix != ix) {
                // The new node has one output fewer: slots above `ix` moved down.
                let new_index = if oix > ix { oix - 1 } else { oix };
                patch.shunt_outside(model, OutletId::new(node.id, oix), wires[new_index])?;
            }
            return Ok(Some(patch));
        }
        Ok(None)
    }

    /// Removes a body output whose mapping exports nothing.
    ///
    /// The first output mapping with no last-value slot, no scan info and no state flag
    /// is dropped together with the matching body output. The rewritten op is marked as
    /// not decluttered.
    fn declutter_discard_empty_output_mapping_with_body_output(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        let unused = self
            .output_mapping
            .iter()
            .position(|om| om.last_value_slot.is_none() && om.scan.is_none() && !om.state);
        let Some(ix) = unused else { return Ok(None) };
        let mut new_op = self.clone();
        new_op.output_mapping.remove(ix);
        new_op.body.outputs.remove(ix);
        new_op.decluttered = false;
        Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new_op)?))
    }

    /// Hoists a single-input, single-output body node that consumes a scanned body
    /// input out of the loop.
    ///
    /// For each scanned input, each direct consumer of its body source is considered.
    /// If the consumer's invariants let the scan axis be tracked down through it
    /// (`unary_track_axis_down`), three things happen:
    /// - The consumer op is wired once outside the scan, on the whole outer input.
    /// - Its result is fed to a rewritten scan as an extra scanned input at outer slot
    ///   `node.inputs.len()`.
    /// - Inside the body, a new source holding one chunk along the tracked axis replaces
    ///   the consumer's output.
    ///
    /// Returns the outer patch for the first candidate found, or `None`.
    fn declutter_pull_batcheable_input(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        // `model_input` indexes both `input_mapping` and the body input outlets
        // (the two are positional).
        for (model_input, input) in self.input_mapping.iter().enumerate() {
            if let Some(info) = input.as_scan() {
                let scan_source = self.body.input_outlets()?[model_input];
                let scan_source_node = self.body.node(scan_source.node);
                for successor in &scan_source_node.outputs[0].successors {
                    let successor_node = self.body.node(successor.node);
                    // Only plain unary nodes are candidates.
                    if successor_node.inputs.len() != 1 || successor_node.outputs.len() != 1 {
                        continue;
                    }
                    let (input_facts, output_facts) = self.body.node_facts(successor_node.id)?;
                    let invariants = successor_node.op.invariants(&input_facts, &output_facts)?;
                    // The scan axis must be trackable through the node; `axis_after` is
                    // where it lands in the node's output.
                    if let Some(axis_after) = invariants.unary_track_axis_down(info.axis, false) {
                        let mut outside_patch = TypedModelPatch::new(format!(
                            "Outer patch for input extraction of {}",
                            successor_node
                        ));
                        // Tap every current outer input of the scan node.
                        let mut patch_inputs = node
                            .inputs
                            .iter()
                            .map(|&i| outside_patch.tap_model(model, i))
                            .collect::<TractResult<TVec<_>>>()?;
                        // Apply the extracted op once, outside the loop, on the whole
                        // outer tensor feeding this scanned input.
                        let input = patch_inputs[info.slot];
                        let new_input_wire = outside_patch.wire_node(
                            format!("{}.extracted.{}", node.name, successor_node.name),
                            successor_node.op.clone(),
                            &[input],
                        )?[0];
                        // Its result becomes a new outer input, appended last.
                        patch_inputs.push(new_input_wire);
                        let new_input_outer_fact = outside_patch.outlet_fact(new_input_wire)?;
                        // Inside the body, the new source only sees one chunk (|chunk|)
                        // along the tracked scan axis.
                        let mut new_input_inner_fact = new_input_outer_fact.clone();
                        new_input_inner_fact.shape.set(axis_after, info.chunk.abs().to_dim());

                        let mut new_body = self.body.clone();
                        let new_source_wire = new_body.add_source(
                            format!("{}.extracted.{}", node.name, successor_node.name),
                            new_input_inner_fact,
                        )?;
                        // Reroute the consumers of the extracted node's output to the
                        // new body source.
                        let mut inner_patch = TypedModelPatch::new(format!(
                            "Inner body patch for extraction of {}",
                            successor_node
                        ));
                        let new_source_wire_in_patch =
                            inner_patch.tap_model(&new_body, new_source_wire)?;
                        inner_patch
                            .shunt_outside(
                                &new_body,
                                OutletId::new(successor.node, 0),
                                new_source_wire_in_patch,
                            )
                            .with_context(|| "patching inner model")?;
                        inner_patch.apply(&mut new_body)?;

                        // Map the new source to the appended outer input. NOTE(review):
                        // assumes `add_source` appends to the body inputs so that this
                        // pushed mapping lines up with it — confirm.
                        let mut input_mapping = self.input_mapping.clone();
                        input_mapping.push(InputMapping::Scan(ScanInfo {
                            axis: axis_after,
                            chunk: info.chunk,
                            slot: node.inputs.len(),
                        }));

                        let new_op = Self {
                            input_mapping,
                            output_mapping: self.output_mapping.clone(),
                            decluttered: false,
                            body: new_body,
                            skip: self.skip,
                            seq_length_input_slot: self.seq_length_input_slot,
                        };
                        // Wire the rewritten scan and redirect every outer output to it.
                        let output_wires =
                            outside_patch.wire_node(&*node.name, new_op, &patch_inputs)?;
                        for w in output_wires {
                            outside_patch
                                .shunt_outside(model, OutletId::new(node.id, w.slot), w)
                                .with_context(|| "patching outer model")?;
                        }
                        return Ok(Some(outside_patch));
                    }
                }
            }
        }
        Ok(None)
    }

    /// Replaces a last-value outer output whose body output is a known constant.
    ///
    /// For the first output mapping with a last-value slot whose body output fact
    /// carries a constant, a const node holding that value is added outside the scan.
    /// The outer output is then redirected to it.
    fn declutter_pull_constant_outputs(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (body_output_ix, mapping) in self.output_mapping.iter().enumerate() {
            let Some(outer_slot) = mapping.last_value_slot else { continue };
            let Some(konst) = self.body.output_fact(body_output_ix)?.konst.clone() else {
                continue;
            };
            let producer_id = self.body.output_outlets()?[body_output_ix].node;
            let producer = self.body.node(producer_id);
            let mut patch = TypedModelPatch::new(format!("Extract const node {}", producer));
            let cst = patch.add_const(format!("{}.{}", &node.name, &producer.name), konst)?;
            patch.shunt_outside(model, OutletId::new(node.id, outer_slot), cst)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

    /// Moves a body node that produces a scanned output out of the loop.
    ///
    /// A candidate is a scanned output mapping that meets all of these conditions:
    /// - it is not a state;
    /// - its chunk is at most 1;
    /// - its producing body node has no other consumer inside the body;
    /// - the node's invariants let the scan axis be tracked up to its input(s).
    ///
    /// The producer's inputs are then exported from the body instead: constants as last
    /// values, everything else as scanned outputs. New outer slots are allocated
    /// contiguously after the existing ones. The producer op is re-wired once outside
    /// the scan on those batched results, and replaces the original outer output.
    fn declutter_pull_batcheable_output(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (model_ix, mapping) in self.output_mapping.iter().enumerate() {
            let Some(info) = mapping.scan else { continue };
            let emitter_outlet = self.body.output_outlets()?[model_ix];
            let emitter_node = self.body.node(emitter_outlet.node);
            // Skip when the emitter is still consumed inside the body, when the output is
            // a loop state, or when a chunk spans more than one step.
            if !emitter_node.outputs[emitter_outlet.slot].successors.is_empty()
                || mapping.state
                || info.chunk > 1
            {
                continue;
            }
            let (input_facts, output_facts) = self.body.node_facts(emitter_node.id)?;
            let invariants = emitter_node.op.invariants(&input_facts, &output_facts)?;
            let Some(axis_before) = invariants.unary_track_axis_up(info.axis, false) else {
                continue;
            };

            let mut new_body = self.body.clone();
            let mut new_output_mapping = self.output_mapping.clone();
            // Next free outer output slot of the rewritten op.
            let mut new_scan_outputs = node.outputs.len();
            // Outer slot of the rewritten op carrying each emitter input, in input order.
            let mut outer_slots = vec![];

            for input in &emitter_node.inputs {
                if new_body.outputs.iter().all(|o| o != input) {
                    new_output_mapping.push(OutputMapping::default());
                    new_body.outputs.push(*input);
                }
                let body_output_id = new_body
                    .outputs
                    .iter()
                    .position(|o| o == input)
                    .expect("emitter input was just ensured to be a body output");
                let body_mapping = &mut new_output_mapping[body_output_id];
                let outer_slot = if new_body.outlet_fact(*input)?.konst.is_some() {
                    // Constants are exported once, as a last value.
                    if body_mapping.last_value_slot.is_none() {
                        body_mapping.last_value_slot = Some(new_scan_outputs);
                        // Only consume a slot when one is actually allocated, so that the
                        // rewritten op's outer output slots stay contiguous.
                        new_scan_outputs += 1;
                    }
                    body_mapping.last_value_slot.unwrap()
                } else {
                    if body_mapping.scan.is_none() {
                        body_mapping.scan = Some(ScanInfo {
                            slot: new_scan_outputs,
                            axis: axis_before,
                            chunk: info.chunk,
                        });
                        new_scan_outputs += 1;
                    }
                    body_mapping.scan.unwrap().slot
                };
                outer_slots.push(outer_slot);
            }
            let mut outside_patch = TypedModelPatch::new(format!(
                "Outside patch for output extraction of {}",
                emitter_node
            ));
            let inputs = node
                .inputs
                .iter()
                .map(|&i| outside_patch.tap_model(model, i))
                .collect::<TractResult<TVec<_>>>()?;
            let new_op = Self {
                input_mapping: self.input_mapping.clone(),
                output_mapping: new_output_mapping,
                decluttered: false,
                body: new_body,
                skip: self.skip,
                seq_length_input_slot: self.seq_length_input_slot,
            };
            let scan_outputs = outside_patch.wire_node(&node.name, new_op, &inputs)?;
            // Re-apply the emitter outside the loop on the batched inputs.
            let inputs =
                outer_slots.iter().map(|slot| scan_outputs[*slot]).collect::<TVec<_>>();
            let wire = outside_patch.wire_node(
                &*emitter_node.name,
                emitter_node.op.clone(),
                &inputs,
            )?[0];
            outside_patch.shunt_outside(model, OutletId::new(node.id, info.slot), wire)?;
            // Every other original outer output maps to the same slot on the new scan.
            for output_slot in 0..node.outputs.len() {
                if output_slot != info.slot {
                    outside_patch.shunt_outside(
                        model,
                        OutletId::new(node.id, output_slot),
                        OutletId::new(scan_outputs[0].node, output_slot),
                    )?;
                }
            }
            return Ok(Some(outside_patch));
        }
        Ok(None)
    }

    /// Pairs each state body input with its state body output.
    ///
    /// Pairing is positional: the k-th input mapping that is a state is bound to the
    /// k-th output mapping flagged `state`. Each pair is returned as `[input, output]`.
    fn body_bounds(&self) -> TractResult<TVec<TVec<OutletId>>> {
        let body_inputs = self.body.input_outlets()?;
        let body_outputs = self.body.output_outlets()?;
        let state_inputs = self
            .input_mapping
            .iter()
            .zip(body_inputs)
            .filter_map(|(m, o)| m.as_state().map(|_| *o));
        let state_outputs =
            self.output_mapping.iter().zip(body_outputs).filter_map(|(m, o)| m.state.then(|| *o));
        Ok(state_inputs.zip(state_outputs).map(|(i, o)| tvec!(i, o)).collect())
    }

    fn body_exposed_outlets(&self) -> TractResult<TVec<OutletId>> {
        let input_outlets = self
            .input_mapping
            .iter()
            .zip(self.body.input_outlets()?.iter())
            .filter(|(m, _)| !m.invisible())
            .map(|(_, o)| o);
        let output_outlets = self
            .output_mapping
            .iter()
            .zip(self.body.output_outlets()?.iter())
            .filter(|(m, _)| !m.invisible())
            .map(|(_, o)| o);
        Ok(input_outlets.chain(output_outlets).cloned().collect())
    }

    /// Tries to apply `change` inside the body and rebuild a `Scan` around the result.
    ///
    /// `change_axes` is run on the body with the state bounds from `body_bounds`. When
    /// `locked_interface` is true, the exposed body outlets (`body_exposed_outlets`) are
    /// passed as the locked interface; otherwise an empty one is passed.
    ///
    /// The resulting patch is applied to a copy of the body, which is then compacted.
    /// The mappings are updated as follows:
    /// - scan axes are transformed;
    /// - constant state initializers are transformed with the change;
    /// - the induced changes on outer slots are collected as `wire_changes`.
    ///
    /// Returns `None` when `change_axes` finds no solution, or when a scan axis of an
    /// input or output mapping cannot be transformed (`transform_axis` returns `None`).
    fn try_body_axes_change(
        &self,
        change: AxisChange,
        locked_interface: bool,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        self.body.check_consistency()?;
        let interface = self.body_exposed_outlets()?;
        let (patch, body_changed_wires) = if let Some(changes) =
            crate::ops::change_axes::change_axes(
                &self.body,
                &change,
                if locked_interface { &interface } else { &[] },
                &self.body_bounds()?,
            )? {
            changes
        } else {
            return Ok(None);
        };
        let mut body = self.body.clone();
        patch.apply(&mut body)?;
        body.compact()?;
        // Changes to report on the outer node, keyed by outer slot.
        let mut wire_changes = tvec!();
        let mut input_mapping: Vec<InputMapping> = self.input_mapping.clone();
        for (ix, m) in input_mapping.iter_mut().enumerate() {
            // Change (if any) that `change_axes` applied to body input `ix`.
            if let Some(change) = body_changed_wires
                .iter()
                .find(|(iface, _change)| iface == &InOut::In(ix))
                .map(|pair| pair.1.clone())
            {
                if let Some(slot) = m.slot() {
                    wire_changes.push((InOut::In(slot), change.clone()));
                }
                match &*m {
                    InputMapping::Full { .. } => (),
                    &InputMapping::Scan(info) => {
                        // The scan axis has to survive the change.
                        if let Some(axis) = change.transform_axis(info.axis) {
                            *m = InputMapping::Scan(ScanInfo { axis, ..info });
                        } else {
                            return Ok(None);
                        };
                    }
                    InputMapping::State { initializer } => match initializer {
                        StateInitializer::FromInput(_) => (),
                        // A constant initializer is rewritten with the same axis change.
                        StateInitializer::Value(ref v) => {
                            let mut v = v.clone().into_tensor();
                            change.change_tensor(&mut v, false)?;
                            *m = InputMapping::State {
                                initializer: StateInitializer::Value(v.into_arc_tensor()),
                            };
                        }
                    },
                };
            }
        }
        let mut output_mapping: Vec<OutputMapping<TDim>> = self.output_mapping.clone();
        for (ix, m) in output_mapping.iter_mut().enumerate() {
            // Change (if any) that `change_axes` applied to body output `ix`.
            if let Some(change) = body_changed_wires
                .iter()
                .find(|(iface, _change)| iface == &InOut::Out(ix))
                .map(|pair| pair.1.clone())
            {
                if let Some(info) = m.scan.as_mut() {
                    if let Some(new_axis) = change.transform_axis(info.axis) {
                        info.axis = new_axis;
                    } else {
                        return Ok(None);
                    }
                    wire_changes.push((InOut::Out(info.slot), change.clone()));
                }
                if let Some(slot) = m.last_value_slot {
                    wire_changes.push((InOut::Out(slot), change.clone()));
                }
            };
        }
        body.check_consistency()?;
        // The rebuilt op must be decluttered again since its body changed.
        let op = Some(Box::new(Scan {
            body,
            input_mapping,
            output_mapping,
            decluttered: false,
            ..self.clone()
        }) as _);
        Ok(Some(AxisChangeConsequence { substitute_op: op, wire_changes }))
    }
src/ops/downsample/mod.rs (line 113)
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
/// Tries to move a `Downsample` node above its single predecessor.
///
/// `Slice`, `AxisOp`, `ConvUnary` and `Scan` predecessors are handed to dedicated
/// rules. A `Slice` falls back to the generic rule when its dedicated rule declines;
/// the other three return their rule's answer directly. The generic rule applies when
/// the predecessor's invariants let the downsample axis be tracked up. It then inserts
/// a downsample on every input of the predecessor and re-wires the predecessor after
/// them.
///
/// # Panics
/// Panics if `down_node`'s op is not a `Downsample`.
fn pull_downsample_up(
    model: &TypedModel,
    down_node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
    model.check_consistency()?;
    let down_op = down_node.op_as::<Downsample>().unwrap();
    let Some(prec) = model.single_prec(down_node.id)? else {
        return Ok(None);
    };
    let (input_facts, output_facts) = model.node_facts(prec.id)?;
    let invariants = prec.op.invariants(&input_facts, &output_facts)?;
    debug!("Consider pull {:?} over {:?} (invariants: {:?})", down_op, prec, invariants);

    if let Some(slice_op) = prec.op_as::<ops::array::Slice>() {
        let pulled =
            array::pull_downsample_over_slice(model, prec, slice_op, down_node, down_op)?;
        if pulled.is_some() {
            return Ok(pulled);
        }
    } else if let Some(axis_op) = prec.op_as::<AxisOp>() {
        return array::pull_downsample_over_axis_op(model, prec, axis_op, down_node, down_op);
    } else if let Some(conv_op) = prec.op_as::<ops::cnn::conv::ConvUnary>() {
        return conv::fuse_downsample_into_conv(model, prec, conv_op, down_node, down_op);
    } else if let Some(scan_op) = prec.op_as::<ops::scan::Scan>() {
        return scan::pull_downsample_over_scan(model, prec, scan_op, down_node, down_op);
    }

    let Some(above_axis) = invariants.unary_track_axis_up(down_op.axis, false) else {
        return Ok(None);
    };
    let mut patch = TypedModelPatch::default();
    let mut downsampled_inputs = Vec::with_capacity(prec.inputs.len());
    for (ix, &outer) in prec.inputs.iter().enumerate() {
        let source = patch.tap_model(model, outer)?;
        let mut op = down_op.clone();
        op.axis = above_axis;
        let wires = patch.wire_node(
            format!("{}.{}-{}", down_node.name, prec.name, ix),
            op,
            &[source],
        )?;
        downsampled_inputs.push(wires[0]);
    }
    let rewired = patch.wire_node(&prec.name, prec.op.clone(), &downsampled_inputs)?;
    patch.shunt_outside(model, OutletId::new(down_node.id, 0), rewired[0])?;
    Ok(Some(patch))
}
src/ops/downsample/scan.rs (line 19)
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
/// Pulls a `Downsample` that follows a `Scan` up through the scan.
///
/// The method has three steps:
/// 1. The downsample is appended to every body output and the body is decluttered.
/// 2. If that moves all the downsamples up to the body inputs, they are removed from
///    the body (the body input facts are downsampled instead).
/// 3. The downsample is then applied to every outer input of the scan. Scan chunks,
///    full-dim hints and constant state initializers are downsampled to match.
///
/// Returns `Ok(None)` (no rewrite) when any of these holds:
/// - the stride is not strictly positive;
/// - some scan output is not consumed by exactly one downsample identical to
///   `down_op` (the patch replaces exactly those downsample nodes);
/// - the downsamples do not all reach the body inputs;
/// - a scanned chunk size (in absolute value) is not a multiple of the stride.
pub fn pull_downsample_over_scan(
    model: &TypedModel,
    scan_node: &TypedNode,
    scan_op: &ops::scan::Scan,
    down_node: &TypedNode,
    down_op: &Downsample,
) -> TractResult<Option<TypedModelPatch>> {
    // Negative strides are not handled; a zero stride would divide by zero below.
    if down_op.stride <= 0 {
        return Ok(None);
    }
    let stride = down_op.stride as usize;

    // The final patch substitutes the downsample consuming each scan output. That is
    // only sound if every output feeds exactly one such downsample. Otherwise another
    // consumer would silently receive downsampled data, and an output without
    // successors would make the indexing below panic.
    let all_outputs_downsampled = scan_node.outputs.iter().all(|output| {
        output.successors.len() == 1
            && model.node(output.successors[0].node).op().same_as(down_op)
    });
    if !all_outputs_downsampled {
        return Ok(None);
    }

    // introduce downsample at end of body
    let mut downsampled_body = scan_op.body.clone();
    downsampled_body.check_consistency()?;
    let outputs = downsampled_body.output_outlets()?.to_owned();
    let downsample_outputs = outputs
        .into_iter()
        .enumerate()
        .map(|(ix, oo)| {
            Ok(downsampled_body.wire_node(
                format!("{}-{}", &down_node.name, ix),
                down_op.clone(),
                &[oo],
            )?[0])
        })
        .collect::<TractResult<Vec<_>>>()?;
    downsampled_body.set_output_outlets(&downsample_outputs)?;
    downsampled_body.declutter()?;
    downsampled_body.check_consistency()?;

    // check if downsample ops introduced at end have swimmed up to scan inputs during declutter
    for input in downsampled_body.input_outlets()? {
        let input = downsampled_body.node(input.node);
        if input.outputs[0]
            .successors
            .iter()
            .any(|succ| !downsampled_body.node(succ.node).op().same_as(down_op))
        {
            return Ok(None);
        }
    }

    // Downsample the body sources' facts and drop the downsamples sitting right after them.
    let inputs = downsampled_body.input_outlets()?.to_vec();
    for input in inputs {
        let node = &mut downsampled_body.node_mut(input.node);
        let fact = &mut node.outputs[0].fact;
        *fact = down_op.transform_fact(fact)?;
        node.op_as_mut::<crate::ops::source::TypedSource>().unwrap().fact = fact.clone();
        let downsamples = downsampled_body.node(input.node).outputs[0].successors.clone();
        for ds in downsamples {
            TypedModelPatch::shunt_one_op(&downsampled_body as _, downsampled_body.node(ds.node))?
                .apply(&mut downsampled_body)?;
        }
    }

    downsampled_body.check_consistency()?;
    let inner_model = downsampled_body.into_decluttered()?;

    let mut new_scan = scan_op.clone();
    new_scan.body = inner_model;
    for input in &mut new_scan.input_mapping {
        match input {
            InputMapping::State { ref mut initializer } => {
                if let StateInitializer::Value(ref v) = initializer {
                    let mut new_v = down_op.eval(tvec!(v.clone().into_tvalue()))?;
                    *initializer = StateInitializer::Value(new_v.remove(0).into_arc_tensor());
                }
            }
            InputMapping::Scan(info) => {
                // Use the magnitude: negative chunks (reverse scans) must divide too.
                if info.chunk.unsigned_abs() % stride != 0 {
                    return Ok(None);
                }
                info.chunk =
                    info.chunk.unsigned_abs().divceil(stride) as isize * info.chunk.signum()
            }
            _ => (),
        }
    }
    for output in &mut new_scan.output_mapping {
        if let Some(d) = output.full_dim_hint.as_mut() {
            *d = down_op.transform_dim(d)
        }
        if let Some(info) = &mut output.scan {
            // `chunk as usize` would wrap for negative chunks; check the magnitude.
            if info.chunk.unsigned_abs() % stride != 0 {
                return Ok(None);
            }
            info.chunk =
                info.chunk.unsigned_abs().divceil(stride) as isize * info.chunk.signum()
        }
    }

    let mut patch = TypedModelPatch::default();
    let mut inputs = tvec!();
    for (ix, &i) in scan_node.inputs.iter().enumerate() {
        let tap = patch.tap_model(model, i)?;
        let ds = patch.wire_node(format!("{}-{}", down_node.name, ix), down_op.clone(), &[tap])?[0];
        inputs.push(ds);
    }
    let scan = patch.wire_node(&*scan_node.name, new_scan, &inputs)?;
    for (ix, output) in scan_node.outputs.iter().enumerate() {
        // Checked above: exactly one successor, a downsample identical to `down_op`.
        let ds_node = output.successors[0].node;
        patch.shunt_outside(model, OutletId::new(ds_node, 0), scan[ix])?;
    }
    Ok(Some(patch))
}
Examples found in repository?
src/ops/downsample/scan.rs (line 62)
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
/// Try to pull a `Downsample` that consumes the outputs of a `Scan` node up
/// through the scan, so that the scan body iterates over downsampled data.
///
/// A copy of `down_op` is appended to every body output, the body is decluttered,
/// and the rewrite is accepted only if all those downsamples ended up directly on
/// the body inputs. The body inputs are then given downsampled facts, the
/// downsample ops inside the body are bypassed, and the outer inputs are
/// downsampled instead.
///
/// Returns `Ok(None)` (no rewrite) when:
/// - `down_op.stride` is not strictly positive;
/// - some scan output is unused, or is consumed by anything other than a
///   downsample equivalent to `down_op` (`same_as`);
/// - after declutter, some body input feeds a node that is not such a downsample;
/// - a scanned input or output chunk is not a multiple of the stride.
///
/// Otherwise returns a patch that taps every outer scan input through a copy of
/// `down_op`, wires the rewritten scan, and reroutes every downsample that
/// consumed an original scan output to the matching new scan output.
pub fn pull_downsample_over_scan(
    model: &TypedModel,
    scan_node: &TypedNode,
    scan_op: &ops::scan::Scan,
    down_node: &TypedNode,
    down_op: &Downsample,
) -> TractResult<Option<TypedModelPatch>> {
    // Only strictly positive strides are supported; a zero stride would also make
    // the chunk divisibility checks below divide by zero.
    if down_op.stride <= 0 {
        return Ok(None);
    }
    let stride = down_op.stride as usize;

    // The patch reroutes the consumers of every scan output, so each output must be
    // consumed, and only by downsamples equivalent to `down_op`. Otherwise picking
    // the first successor would panic (no consumer) or an unrelated node would be
    // silently miswired.
    for output in &scan_node.outputs {
        if output.successors.is_empty()
            || output.successors.iter().any(|succ| !model.node(succ.node).op().same_as(down_op))
        {
            return Ok(None);
        }
    }

    // Introduce a downsample at the end of every body output.
    let mut downsampled_body = scan_op.body.clone();
    downsampled_body.check_consistency()?;
    let outputs = downsampled_body.output_outlets()?.to_owned();
    let downsample_outputs = outputs
        .into_iter()
        .enumerate()
        .map(|(ix, oo)| {
            Ok(downsampled_body.wire_node(
                format!("{}-{}", &down_node.name, ix),
                down_op.clone(),
                &[oo],
            )?[0])
        })
        .collect::<TractResult<Vec<_>>>()?;
    downsampled_body.set_output_outlets(&downsample_outputs)?;
    downsampled_body.declutter()?;
    downsampled_body.check_consistency()?;

    // Check that declutter has moved the downsamples all the way up to the body
    // inputs: every consumer of every body input must be such a downsample.
    for input in downsampled_body.input_outlets()? {
        let input = downsampled_body.node(input.node);
        if input.outputs[0]
            .successors
            .iter()
            .any(|succ| !downsampled_body.node(succ.node).op().same_as(down_op))
        {
            return Ok(None);
        }
    }

    // Body inputs now receive downsampled data: update their facts (both on the
    // outlet and in the source op) and bypass the downsample ops that follow them.
    let inputs = downsampled_body.input_outlets()?.to_vec();
    for input in inputs {
        let node = downsampled_body.node_mut(input.node);
        let fact = down_op.transform_fact(&node.outputs[0].fact)?;
        node.op_as_mut::<crate::ops::source::TypedSource>()
            .expect("scan body inputs are TypedSource nodes")
            .fact = fact.clone();
        node.outputs[0].fact = fact;
        let downsamples = downsampled_body.node(input.node).outputs[0].successors.clone();
        for ds in downsamples {
            TypedModelPatch::shunt_one_op(&downsampled_body as _, downsampled_body.node(ds.node))?
                .apply(&mut downsampled_body)?;
        }
    }

    downsampled_body.check_consistency()?;
    let inner_model = downsampled_body.into_decluttered()?;

    let mut new_scan = scan_op.clone();
    new_scan.body = inner_model;
    for input in &mut new_scan.input_mapping {
        match input {
            InputMapping::State { ref mut initializer } => {
                if let StateInitializer::Value(ref v) = initializer {
                    let mut new_v = down_op.eval(tvec!(v.clone().into_tvalue()))?;
                    *initializer = StateInitializer::Value(new_v.remove(0).into_arc_tensor());
                }
            }
            InputMapping::Scan(info) => {
                // Check the magnitude so reverse (negative) chunks are validated too.
                if info.chunk.unsigned_abs() % stride != 0 {
                    return Ok(None);
                }
                info.chunk =
                    info.chunk.unsigned_abs().divceil(stride) as isize * info.chunk.signum()
            }
            _ => (),
        }
    }
    for output in &mut new_scan.output_mapping {
        if let Some(d) = output.full_dim_hint.as_mut() {
            *d = down_op.transform_dim(d)
        }
        if let Some(info) = &mut output.scan {
            // `chunk as usize` would wrap around for a negative chunk: use the magnitude.
            if info.chunk.unsigned_abs() % stride != 0 {
                return Ok(None);
            }
            info.chunk =
                info.chunk.unsigned_abs().divceil(stride) as isize * info.chunk.signum()
        }
    }

    let mut patch = TypedModelPatch::default();
    let mut inputs = tvec!();
    for (ix, &i) in scan_node.inputs.iter().enumerate() {
        let tap = patch.tap_model(model, i)?;
        let ds = patch.wire_node(format!("{}-{}", down_node.name, ix), down_op.clone(), &[tap])?[0];
        inputs.push(ds);
    }
    let scan = patch.wire_node(&*scan_node.name, new_scan, &inputs)?;
    // Every consumer of every original scan output was checked above to be an
    // equivalent downsample: make each of them read the new scan output instead.
    for (ix, output) in scan_node.outputs.iter().enumerate() {
        for succ in &output.successors {
            patch.shunt_outside(model, OutletId::new(succ.node, 0), scan[ix])?;
        }
    }
    Ok(Some(patch))
}

Perform declutter passes on the network.

Examples found in repository?
src/model/typed.rs (line 96)
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
    /// Consume the model, run the declutter passes and then the optimization
    /// passes on it, and hand the transformed model back.
    pub fn into_optimized(mut self) -> TractResult<TypedModel> {
        self.declutter()?;
        self.optimize().map(|_| self)
    }
    /// Consistency check used when paranoid assertions are disabled: always succeeds.
    #[cfg(not(all(debug_assertions, feature = "paranoid_assertions")))]
    #[inline]
    pub fn check_consistency(&self) -> TractResult<()> {
        Ok(())
    }

    /// Exhaustive consistency check, compiled only with `debug_assertions` and the
    /// `paranoid_assertions` feature.
    ///
    /// Runs `check_edges`. Then, for every node in evaluation order, it re-derives
    /// the output facts from the node's input facts and compares them with the
    /// facts stored on the node: output count, datum type and shape. Finally it
    /// asks every stored output fact to check its own consistency.
    #[cfg(all(debug_assertions, feature = "paranoid_assertions"))]
    pub fn check_consistency(&self) -> TractResult<()> {
        self.check_edges()?;
        let order = self.eval_order()?;
        for &id in order.iter() {
            let input_facts = self.node_input_facts(id)?;
            let node = &self.nodes[id];
            if node.id != id {
                bail!("Node at position {} has id {}", id, node.id);
            }
            let expected = node.op.output_facts(&input_facts)?;
            if expected.len() != node.outputs.len() {
                bail!(
                    "Inconsistent model, node output count mismatch. Op says {}, node says {}. {}",
                    expected.len(),
                    node.outputs.len(),
                    node
                );
            }
            let mismatch = node.outputs.iter().zip(expected.iter()).any(|(stored, derived)| {
                stored.fact.datum_type != derived.datum_type || stored.fact.shape != derived.shape
            });
            if mismatch {
                bail!(
                    "Inconsistent model, output types mismatch. Op says: {:?}, node says: {:?}. {} with inputs {:?}. {}",
                    expected,
                    node.outputs.iter().map(|o| &o.fact).collect::<Vec<_>>(),
                    node,
                    input_facts,
                    node
                )
            }
        }
        for node in &self.nodes {
            for (slot, outlet) in node.outputs.iter().enumerate() {
                outlet.fact.consistent().with_context(|| {
                    format!("Inconsistent fact {:?}: {:?}", OutletId::new(node.id, slot), outlet.fact)
                })?;
            }
        }
        Ok(())
    }

    /// Consume the model, run the declutter passes on it, and hand it back.
    pub fn into_decluttered(mut self) -> TractResult<TypedModel> {
        self.declutter().map(|_| self)
    }
More examples
Hide additional examples
src/ops/scan/mir.rs (line 211)
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
    /// Declutter rule: drop one body input that nothing inside the body uses.
    ///
    /// Body inputs are scanned in order. The rule acts on the first source node
    /// that has no successor and is not itself a body output. That input is
    /// removed in four steps:
    /// - the source node is obliterated from a clone of the body;
    /// - the body input list is reset and the body is decluttered;
    /// - its entry is removed from `input_mapping`;
    /// - when the mapping was fed by an outer input, that outer input is removed
    ///   from the node, and the remaining mappings and `seq_length_input_slot` are
    ///   renumbered.
    ///
    /// Returns a patch replacing `node` with the slimmed-down scan, or `None` when
    /// no body input can be dropped.
    fn declutter_discard_unused_input_mapping(
        &self,
        _session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        for (inner_input_id, input) in self.body.input_outlets()?.iter().enumerate() {
            let source_node = self.body.node(input.node);
            if !source_node.outputs[0].successors.is_empty()
                || self.body.output_outlets()?.contains(input)
            {
                continue;
            }
            // Outer input slot feeding this body input, if any. A state initialized
            // from a constant value has no outer input.
            let slot = match &self.input_mapping[inner_input_id] {
                InputMapping::Full { slot } => Some(*slot),
                InputMapping::Scan(info) => Some(info.slot),
                InputMapping::State { initializer } => match initializer {
                    StateInitializer::FromInput(n) => Some(*n),
                    _ => None,
                },
            };
            // Removing an outer input shifts every later outer slot down by one, so
            // the sequence-length slot must follow. If the discarded outer input is
            // the sequence-length input itself, it cannot be removed: try the next.
            let seq_length_input_slot = match (self.seq_length_input_slot, slot) {
                (Some(seq), Some(removed)) if seq == removed => continue,
                (Some(seq), Some(removed)) if seq > removed => Some(seq - 1),
                (seq, _) => seq,
            };
            let mut new_mappings: Vec<_> = self.input_mapping.clone();
            new_mappings.remove(inner_input_id);
            let mut new_inputs = node.inputs.clone();
            if let Some(slot) = slot {
                new_mappings = Self::remove_outer_input_from_mappings(&new_mappings, slot);
                new_inputs.remove(slot);
            }
            let mut model_inputs = self.body.input_outlets()?.to_vec();
            model_inputs.remove(inner_input_id);
            let mut body = self.body.clone();
            let mut patch = TypedModelPatch::default();
            patch.obliterate(source_node.id)?;
            patch.apply(&mut body)?;
            body.set_input_outlets(&model_inputs)?;
            body.declutter()?;
            let op = Self {
                body,
                skip: self.skip,
                seq_length_input_slot,
                input_mapping: new_mappings,
                decluttered: true,
                output_mapping: self.output_mapping.clone(),
            };
            return Ok(Some(TypedModelPatch::replace_single_op(model, node, &new_inputs, op)?));
        }
        Ok(None)
    }
src/ops/downsample/scan.rs (line 33)
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
/// Try to pull a `Downsample` that consumes the outputs of a `Scan` node up
/// through the scan, so that the scan body iterates over downsampled data.
///
/// A copy of `down_op` is appended to every body output, the body is decluttered,
/// and the rewrite is accepted only if all those downsamples ended up directly on
/// the body inputs. The body inputs are then given downsampled facts, the
/// downsample ops inside the body are bypassed, and the outer inputs are
/// downsampled instead.
///
/// Returns `Ok(None)` (no rewrite) when:
/// - `down_op.stride` is not strictly positive;
/// - some scan output is unused, or is consumed by anything other than a
///   downsample equivalent to `down_op` (`same_as`);
/// - after declutter, some body input feeds a node that is not such a downsample;
/// - a scanned input or output chunk is not a multiple of the stride.
///
/// Otherwise returns a patch that taps every outer scan input through a copy of
/// `down_op`, wires the rewritten scan, and reroutes every downsample that
/// consumed an original scan output to the matching new scan output.
pub fn pull_downsample_over_scan(
    model: &TypedModel,
    scan_node: &TypedNode,
    scan_op: &ops::scan::Scan,
    down_node: &TypedNode,
    down_op: &Downsample,
) -> TractResult<Option<TypedModelPatch>> {
    // Only strictly positive strides are supported; a zero stride would also make
    // the chunk divisibility checks below divide by zero.
    if down_op.stride <= 0 {
        return Ok(None);
    }
    let stride = down_op.stride as usize;

    // The patch reroutes the consumers of every scan output, so each output must be
    // consumed, and only by downsamples equivalent to `down_op`. Otherwise picking
    // the first successor would panic (no consumer) or an unrelated node would be
    // silently miswired.
    for output in &scan_node.outputs {
        if output.successors.is_empty()
            || output.successors.iter().any(|succ| !model.node(succ.node).op().same_as(down_op))
        {
            return Ok(None);
        }
    }

    // Introduce a downsample at the end of every body output.
    let mut downsampled_body = scan_op.body.clone();
    downsampled_body.check_consistency()?;
    let outputs = downsampled_body.output_outlets()?.to_owned();
    let downsample_outputs = outputs
        .into_iter()
        .enumerate()
        .map(|(ix, oo)| {
            Ok(downsampled_body.wire_node(
                format!("{}-{}", &down_node.name, ix),
                down_op.clone(),
                &[oo],
            )?[0])
        })
        .collect::<TractResult<Vec<_>>>()?;
    downsampled_body.set_output_outlets(&downsample_outputs)?;
    downsampled_body.declutter()?;
    downsampled_body.check_consistency()?;

    // Check that declutter has moved the downsamples all the way up to the body
    // inputs: every consumer of every body input must be such a downsample.
    for input in downsampled_body.input_outlets()? {
        let input = downsampled_body.node(input.node);
        if input.outputs[0]
            .successors
            .iter()
            .any(|succ| !downsampled_body.node(succ.node).op().same_as(down_op))
        {
            return Ok(None);
        }
    }

    // Body inputs now receive downsampled data: update their facts (both on the
    // outlet and in the source op) and bypass the downsample ops that follow them.
    let inputs = downsampled_body.input_outlets()?.to_vec();
    for input in inputs {
        let node = downsampled_body.node_mut(input.node);
        let fact = down_op.transform_fact(&node.outputs[0].fact)?;
        node.op_as_mut::<crate::ops::source::TypedSource>()
            .expect("scan body inputs are TypedSource nodes")
            .fact = fact.clone();
        node.outputs[0].fact = fact;
        let downsamples = downsampled_body.node(input.node).outputs[0].successors.clone();
        for ds in downsamples {
            TypedModelPatch::shunt_one_op(&downsampled_body as _, downsampled_body.node(ds.node))?
                .apply(&mut downsampled_body)?;
        }
    }

    downsampled_body.check_consistency()?;
    let inner_model = downsampled_body.into_decluttered()?;

    let mut new_scan = scan_op.clone();
    new_scan.body = inner_model;
    for input in &mut new_scan.input_mapping {
        match input {
            InputMapping::State { ref mut initializer } => {
                if let StateInitializer::Value(ref v) = initializer {
                    let mut new_v = down_op.eval(tvec!(v.clone().into_tvalue()))?;
                    *initializer = StateInitializer::Value(new_v.remove(0).into_arc_tensor());
                }
            }
            InputMapping::Scan(info) => {
                // Check the magnitude so reverse (negative) chunks are validated too.
                if info.chunk.unsigned_abs() % stride != 0 {
                    return Ok(None);
                }
                info.chunk =
                    info.chunk.unsigned_abs().divceil(stride) as isize * info.chunk.signum()
            }
            _ => (),
        }
    }
    for output in &mut new_scan.output_mapping {
        if let Some(d) = output.full_dim_hint.as_mut() {
            *d = down_op.transform_dim(d)
        }
        if let Some(info) = &mut output.scan {
            // `chunk as usize` would wrap around for a negative chunk: use the magnitude.
            if info.chunk.unsigned_abs() % stride != 0 {
                return Ok(None);
            }
            info.chunk =
                info.chunk.unsigned_abs().divceil(stride) as isize * info.chunk.signum()
        }
    }

    let mut patch = TypedModelPatch::default();
    let mut inputs = tvec!();
    for (ix, &i) in scan_node.inputs.iter().enumerate() {
        let tap = patch.tap_model(model, i)?;
        let ds = patch.wire_node(format!("{}-{}", down_node.name, ix), down_op.clone(), &[tap])?[0];
        inputs.push(ds);
    }
    let scan = patch.wire_node(&*scan_node.name, new_scan, &inputs)?;
    // Every consumer of every original scan output was checked above to be an
    // equivalent downsample: make each of them read the new scan output instead.
    for (ix, output) in scan_node.outputs.iter().enumerate() {
        for succ in &output.successors {
            patch.shunt_outside(model, OutletId::new(succ.node, 0), scan[ix])?;
        }
    }
    Ok(Some(patch))
}

Perform optimization passes on the model, using a given optimizer session.

Examples found in repository?
src/ops/scan/mir.rs (line 823)
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
    /// Rebuild this scan in `target` with its symbolic dimensions resolved from
    /// `values`. Both the body and every output mapping are concretized; the other
    /// fields of the op are cloned unchanged. Node inputs are translated through
    /// `mapping` before wiring.
    fn concretize_dims(
        &self,
        _source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
        values: &SymbolValues,
    ) -> TractResult<TVec<OutletId>> {
        let inputs: TVec<OutletId> = node.inputs.iter().map(|outlet| mapping[outlet]).collect();
        let mut output_mapping = vec![];
        for om in &self.output_mapping {
            output_mapping.push(om.concretize_dims(values)?);
        }
        let body = self.body.concretize_dims(values)?;
        target.wire_node(&node.name, Self { output_mapping, body, ..self.clone() }, &inputs)
    }

Translate the graph to locally optimized operators (LIR or MIR ops).

Examples found in repository?
src/model/typed.rs (line 97)
95
96
97
98
99
    /// Consume the model, run the declutter passes and then the optimization
    /// passes on it, and hand the transformed model back.
    pub fn into_optimized(mut self) -> TractResult<TypedModel> {
        self.declutter()?;
        self.optimize().map(|_| self)
    }
Examples found in repository?
src/ops/scan/mir.rs (line 686)
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
    /// Derive the scan's axis invariants from the invariants of its body.
    ///
    /// Each body axis is mapped back onto outer slots:
    /// - body input `ix` corresponds to `input_mapping[ix]`, when that mapping has
    ///   an outer slot;
    /// - body output `ix` corresponds to the scan slot and/or the last-value slot
    ///   of `output_mapping[ix]`.
    ///
    /// Axes that reach no outer slot at all are dropped. Every kept axis has
    /// period 1 and inherits the body axis's `disposable` flag.
    fn invariants(
        &self,
        _inputs: &[&TypedFact],
        _outputs: &[&TypedFact],
    ) -> TractResult<Invariants> {
        let mut invariants = tvec!();
        let body_invs = self.body.invariants().with_context(|| "Computing body invariants")?;
        for body_axis in body_invs.axes {
            let mut axis = AxisInfo::default().with_period(1);
            for (inner, mapping) in self.input_mapping.iter().enumerate() {
                if let Some(outer) = mapping.slot() {
                    if axis.inputs.len() <= outer {
                        axis.inputs.resize(outer + 1, None);
                    }
                    axis.inputs[outer] = body_axis.inputs[inner];
                }
            }
            for (inner, mapping) in self.output_mapping.iter().enumerate() {
                let outer_slots =
                    mapping.scan.map(|scan| scan.slot).into_iter().chain(mapping.last_value_slot);
                for outer in outer_slots {
                    if axis.outputs.len() <= outer {
                        axis.outputs.resize(outer + 1, None);
                    }
                    axis.outputs[outer] = body_axis.outputs[inner];
                }
            }
            let reaches_outer = axis.inputs.iter().chain(axis.outputs.iter()).any(Option::is_some);
            if reaches_outer {
                axis.disposable = body_axis.disposable;
                invariants.push(axis);
            }
        }
        Ok(Invariants::from(invariants))
    }

Trait Implementations§

Returns a copy of the value. Read more
Performs copy-assignment from source. Read more
Formats the value using the given formatter. Read more
Returns the “default value” for a type. Read more
Formats the value using the given formatter. Read more
Feeds this value into the given Hasher. Read more
Feeds a slice of this type into the given Hasher. Read more

Auto Trait Implementations§

Blanket Implementations§

Gets the TypeId of self. Read more
Immutably borrows from an owned value. Read more
Mutably borrows from an owned value. Read more
Convert Box<dyn Trait> (where Trait: Downcast) to Box<dyn Any>. Box<dyn Any> can then be further downcast into Box<ConcreteType> where ConcreteType implements Trait.
Convert Rc<Trait> (where Trait: Downcast) to Rc<Any>. Rc<Any> can then be further downcast into Rc<ConcreteType> where ConcreteType implements Trait.
Convert &Trait (where Trait: Downcast) to &Any. This is needed since Rust cannot generate &Any’s vtable from &Trait’s.
Convert &mut Trait (where Trait: Downcast) to &Any. This is needed since Rust cannot generate &mut Any’s vtable from &mut Trait’s.
Convert Arc<Trait> (where Trait: Downcast) to Arc<Any>. Arc<Any> can then be further downcast into Arc<ConcreteType> where ConcreteType implements Trait.

Returns the argument unchanged.

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

The resulting type after obtaining ownership.
Creates owned data from borrowed data, usually by cloning. Read more
Uses borrowed data to replace owned data, usually by cloning. Read more
Converts the given value to a String. Read more
The type returned in the event of a conversion error.
Performs the conversion.
The type returned in the event of a conversion error.
Performs the conversion.