tract_libcli/profile.rs

use crate::model::Model;
use crate::tensor::RunTensors;
use crate::tensor::make_inputs_for_model;
use crate::{annotations::*, capture_gpu_trace};
use std::any::TypeId;
use std::time::{Duration, Instant};
use tract_core::internal::*;
use tract_core::num_traits::Zero;
use tract_core::ops::scan::State;
use tract_core::ops::submodel::TypedModelOpState;

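/// Limits applied to the warmup and benchmark loops: each phase stops as soon
/// as either its loop count or its time budget is exhausted.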
pub struct BenchLimits {
    pub warmup_loops: usize,
    pub warmup_time: std::time::Duration,
    pub max_loops: usize,
    pub max_time: std::time::Duration,
}

impl Default for BenchLimits {
    fn default() -> Self {
        BenchLimits {
            warmup_loops: 0,
            warmup_time: Duration::default(),
            max_loops: 100_000,
            max_time: std::time::Duration::from_secs(5),
        }
    }
}

impl BenchLimits {
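    /// Run the model repeatedly before profiling so that one-time setup costs
    /// do not show up in the measurements. Skipped when both `warmup_loops`
    /// and `warmup_time` are zero; otherwise a zero value means "no limit"
    /// for that criterion.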
    pub fn warmup(&self, model: &TypedModel, inputs: &RunTensors) -> TractResult<()> {
        if self.warmup_time.is_zero() && self.warmup_loops.is_zero() {
            return Ok(());
        }
        let plan = TypedSimplePlan::new(model.clone())?;
        let mut state = TypedSimpleState::new(Arc::new(plan))?;
        let mut iters = 0;
        let max_loops = if self.warmup_loops.is_zero() { usize::MAX } else { self.warmup_loops };
        let max_time = if self.warmup_time.is_zero() { Duration::MAX } else { self.warmup_time };

        let start_warmup = Instant::now();
        debug!("Warming up before profiling...");
        while iters < max_loops && start_warmup.elapsed() < max_time {
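            // Pulsed models ("pulse.delay" property) keep their state across
            // runs; other models get fresh op states on every iteration.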
            if state.model().properties().contains_key("pulse.delay") {
                state.run(inputs.sources[0].clone())?;
            } else {
                state.init_states(&mut inputs.state_initializers.clone())?;
                state.run(inputs.sources[0].clone())?;
                state.reset_op_states()?
            }
            iters += 1;
        }
        debug!("Done warming up.");

        Ok(())
    }
}

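/// Benchmark and profile the whole network node by node: run it repeatedly
/// within `bench_limits`, accumulate per-node wall time in `dg`, then average
/// over iterations and record a global `ProfileSummary`. Custom profilers can
/// be registered per op-state `TypeId`; when `folded` is set, submodels are
/// not profiled recursively.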
pub fn profile(
    model: &TypedModel,
    bench_limits: &BenchLimits,
    dg: &mut Annotations,
    plan_options: &PlanOptions,
    inputs: &RunTensors,
    custom_profiler: Option<HashMap<TypeId, Profiler>>,
    folded: bool,
) -> TractResult<()> {
    info!("Running entire network");
    let mut iters = 0usize;
    let prefix = tvec!();

    bench_limits.warmup(model, inputs)?;

    let plan = TypedSimplePlan::new_with_options(model.clone(), plan_options)?;
    let mut state = TypedSimpleState::new(Arc::new(plan))?;

    let mut dur = Duration::default();
    let mut time_accounted_by_inner_nodes = Duration::default();
    while iters < bench_limits.max_loops && dur < bench_limits.max_time {
        if !state.model().properties().contains_key("pulse.delay") {
            state.init_states(&mut inputs.state_initializers.clone())?;
        }
        let start = Instant::now();
        rec_profiler(
            &mut state,
            dg,
            &inputs.sources[0],
            custom_profiler.as_ref(),
            &prefix,
            None,
            &mut time_accounted_by_inner_nodes,
            folded,
        )?;
        dur += start.elapsed();
        if !state.model().properties().contains_key("pulse.delay") {
            state.reset_op_states()?;
        }
        iters += 1;
    }

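    // The time spent re-running submodels inside the profiler is not part of
    // the outer network's runtime: exclude it from the measured duration.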
    dur -= time_accounted_by_inner_nodes;

    info!("Running {} iterations max. for each node.", bench_limits.max_loops);
    info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());

    let denum = (iters as f32).recip();
    let entire = dur.mul_f32(denum);
    for d in dg.tags.values_mut() {
        if let Some(d) = d.profile.as_mut() {
            *d = d.mul_f32(denum);
        }

        if let Some(d) = d.accelerator_profile.as_mut() {
            *d = d.mul_f32(denum);
        }
    }
    let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
    let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
    let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
    dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
    Ok(())
}

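/// GPU counterpart of `profile`: builds the plan with a device session handler
/// and wraps the benchmark loop in `capture_gpu_trace` so a GPU trace can be
/// recorded alongside the per-node host timings summarized in `dg`.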
pub fn profile_gpu(
    model: &TypedModel,
    bench_limits: &BenchLimits,
    sub_matches: &clap::ArgMatches,
    dg: &mut Annotations,
    plan_options: &PlanOptions,
    inputs: &RunTensors,
) -> TractResult<()> {
    info!("Running entire network");
    let mut iters = 0usize;
    let prefix = tvec!();

    bench_limits.warmup(model, inputs)?;

    let mut plan = TypedSimplePlan::new_with_options(model.clone(), plan_options)?;
    let state = TypedSimpleState::new_from_inputs(&plan, inputs.sources[0].clone())?;

    let session_handler = tract_gpu::session_handler::DeviceSessionHandler::from_plan(
        &plan,
        &state.session_state.resolved_symbols,
    )?;

    plan = plan.with_session_handler(session_handler);

    let mut state = TypedSimpleState::new(Arc::new(plan))?;
    let mut dur = Duration::default();

    capture_gpu_trace(sub_matches, || -> TractResult<()> {
        while iters < bench_limits.max_loops && dur < bench_limits.max_time {
            if !state.model().properties().contains_key("pulse.delay") {
                state.init_states(&mut inputs.state_initializers.clone())?;
            }
            let start = Instant::now();
            rec_profiler_gpu(&mut state, dg, &inputs.sources[0], &prefix)?;
            dur += start.elapsed();
            if !state.model().properties().contains_key("pulse.delay") {
                state.reset_op_states()?;
            }
            iters += 1;
        }
        Ok(())
    })?;

    info!("Running {} iterations max. for each node.", bench_limits.max_loops);
    info!("Running for {} ms max. for each node.", bench_limits.max_time.as_millis());

    let denum = (iters as f32).recip();
    let entire = dur.mul_f32(denum);
    for d in dg.tags.values_mut() {
        if let Some(d) = d.profile.as_mut() {
            *d = d.mul_f32(denum);
        }

        if let Some(d) = d.accelerator_profile.as_mut() {
            *d = d.mul_f32(denum);
        }
    }
    let max = dg.tags.values().filter_map(|t| t.profile).max().unwrap();
    let sum = dg.tags.values().filter_map(|t| t.profile).sum::<Duration>();
    let accel_sum = dg.tags.values().filter_map(|t| t.accelerator_profile).sum::<Duration>();
    dg.profile_summary = Some(ProfileSummary { max, sum, accel_sum, entire, iters });
    Ok(())
}

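/// Run the plan once, timing every node on the host side and accumulating the
/// elapsed time into the node's `profile` annotation.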
pub fn rec_profiler_gpu(
    state: &mut TypedSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>,
    dg: &mut Annotations,
    inputs: &TVec<TValue>,
    prefix: &[(usize, String)],
) -> TractResult<TVec<TValue>> {
    let r = state.run_plan_with_eval(
        inputs.clone(),
        |session_state, mut node_state, node, input| {
            // Profile node
            let start = crate::time::now();
            let res = tract_core::plan::eval(
                session_state,
                node_state.as_deref_mut(),
                node,
                input.clone(),
            );
            let elapsed = start.elapsed();
            let node_id = NodeQId(prefix.into(), node.id);
            *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;

            res
        },
    )?;

    Ok(r)
}

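/// Run the plan once while timing every node. Elapsed times are scaled by
/// `multiplier` (used for loop bodies that run several times per outer
/// iteration), accumulated into the node's `profile` annotation, and deducted
/// from every enclosing parent so inner time is not counted twice. Unless
/// `folded` is set, submodels are profiled recursively through
/// `profile_submodel`; the extra time this takes is added to
/// `time_accounted_by_inner_nodes` so the caller can exclude it.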
#[allow(clippy::too_many_arguments)]
pub fn rec_profiler(
    state: &mut TypedSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>,
    dg: &mut Annotations,
    inputs: &TVec<TValue>,
    profilers: Option<&HashMap<TypeId, Profiler>>,
    prefix: &[(usize, String)],
    multiplier: Option<usize>,
    time_accounted_by_inner_nodes: &mut Duration,
    folded: bool,
) -> TractResult<TVec<TValue>> {
    let r = state.run_plan_with_eval(
        inputs.clone(),
        |session_state, mut node_state, node, input| {
            // Profile node
            let start = crate::time::now();
            let res = tract_core::plan::eval(
                session_state,
                node_state.as_deref_mut(),
                node,
                input.clone(),
            );
            let elapsed = start.elapsed().mul_f32(multiplier.unwrap_or(1) as _);
            let node_id = NodeQId(prefix.into(), node.id);
            *dg.node_mut(node_id).profile.get_or_insert(Duration::default()) += elapsed;

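            // Recurse into the node's inner model (scan body, submodel, custom
            // op) unless a folded, single-level profile was requested. The time
            // this takes is tracked separately so it can be excluded from the
            // overall measurement.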
            if !folded {
                let start = crate::time::now();
                profile_submodel(
                    node,
                    node_state,
                    input,
                    dg,
                    profilers,
                    prefix,
                    time_accounted_by_inner_nodes,
                )?;
                *time_accounted_by_inner_nodes += start.elapsed();
            }

            // Update parent nodes if any (children's timings are deducted from their parents)
            let prefix_vec = prefix.to_vec();
            if !prefix_vec.is_empty() {
                (1..prefix_vec.len() + 1).map(|idx| prefix_vec[..idx].to_vec()).for_each(
                    |parent_path| {
                        let parent_node = parent_path.last().map(|it| it.0).unwrap();
                        let parent = dg
                            .node_mut(NodeQId(
                                parent_path[..parent_path.len() - 1].into(),
                                parent_node,
                            ))
                            .profile
                            .get_or_insert(Duration::default());
                        *parent -= elapsed.min(*parent);
                    },
                );
            }
            res
        },
    )?;
    Ok(r)
}

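/// Profile the inner graph of a node, if it has one: a custom `Profiler`
/// registered for the op state's `TypeId` takes precedence, then scan loops
/// (profiled with their iteration count as multiplier), then plain nested
/// `TypedModel` op states.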
fn profile_submodel(
    node: &TypedNode,
    mut node_state: Option<&mut dyn OpState>,
    input: TVec<TValue>,
    dg: &mut Annotations,
    profilers: Option<&HashMap<TypeId, Profiler>>,
    prefix: &[(usize, String)],
    time_accounted_by_inner_nodes: &mut Duration,
) -> TractResult<()> {
    if let Some(ref mut op_state) = node_state {
        if let Some(profiler) = profilers.and_then(|it| it.get(&op_state.type_id())) {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "submodel".to_string()));

            let (_, _) =
                (profiler.func)(*op_state, input, dg, &new_prefix, time_accounted_by_inner_nodes)?;
        } else if let Some(scan_state) = op_state.downcast_mut::<State>() {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "loop".to_string()));

            let scan_inputs = make_inputs_for_model(scan_state.model_state.model())?;
            let multi = scan_state.iteration_count(&input);

            rec_profiler(
                &mut scan_state.model_state,
                dg,
                &scan_inputs,
                None,
                &new_prefix,
                Some(multi),
                time_accounted_by_inner_nodes,
                false,
            )?;
        } else if let Some(typed_model_state) = op_state.downcast_mut::<TypedModelOpState>() {
            let mut new_prefix: TVec<_> = prefix.into();
            new_prefix.push((node.id, "submodel".to_string()));

            rec_profiler(
                typed_model_state,
                dg,
                &input,
                None,
                &new_prefix,
                None,
                time_accounted_by_inner_nodes,
                false,
            )?;
        }
    }

    Ok(())
}

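/// Signature of a custom profiling hook: given an op state, its inputs, the
/// annotations to fill, the current node prefix and the inner-time accumulator,
/// it returns the op outputs together with the time it spent.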
type ProfilerFn = fn(
    &mut dyn OpState,
    TVec<TValue>,
    &mut Annotations,
    &[(usize, String)],
    &mut Duration,
) -> TractResult<(TractResult<TVec<TValue>>, Duration)>;

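/// A named custom profiler for a specific op-state type, looked up by `TypeId`
/// in the map handed to `profile` and `rec_profiler`.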
#[derive(Clone)]
pub struct Profiler {
    pub func: ProfilerFn,
    pub name: &'static str,
}

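// Profilers are identified by name only: the function pointer is not hashed.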
impl Hash for Profiler {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.name.hash(state)
    }
}

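/// Walk the model, recursing into nested models, and attach each node's
/// declared cost to its annotations. Compute costs are scaled by the iteration
/// count of the enclosing loops, and symbolic dimensions are resolved with
/// `extra_symbols`.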
pub fn extract_costs(
    annotations: &mut Annotations,
    model: &dyn Model,
    extra_symbols: &SymbolValues,
) -> TractResult<()> {
    fn extract_costs_rec(
        annotations: &mut Annotations,
        model: &dyn Model,
        prefix: &[(usize, String)],
        multiplier: TDim,
        extra_symbols: &SymbolValues,
    ) -> TractResult<()> {
        if let Some(model) = model.downcast_ref::<TypedModel>() {
            for node_id in 0..model.nodes().len() {
                let inputs = model.node_input_facts(node_id)?;
                let cost = model
                    .node(node_id)
                    .op
                    .cost(&inputs)
                    .with_context(|| format!("costing node {}", model.node(node_id)))?;
                annotations.node_mut(NodeQId(prefix.into(), node_id)).cost = cost
                    .into_iter()
                    .map(|(k, v)| {
                        let cost = if k.is_compute() { v * &multiplier } else { v };
                        (k, cost.eval(extra_symbols))
                    })
                    .collect();

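                // Recurse into nested models, scaling their costs by the
                // number of iterations of the enclosing node when it is known.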
                let nested_subs = model.nested_models(node_id);
                let nested_multis = (model as &dyn Model).nested_models_iters(node_id, &inputs);
                for (name, sub) in nested_subs {
                    let mut prefix: TVec<_> = prefix.into();
                    prefix.push((node_id, name.to_string()));
                    extract_costs_rec(
                        annotations,
                        sub,
                        &prefix,
                        nested_multis.clone().unwrap_or_else(|| 1.into()) * &multiplier,
                        extra_symbols,
                    )?;
                }
            }
        }
        Ok(())
    }
    extract_costs_rec(annotations, model, &[], 1.into(), extra_symbols)
}