triton_vm/
execution_trace_profiler.rs

1use std::collections::HashSet;
2use std::fmt::Display;
3use std::fmt::Formatter;
4use std::fmt::Result as FmtResult;
5use std::ops::Add;
6use std::ops::AddAssign;
7use std::ops::Sub;
8
9use air::table::TableId;
10use arbitrary::Arbitrary;
11use strum::IntoEnumIterator;
12
13use crate::aet::AlgebraicExecutionTrace;
14use crate::table::u32::U32TableEntry;
15
/// Records spans of a program's execution together with the table heights at
/// the span boundaries. Calling `finish` turns the recorded spans into an
/// [`ExecutionTraceProfile`].
#[derive(Debug, Default, Clone, Eq, PartialEq, Arbitrary)]
pub(crate) struct ExecutionTraceProfiler {
    /// Indices into `profile` of the spans that have been entered but not yet
    /// exited; the most recently entered span is last.
    call_stack: Vec<usize>,
    /// One line per span, in the order in which the spans were entered.
    profile: Vec<ProfileLine>,
    // NOTE(review): apart from being initialized empty, this field is neither
    // read nor updated by any method visible in this file; confirm whether it
    // is still needed.
    u32_table_entries: HashSet<U32TableEntry>,
}
22
/// A single line in a [profile report](ExecutionTraceProfile) for profiling
/// [Triton](crate) programs.
///
/// Each line describes one span of the execution. The difference between
/// [`table_heights_stop`](Self::table_heights_stop) and
/// [`table_heights_start`](Self::table_heights_start) is what the span
/// contributed to the various tables.
#[non_exhaustive]
#[derive(Debug, Default, Clone, Eq, PartialEq, Hash, Arbitrary)]
pub struct ProfileLine {
    /// The label under which this span was recorded.
    pub label: String,
    /// The number of spans that were still open when this span was entered.
    /// Top-level spans have depth 0.
    pub call_depth: usize,

    /// Table heights at the start of this span, _i.e._, right after the
    /// corresponding [`call`](isa::instruction::Instruction::Call)
    /// instruction was executed.
    pub table_heights_start: VMTableHeights,

    /// Table heights at the end of this span, _i.e._, right after the
    /// corresponding [`return`](isa::instruction::Instruction::Return)
    /// or [`recurse_or_return`](isa::instruction::Instruction::RecurseOrReturn)
    /// (in “return” mode) was executed. For spans that were still open when
    /// profiling finished, these are the final table heights.
    pub table_heights_stop: VMTableHeights,
}
42
/// A report for the completed execution of a [Triton](crate) program.
///
/// Offers a human-readable [`Display`] implementation and can be processed
/// programmatically.
#[non_exhaustive]
#[derive(Debug, Clone, Eq, PartialEq, Hash, Arbitrary)]
pub struct ExecutionTraceProfile {
    /// The height of every table in the [AET](AlgebraicExecutionTrace), as
    /// captured when profiling finished.
    pub total: VMTableHeights,

    /// The profile lines, each representing a span of the program execution,
    /// in the order in which the spans were entered.
    pub profile: Vec<ProfileLine>,

    /// The [padded height](AlgebraicExecutionTrace::padded_height) of the
    /// [AET](AlgebraicExecutionTrace).
    pub padded_height: usize,
}
60
/// The heights of various [tables](AlgebraicExecutionTrace)
/// relevant for proving the correct execution in [Triton VM](crate).
///
/// Each field holds the height of the table of the same name. Heights that do
/// not fit into a `u32` are recorded as [`u32::MAX`]. Subtracting two values
/// works field by field and saturates at zero.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
pub struct VMTableHeights {
    /// Height of the program table.
    pub program: u32,
    /// Height of the processor table.
    pub processor: u32,
    /// Height of the operational stack table.
    pub op_stack: u32,
    /// Height of the RAM table.
    pub ram: u32,
    /// Height of the jump stack table.
    pub jump_stack: u32,
    /// Height of the hash table.
    pub hash: u32,
    /// Height of the cascade table.
    pub cascade: u32,
    /// Height of the lookup table.
    pub lookup: u32,
    /// Height of the u32 table.
    pub u32: u32,
}
75
76impl ExecutionTraceProfiler {
77    pub fn new() -> Self {
78        Self {
79            call_stack: vec![],
80            profile: vec![],
81            u32_table_entries: HashSet::default(),
82        }
83    }
84
85    pub fn enter_span(&mut self, label: impl Into<String>, aet: &AlgebraicExecutionTrace) {
86        let profile_line = ProfileLine {
87            label: label.into(),
88            call_depth: self.call_stack.len(),
89            table_heights_start: Self::extract_table_heights(aet),
90            table_heights_stop: VMTableHeights::default(),
91        };
92
93        let line_number = self.profile.len();
94        self.profile.push(profile_line);
95        self.call_stack.push(line_number);
96    }
97
98    pub fn exit_span(&mut self, aet: &AlgebraicExecutionTrace) {
99        if let Some(line_number) = self.call_stack.pop() {
100            self.profile[line_number].table_heights_stop = Self::extract_table_heights(aet);
101        };
102    }
103
104    pub fn finish(mut self, aet: &AlgebraicExecutionTrace) -> ExecutionTraceProfile {
105        let table_heights = Self::extract_table_heights(aet);
106        for &line_number in &self.call_stack {
107            self.profile[line_number].table_heights_stop = table_heights;
108        }
109
110        ExecutionTraceProfile {
111            total: table_heights,
112            profile: self.profile,
113            padded_height: aet.padded_height(),
114        }
115    }
116
117    fn extract_table_heights(aet: &AlgebraicExecutionTrace) -> VMTableHeights {
118        // any table in Triton VM can be of length at most u32::MAX
119        let height = |id| aet.height_of_table(id).try_into().unwrap_or(u32::MAX);
120
121        VMTableHeights {
122            program: height(TableId::Program),
123            processor: height(TableId::Processor),
124            op_stack: height(TableId::OpStack),
125            ram: height(TableId::Ram),
126            jump_stack: height(TableId::JumpStack),
127            hash: height(TableId::Hash),
128            cascade: height(TableId::Cascade),
129            lookup: height(TableId::Lookup),
130            u32: height(TableId::U32),
131        }
132    }
133}
134
135impl VMTableHeights {
136    fn height_of_table(&self, table: TableId) -> u32 {
137        match table {
138            TableId::Program => self.program,
139            TableId::Processor => self.processor,
140            TableId::OpStack => self.op_stack,
141            TableId::Ram => self.ram,
142            TableId::JumpStack => self.jump_stack,
143            TableId::Hash => self.hash,
144            TableId::Cascade => self.cascade,
145            TableId::Lookup => self.lookup,
146            TableId::U32 => self.u32,
147        }
148    }
149
150    fn max(&self) -> u32 {
151        TableId::iter()
152            .map(|id| self.height_of_table(id))
153            .max()
154            .unwrap()
155    }
156}
157
158impl Sub for VMTableHeights {
159    type Output = Self;
160
161    fn sub(self, rhs: Self) -> Self::Output {
162        Self {
163            program: self.program.saturating_sub(rhs.program),
164            processor: self.processor.saturating_sub(rhs.processor),
165            op_stack: self.op_stack.saturating_sub(rhs.op_stack),
166            ram: self.ram.saturating_sub(rhs.ram),
167            jump_stack: self.jump_stack.saturating_sub(rhs.jump_stack),
168            hash: self.hash.saturating_sub(rhs.hash),
169            cascade: self.cascade.saturating_sub(rhs.cascade),
170            lookup: self.lookup.saturating_sub(rhs.lookup),
171            u32: self.u32.saturating_sub(rhs.u32),
172        }
173    }
174}
175
176impl Add for VMTableHeights {
177    type Output = Self;
178
179    fn add(self, rhs: Self) -> Self::Output {
180        Self {
181            program: self.program + rhs.program,
182            processor: self.processor + rhs.processor,
183            op_stack: self.op_stack + rhs.op_stack,
184            ram: self.ram + rhs.ram,
185            jump_stack: self.jump_stack + rhs.jump_stack,
186            hash: self.hash + rhs.hash,
187            cascade: self.cascade + rhs.cascade,
188            lookup: self.lookup + rhs.lookup,
189            u32: self.u32 + rhs.u32,
190        }
191    }
192}
193
194impl AddAssign<Self> for VMTableHeights {
195    fn add_assign(&mut self, rhs: Self) {
196        *self = *self + rhs;
197    }
198}
199
200impl ProfileLine {
201    fn table_height_contributions(&self) -> VMTableHeights {
202        self.table_heights_stop - self.table_heights_start
203    }
204}
205
206impl Display for ProfileLine {
207    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
208        let indentation = "  ".repeat(self.call_depth);
209        let label = &self.label;
210        let cycle_count = self.table_height_contributions().processor;
211        write!(f, "{indentation}{label}: {cycle_count}")
212    }
213}
214
215impl Display for ExecutionTraceProfile {
216    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
217        struct AggregateLine {
218            label: String,
219            call_depth: usize,
220            table_heights: VMTableHeights,
221        }
222
223        const COL_WIDTH: usize = 20;
224
225        let mut aggregated: Vec<AggregateLine> = vec![];
226        for line in &self.profile {
227            if let Some(agg) = aggregated
228                .iter_mut()
229                .find(|agg| agg.label == line.label && agg.call_depth == line.call_depth)
230            {
231                agg.table_heights += line.table_height_contributions();
232            } else {
233                aggregated.push(AggregateLine {
234                    label: line.label.clone(),
235                    call_depth: line.call_depth,
236                    table_heights: line.table_height_contributions(),
237                });
238            }
239        }
240        aggregated.push(AggregateLine {
241            label: "Total".to_string(),
242            call_depth: 0,
243            table_heights: self.total,
244        });
245
246        let label = |line: &AggregateLine| "··".repeat(line.call_depth) + &line.label;
247
248        let max_label_len = aggregated.iter().map(|line| label(line).len()).max();
249        let max_label_len = max_label_len.unwrap_or_default().max(COL_WIDTH);
250
251        write!(f, "| {: <max_label_len$} ", "Subroutine")?;
252        write!(f, "| {: >COL_WIDTH$} ", "Processor")?;
253        write!(f, "| {: >COL_WIDTH$} ", "OpStack")?;
254        write!(f, "| {: >COL_WIDTH$} ", "Ram")?;
255        write!(f, "| {: >COL_WIDTH$} ", "Hash")?;
256        write!(f, "| {: >COL_WIDTH$} ", "U32")?;
257        writeln!(f, "|")?;
258
259        write!(f, "|:{:-<max_label_len$}-", "")?;
260        write!(f, "|-{:->COL_WIDTH$}:", "")?;
261        write!(f, "|-{:->COL_WIDTH$}:", "")?;
262        write!(f, "|-{:->COL_WIDTH$}:", "")?;
263        write!(f, "|-{:->COL_WIDTH$}:", "")?;
264        write!(f, "|-{:->COL_WIDTH$}:", "")?;
265        writeln!(f, "|")?;
266
267        for line in &aggregated {
268            let rel_precision = 1;
269            let rel_width = 3 + 1 + rel_precision; // eg '100.0'
270            let abs_width = COL_WIDTH - rel_width - 4; // ' (' and '%)'
271
272            let label = label(line);
273            let proc_abs = line.table_heights.processor;
274            let proc_rel = 100.0 * f64::from(proc_abs) / f64::from(self.total.processor);
275            let proc_rel = format!("{proc_rel:.rel_precision$}");
276            let stack_abs = line.table_heights.op_stack;
277            let stack_rel = 100.0 * f64::from(stack_abs) / f64::from(self.total.op_stack);
278            let stack_rel = format!("{stack_rel:.rel_precision$}");
279            let ram_abs = line.table_heights.ram;
280            let ram_rel = 100.0 * f64::from(ram_abs) / f64::from(self.total.ram);
281            let ram_rel = format!("{ram_rel:.rel_precision$}");
282            let hash_abs = line.table_heights.hash;
283            let hash_rel = 100.0 * f64::from(hash_abs) / f64::from(self.total.hash);
284            let hash_rel = format!("{hash_rel:.rel_precision$}");
285            let u32_abs = line.table_heights.u32;
286            let u32_rel = 100.0 * f64::from(u32_abs) / f64::from(self.total.u32);
287            let u32_rel = format!("{u32_rel:.rel_precision$}");
288
289            write!(f, "| {label:<max_label_len$} ")?;
290            write!(f, "| {proc_abs:>abs_width$} ({proc_rel:>rel_width$}%) ")?;
291            write!(f, "| {stack_abs:>abs_width$} ({stack_rel:>rel_width$}%) ")?;
292            write!(f, "| {ram_abs:>abs_width$} ({ram_rel:>rel_width$}%) ")?;
293            write!(f, "| {hash_abs:>abs_width$} ({hash_rel:>rel_width$}%) ")?;
294            write!(f, "| {u32_abs:>abs_width$} ({u32_rel:>rel_width$}%) ")?;
295            writeln!(f, "|")?;
296        }
297
298        // print total height of all tables
299        let max_height = self.total.max();
300        let height_len = std::cmp::max(max_height.to_string().len(), "Height".len());
301
302        writeln!(f)?;
303        writeln!(f, "| Table     | {: >height_len$} | Dominates |", "Height")?;
304        writeln!(f, "|:----------|-{:->height_len$}:|----------:|", "")?;
305        for id in TableId::iter() {
306            let height = self.total.height_of_table(id);
307            let dominates = if height == max_height { "yes" } else { "no" };
308            writeln!(f, "| {id:<9} | {height:>height_len$} | {dominates:>9} |")?;
309        }
310        writeln!(f)?;
311        writeln!(f, "Padded height: 2^{}", self.padded_height.ilog2())?;
312
313        Ok(())
314    }
315}
316
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use assert2::assert;
    use assert2::let_assert;

    use crate::prelude::InstructionError;
    use crate::prelude::TableId;
    use crate::prelude::VM;
    use crate::prelude::VMState;
    use crate::prelude::triton_program;

    /// Profiling a program must produce the same public output as a regular
    /// run and as tracing the execution, and the profile's total table
    /// heights must equal the heights of the traced AET for every table.
    #[test]
    fn profile_can_be_created_and_agrees_with_regular_vm_run() {
        let program =
            crate::example_programs::CALCULATE_NEW_MMR_PEAKS_FROM_APPEND_WITH_SAFE_LISTS.clone();
        let (profile_output, profile) = VM::profile(program.clone(), [].into(), [].into()).unwrap();
        let mut vm_state = VMState::new(program.clone(), [].into(), [].into());
        let_assert!(Ok(()) = vm_state.run());
        assert!(profile_output == vm_state.public_output);
        assert!(profile.total.processor == vm_state.cycle_count);

        let_assert!(Ok((aet, trace_output)) = VM::trace_execution(program, [].into(), [].into()));
        assert!(profile_output == trace_output);

        let height = |id| u32::try_from(aet.height_of_table(id)).unwrap();
        assert!(height(TableId::Program) == profile.total.program);
        assert!(height(TableId::Processor) == profile.total.processor);
        assert!(height(TableId::OpStack) == profile.total.op_stack);
        assert!(height(TableId::Ram) == profile.total.ram);
        assert!(height(TableId::JumpStack) == profile.total.jump_stack);
        assert!(height(TableId::Hash) == profile.total.hash);
        assert!(height(TableId::Cascade) == profile.total.cascade);
        assert!(height(TableId::Lookup) == profile.total.lookup);
        assert!(height(TableId::U32) == profile.total.u32);

        println!("{profile}");
    }

    /// A surplus `return` must surface as the VM's regular error instead of
    /// making the profiler panic.
    #[test]
    fn program_with_too_many_returns_crashes_vm_but_not_profiler() {
        let program = triton_program! {
            call foo return halt
            foo: return
        };
        let_assert!(Err(err) = VM::profile(program, [].into(), [].into()));
        let_assert!(InstructionError::JumpStackIsEmpty = err.source);
    }

    /// The span of `foo` consists of the single `return` instruction; the
    /// `call` that opens the span must not be attributed to it.
    #[test]
    fn call_instruction_does_not_contribute_to_profile_span() {
        let program = triton_program! { call foo halt foo: return };
        let_assert!(Ok((_, profile)) = VM::profile(program, [].into(), [].into()));

        let [foo_span] = &profile.profile[..] else {
            panic!("span `foo` must be present")
        };
        assert!("foo" == foo_span.label);
        assert!(1 == foo_span.table_height_contributions().processor);
    }
}