1use crate::model::Model;
2use crate::tensor::RunTensors;
3use crate::tensor::make_inputs_for_model;
4use crate::{annotations::*, capture_gpu_trace};
5use std::any::TypeId;
6use std::time::{Duration, Instant};
7use tract_core::internal::*;
8use tract_core::num_traits::Zero;
9use tract_core::ops::submodel::TypedModelOpState;
10
11pub fn reusable_state(runnable: &Arc<dyn Runnable>) -> bool {
12 runnable.typed_model().is_some_and(|model| model.properties().contains_key("pulse.delay"))
13}
14
15pub fn run_one_step(
16 runnable: &Arc<dyn Runnable>,
17 state: &mut Box<dyn State>,
18 inputs: &RunTensors,
19) -> TractResult<Duration> {
20 if !reusable_state(runnable) {
21 *state = runnable.spawn()?;
22 state.init_state(&inputs.state_initializers.clone())?;
23 }
24 let start = Instant::now();
25 for source in &inputs.sources {
26 state.run(source.clone())?;
27 }
28 Ok(start.elapsed())
29}
30
/// Iteration-count and wall-clock limits for the warmup and benchmark phases.
///
/// In [`BenchLimits::warmup`] and [`BenchLimits::bench`], a zero value for one
/// criterion means "unlimited" as long as the other is set; when both a phase's
/// loop and time limits are zero, that phase is skipped entirely.
pub struct BenchLimits {
    /// Maximum number of warmup iterations before measuring.
    pub warmup_loops: usize,
    /// Maximum wall-clock time to spend warming up.
    pub warmup_time: std::time::Duration,
    /// Maximum number of measured iterations.
    pub max_loops: usize,
    /// Maximum wall-clock time to spend measuring.
    pub max_time: std::time::Duration,
}
37
38impl Default for BenchLimits {
39 fn default() -> Self {
40 BenchLimits {
41 warmup_loops: 0,
42 warmup_time: Duration::default(),
43 max_loops: 100_000,
44 max_time: std::time::Duration::from_secs(5),
45 }
46 }
47}
48
49impl BenchLimits {
50 pub fn warmup(&self, runnable: &Arc<dyn Runnable>, inputs: &RunTensors) -> TractResult<()> {
51 if self.warmup_time.is_zero() && self.warmup_loops.is_zero() {
52 return Ok(());
53 }
54 let reuse = reusable_state(runnable);
55 let mut state = runnable.spawn()?;
56 state.init_state(&inputs.state_initializers.clone())?;
57
58 let mut iters = 0;
59 let max_loops = if self.warmup_loops.is_zero() { usize::MAX } else { self.warmup_loops };
60 let max_time = if self.warmup_time.is_zero() { Duration::MAX } else { self.warmup_time };
61
62 let start_warmup = Instant::now();
63 info!("Warming up before profiling...");
64 while iters < max_loops && start_warmup.elapsed() < max_time {
65 if !reuse {
66 state = runnable.spawn()?;
67 state.init_state(&inputs.state_initializers.clone())?;
68 }
69 state.run(inputs.sources[0].clone())?;
70 iters += 1;
71 }
72 info!("Done warming up.");
73
74 Ok(())
75 }
76
77 pub fn bench(
78 &self,
79 runnable: &Arc<dyn Runnable>,
80 inputs: &RunTensors,
81 ) -> TractResult<(usize, Duration)> {
82 if self.max_time.is_zero() && self.max_loops.is_zero() {
83 return Ok(Default::default());
84 }
85 let reuse = reusable_state(runnable);
86 let mut state = runnable.spawn()?;
87 state.init_state(&inputs.state_initializers.clone())?;
88
89 let mut iters = 0;
90 let max_loops = if self.max_loops.is_zero() { usize::MAX } else { self.max_loops };
91 let max_time = if self.max_time.is_zero() { Duration::MAX } else { self.max_time };
92
93 let mut dur = Duration::default();
94 let start = Instant::now();
95 while iters < max_loops && start.elapsed() < max_time {
96 if !reuse {
97 state = runnable.spawn()?;
98 state.init_state(&inputs.state_initializers.clone())?;
99 }
100 let start_inner = Instant::now();
101 state.run(inputs.sources[0].clone())?;
102 dur += start_inner.elapsed();
103 iters += 1;
104 }
105
106 Ok((iters, dur))
107 }
108}
109
110pub fn profile(
111 runnable: &Arc<dyn Runnable>,
112 bench_limits: &BenchLimits,
113 dg: &mut Annotations,
114 inputs: &RunTensors,
115 custom_profiler: Option<HashMap<TypeId, Profiler>>,
116 folded: bool,
117) -> TractResult<()> {
118 let Some(plan) = runnable.typed_plan() else {
119 bail!("Can only profile TypedRunnable");
120 };
121 info!("Running entire network");
122 let mut iters = 0usize;
123 let prefix = tvec!();
124
125 bench_limits.warmup(runnable, inputs)?;
126
127 let reuse = reusable_state(runnable);
128 let mut state = plan.spawn()?;
129 state.init_state(&inputs.state_initializers.clone())?;
130
131 let mut dur = Duration::default();
132 let mut time_accounted_by_inner_nodes = Duration::default();
133 while iters < bench_limits.max_loops && dur < bench_limits.max_time {
134 if !reuse {
135 state = plan.spawn()?;
136 state.init_state(&inputs.state_initializers.clone())?;
137 }
138 let start = Instant::now();
139
140 for source in &inputs.sources {
141 rec_profiler(
142 &mut state,
143 dg,
144 source,
145 custom_profiler.as_ref(),
146 &prefix,
147 None,
148 &mut time_accounted_by_inner_nodes,
149 folded,
150 )?;
151 }
152 dur += start.elapsed();
153 iters += 1;
154 }
155
156 dur -= time_accounted_by_inner_nodes;
157
158 info!("Running {} iterations max. for each node.", bench_limits.max_loops);
159 info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());
160
161 let denum = (iters as f32).recip();
162 let entire = dur.mul_f32(denum);
163 for d in dg.tags.values_mut() {
164 if let Some(d) = d.profile.as_mut() {
165 *d = d.mul_f32(denum);
166 }
167
168 if let Some(d) = d.accelerator_profile.as_mut() {
169 *d = d.mul_f32(denum);
170 }
171 }
172 let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
173 let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
174 let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
175 dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
176 Ok(())
177}
178
179pub fn profile_gpu(
180 runnable: &Arc<dyn Runnable>,
181 bench_limits: &BenchLimits,
182 sub_matches: &clap::ArgMatches,
183 dg: &mut Annotations,
184 inputs: &RunTensors,
185) -> TractResult<()> {
186 let Some(plan) = runnable.typed_plan() else {
187 bail!("Can only profile TypedRunnable");
188 };
189 info!("Running entire network");
190 let mut iters = 0usize;
191 let prefix = tvec!();
192
193 bench_limits.warmup(runnable, inputs)?;
194
195 let reuse = reusable_state(runnable);
196 let mut state = plan.spawn()?;
197 state.init_state(&inputs.state_initializers.clone())?;
198
199 let mut dur = Duration::default();
200
201 capture_gpu_trace(sub_matches, || -> TractResult<()> {
202 while iters < bench_limits.max_loops && dur < bench_limits.max_time {
203 if !reuse {
204 state = plan.spawn()?;
205 state.init_state(&inputs.state_initializers.clone())?;
206 }
207 let start = Instant::now();
208 for source in &inputs.sources {
209 rec_profiler_gpu(&mut state, dg, source, &prefix)?;
210 }
211 dur += start.elapsed();
212 iters += 1;
213 }
214 Ok(())
215 })?;
216
217 info!("Running {} iterations max. for each node.", bench_limits.max_loops);
218 info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());
219
220 let denum = (iters as f32).recip();
221 let entire = dur.mul_f32(denum);
222 for d in dg.tags.values_mut() {
223 if let Some(d) = d.profile.as_mut() {
224 *d = d.mul_f32(denum);
225 }
226
227 if let Some(d) = d.accelerator_profile.as_mut() {
228 *d = d.mul_f32(denum);
229 }
230 }
231 let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
232 let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
233 let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
234 dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
235 Ok(())
236}
237
238pub fn rec_profiler_gpu(
239 state: &mut TypedSimpleState,
240 dg: &mut Annotations,
241 inputs: &TVec<TValue>,
242 prefix: &[(usize, String)],
243) -> TractResult<TVec<TValue>> {
244 let r = state.run_plan_with_eval(
245 inputs.clone(),
246 |session_state, mut node_state, node, input| {
247 let start = crate::time::now();
249 let res = tract_core::plan::eval(
250 session_state,
251 node_state.as_deref_mut(),
252 node,
253 input.clone(),
254 );
255 let elapsed = start.elapsed();
256 let node_id = NodeQId(prefix.into(), node.id);
257 *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;
258
259 res
260 },
261 )?;
262
263 Ok(r)
264}
265
/// Runs one evaluation of `state`'s plan, recording per-node wall-clock time
/// into `dg`.
///
/// * `profilers` — optional custom per-`TypeId` hooks passed down to
///   [`profile_submodel`].
/// * `prefix` — path of (node id, label) pairs locating this (sub)model.
/// * `multiplier` — scales each measured duration (used for loop bodies that
///   run several iterations per outer evaluation).
/// * `time_accounted_by_inner_nodes` — accumulates the time spent recursing
///   into nested models so the caller can subtract it from its own total.
/// * `folded` — when true, nested models are not profiled individually.
#[allow(clippy::too_many_arguments)]
pub fn rec_profiler(
    state: &mut TypedSimpleState,
    dg: &mut Annotations,
    inputs: &TVec<TValue>,
    profilers: Option<&HashMap<TypeId, Profiler>>,
    prefix: &[(usize, String)],
    multiplier: Option<usize>,
    time_accounted_by_inner_nodes: &mut Duration,
    folded: bool,
) -> TractResult<TVec<TValue>> {
    let r = state.run_plan_with_eval(
        inputs.clone(),
        |session_state, mut node_state, node, input| {
            // Time this node's own evaluation.
            let start = crate::time::now();
            let res = tract_core::plan::eval(
                session_state,
                node_state.as_deref_mut(),
                node,
                input.clone(),
            );
            let elapsed = start.elapsed().mul_f32(multiplier.unwrap_or(1) as _);
            let node_id = NodeQId(prefix.into(), node.id);
            *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;

            if !folded {
                // Recurse into nested models; the time spent doing so is
                // tracked so the outer caller can discount it.
                let start = crate::time::now();
                profile_submodel(
                    node,
                    node_state,
                    input,
                    dg,
                    profilers,
                    prefix,
                    time_accounted_by_inner_nodes,
                )?;
                *time_accounted_by_inner_nodes += start.elapsed();
            }

            // When profiling inside a nested model, every ancestor on the
            // prefix path already counted this node's time: subtract it from
            // each ancestor, clamped so a profile never underflows.
            let prefix_vec = prefix.to_vec();
            if !prefix_vec.is_empty() {
                (1..prefix_vec.len() + 1).map(|idx| prefix_vec[..idx].to_vec()).for_each(
                    |parent_path| {
                        let parent_node = parent_path.last().map(|it| it.0).unwrap();
                        let parent = dg
                            .node_mut(NodeQId(
                                parent_path[..parent_path.len() - 1].into(),
                                parent_node,
                            ))
                            .profile
                            .get_or_insert(Duration::default());
                        *parent -= elapsed.min(*parent);
                    },
                );
            }
            res
        },
    )?;
    Ok(r)
}
328
/// Profiles the nested model(s) carried by `node`'s op state, if any.
///
/// Dispatch order: a caller-provided custom [`Profiler`] matching the op
/// state's `TypeId` wins; otherwise scan (loop) states and plain submodel
/// states are recursed into with [`rec_profiler`] under an extended prefix.
/// Nodes with no state or no recognized state kind are silently skipped.
fn profile_submodel(
    node: &TypedNode,
    mut node_state: Option<&mut dyn OpState>,
    input: TVec<TValue>,
    dg: &mut Annotations,
    profilers: Option<&HashMap<TypeId, Profiler>>,
    prefix: &[(usize, String)],
    time_accounted_by_inner_nodes: &mut Duration,
) -> TractResult<()> {
    if let Some(ref mut op_state) = node_state {
        if let Some(profiler) = profilers.and_then(|it| it.get(&op_state.type_id())) {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "submodel".to_string()));

            // Custom hook: it is expected to fill `dg` itself, so its returned
            // values are deliberately discarded here.
            let (_, _) =
                (profiler.func)(*op_state, input, dg, &new_prefix, time_accounted_by_inner_nodes)?;
        } else if let Some(scan_state) = op_state.downcast_mut::<tract_core::ops::scan::State>() {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "loop".to_string()));

            // Scan bodies run `multi` iterations per outer step: synthesize
            // inputs for the body model and scale measured times accordingly.
            let scan_inputs = make_inputs_for_model(scan_state.model_state.model())?;
            let multi = scan_state.iteration_count(&input);

            rec_profiler(
                &mut scan_state.model_state,
                dg,
                &scan_inputs,
                None,
                &new_prefix,
                Some(multi),
                time_accounted_by_inner_nodes,
                false,
            )?;
        } else if let Some(typed_model_state) = op_state.downcast_mut::<TypedModelOpState>() {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "submodel".to_string()));

            rec_profiler(
                typed_model_state,
                dg,
                &input,
                None,
                &new_prefix,
                None,
                time_accounted_by_inner_nodes,
                false,
            )?;
        }
    }

    Ok(())
}
381
/// Signature of a custom profiling hook: receives the op state, its inputs,
/// the annotations to fill, the current node-path prefix, and the running
/// inner-node time accumulator; returns the op's own result paired with the
/// time it took. See its call site in `profile_submodel`.
type ProfilerFn = fn(
    &mut dyn OpState,
    TVec<TValue>,
    &mut Annotations,
    &[(usize, String)],
    &mut Duration,
) -> TractResult<(TractResult<TVec<TValue>>, Duration)>;

/// A named custom profiler, selected per op-state `TypeId` by the caller of
/// [`profile`].
#[derive(Clone)]
pub struct Profiler {
    /// The profiling hook itself.
    pub func: ProfilerFn,
    /// Stable identifier; also the sole input to this type's `Hash` impl.
    pub name: &'static str,
}
395
396impl Hash for Profiler {
397 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
398 self.name.hash(state)
399 }
400}
401
/// Computes the declared cost of every node of `model`, recursing into nested
/// models, and stores the results in `annotations`. Symbolic dimensions are
/// resolved against `extra_symbols`.
pub fn extract_costs(
    annotations: &mut Annotations,
    model: &dyn Model,
    extra_symbols: &SymbolValues,
) -> TractResult<()> {
    // Recursive worker: `prefix` locates the current (sub)model in the node
    // tree, `multiplier` is the product of the iteration counts of all
    // enclosing nested models.
    fn extract_costs_rec(
        annotations: &mut Annotations,
        model: &dyn Model,
        prefix: &[(usize, String)],
        multiplier: TDim,
        extra_symbols: &SymbolValues,
    ) -> TractResult<()> {
        // Only typed models expose per-node costs; anything else is skipped.
        if let Some(model) = model.downcast_ref::<TypedModel>() {
            for node_id in 0..model.nodes().len() {
                let inputs = model.node_input_facts(node_id)?;
                let cost = model
                    .node(node_id)
                    .op
                    .cost(&inputs)
                    .with_context(|| format!("costing node {}", model.node(node_id)))?;
                // Compute-type costs scale with the loop multiplier;
                // other cost kinds are taken as-is.
                annotations.node_mut(NodeQId(prefix.into(), node_id)).cost = cost
                    .into_iter()
                    .map(|(k, v)| {
                        let cost = if k.is_compute() { v * &multiplier } else { v };
                        (k, cost.eval(extra_symbols))
                    })
                    .collect();

                let nested_subs = model.nested_models(node_id);
                // Cast to the trait object to reach the Model-level iteration
                // count query.
                let nested_multis = (model as &dyn Model).nested_models_iters(node_id, &inputs);
                for (name, sub) in nested_subs {
                    let mut prefix: TVec<_> = prefix.into();
                    prefix.push((node_id, name.to_string()));
                    extract_costs_rec(
                        annotations,
                        sub,
                        &prefix,
                        // An unknown iteration count defaults to 1.
                        nested_multis.clone().unwrap_or_else(|| 1.into()) * &multiplier,
                        extra_symbols,
                    )?;
                }
            }
        }
        Ok(())
    }
    extract_costs_rec(annotations, model, &[], 1.into(), extra_symbols)
}