1use std::collections::BTreeMap;
2
3use super::model::{create_cost_model, CostModel, TableCost};
4use crate::ast::*;
5use crate::field::proof;
6
7#[derive(Clone, Debug)]
11pub struct FunctionCost {
12 pub name: String,
13 pub cost: TableCost,
14 pub per_iteration: Option<(TableCost, u64)>,
16}
17
18#[derive(Clone, Debug)]
20pub struct ProgramCost {
21 pub program_name: String,
22 pub functions: Vec<FunctionCost>,
23 pub total: TableCost,
24 pub table_names: Vec<String>,
26 pub table_short_names: Vec<String>,
28 pub attestation_hash_rows: u64,
30 pub padded_height: u64,
31 pub estimated_proving_ns: u64,
32 pub loop_bound_waste: Vec<(String, u64, u64)>, }
35
36impl ProgramCost {
37 pub fn short_names(&self) -> Vec<&str> {
39 self.table_short_names.iter().map(|s| s.as_str()).collect()
40 }
41
42 pub fn long_names(&self) -> Vec<&str> {
44 self.table_names.iter().map(|s| s.as_str()).collect()
45 }
46}
47
48pub(crate) struct CostAnalyzer<'a> {
55 pub(crate) cost_model: &'a dyn CostModel,
57 pub(crate) fn_bodies: BTreeMap<String, FnDef>,
59 fn_costs: BTreeMap<String, TableCost>,
61 in_progress: Vec<String>,
63 pub(crate) loop_bound_waste: Vec<(String, u64, u64)>,
65}
66
67impl Default for CostAnalyzer<'_> {
68 fn default() -> Self {
69 Self::for_target("triton")
70 }
71}
72
73impl<'a> CostAnalyzer<'a> {
74 pub(crate) fn for_target(target_name: &str) -> Self {
76 Self::with_cost_model(create_cost_model(target_name))
77 }
78
79 pub(crate) fn with_cost_model(cost_model: &'a dyn CostModel) -> Self {
81 Self {
82 cost_model,
83 fn_bodies: BTreeMap::new(),
84 fn_costs: BTreeMap::new(),
85 in_progress: Vec::new(),
86 loop_bound_waste: Vec::new(),
87 }
88 }
89
90 pub(crate) fn analyze_file(&mut self, file: &File) -> ProgramCost {
92 for item in &file.items {
94 if let Item::Fn(func) = &item.node {
95 self.fn_bodies.insert(func.name.node.clone(), func.clone());
96 }
97 }
98
99 let mut functions = Vec::new();
101 let fn_names: Vec<String> = self.fn_bodies.keys().cloned().collect();
102 for name in &fn_names {
103 let func = self
104 .fn_bodies
105 .get(name)
106 .expect("name comes from fn_bodies.keys()")
107 .clone();
108 let cost = self.cost_fn(&func);
109 let per_iteration = self.find_loop_iteration_cost(&func);
110 functions.push(FunctionCost {
111 name: name.clone(),
112 cost,
113 per_iteration,
114 });
115 }
116
117 let total = if let Some(main_cost) = self.fn_costs.get("main") {
119 main_cost.add(&self.cost_model.call_overhead()) } else {
121 functions
122 .iter()
123 .fold(TableCost::ZERO, |acc, f| acc.add(&f.cost))
124 };
125
126 let instruction_count = total.get(0).max(10);
129 let hash_rows = self.cost_model.hash_rows_per_permutation();
130 let attestation_hash_rows = instruction_count.div_ceil(10) * hash_rows;
131
132 let max_height = total.max_height().max(attestation_hash_rows);
134 let padded_height = proof::padded_height(max_height);
135
136 let columns = self.cost_model.trace_column_count();
137 let estimated_proving_ns = proof::estimate_proving_ns(padded_height, columns);
138
139 for item in &file.items {
141 if let Item::Fn(func) = &item.node {
142 if let Some(body) = &func.body {
143 self.scan_loop_bound_waste(&func.name.node, &body.node);
144 }
145 }
146 }
147
148 ProgramCost {
149 program_name: file.name.node.clone(),
150 functions,
151 total,
152 table_names: self
153 .cost_model
154 .table_names()
155 .iter()
156 .map(|s| s.to_string())
157 .collect(),
158 table_short_names: self
159 .cost_model
160 .table_short_names()
161 .iter()
162 .map(|s| s.to_string())
163 .collect(),
164 attestation_hash_rows,
165 padded_height,
166 estimated_proving_ns,
167 loop_bound_waste: std::mem::take(&mut self.loop_bound_waste),
168 }
169 }
170
171 pub(crate) fn cost_fn(&mut self, func: &FnDef) -> TableCost {
172 if let Some(cached) = self.fn_costs.get(&func.name.node) {
173 return *cached;
174 }
175
176 if self.in_progress.contains(&func.name.node) {
179 return TableCost::ZERO;
180 }
181
182 let depth_before = self.in_progress.len();
183 self.in_progress.push(func.name.node.clone());
184
185 let cost = if let Some(body) = &func.body {
186 self.cost_block(&body.node)
187 } else {
188 TableCost::ZERO
189 };
190
191 self.in_progress.pop();
192
193 if depth_before == 0 {
197 self.fn_costs.insert(func.name.node.clone(), cost);
198 }
199 cost
200 }
201
202 pub(crate) fn cost_block(&mut self, block: &Block) -> TableCost {
203 let mut cost = TableCost::ZERO;
204 for stmt in &block.stmts {
205 cost = cost.add(&self.cost_stmt(&stmt.node));
206 }
207 if let Some(tail) = &block.tail_expr {
208 cost = cost.add(&self.cost_expr(&tail.node));
209 }
210 cost
211 }
212
213 pub(crate) fn cost_stmt(&mut self, stmt: &Stmt) -> TableCost {
214 let stack_op = self.cost_model.stack_op();
215 match stmt {
216 Stmt::Let { init, .. } => {
217 self.cost_expr(&init.node).add(&stack_op)
219 }
220 Stmt::Assign { value, .. } => {
221 self.cost_expr(&value.node).add(&stack_op).add(&stack_op)
223 }
224 Stmt::TupleAssign { names, value } => {
225 let mut cost = self.cost_expr(&value.node);
226 for _ in names {
228 cost = cost.add(&stack_op).add(&stack_op);
229 }
230 cost
231 }
232 Stmt::If {
233 cond,
234 then_block,
235 else_block,
236 } => {
237 let cond_cost = self.cost_expr(&cond.node);
238 let then_cost = self.cost_block(&then_block.node);
239 let else_cost = if let Some(eb) = else_block {
240 self.cost_block(&eb.node)
241 } else {
242 TableCost::ZERO
243 };
244 cond_cost
246 .add(&then_cost.max(&else_cost))
247 .add(&self.cost_model.if_overhead())
248 }
249 Stmt::For {
250 end, bound, body, ..
251 } => {
252 let end_cost = self.cost_expr(&end.node);
253 let body_cost = self.cost_block(&body.node);
254 let iterations = if let Some(b) = bound {
256 *b
257 } else if let Expr::Literal(Literal::Integer(n)) = &end.node {
258 *n
259 } else {
260 self.loop_bound_waste.push((
264 self.in_progress.last().cloned().unwrap_or_default(),
265 1, 0, ));
268 1
269 };
270 let per_iter = body_cost.add(&self.cost_model.loop_overhead());
272 end_cost.add(&per_iter.scale(iterations))
273 }
274 Stmt::Expr(expr) => self.cost_expr(&expr.node),
275 Stmt::Return(val) => {
276 if let Some(v) = val {
277 self.cost_expr(&v.node)
278 } else {
279 TableCost::ZERO
280 }
281 }
282 Stmt::Reveal { fields, .. } => {
283 let io_cost = self.cost_model.builtin_cost("pub_write");
285 let mut cost = stack_op.clone(); cost = cost.add(&io_cost); for (_name, val) in fields {
288 cost = cost.add(&self.cost_expr(&val.node));
289 cost = cost.add(&io_cost); }
291 cost
292 }
293 Stmt::Asm { body, .. } => {
294 let line_count = body
296 .lines()
297 .filter(|l| {
298 let t = l.trim();
299 !t.is_empty() && !t.starts_with("//")
300 })
301 .count() as u64;
302 stack_op.scale(line_count)
303 }
304 Stmt::Match { expr, arms } => {
305 let scrutinee_cost = self.cost_expr(&expr.node);
306 let arm_overhead = stack_op.scale(3).add(&self.cost_model.if_overhead());
308 let num_checked_arms = arms
310 .iter()
311 .filter(|a| !matches!(a.pattern.node, MatchPattern::Wildcard))
312 .count() as u64;
313 let check_cost = arm_overhead.scale(num_checked_arms);
314 let max_body = arms
316 .iter()
317 .map(|a| self.cost_block(&a.body.node))
318 .fold(TableCost::ZERO, |acc, c| acc.max(&c));
319 scrutinee_cost.add(&check_cost).add(&max_body)
320 }
321 Stmt::Seal { fields, .. } => {
322 let mut cost = stack_op.clone(); for (_name, val) in fields {
326 cost = cost.add(&self.cost_expr(&val.node));
327 }
328 let padding = 9usize.saturating_sub(fields.len());
329 for _ in 0..padding {
330 cost = cost.add(&stack_op); }
332 let hash_count = (1 + fields.len()).div_ceil(10);
334 for _ in 0..hash_count {
335 cost = cost.add(&self.cost_model.builtin_cost("hash"));
336 }
337 cost = cost.add(&self.cost_model.builtin_cost("pub_write5"));
339 cost
340 }
341 }
342 }
343}