tract_core/model/
memory.rs1use super::*;
2use crate::prelude::*;
3use std::collections::HashSet;
4use std::fmt::Debug;
5use std::fmt::Display;
6use tract_data::internal::*;
7
8pub fn eval_tmp_memory_usage<F, O, Flushable>(
10 model: &Graph<F, O>,
11 order: &[usize],
12 flushable: Flushable,
13) -> TractResult<TVec<(usize, TDim)>>
14where
15 F: Fact + Clone + 'static,
16 O: Debug + Display + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
17 Flushable: Fn(&Node<F, O>) -> bool,
18{
19 let outputs = model.output_outlets()?.to_vec();
20
21 let flush_lists = super::order::build_flush_list(model, order, &outputs, &flushable);
22 let mut values: TVec<bool> = tvec![false; model.nodes.len()];
23
24 let mut mem_by_steps: TVec<_> = tvec![(0, 0.into()); order.len()];
25
26 let flushable_nodes = model
27 .nodes()
28 .iter()
29 .filter(|node| (flushable)(node))
30 .map(|it| it.id)
31 .collect::<HashSet<_>>();
32
33 for (step, n) in order.iter().enumerate() {
34 let node = model.node(*n);
35
36 for flush in flush_lists[step].iter() {
37 values[*flush] = false;
38 }
39
40 let mut step_active_nodes: HashSet<_> =
42 values.iter().enumerate().filter_map(|(n, active)| active.then_some(n)).collect();
43
44 step_active_nodes.extend(node.inputs.iter().map(|it| it.node));
45 step_active_nodes.insert(*n);
46
47 values[*n] = true;
48
49 let step_active_flushable_nodes = step_active_nodes.intersection(&flushable_nodes);
51
52 mem_by_steps[step] = (*n, 0.into());
53
54 for n in step_active_flushable_nodes {
55 let out_facts = model
56 .node_output_facts(*n)?
57 .into_iter()
58 .map(|it| it.to_typed_fact())
59 .collect::<TractResult<TVec<_>>>()?;
60 mem_by_steps[step].1 += out_facts.iter().map(|it| it.mem_size()).sum::<TDim>();
61 }
62 }
63 Ok(mem_by_steps)
64}