1use crate::model::Model;
2use crate::tensor::RunTensors;
3use crate::tensor::make_inputs_for_model;
4use crate::{annotations::*, capture_gpu_trace};
5use std::any::TypeId;
6use std::time::{Duration, Instant};
7use tract_core::internal::*;
8use tract_core::num_traits::Zero;
9use tract_core::ops::scan::State;
10use tract_core::ops::submodel::TypedModelOpState;
11
/// Bounds on how long / how many iterations a benchmark run may take.
pub struct BenchLimits {
    /// Number of warmup iterations before measurement (0 means unbounded when
    /// `warmup_time` is set; warmup is skipped entirely if both are zero).
    pub warmup_loops: usize,
    /// Wall-clock budget for warmup (zero means unbounded when `warmup_loops`
    /// is set; warmup is skipped entirely if both are zero).
    pub warmup_time: std::time::Duration,
    /// Maximum number of measured iterations.
    pub max_loops: usize,
    /// Wall-clock budget for the measured iterations.
    pub max_time: std::time::Duration,
}
18
19impl Default for BenchLimits {
20 fn default() -> Self {
21 BenchLimits {
22 warmup_loops: 0,
23 warmup_time: Duration::default(),
24 max_loops: 100_000,
25 max_time: std::time::Duration::from_secs(5),
26 }
27 }
28}
29
impl BenchLimits {
    /// Runs the model in a loop, discarding timings, until `warmup_loops`
    /// iterations or `warmup_time` have elapsed — whichever is hit first.
    ///
    /// A zero value for either limit means "unbounded" for that dimension;
    /// when both are zero the warmup is skipped entirely.
    pub fn warmup(&self, model: &TypedModel, inputs: &RunTensors) -> TractResult<()> {
        if self.warmup_time.is_zero() && self.warmup_loops.is_zero() {
            return Ok(());
        }
        let plan = TypedSimplePlan::new(model.clone())?;
        let mut state = TypedSimpleState::new(Arc::new(plan))?;
        let mut iters = 0;
        // A zero limit is widened to the maximum so the other limit governs.
        let max_loops = if self.warmup_loops.is_zero() { usize::MAX } else { self.warmup_loops };
        let max_time = if self.warmup_time.is_zero() { Duration::MAX } else { self.warmup_time };

        let start_warmup = Instant::now();
        debug!("Warming up before profiling...");
        while iters < max_loops && start_warmup.elapsed() < max_time {
            // Models with a "pulse.delay" property carry state across runs;
            // others get their op states initialized before, and reset after,
            // every run.
            if state.model().properties().contains_key("pulse.delay") {
                state.run(inputs.sources[0].clone())?;
            } else {
                state.init_states(&mut inputs.state_initializers.clone())?;
                state.run(inputs.sources[0].clone())?;
                state.reset_op_states()?
            }
            iters += 1;
        }
        debug!("Done warming up.");

        Ok(())
    }
}
58
59pub fn profile(
60 model: &TypedModel,
61 bench_limits: &BenchLimits,
62 dg: &mut Annotations,
63 plan_options: &PlanOptions,
64 inputs: &RunTensors,
65 custom_profiler: Option<HashMap<TypeId, Profiler>>,
66 folded: bool,
67) -> TractResult<()> {
68 info!("Running entire network");
69 let mut iters = 0usize;
70 let prefix = tvec!();
71
72 bench_limits.warmup(model, inputs)?;
73
74 let plan = TypedSimplePlan::new_with_options(model.clone(), plan_options)?;
75 let mut state = TypedSimpleState::new(Arc::new(plan))?;
76
77 let mut dur = Duration::default();
78 let mut time_accounted_by_inner_nodes = Duration::default();
79 while iters < bench_limits.max_loops && dur < bench_limits.max_time {
80 if !state.model().properties().contains_key("pulse.delay") {
81 state.init_states(&mut inputs.state_initializers.clone())?;
82 }
83 let start = Instant::now();
84 rec_profiler(
85 &mut state,
86 dg,
87 &inputs.sources[0],
88 custom_profiler.as_ref(),
89 &prefix,
90 None,
91 &mut time_accounted_by_inner_nodes,
92 folded,
93 )?;
94 dur += start.elapsed();
95 if !state.model().properties().contains_key("pulse.delay") {
96 state.reset_op_states()?;
97 }
98 iters += 1;
99 }
100
101 dur -= time_accounted_by_inner_nodes;
102
103 info!("Running {} iterations max. for each node.", bench_limits.max_loops);
104 info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());
105
106 let denum = (iters as f32).recip();
107 let entire = dur.mul_f32(denum);
108 for d in dg.tags.values_mut() {
109 if let Some(d) = d.profile.as_mut() {
110 *d = d.mul_f32(denum);
111 }
112
113 if let Some(d) = d.accelerator_profile.as_mut() {
114 *d = d.mul_f32(denum);
115 }
116 }
117 let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
118 let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
119 let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
120 dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
121 Ok(())
122}
123
124pub fn profile_gpu(
125 model: &TypedModel,
126 bench_limits: &BenchLimits,
127 sub_matches: &clap::ArgMatches,
128 dg: &mut Annotations,
129 plan_options: &PlanOptions,
130 inputs: &RunTensors,
131) -> TractResult<()> {
132 info!("Running entire network");
133 let mut iters = 0usize;
134 let prefix = tvec!();
135
136 bench_limits.warmup(model, inputs)?;
137
138 let mut plan = TypedSimplePlan::new_with_options(model.clone(), plan_options)?;
139 let state = TypedSimpleState::new_from_inputs(&plan, inputs.sources[0].clone())?;
140
141 let session_handler = tract_gpu::session_handler::DeviceSessionHandler::from_plan(
142 &plan,
143 &state.session_state.resolved_symbols,
144 )?;
145
146 plan = plan.with_session_handler(session_handler);
147
148 let mut state = TypedSimpleState::new(Arc::new(plan))?;
149 let mut dur = Duration::default();
150
151 capture_gpu_trace(sub_matches, || -> TractResult<()> {
152 while iters < bench_limits.max_loops && dur < bench_limits.max_time {
153 if !state.model().properties().contains_key("pulse.delay") {
154 state.init_states(&mut inputs.state_initializers.clone())?;
155 }
156 let start = Instant::now();
157 rec_profiler_gpu(&mut state, dg, &inputs.sources[0], &prefix)?;
158 dur += start.elapsed();
159 if !state.model().properties().contains_key("pulse.delay") {
160 state.reset_op_states()?;
161 }
162 iters += 1;
163 }
164 Ok(())
165 })?;
166
167 info!("Running {} iterations max. for each node.", bench_limits.max_loops);
168 info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());
169
170 let denum = (iters as f32).recip();
171 let entire = dur.mul_f32(denum);
172 for d in dg.tags.values_mut() {
173 if let Some(d) = d.profile.as_mut() {
174 *d = d.mul_f32(denum);
175 }
176
177 if let Some(d) = d.accelerator_profile.as_mut() {
178 *d = d.mul_f32(denum);
179 }
180 }
181 let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
182 let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
183 let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
184 dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
185 Ok(())
186}
187
188pub fn rec_profiler_gpu(
189 state: &mut TypedSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>,
190 dg: &mut Annotations,
191 inputs: &TVec<TValue>,
192 prefix: &[(usize, String)],
193) -> TractResult<TVec<TValue>> {
194 let r = state.run_plan_with_eval(
195 inputs.clone(),
196 |session_state, mut node_state, node, input| {
197 let start = crate::time::now();
199 let res = tract_core::plan::eval(
200 session_state,
201 node_state.as_deref_mut(),
202 node,
203 input.clone(),
204 );
205 let elapsed = start.elapsed();
206 let node_id = NodeQId(prefix.into(), node.id);
207 *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;
208
209 res
210 },
211 )?;
212
213 Ok(r)
214}
215
/// Runs one evaluation of `state`, timing each node and accumulating the
/// (optionally multiplied) elapsed time into the node's profile entry in `dg`.
///
/// * `profilers` — custom per-op-state profilers keyed by `TypeId`.
/// * `prefix` — node path of enclosing (sub)models, used to build `NodeQId`s.
/// * `multiplier` — scales each measured duration (used for loop bodies that
///   run several iterations per outer evaluation).
/// * `time_accounted_by_inner_nodes` — accumulates time spent profiling inner
///   models so the caller can discount it from the outer measurement.
/// * `folded` — when true, inner models (loops, submodels) are not recursed into.
#[allow(clippy::too_many_arguments)]
pub fn rec_profiler(
    state: &mut TypedSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>,
    dg: &mut Annotations,
    inputs: &TVec<TValue>,
    profilers: Option<&HashMap<TypeId, Profiler>>,
    prefix: &[(usize, String)],
    multiplier: Option<usize>,
    time_accounted_by_inner_nodes: &mut Duration,
    folded: bool,
) -> TractResult<TVec<TValue>> {
    let r = state.run_plan_with_eval(
        inputs.clone(),
        |session_state, mut node_state, node, input| {
            let start = crate::time::now();
            let res = tract_core::plan::eval(
                session_state,
                node_state.as_deref_mut(),
                node,
                input.clone(),
            );
            let elapsed = start.elapsed().mul_f32(multiplier.unwrap_or(1) as _);
            let node_id = NodeQId(prefix.into(), node.id);
            *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;

            if !folded {
                // Recurse into this node's inner model (if any); the time spent
                // doing so is tracked separately so the caller can subtract it.
                let start = crate::time::now();
                profile_submodel(
                    node,
                    node_state,
                    input,
                    dg,
                    profilers,
                    prefix,
                    time_accounted_by_inner_nodes,
                )?;
                *time_accounted_by_inner_nodes += start.elapsed();
            }

            // Subtract this node's time from every ancestor along the prefix so
            // parents only report their own (non-child) time, clamped at zero to
            // avoid Duration underflow.
            let prefix_vec = prefix.to_vec();
            if !prefix_vec.is_empty() {
                (1..prefix_vec.len() + 1).map(|idx| prefix_vec[..idx].to_vec()).for_each(
                    |parent_path| {
                        let parent_node = parent_path.last().map(|it| it.0).unwrap();
                        let parent = dg
                            .node_mut(NodeQId(
                                parent_path[..parent_path.len() - 1].into(),
                                parent_node,
                            ))
                            .profile
                            .get_or_insert(Duration::default());
                        *parent -= elapsed.min(*parent);
                    },
                );
            }
            res
        },
    )?;
    Ok(r)
}
278
/// Profiles the inner model of `node`, if it has one.
///
/// Dispatch order: a custom [`Profiler`] registered for the op state's
/// `TypeId` wins; otherwise scan loops (`State`) and plain submodels
/// (`TypedModelOpState`) are recursed into with [`rec_profiler`]. Nodes
/// without state, or with an unrecognized state type, are skipped.
fn profile_submodel(
    node: &TypedNode,
    mut node_state: Option<&mut dyn OpState>,
    input: TVec<TValue>,
    dg: &mut Annotations,
    profilers: Option<&HashMap<TypeId, Profiler>>,
    prefix: &[(usize, String)],
    time_accounted_by_inner_nodes: &mut Duration,
) -> TractResult<()> {
    if let Some(ref mut op_state) = node_state {
        if let Some(profiler) = profilers.and_then(|it| it.get(&op_state.type_id())) {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "submodel".to_string()));

            // The custom profiler's returned outputs/duration are discarded;
            // it records timings into `dg` itself.
            let (_, _) =
                (profiler.func)(*op_state, input, dg, &new_prefix, time_accounted_by_inner_nodes)?;
        } else if let Some(scan_state) = op_state.downcast_mut::<State>() {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "loop".to_string()));

            // The loop body runs `multi` times per outer evaluation, so inner
            // timings are scaled accordingly via the multiplier argument.
            let scan_inputs = make_inputs_for_model(scan_state.model_state.model())?;
            let multi = scan_state.iteration_count(&input);

            rec_profiler(
                &mut scan_state.model_state,
                dg,
                &scan_inputs,
                None,
                &new_prefix,
                Some(multi),
                time_accounted_by_inner_nodes,
                false,
            )?;
        } else if let Some(typed_model_state) = op_state.downcast_mut::<TypedModelOpState>() {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "submodel".to_string()));

            rec_profiler(
                typed_model_state,
                dg,
                &input,
                None,
                &new_prefix,
                None,
                time_accounted_by_inner_nodes,
                false,
            )?;
        }
    }

    Ok(())
}
331
/// Signature of a custom profiling routine: receives the op state, its inputs,
/// the annotations to fill, the node path prefix, and the inner-node time
/// accumulator; returns the eval outcome together with the time it took.
type ProfilerFn = fn(
    &mut dyn OpState,
    TVec<TValue>,
    &mut Annotations,
    &[(usize, String)],
    &mut Duration,
) -> TractResult<(TractResult<TVec<TValue>>, Duration)>;

/// A named custom profiler, looked up by op-state `TypeId` in `rec_profiler`.
#[derive(Clone)]
pub struct Profiler {
    pub func: ProfilerFn,
    pub name: &'static str,
}

impl Hash for Profiler {
    // Function pointers can't be hashed meaningfully, so the (assumed unique)
    // name serves as the profiler's identity.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.name.hash(state)
    }
}
351
/// Computes per-node cost annotations for `model` and all its nested models,
/// evaluating symbolic dimensions against `extra_symbols`.
pub fn extract_costs(
    annotations: &mut Annotations,
    model: &dyn Model,
    extra_symbols: &SymbolValues,
) -> TractResult<()> {
    // Recursive worker. `prefix` is the node path of the enclosing models;
    // `multiplier` is the accumulated iteration count so that compute costs of
    // nodes inside loops are scaled by how many times they run.
    fn extract_costs_rec(
        annotations: &mut Annotations,
        model: &dyn Model,
        prefix: &[(usize, String)],
        multiplier: TDim,
        extra_symbols: &SymbolValues,
    ) -> TractResult<()> {
        // Only typed models expose per-op costs; other model kinds are skipped.
        if let Some(model) = model.downcast_ref::<TypedModel>() {
            for node_id in 0..model.nodes().len() {
                let inputs = model.node_input_facts(node_id)?;
                let cost = model
                    .node(node_id)
                    .op
                    .cost(&inputs)
                    .with_context(|| format!("costing node {}", model.node(node_id)))?;
                // Only compute-like cost kinds scale with the loop multiplier;
                // others (e.g. parameter sizes) are counted once.
                annotations.node_mut(NodeQId(prefix.into(), node_id)).cost = cost
                    .into_iter()
                    .map(|(k, v)| {
                        let cost = if k.is_compute() { v * &multiplier } else { v };
                        (k, cost.eval(extra_symbols))
                    })
                    .collect();

                let nested_subs = model.nested_models(node_id);
                let nested_multis = (model as &dyn Model).nested_models_iters(node_id, &inputs);
                for (name, sub) in nested_subs {
                    let mut prefix: TVec<_> = prefix.into();
                    prefix.push((node_id, name.to_string()));
                    // Missing iteration count defaults to 1 (run-once submodel).
                    extract_costs_rec(
                        annotations,
                        sub,
                        &prefix,
                        nested_multis.clone().unwrap_or_else(|| 1.into()) * &multiplier,
                        extra_symbols,
                    )?;
                }
            }
        }
        Ok(())
    }
    extract_costs_rec(annotations, model, &[], 1.into(), extra_symbols)
}