tract_core/model/
memory.rs

1use super::*;
2use crate::prelude::*;
3use std::collections::HashSet;
4use std::fmt::Debug;
5use std::fmt::Display;
6use tract_data::internal::*;
7
8/// Evaluate temporary memory usage with its related node at each step of the given order.
9pub fn eval_tmp_memory_usage<F, O, Flushable>(
10    model: &Graph<F, O>,
11    order: &[usize],
12    flushable: Flushable,
13) -> TractResult<TVec<(usize, TDim)>>
14where
15    F: Fact + Clone + 'static,
16    O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
17    Flushable: Fn(&Node<F, O>) -> bool,
18{
19    let outputs = model.output_outlets()?.to_vec();
20
21    let flush_lists = super::order::build_flush_list(model, order, &outputs, &flushable);
22    let mut values: TVec<bool> = tvec![false; model.nodes.len()];
23
24    let mut mem_by_steps: TVec<_> = tvec![(0, 0.into()); order.len()];
25
26    let flushable_nodes = model
27        .nodes()
28        .iter()
29        .filter(|node| (flushable)(node))
30        .map(|it| it.id)
31        .collect::<HashSet<_>>();
32
33    for (step, n) in order.iter().enumerate() {
34        let node = model.node(*n);
35
36        for flush in flush_lists[step].iter() {
37            values[*flush] = false;
38        }
39
40        // Active nodes are node that has not been flushed + inputs of the current node and current node.
41        let mut step_active_nodes: HashSet<_> =
42            values.iter().enumerate().filter_map(|(n, active)| active.then_some(n)).collect();
43
44        step_active_nodes.extend(node.inputs.iter().map(|it| it.node));
45        step_active_nodes.insert(*n);
46
47        values[*n] = true;
48
49        // Keep only flushable nodes.
50        let step_active_flushable_nodes = step_active_nodes.intersection(&flushable_nodes);
51
52        mem_by_steps[step] = (*n, 0.into());
53
54        for n in step_active_flushable_nodes {
55            let out_facts = model
56                .node_output_facts(*n)?
57                .into_iter()
58                .map(|it| it.to_typed_fact())
59                .collect::<TractResult<TVec<_>>>()?;
60            mem_by_steps[step].1 += out_facts.iter().map(|it| it.mem_size()).sum::<TDim>();
61        }
62    }
63    Ok(mem_by_steps)
64}