1use std::collections::HashSet;
2use std::fmt::Display;
3use std::fmt::Formatter;
4use std::fmt::Result as FmtResult;
5use std::ops::Add;
6use std::ops::AddAssign;
7use std::ops::Sub;
8
9use air::table::TableId;
10use arbitrary::Arbitrary;
11use strum::IntoEnumIterator;
12
13use crate::aet::AlgebraicExecutionTrace;
14use crate::table::u32::U32TableEntry;
15
/// Records, for every labeled span of an execution, the table heights of the
/// algebraic execution trace at the moment the span is entered and at the
/// moment it is exited.
#[derive(Debug, Default, Clone, Eq, PartialEq, Arbitrary)]
pub(crate) struct ExecutionTraceProfiler {
    /// Indices into `profile` of the spans that are currently open, innermost
    /// span last.
    call_stack: Vec<usize>,

    /// One line per entered span, in the order the spans were entered.
    profile: Vec<ProfileLine>,

    // NOTE(review): only initialized in `new`; not read anywhere in this file.
    // Confirm whether it is still needed.
    u32_table_entries: HashSet<U32TableEntry>,
}
22
/// The table heights observed when one profiled span was entered and exited.
#[non_exhaustive]
#[derive(Debug, Default, Clone, Eq, PartialEq, Hash, Arbitrary)]
pub struct ProfileLine {
    /// The label the span was entered with.
    pub label: String,

    /// The number of spans that were still open when this span was entered.
    pub call_depth: usize,

    /// Table heights at the moment the span was entered.
    pub table_heights_start: VMTableHeights,

    /// Table heights at the moment the span was exited. Spans still open when
    /// profiling finishes receive the final table heights.
    pub table_heights_stop: VMTableHeights,
}
42
/// The result of profiling one execution.
#[non_exhaustive]
#[derive(Debug, Clone, Eq, PartialEq, Hash, Arbitrary)]
pub struct ExecutionTraceProfile {
    /// Table heights at the end of the execution.
    pub total: VMTableHeights,

    /// One line per profiled span, in the order the spans were entered.
    pub profile: Vec<ProfileLine>,

    /// The padded height reported by the algebraic execution trace.
    pub padded_height: usize,
}
60
/// The height of each of the VM's tables.
///
/// When produced by the profiler, heights that do not fit into a `u32` are
/// clamped to `u32::MAX`.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
pub struct VMTableHeights {
    pub program: u32,
    pub processor: u32,
    pub op_stack: u32,
    pub ram: u32,
    pub jump_stack: u32,
    pub hash: u32,
    pub cascade: u32,
    pub lookup: u32,
    pub u32: u32,
}
75
76impl ExecutionTraceProfiler {
77 pub fn new() -> Self {
78 Self {
79 call_stack: vec![],
80 profile: vec![],
81 u32_table_entries: HashSet::default(),
82 }
83 }
84
85 pub fn enter_span(&mut self, label: impl Into<String>, aet: &AlgebraicExecutionTrace) {
86 let profile_line = ProfileLine {
87 label: label.into(),
88 call_depth: self.call_stack.len(),
89 table_heights_start: Self::extract_table_heights(aet),
90 table_heights_stop: VMTableHeights::default(),
91 };
92
93 let line_number = self.profile.len();
94 self.profile.push(profile_line);
95 self.call_stack.push(line_number);
96 }
97
98 pub fn exit_span(&mut self, aet: &AlgebraicExecutionTrace) {
99 if let Some(line_number) = self.call_stack.pop() {
100 self.profile[line_number].table_heights_stop = Self::extract_table_heights(aet);
101 };
102 }
103
104 pub fn finish(mut self, aet: &AlgebraicExecutionTrace) -> ExecutionTraceProfile {
105 let table_heights = Self::extract_table_heights(aet);
106 for &line_number in &self.call_stack {
107 self.profile[line_number].table_heights_stop = table_heights;
108 }
109
110 ExecutionTraceProfile {
111 total: table_heights,
112 profile: self.profile,
113 padded_height: aet.padded_height(),
114 }
115 }
116
117 fn extract_table_heights(aet: &AlgebraicExecutionTrace) -> VMTableHeights {
118 let height = |id| aet.height_of_table(id).try_into().unwrap_or(u32::MAX);
120
121 VMTableHeights {
122 program: height(TableId::Program),
123 processor: height(TableId::Processor),
124 op_stack: height(TableId::OpStack),
125 ram: height(TableId::Ram),
126 jump_stack: height(TableId::JumpStack),
127 hash: height(TableId::Hash),
128 cascade: height(TableId::Cascade),
129 lookup: height(TableId::Lookup),
130 u32: height(TableId::U32),
131 }
132 }
133}
134
135impl VMTableHeights {
136 fn height_of_table(&self, table: TableId) -> u32 {
137 match table {
138 TableId::Program => self.program,
139 TableId::Processor => self.processor,
140 TableId::OpStack => self.op_stack,
141 TableId::Ram => self.ram,
142 TableId::JumpStack => self.jump_stack,
143 TableId::Hash => self.hash,
144 TableId::Cascade => self.cascade,
145 TableId::Lookup => self.lookup,
146 TableId::U32 => self.u32,
147 }
148 }
149
150 fn max(&self) -> u32 {
151 TableId::iter()
152 .map(|id| self.height_of_table(id))
153 .max()
154 .unwrap()
155 }
156}
157
158impl Sub for VMTableHeights {
159 type Output = Self;
160
161 fn sub(self, rhs: Self) -> Self::Output {
162 Self {
163 program: self.program.saturating_sub(rhs.program),
164 processor: self.processor.saturating_sub(rhs.processor),
165 op_stack: self.op_stack.saturating_sub(rhs.op_stack),
166 ram: self.ram.saturating_sub(rhs.ram),
167 jump_stack: self.jump_stack.saturating_sub(rhs.jump_stack),
168 hash: self.hash.saturating_sub(rhs.hash),
169 cascade: self.cascade.saturating_sub(rhs.cascade),
170 lookup: self.lookup.saturating_sub(rhs.lookup),
171 u32: self.u32.saturating_sub(rhs.u32),
172 }
173 }
174}
175
176impl Add for VMTableHeights {
177 type Output = Self;
178
179 fn add(self, rhs: Self) -> Self::Output {
180 Self {
181 program: self.program + rhs.program,
182 processor: self.processor + rhs.processor,
183 op_stack: self.op_stack + rhs.op_stack,
184 ram: self.ram + rhs.ram,
185 jump_stack: self.jump_stack + rhs.jump_stack,
186 hash: self.hash + rhs.hash,
187 cascade: self.cascade + rhs.cascade,
188 lookup: self.lookup + rhs.lookup,
189 u32: self.u32 + rhs.u32,
190 }
191 }
192}
193
194impl AddAssign<Self> for VMTableHeights {
195 fn add_assign(&mut self, rhs: Self) {
196 *self = *self + rhs;
197 }
198}
199
200impl ProfileLine {
201 fn table_height_contributions(&self) -> VMTableHeights {
202 self.table_heights_stop - self.table_heights_start
203 }
204}
205
206impl Display for ProfileLine {
207 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
208 let indentation = " ".repeat(self.call_depth);
209 let label = &self.label;
210 let cycle_count = self.table_height_contributions().processor;
211 write!(f, "{indentation}{label}: {cycle_count}")
212 }
213}
214
215impl Display for ExecutionTraceProfile {
216 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
217 struct AggregateLine {
218 label: String,
219 call_depth: usize,
220 table_heights: VMTableHeights,
221 }
222
223 const COL_WIDTH: usize = 20;
224
225 let mut aggregated: Vec<AggregateLine> = vec![];
226 for line in &self.profile {
227 if let Some(agg) = aggregated
228 .iter_mut()
229 .find(|agg| agg.label == line.label && agg.call_depth == line.call_depth)
230 {
231 agg.table_heights += line.table_height_contributions();
232 } else {
233 aggregated.push(AggregateLine {
234 label: line.label.clone(),
235 call_depth: line.call_depth,
236 table_heights: line.table_height_contributions(),
237 });
238 }
239 }
240 aggregated.push(AggregateLine {
241 label: "Total".to_string(),
242 call_depth: 0,
243 table_heights: self.total,
244 });
245
246 let label = |line: &AggregateLine| "··".repeat(line.call_depth) + &line.label;
247
248 let max_label_len = aggregated.iter().map(|line| label(line).len()).max();
249 let max_label_len = max_label_len.unwrap_or_default().max(COL_WIDTH);
250
251 write!(f, "| {: <max_label_len$} ", "Subroutine")?;
252 write!(f, "| {: >COL_WIDTH$} ", "Processor")?;
253 write!(f, "| {: >COL_WIDTH$} ", "OpStack")?;
254 write!(f, "| {: >COL_WIDTH$} ", "Ram")?;
255 write!(f, "| {: >COL_WIDTH$} ", "Hash")?;
256 write!(f, "| {: >COL_WIDTH$} ", "U32")?;
257 writeln!(f, "|")?;
258
259 write!(f, "|:{:-<max_label_len$}-", "")?;
260 write!(f, "|-{:->COL_WIDTH$}:", "")?;
261 write!(f, "|-{:->COL_WIDTH$}:", "")?;
262 write!(f, "|-{:->COL_WIDTH$}:", "")?;
263 write!(f, "|-{:->COL_WIDTH$}:", "")?;
264 write!(f, "|-{:->COL_WIDTH$}:", "")?;
265 writeln!(f, "|")?;
266
267 for line in &aggregated {
268 let rel_precision = 1;
269 let rel_width = 3 + 1 + rel_precision; let abs_width = COL_WIDTH - rel_width - 4; let label = label(line);
273 let proc_abs = line.table_heights.processor;
274 let proc_rel = 100.0 * f64::from(proc_abs) / f64::from(self.total.processor);
275 let proc_rel = format!("{proc_rel:.rel_precision$}");
276 let stack_abs = line.table_heights.op_stack;
277 let stack_rel = 100.0 * f64::from(stack_abs) / f64::from(self.total.op_stack);
278 let stack_rel = format!("{stack_rel:.rel_precision$}");
279 let ram_abs = line.table_heights.ram;
280 let ram_rel = 100.0 * f64::from(ram_abs) / f64::from(self.total.ram);
281 let ram_rel = format!("{ram_rel:.rel_precision$}");
282 let hash_abs = line.table_heights.hash;
283 let hash_rel = 100.0 * f64::from(hash_abs) / f64::from(self.total.hash);
284 let hash_rel = format!("{hash_rel:.rel_precision$}");
285 let u32_abs = line.table_heights.u32;
286 let u32_rel = 100.0 * f64::from(u32_abs) / f64::from(self.total.u32);
287 let u32_rel = format!("{u32_rel:.rel_precision$}");
288
289 write!(f, "| {label:<max_label_len$} ")?;
290 write!(f, "| {proc_abs:>abs_width$} ({proc_rel:>rel_width$}%) ")?;
291 write!(f, "| {stack_abs:>abs_width$} ({stack_rel:>rel_width$}%) ")?;
292 write!(f, "| {ram_abs:>abs_width$} ({ram_rel:>rel_width$}%) ")?;
293 write!(f, "| {hash_abs:>abs_width$} ({hash_rel:>rel_width$}%) ")?;
294 write!(f, "| {u32_abs:>abs_width$} ({u32_rel:>rel_width$}%) ")?;
295 writeln!(f, "|")?;
296 }
297
298 let max_height = self.total.max();
300 let height_len = std::cmp::max(max_height.to_string().len(), "Height".len());
301
302 writeln!(f)?;
303 writeln!(f, "| Table | {: >height_len$} | Dominates |", "Height")?;
304 writeln!(f, "|:----------|-{:->height_len$}:|----------:|", "")?;
305 for id in TableId::iter() {
306 let height = self.total.height_of_table(id);
307 let dominates = if height == max_height { "yes" } else { "no" };
308 writeln!(f, "| {id:<9} | {height:>height_len$} | {dominates:>9} |")?;
309 }
310 writeln!(f)?;
311 writeln!(f, "Padded height: 2^{}", self.padded_height.ilog2())?;
312
313 Ok(())
314 }
315}
316
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use assert2::assert;
    use assert2::let_assert;

    use crate::prelude::InstructionError;
    use crate::prelude::TableId;
    use crate::prelude::VM;
    use crate::prelude::VMState;
    use crate::prelude::triton_program;

    /// Profiling must produce the same output, cycle count, and table heights
    /// as running and tracing the program without the profiler.
    #[test]
    fn profile_can_be_created_and_agrees_with_regular_vm_run() {
        let program =
            crate::example_programs::CALCULATE_NEW_MMR_PEAKS_FROM_APPEND_WITH_SAFE_LISTS.clone();
        let (profile_output, profile) = VM::profile(program.clone(), [].into(), [].into()).unwrap();
        let mut vm_state = VMState::new(program.clone(), [].into(), [].into());
        let_assert!(Ok(()) = vm_state.run());
        assert!(profile_output == vm_state.public_output);
        assert!(profile.total.processor == vm_state.cycle_count);

        let_assert!(Ok((aet, trace_output)) = VM::trace_execution(program, [].into(), [].into()));
        assert!(profile_output == trace_output);

        // Every table is compared, including the jump stack.
        let height = |id| u32::try_from(aet.height_of_table(id)).unwrap();
        assert!(height(TableId::Program) == profile.total.program);
        assert!(height(TableId::Processor) == profile.total.processor);
        assert!(height(TableId::OpStack) == profile.total.op_stack);
        assert!(height(TableId::Ram) == profile.total.ram);
        assert!(height(TableId::JumpStack) == profile.total.jump_stack);
        assert!(height(TableId::Hash) == profile.total.hash);
        assert!(height(TableId::Cascade) == profile.total.cascade);
        assert!(height(TableId::Lookup) == profile.total.lookup);
        assert!(height(TableId::U32) == profile.total.u32);

        println!("{profile}");
    }

    /// An unbalanced `return` is a VM error, but the profiler itself must not
    /// panic on the superfluous span exit.
    #[test]
    fn program_with_too_many_returns_crashes_vm_but_not_profiler() {
        let program = triton_program! {
            call foo return halt
            foo: return
        };
        let_assert!(Err(err) = VM::profile(program, [].into(), [].into()));
        let_assert!(InstructionError::JumpStackIsEmpty = err.source);
    }

    /// The `call` instruction is accounted to the caller; only the callee's
    /// `return` is attributed to span `foo`.
    #[test]
    fn call_instruction_does_not_contribute_to_profile_span() {
        let program = triton_program! { call foo halt foo: return };
        let_assert!(Ok((_, profile)) = VM::profile(program, [].into(), [].into()));

        let [foo_span] = &profile.profile[..] else {
            panic!("span `foo` must be present")
        };
        assert!("foo" == foo_span.label);
        assert!(1 == foo_span.table_height_contributions().processor);
    }
}