1use crate::model::Model;
2use crate::tensor::RunTensors;
3use crate::tensor::make_inputs_for_model;
4use crate::{annotations::*, capture_gpu_trace};
5use std::any::TypeId;
6use std::time::{Duration, Instant};
7use tract_core::internal::*;
8use tract_core::num_traits::Zero;
9use tract_core::ops::submodel::TypedModelOpState;
10
11pub fn reusable_state(runnable: &Arc<dyn Runnable>) -> bool {
12 runnable.typed_model().is_some_and(|model| model.properties().contains_key("pulse.delay"))
13}
14
15pub fn run_one_step(
16 runnable: &Arc<dyn Runnable>,
17 state: &mut Box<dyn State>,
18 inputs: &RunTensors,
19) -> TractResult<Duration> {
20 if !reusable_state(runnable) {
21 *state = runnable.spawn()?;
22 }
23 let start = Instant::now();
24 for source in &inputs.sources {
25 state.run(source.clone())?;
26 }
27 Ok(start.elapsed())
28}
29
/// Iteration and wall-clock budgets for warmup and benchmark loops.
///
/// Within a phase, a zero value means "no limit" for that dimension —
/// except when both the loop and time limits of the phase are zero, in
/// which case the phase is skipped entirely (see `warmup` and `bench`).
pub struct BenchLimits {
    /// Maximum number of untimed warmup iterations.
    pub warmup_loops: usize,
    /// Maximum wall-clock time spent warming up.
    pub warmup_time: std::time::Duration,
    /// Maximum number of timed benchmark iterations.
    pub max_loops: usize,
    /// Maximum wall-clock time spent benchmarking.
    pub max_time: std::time::Duration,
}
36
37impl Default for BenchLimits {
38 fn default() -> Self {
39 BenchLimits {
40 warmup_loops: 0,
41 warmup_time: Duration::default(),
42 max_loops: 100_000,
43 max_time: std::time::Duration::from_secs(5),
44 }
45 }
46}
47
48impl BenchLimits {
49 pub fn warmup(&self, runnable: &Arc<dyn Runnable>, inputs: &RunTensors) -> TractResult<()> {
50 if self.warmup_time.is_zero() && self.warmup_loops.is_zero() {
51 return Ok(());
52 }
53 let reuse = reusable_state(runnable);
54 let mut state = runnable.spawn()?;
55
56 let mut iters = 0;
57 let max_loops = if self.warmup_loops.is_zero() { usize::MAX } else { self.warmup_loops };
58 let max_time = if self.warmup_time.is_zero() { Duration::MAX } else { self.warmup_time };
59
60 let start_warmup = Instant::now();
61 info!("Warming up before profiling...");
62 while iters < max_loops && start_warmup.elapsed() < max_time {
63 if !reuse {
64 state = runnable.spawn()?;
65 }
66 state.run(inputs.sources[0].clone())?;
67 iters += 1;
68 }
69 info!("Done warming up.");
70
71 Ok(())
72 }
73
74 pub fn bench(
75 &self,
76 runnable: &Arc<dyn Runnable>,
77 inputs: &RunTensors,
78 ) -> TractResult<(usize, Duration)> {
79 if self.max_time.is_zero() && self.max_loops.is_zero() {
80 return Ok(Default::default());
81 }
82 let reuse = reusable_state(runnable);
83 let mut state = runnable.spawn()?;
84
85 let mut iters = 0;
86 let max_loops = if self.max_loops.is_zero() { usize::MAX } else { self.max_loops };
87 let max_time = if self.max_time.is_zero() { Duration::MAX } else { self.max_time };
88
89 let mut dur = Duration::default();
90 let start = Instant::now();
91 while iters < max_loops && start.elapsed() < max_time {
92 if !reuse {
93 state = runnable.spawn()?;
94 }
95 let start_inner = Instant::now();
96 state.run(inputs.sources[0].clone())?;
97 dur += start_inner.elapsed();
98 iters += 1;
99 }
100
101 Ok((iters, dur))
102 }
103}
104
105pub fn profile(
106 runnable: &Arc<dyn Runnable>,
107 bench_limits: &BenchLimits,
108 dg: &mut Annotations,
109 inputs: &RunTensors,
110 custom_profiler: Option<HashMap<TypeId, Profiler>>,
111 folded: bool,
112) -> TractResult<()> {
113 let Some(plan) = runnable.typed_plan() else {
114 bail!("Can only profile TypedRunnable");
115 };
116 info!("Running entire network");
117 let mut iters = 0usize;
118 let prefix = tvec!();
119
120 bench_limits.warmup(runnable, inputs)?;
121
122 let reuse = reusable_state(runnable);
123 let mut state = plan.spawn()?;
124
125 let mut dur = Duration::default();
126 let mut time_accounted_by_inner_nodes = Duration::default();
127 while iters < bench_limits.max_loops && dur < bench_limits.max_time {
128 if !reuse {
129 state = plan.spawn()?;
130 }
131 let start = Instant::now();
132
133 for source in &inputs.sources {
134 rec_profiler(
135 &mut state,
136 dg,
137 source,
138 custom_profiler.as_ref(),
139 &prefix,
140 None,
141 &mut time_accounted_by_inner_nodes,
142 folded,
143 )?;
144 }
145 dur += start.elapsed();
146 iters += 1;
147 }
148
149 dur -= time_accounted_by_inner_nodes;
150
151 info!("Running {} iterations max. for each node.", bench_limits.max_loops);
152 info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());
153
154 let denum = (iters as f32).recip();
155 let entire = dur.mul_f32(denum);
156 for d in dg.tags.values_mut() {
157 if let Some(d) = d.profile.as_mut() {
158 *d = d.mul_f32(denum);
159 }
160
161 if let Some(d) = d.accelerator_profile.as_mut() {
162 *d = d.mul_f32(denum);
163 }
164 }
165 let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
166 let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
167 let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
168 dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
169 Ok(())
170}
171
172pub fn profile_gpu(
173 runnable: &Arc<dyn Runnable>,
174 bench_limits: &BenchLimits,
175 sub_matches: &clap::ArgMatches,
176 dg: &mut Annotations,
177 inputs: &RunTensors,
178) -> TractResult<()> {
179 let Some(plan) = runnable.typed_plan() else {
180 bail!("Can only profile TypedRunnable");
181 };
182 info!("Running entire network");
183 let mut iters = 0usize;
184 let prefix = tvec!();
185
186 bench_limits.warmup(runnable, inputs)?;
187
188 let reuse = reusable_state(runnable);
189 let mut state = plan.spawn()?;
190
191 let mut dur = Duration::default();
192
193 capture_gpu_trace(sub_matches, || -> TractResult<()> {
194 while iters < bench_limits.max_loops && dur < bench_limits.max_time {
195 if !reuse {
196 state = plan.spawn()?;
197 }
198 let start = Instant::now();
199 for source in &inputs.sources {
200 rec_profiler_gpu(&mut state, dg, source, &prefix)?;
201 }
202 dur += start.elapsed();
203 iters += 1;
204 }
205 Ok(())
206 })?;
207
208 info!("Running {} iterations max. for each node.", bench_limits.max_loops);
209 info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());
210
211 let denum = (iters as f32).recip();
212 let entire = dur.mul_f32(denum);
213 for d in dg.tags.values_mut() {
214 if let Some(d) = d.profile.as_mut() {
215 *d = d.mul_f32(denum);
216 }
217
218 if let Some(d) = d.accelerator_profile.as_mut() {
219 *d = d.mul_f32(denum);
220 }
221 }
222 let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
223 let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
224 let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
225 dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
226 Ok(())
227}
228
229pub fn rec_profiler_gpu(
230 state: &mut TypedSimpleState,
231 dg: &mut Annotations,
232 inputs: &TVec<TValue>,
233 prefix: &[(usize, String)],
234) -> TractResult<TVec<TValue>> {
235 let r = state.run_plan_with_eval(
236 inputs.clone(),
237 |session_state, mut node_state, node, input| {
238 let start = crate::time::now();
240 let res = tract_core::plan::eval(
241 session_state,
242 node_state.as_deref_mut(),
243 node,
244 input.clone(),
245 );
246 let elapsed = start.elapsed();
247 let node_id = NodeQId(prefix.into(), node.id);
248 *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;
249
250 res
251 },
252 )?;
253
254 Ok(r)
255}
256
#[allow(clippy::too_many_arguments)]
/// Run one evaluation of `state`, accumulating per-node wall-clock time into
/// `dg`, and recursing into sub-models unless `folded` is set.
///
/// * `multiplier` — scales each node's measured time (used for scan loops,
///   where the inner model is measured once but iterated `multiplier` times).
/// * `time_accounted_by_inner_nodes` — incremented by the time spent
///   profiling sub-models, so the caller can exclude it from its own total.
/// * `folded` — when true, sub-models are not profiled separately.
///
/// Returns the network outputs of this evaluation.
pub fn rec_profiler(
    state: &mut TypedSimpleState,
    dg: &mut Annotations,
    inputs: &TVec<TValue>,
    profilers: Option<&HashMap<TypeId, Profiler>>,
    prefix: &[(usize, String)],
    multiplier: Option<usize>,
    time_accounted_by_inner_nodes: &mut Duration,
    folded: bool,
) -> TractResult<TVec<TValue>> {
    let r = state.run_plan_with_eval(
        inputs.clone(),
        |session_state, mut node_state, node, input| {
            // Time this node's own evaluation, scaled by the loop multiplier.
            let start = crate::time::now();
            let res = tract_core::plan::eval(
                session_state,
                node_state.as_deref_mut(),
                node,
                input.clone(),
            );
            let elapsed = start.elapsed().mul_f32(multiplier.unwrap_or(1) as _);
            let node_id = NodeQId(prefix.into(), node.id);
            *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;

            if !folded {
                // Profile any inner model; the time spent doing so is
                // reported back so outer totals can subtract it.
                let start = crate::time::now();
                profile_submodel(
                    node,
                    node_state,
                    input,
                    dg,
                    profilers,
                    prefix,
                    time_accounted_by_inner_nodes,
                )?;
                *time_accounted_by_inner_nodes += start.elapsed();
            }

            // When running inside a sub-model (non-empty prefix), deduct this
            // node's time from every enclosing ancestor so parents only
            // report their own overhead.
            let prefix_vec = prefix.to_vec();
            if !prefix_vec.is_empty() {
                (1..prefix_vec.len() + 1).map(|idx| prefix_vec[..idx].to_vec()).for_each(
                    |parent_path| {
                        let parent_node = parent_path.last().map(|it| it.0).unwrap();
                        let parent = dg
                            .node_mut(NodeQId(
                                parent_path[..parent_path.len() - 1].into(),
                                parent_node,
                            ))
                            .profile
                            .get_or_insert(Duration::default());
                        // min() clamps the subtraction to avoid Duration underflow.
                        *parent -= elapsed.min(*parent);
                    },
                );
            }
            res
        },
    )?;
    Ok(r)
}
319
320fn profile_submodel(
321 node: &TypedNode,
322 mut node_state: Option<&mut dyn OpState>,
323 input: TVec<TValue>,
324 dg: &mut Annotations,
325 profilers: Option<&HashMap<TypeId, Profiler>>,
326 prefix: &[(usize, String)],
327 time_accounted_by_inner_nodes: &mut Duration,
328) -> TractResult<()> {
329 if let Some(ref mut op_state) = node_state {
330 if let Some(profiler) = profilers.and_then(|it| it.get(&op_state.type_id())) {
331 let mut new_prefix: TVec<_> = prefix.into();
332 new_prefix.push((node.id, "submodel".to_string()));
333
334 let (_, _) =
335 (profiler.func)(*op_state, input, dg, &new_prefix, time_accounted_by_inner_nodes)?;
336 } else if let Some(scan_state) = op_state.downcast_mut::<tract_core::ops::scan::State>() {
337 let mut new_prefix: TVec<_> = prefix.into();
338 new_prefix.push((node.id, "loop".to_string()));
339
340 let scan_inputs = make_inputs_for_model(scan_state.model_state.model())?;
341 let multi = scan_state.iteration_count(&input);
342
343 rec_profiler(
344 &mut scan_state.model_state,
345 dg,
346 &scan_inputs,
347 None,
348 &new_prefix,
349 Some(multi),
350 time_accounted_by_inner_nodes,
351 false,
352 )?;
353 } else if let Some(typed_model_state) = op_state.downcast_mut::<TypedModelOpState>() {
354 let mut new_prefix: TVec<_> = prefix.into();
355 new_prefix.push((node.id, "submodel".to_string()));
356
357 rec_profiler(
358 typed_model_state,
359 dg,
360 &input,
361 None,
362 &new_prefix,
363 None,
364 time_accounted_by_inner_nodes,
365 false,
366 )?;
367 }
368 }
369
370 Ok(())
371}
372
/// Signature of a custom sub-model profiling routine.
///
/// Arguments: the node's op state, the node inputs, the annotations to fill,
/// the node-path prefix, and the inner-time accumulator. The outer
/// `TractResult` reports profiler failure; the inner pair carries the run's
/// own result and its measured duration.
type ProfilerFn = fn(
    &mut dyn OpState,
    TVec<TValue>,
    &mut Annotations,
    &[(usize, String)],
    &mut Duration,
) -> TractResult<(TractResult<TVec<TValue>>, Duration)>;
380
/// A named custom profiling routine, looked up by op-state `TypeId`
/// (see `profile`'s `custom_profiler` argument).
#[derive(Clone)]
pub struct Profiler {
    /// The profiling entry point.
    pub func: ProfilerFn,
    /// Stable identifier for this profiler.
    pub name: &'static str,
}
386
387impl Hash for Profiler {
388 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
389 self.name.hash(state)
390 }
391}
392
393pub fn extract_costs(
394 annotations: &mut Annotations,
395 model: &dyn Model,
396 extra_symbols: &SymbolValues,
397) -> TractResult<()> {
398 fn extract_costs_rec(
399 annotations: &mut Annotations,
400 model: &dyn Model,
401 prefix: &[(usize, String)],
402 multiplier: TDim,
403 extra_symbols: &SymbolValues,
404 ) -> TractResult<()> {
405 if let Some(model) = model.downcast_ref::<TypedModel>() {
406 for node_id in 0..model.nodes().len() {
407 let inputs = model.node_input_facts(node_id)?;
408 let cost = model
409 .node(node_id)
410 .op
411 .cost(&inputs)
412 .with_context(|| format!("costing node {}", model.node(node_id)))?;
413 annotations.node_mut(NodeQId(prefix.into(), node_id)).cost = cost
414 .into_iter()
415 .map(|(k, v)| {
416 let cost = if k.is_compute() { v * &multiplier } else { v };
417 (k, cost.eval(extra_symbols))
418 })
419 .collect();
420
421 let nested_subs = model.nested_models(node_id);
422 let nested_multis = (model as &dyn Model).nested_models_iters(node_id, &inputs);
423 for (name, sub) in nested_subs {
424 let mut prefix: TVec<_> = prefix.into();
425 prefix.push((node_id, name.to_string()));
426 extract_costs_rec(
427 annotations,
428 sub,
429 &prefix,
430 nested_multis.clone().unwrap_or_else(|| 1.into()) * &multiplier,
431 extra_symbols,
432 )?;
433 }
434 }
435 }
436 Ok(())
437 }
438 extract_costs_rec(annotations, model, &[], 1.into(), extra_symbols)
439}