Skip to main content

trident/cost/
scorer.rs

//! Dynamic TASM table profiler.
//!
//! Counts table row increments per TASM instruction to compute the
//! cliff-aware proving cost: the padded height, i.e. the smallest power of
//! two at least as tall as the tallest table. Lightweight — no memory model,
//! no hash computation. Just table height counters.

/// Table heights for Triton VM's 6 tracked Algebraic Execution Tables.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct TableProfile {
    /// Row counts, indexed as `[processor, hash, u32, op_stack, ram, jump_stack]`.
    pub heights: [u64; 6],
}

const PROC: usize = 0;
const HASH: usize = 1;
const U32: usize = 2;
const OPST: usize = 3;
const RAM: usize = 4;
const JUMP: usize = 5;

impl TableProfile {
    /// Maximum height across all tables (0 when every table is empty).
    pub fn max_height(&self) -> u64 {
        self.heights.iter().copied().max().unwrap_or(0)
    }

    /// Padded height: the smallest power of two `>=` the maximum table height.
    ///
    /// An empty profile pads to 1, and a height that is already a power of
    /// two is kept as is (1024 -> 1024, 1025 -> 2048). A maximum height above
    /// `2^63` has no power of two representable in `u64`. Instead of
    /// overflowing (a panic in debug builds, a wrap to 0 in release builds),
    /// the result saturates to `u64::MAX`, so such a profile never looks cheap.
    pub fn padded_height(&self) -> u64 {
        // `checked_next_power_of_two` maps 0 to `Some(1)`, which covers the
        // empty profile without a special case.
        self.max_height()
            .checked_next_power_of_two()
            .unwrap_or(u64::MAX)
    }

    /// Cliff-aware proving cost, equal to [`padded_height`](Self::padded_height).
    ///
    /// Crossing a power-of-two boundary (a "cliff") doubles this value.
    pub fn cost(&self) -> u64 {
        self.padded_height()
    }

    /// Index of the tallest table (see [`table_name`](Self::table_name)).
    ///
    /// Ties resolve to the *highest* index, because `Iterator::max_by_key`
    /// returns the last maximal element. For example, an all-zero profile
    /// reports `jump_stack`, and equal processor/hash heights report `hash`.
    pub fn dominant_table(&self) -> usize {
        self.heights
            .iter()
            .enumerate()
            .max_by_key(|(_, h)| *h)
            .map(|(i, _)| i)
            .unwrap_or(0)
    }

    /// Table name for display; out-of-range indices map to `"unknown"`.
    pub fn table_name(idx: usize) -> &'static str {
        match idx {
            PROC => "processor",
            HASH => "hash",
            U32 => "u32",
            OPST => "op_stack",
            RAM => "ram",
            JUMP => "jump_stack",
            _ => "unknown",
        }
    }

    /// Whether this (candidate) profile pads to a strictly smaller power of
    /// two than `baseline`. This means it dropped below a cliff that the
    /// baseline sits above.
    ///
    /// Returns `false` when both pad to the same height or when this profile
    /// pads higher.
    pub fn is_cliff_jump(&self, baseline: &TableProfile) -> bool {
        self.padded_height() < baseline.padded_height()
    }

    /// Whether both of the following hold:
    /// - the dominant (tallest) table differs from `baseline`'s;
    /// - this profile's maximum height is strictly lower than the baseline's.
    ///
    /// Dominance ties follow [`dominant_table`](Self::dominant_table).
    pub fn is_table_rebalance(&self, baseline: &TableProfile) -> bool {
        self.dominant_table() != baseline.dominant_table()
            && self.max_height() < baseline.max_height()
    }
}

77/// Profile a sequence of TASM instruction lines, counting table row increments.
78///
79/// Instructions are whitespace-trimmed. Labels, comments, and blank lines
80/// are ignored. Returns cumulative table heights.
81pub fn profile_tasm(lines: &[&str]) -> TableProfile {
82    let mut p = TableProfile::default();
83    for line in lines {
84        let t = line.trim();
85        if t.is_empty() || t.starts_with("//") || t.ends_with(':') {
86            continue;
87        }
88        profile_instruction(t, &mut p);
89    }
90    p
91}
92
93/// Profile from a newline-separated TASM string.
94pub fn profile_tasm_str(tasm: &str) -> TableProfile {
95    let lines: Vec<&str> = tasm.lines().collect();
96    profile_tasm(&lines)
97}
98
99/// Add table row increments for a single TASM instruction.
100fn profile_instruction(instr: &str, p: &mut TableProfile) {
101    let parts: Vec<&str> = instr.split_whitespace().collect();
102    if parts.is_empty() {
103        return;
104    }
105    let op = parts[0];
106    match op {
107        // Stack operations: 1 proc + 1 opstack
108        "push" | "pop" | "dup" | "swap" | "pick" | "place" => {
109            p.heights[PROC] += 1;
110            p.heights[OPST] += 1;
111        }
112
113        // Arithmetic: 1 proc + 1 opstack
114        "add" | "mul" | "eq" | "split" | "invert" => {
115            p.heights[PROC] += 1;
116            p.heights[OPST] += 1;
117        }
118
119        // U32 operations: 1 proc + 33 u32 + 1 opstack
120        "lt" | "and" | "xor" | "pow" | "div_mod" => {
121            p.heights[PROC] += 1;
122            p.heights[U32] += 33;
123            p.heights[OPST] += 1;
124        }
125
126        // U32 no-stack: 1 proc + 33 u32
127        "log_2_floor" | "pop_count" => {
128            p.heights[PROC] += 1;
129            p.heights[U32] += 33;
130        }
131
132        // Hash operations: 1 proc + 6 hash + 1 opstack
133        "hash" => {
134            p.heights[PROC] += 1;
135            p.heights[HASH] += 6;
136            p.heights[OPST] += 1;
137        }
138        "sponge_init" => {
139            p.heights[PROC] += 1;
140            p.heights[HASH] += 6;
141        }
142        "sponge_absorb" | "sponge_squeeze" => {
143            p.heights[PROC] += 1;
144            p.heights[HASH] += 6;
145            p.heights[OPST] += 1;
146        }
147        "sponge_absorb_mem" => {
148            p.heights[PROC] += 1;
149            p.heights[HASH] += 6;
150            p.heights[OPST] += 1;
151            p.heights[RAM] += 10;
152        }
153
154        // Merkle operations
155        "merkle_step" => {
156            p.heights[PROC] += 1;
157            p.heights[HASH] += 6;
158            p.heights[U32] += 33;
159        }
160        "merkle_step_mem" => {
161            p.heights[PROC] += 1;
162            p.heights[HASH] += 6;
163            p.heights[U32] += 33;
164            p.heights[RAM] += 5;
165        }
166
167        // I/O: 1 proc + 1 opstack
168        "read_io" | "write_io" => {
169            p.heights[PROC] += 1;
170            p.heights[OPST] += 1;
171        }
172
173        // Witness: 1 proc + 1 opstack
174        "divine" => {
175            p.heights[PROC] += 1;
176            p.heights[OPST] += 1;
177        }
178
179        // Memory: 2 proc + 2 opstack + 1 ram (per word)
180        "read_mem" | "write_mem" => {
181            let width = parts
182                .get(1)
183                .and_then(|s| s.parse::<u64>().ok())
184                .unwrap_or(1);
185            p.heights[PROC] += 2;
186            p.heights[OPST] += 2;
187            p.heights[RAM] += width;
188        }
189
190        // Control flow: call/return affect jump stack
191        "call" => {
192            p.heights[PROC] += 1;
193            p.heights[JUMP] += 1;
194        }
195        "return" => {
196            p.heights[PROC] += 1;
197            p.heights[JUMP] += 1;
198        }
199        "recurse" | "recurse_or_return" => {
200            p.heights[PROC] += 1;
201            p.heights[JUMP] += 1;
202        }
203
204        // Assertions: 1 proc + 1 opstack
205        "assert" | "assert_vector" => {
206            p.heights[PROC] += 1;
207            p.heights[OPST] += 1;
208        }
209
210        // Skiz: 1 proc + 1 opstack
211        "skiz" => {
212            p.heights[PROC] += 1;
213            p.heights[OPST] += 1;
214        }
215
216        // Halt: 1 proc
217        "halt" => {
218            p.heights[PROC] += 1;
219        }
220
221        // Nop: 1 proc
222        "nop" => {
223            p.heights[PROC] += 1;
224        }
225
226        // Extension field: 1 proc
227        "xb_mul" | "x_invert" | "xx_dot_step" | "xb_dot_step" => {
228            p.heights[PROC] += 1;
229            p.heights[OPST] += 1;
230        }
231
232        // Unknown instruction: count as 1 proc row (conservative)
233        _ => {
234            p.heights[PROC] += 1;
235        }
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Profile whose only non-empty table is `table`, holding `rows` rows.
    fn only(table: usize, rows: u64) -> TableProfile {
        let mut heights = [0u64; 6];
        heights[table] = rows;
        TableProfile { heights }
    }

    #[test]
    fn empty_program() {
        let empty = profile_tasm(&[]);
        assert_eq!(empty.max_height(), 0);
        assert_eq!(empty.padded_height(), 1);
        assert_eq!(empty.cost(), 1);
    }

    #[test]
    fn simple_add() {
        let prof = profile_tasm(&["push 1", "push 2", "add"]);
        assert_eq!(prof.heights[PROC], 3);
        assert_eq!(prof.heights[OPST], 3);
        assert_eq!(prof.heights[HASH], 0);
    }

    #[test]
    fn hash_dominance() {
        let prof = profile_tasm(&["hash"; 3]);
        assert_eq!(prof.heights[PROC], 3);
        assert_eq!(prof.heights[HASH], 18);
        assert_eq!(prof.dominant_table(), HASH);
    }

    #[test]
    fn cliff_boundary() {
        // An exact power of two pads to itself; one more row doubles it.
        assert_eq!(only(PROC, 1024).padded_height(), 1024);
        assert_eq!(only(PROC, 1025).padded_height(), 2048);
    }

    #[test]
    fn cliff_jump_detection() {
        let baseline = only(PROC, 1025);
        let candidate = only(PROC, 1024);
        assert!(candidate.is_cliff_jump(&baseline));
        assert!(!baseline.is_cliff_jump(&candidate));
    }

    #[test]
    fn table_rebalance_detection() {
        let mut baseline = only(PROC, 1000);
        baseline.heights[HASH] = 500;
        let mut candidate = only(PROC, 700);
        candidate.heights[HASH] = 700;
        assert!(candidate.is_table_rebalance(&baseline));
    }

    #[test]
    fn labels_and_comments_ignored() {
        let prof = profile_tasm(&["__main:", "  // comment", "  push 1", ""]);
        assert_eq!(prof.heights[PROC], 1);
    }

    #[test]
    fn memory_width() {
        let prof = profile_tasm(&["read_mem 5"]);
        assert_eq!(prof.heights[RAM], 5);
        assert_eq!(prof.heights[PROC], 2);
    }

    #[test]
    fn u32_operations() {
        let prof = profile_tasm(&["lt", "and"]);
        assert_eq!(prof.heights[U32], 2 * 33);
        assert_eq!(prof.heights[PROC], 2);
    }

    #[test]
    fn call_return_jump_stack() {
        let prof = profile_tasm(&["call __foo", "return"]);
        assert_eq!(prof.heights[JUMP], 2);
    }

    #[test]
    fn profile_from_string() {
        let prof = profile_tasm_str("push 1\npush 2\nadd\nwrite_io 1\nhalt\n");
        assert_eq!(prof.heights[PROC], 5);
    }
}