1use nu_ansi_term::Style;
2use std::collections::HashMap;
3#[allow(unused_imports)]
4use std::convert::TryFrom;
5use std::time::Duration;
6use tract_core::internal::*;
7use tract_core::ops::scan::Scan;
8use tract_itertools::Itertools;
9use tract_itertools::izip;
10
11use crate::model::Model;
12
13#[derive(Debug, Clone, Hash, PartialEq, Eq)]
14pub struct NodeQId(pub TVec<(usize, String)>, pub usize);
15
16impl From<usize> for NodeQId {
17 fn from(id: usize) -> NodeQId {
18 NodeQId(tvec!(), id)
19 }
20}
21
22impl NodeQId {
23 pub fn model<'a>(&self, model: &'a dyn Model) -> Option<&'a dyn Model> {
24 fn scope<'a>(path: &[(usize, String)], model: &'a dyn Model) -> Option<&'a dyn Model> {
25 if path.is_empty() {
26 Some(model)
27 } else {
28 model
29 .nested_models(path[0].0)
30 .into_iter()
31 .find(|(name, _sub)| name == &path[0].1)
32 .map(|(_, sub)| sub)
33 }
34 }
35 scope(&self.0, model)
36 }
37}
38
39#[derive(Debug, Default, Clone)]
40pub struct NodeTags {
41 pub cost: Vec<(Cost, TDim)>,
42 pub tmp_mem_usage: Option<TDim>,
43 pub style: Option<Style>,
44 pub labels: Vec<String>,
45 pub sections: Vec<Vec<String>>,
46 pub profile: Option<Duration>,
47 pub accelerator_profile: Option<Duration>,
48 pub model_input: Option<String>,
49 pub model_output: Option<String>,
50 pub outlet_labels: Vec<Vec<String>>,
51 pub outlet_axes: Vec<Vec<String>>,
52}
53
54impl<'a> std::ops::Add<&'a NodeTags> for &'a NodeTags {
55 type Output = NodeTags;
56 fn add(self, other: &'a NodeTags) -> NodeTags {
57 let cost = self
58 .cost
59 .iter()
60 .chain(other.cost.iter())
61 .sorted_by_key(|(a, _)| a)
62 .group_by(|(a, _)| a)
63 .into_iter()
64 .map(|(cost, dims)| {
65 (cost.clone(), dims.into_iter().fold(0.to_dim(), |acc, d| acc + &d.1))
66 })
67 .collect::<Vec<(Cost, TDim)>>();
68
69 let tmp_mem_usage = match (self.tmp_mem_usage.clone(), other.tmp_mem_usage.clone()) {
70 (Some(self_mem), Some(other_mem)) => Some(self_mem + other_mem),
71 (_, Some(mem)) | (Some(mem), _) => Some(mem),
72 (None, None) => None,
73 };
74
75 let profile = self.profile.unwrap_or_default() + other.profile.unwrap_or_default();
76 let profile = if profile != Duration::default() { Some(profile) } else { None };
77 let accelerator_profile = self.accelerator_profile.unwrap_or_default()
78 + other.accelerator_profile.unwrap_or_default();
79 let accelerator_profile = if accelerator_profile != Duration::default() {
80 Some(accelerator_profile)
81 } else {
82 None
83 };
84
85 let style = self.style.or(other.style);
86 let labels = self.labels.iter().chain(other.labels.iter()).cloned().collect();
87 let sections = self.sections.iter().chain(other.sections.iter()).cloned().collect();
88 let model_input = self.model_input.clone().or_else(|| other.model_input.clone());
89 let model_output = self.model_output.clone().or_else(|| other.model_output.clone());
90 let outlet_labels = izip!(&self.outlet_labels, &other.outlet_labels)
91 .map(|(s, o)| s.iter().chain(o.iter()).cloned().collect())
92 .collect();
93 let outlet_axes = izip!(&self.outlet_axes, &other.outlet_axes)
94 .map(|(s, o)| s.iter().chain(o.iter()).cloned().collect())
95 .collect();
96 NodeTags {
97 cost,
98 tmp_mem_usage,
99 profile,
100 accelerator_profile,
101 style,
102 labels,
103 sections,
104 model_input,
105 model_output,
106 outlet_labels,
107 outlet_axes,
108 }
109 }
110}
111
112impl<'a> std::iter::Sum<&'a NodeTags> for NodeTags {
113 fn sum<I>(iter: I) -> NodeTags
114 where
115 I: std::iter::Iterator<Item = &'a NodeTags>,
116 {
117 iter.fold(EMPTY, |a, b| &a + b)
118 }
119}
120
121const EMPTY: NodeTags = NodeTags {
122 cost: Vec::new(),
123 tmp_mem_usage: None,
124 style: None,
125 labels: Vec::new(),
126 sections: Vec::new(),
127 profile: None,
128 accelerator_profile: None,
129 model_output: None,
130 model_input: None,
131 outlet_labels: Vec::new(),
132 outlet_axes: Vec::new(),
133};
134
135#[derive(Debug, Clone, Default)]
136pub struct Annotations {
137 pub tags: HashMap<NodeQId, NodeTags>,
138 pub profile_summary: Option<ProfileSummary>,
139 pub memory_summary: Option<MemorySummary>,
140}
141
142impl Annotations {
143 pub fn node_mut(&mut self, qid: NodeQId) -> &mut NodeTags {
144 self.tags.entry(qid).or_default()
145 }
146
147 pub fn track_tmp_memory_usage<Flushable>(
148 &mut self,
149 model: &dyn Model,
150 flushable: Flushable,
151 skip_order_opt_ram: bool,
152 ) -> TractResult<()>
153 where
154 Flushable: Fn(&TypedNode) -> bool,
155 {
156 let Some(model) = model.downcast_ref::<TypedModel>() else { return Ok(()) };
157 let order = if skip_order_opt_ram {
158 tract_core::model::order::eval_order(model)?
159 } else {
160 tract_core::model::order::eval_order_opt_ram(model)?
161 };
162
163 let tmp_mem_usage = model.eval_tmp_memory_usage(&order, &flushable)?;
164
165 let peak_tmp_mem_usage = tmp_mem_usage
166 .iter()
167 .map(|(n, mem)| mem.to_usize().map(|m| (*n, m)))
168 .collect::<TractResult<TVec<_>>>()
169 .ok()
170 .and_then(|mems| {
171 mems.into_iter().map(|(n, mem)| (NodeQId(tvec![], n), mem)).max_by_key(|it| it.1)
172 });
173
174 self.memory_summary =
175 peak_tmp_mem_usage.map(|(n, mem)| MemorySummary { max: mem, max_reached_by_node: n });
176
177 for (n, mem_size) in tmp_mem_usage.into_iter() {
178 let qid = NodeQId(tvec![], n);
179 let tags = self.tags.entry(qid).or_default();
180 tags.tmp_mem_usage = Some(mem_size.simplify());
181 }
182 Ok(())
183 }
184
185 pub fn track_axes(
186 &mut self,
187 model: &dyn Model,
188 hints: &HashMap<OutletId, TVec<String>>,
189 ) -> TractResult<()> {
190 let Some(model) = model.downcast_ref::<TypedModel>() else { return Ok(()) };
191 fn sub(
192 annotations: &mut Annotations,
193 prefix: &[(usize, String)],
194 name_prefix: &str,
195 model: &TypedModel,
196 hints: &HashMap<OutletId, TVec<String>>,
197 ) -> TractResult<()> {
198 let tracking = tract_core::axes::full_axis_tracking(model)?;
199 for (ix, axis) in tracking.iter().enumerate() {
200 let name = axis
201 .creators
202 .iter()
203 .find_map(|cre| hints.get(cre).and_then(|hints| hints.get(axis.outlets[cre])))
204 .cloned()
205 .unwrap_or_else(|| format!("{name_prefix}x{ix}"));
206 for outlet in axis.outlets.keys() {
207 let axis = axis.outlets[&outlet];
208 let qid = NodeQId(prefix.into(), outlet.node);
209 let tags = annotations.tags.entry(qid).or_default();
210 while tags.outlet_axes.len() <= outlet.slot {
211 tags.outlet_axes.push(vec![]);
212 }
213 while tags.outlet_axes[outlet.slot].len() <= axis {
214 tags.outlet_axes[outlet.slot].push(Default::default());
215 }
216 tags.outlet_axes[outlet.slot][axis].clone_from(&name);
217 }
218 }
219 for node in &model.nodes {
220 if let Some(scan) = node.op_as::<Scan>() {
221 let mut prefix: TVec<_> = prefix.into();
222 prefix.push((node.id, "loop".to_string()));
223 sub(
224 annotations,
225 &prefix,
226 &format!("{name_prefix}loop_"),
227 &scan.body,
228 &Default::default(),
229 )?;
230 }
231 }
232 Ok(())
233 }
234 sub(self, &[], "", model, hints)
235 }
236
237 pub fn from_model(model: &dyn Model) -> TractResult<Annotations> {
238 let mut annotations = Annotations::default();
239 fn set_subio_labels(
240 model: &dyn Model,
241 prefix: &[(usize, String)],
242 annotations: &mut Annotations,
243 ) {
244 for n in 0..model.nodes_len() {
245 for output in 0..model.node_output_count(n) {
246 if let Some(label) = model.outlet_label((n, output).into()) {
247 let qid = NodeQId(prefix.into(), n);
248 annotations
249 .tags
250 .entry(qid.clone())
251 .or_default()
252 .outlet_labels
253 .resize(output + 1, vec![]);
254 annotations.tags.entry(qid).or_default().outlet_labels[output] =
255 vec![label.to_string()];
256 }
257 }
258 for (label, sub ) in model.nested_models(n) {
259 let mut prefix: TVec<(usize, String)> = prefix.into();
260 prefix.push((n, label.to_string()));
261 set_subio_labels(sub, &prefix, annotations);
262 }
272 }
273 }
274 set_subio_labels(model, &[], &mut annotations);
275 Ok(annotations)
276 }
277}
278
/// Aggregated profiling figures.
///
/// NOTE(review): nothing in this file fills this struct in; the field
/// descriptions below are inferred from the names and should be confirmed
/// against the profiler that produces it.
#[derive(Clone, Debug)]
pub struct ProfileSummary {
    /// Presumably the longest single duration observed.
    pub max: Duration,
    /// Presumably the sum of the per-node durations.
    pub sum: Duration,
    /// Presumably the sum of the per-node accelerator durations.
    pub accel_sum: Duration,
    /// Presumably the duration of the whole run.
    pub entire: Duration,
    /// Presumably the number of iterations that were run.
    pub iters: usize,
}
287
288#[derive(Debug, Clone)]
289pub struct MemorySummary {
290 pub max: usize,
291 pub max_reached_by_node: NodeQId,
292}