sp1_core_machine/utils/span.rs

1use std::{collections::HashMap, fmt::Display, hash::Hash, iter::once};
2
3use sp1_core_executor::events::{format_table_line, sorted_table_lines};
4use thiserror::Error;
5
6/// A builder to create a [`Span`].
7/// `S` is the type of span names and `T` is the type of item names.
8#[derive(Debug, Clone)]
9pub struct SpanBuilder<S, T = S> {
10    pub parents: Vec<Span<S, T>>,
11    pub current_span: Span<S, T>,
12}
13
14impl<S, T> SpanBuilder<S, T>
15where
16    S: Display,
17    T: Ord + Display + Hash,
18{
19    /// Create an empty builder with the given name for the root span.
20    pub fn new(name: S) -> Self {
21        Self { parents: Default::default(), current_span: Span::new(name) }
22    }
23
24    /// Add an item to this span.
25    pub fn item(&mut self, item_name: impl Into<T>) -> &mut Self
26    where
27        T: Hash + Eq,
28    {
29        self.current_span.cts.entry(item_name.into()).and_modify(|x| *x += 1).or_insert(1);
30        self
31    }
32
33    /// Enter a new child span with the given name.
34    pub fn enter(&mut self, span_name: S) -> &mut Self {
35        let span = Span::new(span_name);
36        self.parents.push(core::mem::replace(&mut self.current_span, span));
37        self
38    }
39
40    /// Exit the current span, moving back to its parent.
41    ///
42    /// Yields an error if the current span is the root span, which may not be exited.
43    pub fn exit(&mut self) -> Result<&mut Self, SpanBuilderExitError>
44    where
45        T: Clone + Hash + Eq,
46    {
47        let mut parent_span = self.parents.pop().ok_or(SpanBuilderExitError::RootSpanExit)?;
48        // Add spanned instructions to parent.
49        for (instr_name, &ct) in self.current_span.cts.iter() {
50            // Always clones. Could be avoided with `raw_entry`, but it's not a big deal.
51            parent_span.cts.entry(instr_name.clone()).and_modify(|x| *x += ct).or_insert(ct);
52        }
53        // Move to the parent span.
54        let child_span = core::mem::replace(&mut self.current_span, parent_span);
55        self.current_span.children.push(child_span);
56        Ok(self)
57    }
58
59    /// Get the root span, consuming the builder.
60    ///
61    /// Yields an error if the current span is not the root span.
62    pub fn finish(self) -> Result<Span<S, T>, SpanBuilderFinishError> {
63        if self.parents.is_empty() {
64            Ok(self.current_span)
65        } else {
66            Err(SpanBuilderFinishError::OpenSpan(self.current_span.name.to_string()))
67        }
68    }
69}
70
71#[derive(Error, Debug, Clone)]
72pub enum SpanBuilderError {
73    #[error(transparent)]
74    Exit(#[from] SpanBuilderExitError),
75    #[error(transparent)]
76    Finish(#[from] SpanBuilderFinishError),
77}
78
79#[derive(Error, Debug, Clone)]
80pub enum SpanBuilderExitError {
81    #[error("cannot exit root span")]
82    RootSpanExit,
83}
84
85#[derive(Error, Debug, Clone)]
86pub enum SpanBuilderFinishError {
87    #[error("open span: {0}")]
88    OpenSpan(String),
89}
90
/// A named node in a tree of item counters; create and populate one with [`SpanBuilder`].
///
/// `S` is the span-name type and `T` is the item-name type (defaulting to `S`).
#[derive(Debug, Clone, Default)]
pub struct Span<S, T = S> {
    /// Label of this span.
    pub name: S,
    /// Occurrence count per item. When built via [`SpanBuilder`], these already include the
    /// counts of all descendant spans.
    pub cts: HashMap<T, usize>,
    /// Nested spans; when built via [`SpanBuilder`], in the order they were exited.
    pub children: Vec<Span<S, T>>,
}
99
100impl<S, T> Span<S, T>
101where
102    S: Display,
103    T: Ord + Display + Hash,
104{
105    /// Create a new span with the given name.
106    pub fn new(name: S) -> Self {
107        Self { name, cts: Default::default(), children: Default::default() }
108    }
109
110    /// Calculate the total number of items counted by this span and its children.
111    pub fn total(&self) -> usize {
112        // Counts are already added from children.
113        self.cts.values().cloned().sum()
114    }
115
116    /// Format and yield lines describing this span. Appropriate for logging.
117    pub fn lines(&self) -> Vec<String> {
118        let Self { name, cts: instr_cts, children } = self;
119        let (width, lines) = sorted_table_lines(instr_cts);
120        let lines = lines.map(|(label, count)| format_table_line(&width, &label, count));
121
122        once(format!("{name}"))
123            .chain(
124                children
125                    .iter()
126                    .flat_map(|c| c.lines())
127                    .chain(lines)
128                    .map(|line| format!("│  {line}")),
129            )
130            .chain(once(format!("└╴ {} total", self.total())))
131            .collect()
132    }
133}