//! trident/cost/analyzer.rs — static cost analysis over the AST.
use std::collections::BTreeMap;

use super::model::{create_cost_model, CostModel, TableCost};
use crate::ast::*;
use crate::field::proof;

// --- Per-function cost result ---

/// Cost analysis result for a single function.
#[derive(Clone, Debug)]
pub struct FunctionCost {
    /// Function name as declared in the source.
    pub name: String,
    /// Total static cost of one invocation of this function.
    pub cost: TableCost,
    /// If this function contains a loop, per-iteration cost.
    /// NOTE(review): the `u64` presumably carries the iteration count used
    /// for the estimate — confirm against `find_loop_iteration_cost`.
    pub per_iteration: Option<(TableCost, u64)>,
}

/// Cost analysis result for the full program.
#[derive(Clone, Debug)]
pub struct ProgramCost {
    /// Name of the analyzed program (taken from the `File`'s name node).
    pub program_name: String,
    /// Per-function cost breakdown, one entry per top-level `fn`.
    pub functions: Vec<FunctionCost>,
    /// Whole-program cost: `main` plus call overhead when `main` exists,
    /// otherwise the sum over all functions.
    pub total: TableCost,
    /// Table names from the CostModel (e.g. ["processor", "hash", ...]).
    pub table_names: Vec<String>,
    /// Short display names (e.g. ["cc", "hash", ...]).
    pub table_short_names: Vec<String>,
    /// Program attestation adds
    /// `ceil(instruction_count / 10) * CostModel::hash_rows_per_permutation()`
    /// hash rows.
    pub attestation_hash_rows: u64,
    /// Trace height after padding (includes the attestation hash rows).
    pub padded_height: u64,
    /// Rough proving-time estimate derived from padded height and trace
    /// column count.
    pub estimated_proving_ns: u64,
    /// H0004: loops where declared bound >> actual constant end.
    pub loop_bound_waste: Vec<(String, u64, u64)>, // (fn_name, end_value, bound)
}

36impl ProgramCost {
37    /// Short names as str slice refs (for passing to TableCost methods).
38    pub fn short_names(&self) -> Vec<&str> {
39        self.table_short_names.iter().map(|s| s.as_str()).collect()
40    }
41
42    /// Long names as str slice refs.
43    pub fn long_names(&self) -> Vec<&str> {
44        self.table_names.iter().map(|s| s.as_str()).collect()
45    }
46}

// --- Cost analyzer ---

/// Computes static cost by walking the AST.
///
/// The analyzer is parameterized by a `CostModel` that provides all
/// target-specific cost constants.
pub(crate) struct CostAnalyzer<'a> {
    /// Target-specific cost model.
    pub(crate) cost_model: &'a dyn CostModel,
    /// Function bodies indexed by name (for resolving calls).
    pub(crate) fn_bodies: BTreeMap<String, FnDef>,
    /// Cached function costs to avoid recomputation.
    /// Only top-level (non-recursive) results are cached; see `cost_fn`.
    fn_costs: BTreeMap<String, TableCost>,
    /// Recursion guard to prevent infinite loops in cost computation.
    /// Holds the names of functions currently being analyzed (a stack).
    in_progress: Vec<String>,
    /// H0004: collected loop bound waste entries (fn_name, end_value, bound).
    pub(crate) loop_bound_waste: Vec<(String, u64, u64)>,
}

impl Default for CostAnalyzer<'_> {
    /// The default analyzer uses the "triton" target's cost model.
    fn default() -> Self {
        Self::for_target("triton")
    }
}

impl<'a> CostAnalyzer<'a> {
    /// Create an analyzer for the named target.
    pub(crate) fn for_target(target_name: &str) -> Self {
        Self::with_cost_model(create_cost_model(target_name))
    }

    /// Create an analyzer with a specific cost model.
    pub(crate) fn with_cost_model(cost_model: &'a dyn CostModel) -> Self {
        Self {
            cost_model,
            fn_bodies: BTreeMap::new(),
            fn_costs: BTreeMap::new(),
            in_progress: Vec::new(),
            loop_bound_waste: Vec::new(),
        }
    }

    /// Analyze a complete file and return the program cost.
    ///
    /// Walks every `fn` item, computes per-function costs, derives the
    /// program total, and estimates padded trace height and proving time.
    pub(crate) fn analyze_file(&mut self, file: &File) -> ProgramCost {
        // Collect all function definitions up front so calls can be
        // resolved regardless of declaration order.
        for item in &file.items {
            if let Item::Fn(func) = &item.node {
                self.fn_bodies.insert(func.name.node.clone(), func.clone());
            }
        }

        // Compute cost for each function.
        let mut functions = Vec::new();
        let fn_names: Vec<String> = self.fn_bodies.keys().cloned().collect();
        for name in &fn_names {
            // Clone the definition so `self` can be borrowed mutably below.
            let func = self
                .fn_bodies
                .get(name)
                .expect("name comes from fn_bodies.keys()")
                .clone();
            let cost = self.cost_fn(&func);
            let per_iteration = self.find_loop_iteration_cost(&func);
            functions.push(FunctionCost {
                name: name.clone(),
                cost,
                per_iteration,
            });
        }

        // Total cost: start from main if it exists, otherwise sum all.
        // `main` is guaranteed to be in `fn_costs` here because cost_fn
        // caches all top-level (depth-0) computations.
        let total = if let Some(main_cost) = self.fn_costs.get("main") {
            main_cost.add(&self.cost_model.call_overhead()) // call main + halt
        } else {
            functions
                .iter()
                .fold(TableCost::ZERO, |acc, f| acc.add(&f.cost))
        };

        // Estimate program instruction count for attestation.
        // Rough heuristic: total first-table value (processor cycles) ≈ instruction count.
        // The floor of 10 avoids a degenerate zero-length estimate.
        let instruction_count = total.get(0).max(10);
        let hash_rows = self.cost_model.hash_rows_per_permutation();
        let attestation_hash_rows = instruction_count.div_ceil(10) * hash_rows;

        // Padded height includes attestation.
        let max_height = total.max_height().max(attestation_hash_rows);
        let padded_height = proof::padded_height(max_height);

        let columns = self.cost_model.trace_column_count();
        let estimated_proving_ns = proof::estimate_proving_ns(padded_height, columns);

        // H0004: scan for loop bound waste (bound >> constant end)
        for item in &file.items {
            if let Item::Fn(func) = &item.node {
                if let Some(body) = &func.body {
                    self.scan_loop_bound_waste(&func.name.node, &body.node);
                }
            }
        }

        ProgramCost {
            program_name: file.name.node.clone(),
            functions,
            total,
            table_names: self
                .cost_model
                .table_names()
                .iter()
                .map(|s| s.to_string())
                .collect(),
            table_short_names: self
                .cost_model
                .table_short_names()
                .iter()
                .map(|s| s.to_string())
                .collect(),
            attestation_hash_rows,
            padded_height,
            estimated_proving_ns,
            // Move the collected H0004 entries into the result, leaving the
            // analyzer's buffer empty for a potential subsequent run.
            loop_bound_waste: std::mem::take(&mut self.loop_bound_waste),
        }
    }

    /// Cost of one function body.
    ///
    /// Results are memoized in `fn_costs`, but only for computations that
    /// start at recursion depth zero (see comment below).
    pub(crate) fn cost_fn(&mut self, func: &FnDef) -> TableCost {
        if let Some(cached) = self.fn_costs.get(&func.name.node) {
            return *cached;
        }

        // Recursion guard: if this function is already being analyzed,
        // return ZERO to break the cycle.
        if self.in_progress.contains(&func.name.node) {
            return TableCost::ZERO;
        }

        let depth_before = self.in_progress.len();
        self.in_progress.push(func.name.node.clone());

        let cost = if let Some(body) = &func.body {
            self.cost_block(&body.node)
        } else {
            // Body-less definition costs nothing.
            TableCost::ZERO
        };

        self.in_progress.pop();

        // Only cache if we're at the top-level call (no recursion in flight).
        // Costs computed during active recursion are underestimates because
        // recursive calls are costed as ZERO.
        if depth_before == 0 {
            self.fn_costs.insert(func.name.node.clone(), cost);
        }
        cost
    }

    /// Cost of a block: sum of its statements plus the optional tail
    /// expression.
    pub(crate) fn cost_block(&mut self, block: &Block) -> TableCost {
        let mut cost = TableCost::ZERO;
        for stmt in &block.stmts {
            cost = cost.add(&self.cost_stmt(&stmt.node));
        }
        if let Some(tail) = &block.tail_expr {
            cost = cost.add(&self.cost_expr(&tail.node));
        }
        cost
    }

    /// Cost of a single statement.
    ///
    /// Branching constructs (`if`, `match`) are costed at the worst-case
    /// branch; loops are scaled by their declared or inferred iteration
    /// count.
    pub(crate) fn cost_stmt(&mut self, stmt: &Stmt) -> TableCost {
        let stack_op = self.cost_model.stack_op();
        match stmt {
            Stmt::Let { init, .. } => {
                // Cost of evaluating the init expression + stack placement.
                self.cost_expr(&init.node).add(&stack_op)
            }
            Stmt::Assign { value, .. } => {
                // Cost of evaluating value + swap to replace old value.
                self.cost_expr(&value.node).add(&stack_op).add(&stack_op)
            }
            Stmt::TupleAssign { names, value } => {
                let mut cost = self.cost_expr(&value.node);
                // One swap+pop per element.
                for _ in names {
                    cost = cost.add(&stack_op).add(&stack_op);
                }
                cost
            }
            Stmt::If {
                cond,
                then_block,
                else_block,
            } => {
                let cond_cost = self.cost_expr(&cond.node);
                let then_cost = self.cost_block(&then_block.node);
                let else_cost = if let Some(eb) = else_block {
                    self.cost_block(&eb.node)
                } else {
                    TableCost::ZERO
                };
                // Worst case: max of then/else branches.
                cond_cost
                    .add(&then_cost.max(&else_cost))
                    .add(&self.cost_model.if_overhead())
            }
            Stmt::For {
                end, bound, body, ..
            } => {
                let end_cost = self.cost_expr(&end.node);
                let body_cost = self.cost_block(&body.node);
                // Use declared bound if available, otherwise use end expr as literal.
                let iterations = if let Some(b) = bound {
                    *b
                } else if let Expr::Literal(Literal::Integer(n)) = &end.node {
                    *n
                } else {
                    // Non-constant loop bound with no `bounded` annotation.
                    // Default to 1 iteration but record as a warning so
                    // report.rs can flag it via H0004.
                    self.loop_bound_waste.push((
                        self.in_progress.last().cloned().unwrap_or_default(),
                        1, // assumed iterations
                        0, // no declared bound (0 signals "unknown")
                    ));
                    1
                };
                // Per-iteration: body + loop overhead (dup, check, decrement, recurse).
                let per_iter = body_cost.add(&self.cost_model.loop_overhead());
                end_cost.add(&per_iter.scale(iterations))
            }
            Stmt::Expr(expr) => self.cost_expr(&expr.node),
            Stmt::Return(val) => {
                if let Some(v) = val {
                    self.cost_expr(&v.node)
                } else {
                    TableCost::ZERO
                }
            }
            Stmt::Reveal { fields, .. } => {
                // push tag + write_io 1 + (field expr + write_io 1) per field
                let io_cost = self.cost_model.builtin_cost("pub_write");
                let mut cost = stack_op.clone(); // push tag
                cost = cost.add(&io_cost); // write_io 1 for tag
                for (_name, val) in fields {
                    cost = cost.add(&self.cost_expr(&val.node));
                    cost = cost.add(&io_cost); // write_io 1
                }
                cost
            }
            Stmt::Asm { body, .. } => {
                // Conservative estimate: count non-empty, non-comment lines as stack ops
                let line_count = body
                    .lines()
                    .filter(|l| {
                        let t = l.trim();
                        !t.is_empty() && !t.starts_with("//")
                    })
                    .count() as u64;
                stack_op.scale(line_count)
            }
            Stmt::Match { expr, arms } => {
                let scrutinee_cost = self.cost_expr(&expr.node);
                // Per arm: dup + push + eq + skiz/call overhead = ~5 rows
                let arm_overhead = stack_op.scale(3).add(&self.cost_model.if_overhead());
                // All non-wildcard arms need comparison overhead
                let num_checked_arms = arms
                    .iter()
                    .filter(|a| !matches!(a.pattern.node, MatchPattern::Wildcard))
                    .count() as u64;
                let check_cost = arm_overhead.scale(num_checked_arms);
                // Worst-case body: max across all arms
                let max_body = arms
                    .iter()
                    .map(|a| self.cost_block(&a.body.node))
                    .fold(TableCost::ZERO, |acc, c| acc.max(&c));
                scrutinee_cost.add(&check_cost).add(&max_body)
            }
            Stmt::Seal { fields, .. } => {
                // push tag + field exprs + padding pushes + hash + write_io 5
                // Hash rate is 10 (tag + up to 9 fields); excess fields need extra hashes.
                // NOTE(review): for >9 fields the tail is only costed with extra
                // `hash` calls but no extra padding pushes — confirm the codegen
                // pads the final permutation implicitly.
                let mut cost = stack_op.clone(); // push tag
                for (_name, val) in fields {
                    cost = cost.add(&self.cost_expr(&val.node));
                }
                let padding = 9usize.saturating_sub(fields.len());
                for _ in 0..padding {
                    cost = cost.add(&stack_op); // push 0 padding
                }
                // hash (one per 10 elements; extra hashes if >9 fields)
                let hash_count = (1 + fields.len()).div_ceil(10);
                for _ in 0..hash_count {
                    cost = cost.add(&self.cost_model.builtin_cost("hash"));
                }
                // write_io 5
                cost = cost.add(&self.cost_model.builtin_cost("pub_write5"));
                cost
            }
        }
    }
}