Skip to main content

trident/cost/model/
mod.rs

1pub mod triton;
2
3use crate::ast::BinOp;
4
5pub(crate) use triton::TritonCostModel;
6
7// ---------------------------------------------------------------------------
8// CostModel trait — target-agnostic cost interface
9// ---------------------------------------------------------------------------
10
/// Maximum number of cost tables any target can have.
///
/// `TableCost` records how many of these slots are active in a `u8`, so this
/// bound must fit in one; the compile-time check below enforces that.
pub const MAX_TABLES: usize = 8;

// `TableCost::count` is a `u8` produced by `as u8` casts from lengths capped
// at `MAX_TABLES`; fail the build instead of silently truncating.
const _: () = assert!(MAX_TABLES <= u8::MAX as usize);
13
14/// Trait for target-specific cost models.
15///
16/// Each target VM implements this to provide table names, per-instruction
17/// costs, and formatting for cost reports. The cost analyzer delegates all
18/// target-specific knowledge through this trait.
19pub(crate) trait CostModel {
20    /// Names of the execution tables (e.g. ["processor", "hash", "u32", ...]).
21    fn table_names(&self) -> &[&str];
22
23    /// Short display names for compact annotations (e.g. ["cc", "hash", "u32", ...]).
24    fn table_short_names(&self) -> &[&str];
25
26    /// Cost of a builtin function call by name.
27    fn builtin_cost(&self, name: &str) -> TableCost;
28
29    /// Cost of a binary operation.
30    fn binop_cost(&self, op: &BinOp) -> TableCost;
31
32    /// Overhead cost for a function call/return pair.
33    fn call_overhead(&self) -> TableCost;
34
35    /// Cost of a single stack manipulation (push/dup/swap).
36    fn stack_op(&self) -> TableCost;
37
38    /// Overhead cost for an if/else branch.
39    fn if_overhead(&self) -> TableCost;
40
41    /// Overhead cost per loop iteration.
42    fn loop_overhead(&self) -> TableCost;
43
44    /// Number of hash table rows per hash permutation.
45    fn hash_rows_per_permutation(&self) -> u64;
46
47    /// Number of trace columns (used for proving time estimation).
48    fn trace_column_count(&self) -> u64;
49}
50
51// ---------------------------------------------------------------------------
52// TableCost — target-generic cost vector
53// ---------------------------------------------------------------------------
54
55/// Cost across execution tables. Fixed-size array indexed by table position
56/// as defined by the target's CostModel. Table names are external metadata,
57/// not baked into this struct.
58#[derive(Clone, Copy, Debug, PartialEq, Eq)]
59pub struct TableCost {
60    /// Cost values indexed by table position.
61    pub values: [u64; MAX_TABLES],
62    /// Number of active tables.
63    pub count: u8,
64}
65
66impl Default for TableCost {
67    fn default() -> Self {
68        Self::ZERO
69    }
70}
71
72impl TableCost {
73    /// Zero cost with no active tables. When added to a real cost,
74    /// the result inherits the other cost's `count` (via `max`).
75    /// Use `from_slice(&[0; N])` when you need a zero cost with a specific table count.
76    pub const ZERO: TableCost = TableCost {
77        values: [0; MAX_TABLES],
78        count: 0,
79    };
80
81    /// Build from a slice of values (used by CostModel implementations).
82    pub fn from_slice(vals: &[u64]) -> TableCost {
83        let mut values = [0u64; MAX_TABLES];
84        let n = vals.len().min(MAX_TABLES);
85        values[..n].copy_from_slice(&vals[..n]);
86        TableCost {
87            values,
88            count: n as u8,
89        }
90    }
91
92    /// Get value at table index.
93    pub fn get(&self, i: usize) -> u64 {
94        self.values[i]
95    }
96
97    /// Check if any table has non-zero cost.
98    pub fn is_nonzero(&self) -> bool {
99        let n = self.count as usize;
100        self.values[..n].iter().any(|&v| v > 0)
101    }
102
103    pub fn add(&self, other: &TableCost) -> TableCost {
104        let n = self.count.max(other.count) as usize;
105        let mut values = [0u64; MAX_TABLES];
106        for i in 0..n {
107            values[i] = self.values[i] + other.values[i];
108        }
109        TableCost {
110            values,
111            count: n as u8,
112        }
113    }
114
115    pub fn scale(&self, factor: u64) -> TableCost {
116        let n = self.count as usize;
117        let mut values = [0u64; MAX_TABLES];
118        for i in 0..n {
119            values[i] = self.values[i].saturating_mul(factor);
120        }
121        TableCost {
122            values,
123            count: self.count,
124        }
125    }
126
127    pub fn max(&self, other: &TableCost) -> TableCost {
128        let n = self.count.max(other.count) as usize;
129        let mut values = [0u64; MAX_TABLES];
130        for i in 0..n {
131            values[i] = self.values[i].max(other.values[i]);
132        }
133        TableCost {
134            values,
135            count: n as u8,
136        }
137    }
138
139    /// The maximum height across all active tables.
140    pub fn max_height(&self) -> u64 {
141        let n = self.count as usize;
142        self.values[..n].iter().copied().max().unwrap_or(0)
143    }
144
145    /// Which table is the tallest, by short name.
146    pub fn dominant_table<'a>(&self, short_names: &[&'a str]) -> &'a str {
147        let n = self.count as usize;
148        if n == 0 || short_names.is_empty() {
149            return "?";
150        }
151        let max = self.max_height();
152        if max == 0 {
153            return short_names[0];
154        }
155        for i in 0..n.min(short_names.len()) {
156            if self.values[i] == max {
157                return short_names[i];
158            }
159        }
160        short_names[0]
161    }
162
163    /// Serialize to a JSON object string using the given table names as keys.
164    pub fn to_json_value(&self, names: &[&str]) -> String {
165        let n = self.count as usize;
166        let mut parts = Vec::new();
167        for i in 0..n.min(names.len()) {
168            // Escape quotes and backslashes in table names for valid JSON
169            let escaped = names[i].replace('\\', "\\\\").replace('"', "\\\"");
170            parts.push(format!("\"{}\": {}", escaped, self.values[i]));
171        }
172        format!("{{{}}}", parts.join(", "))
173    }
174
175    /// Deserialize from a JSON object string using the given table names as keys.
176    pub fn from_json_value(s: &str, names: &[&str]) -> Option<TableCost> {
177        fn extract_u64(s: &str, key: &str) -> Option<u64> {
178            // Escape the key the same way to_json_value does
179            let escaped = key.replace('\\', "\\\\").replace('"', "\\\"");
180            let needle = format!("\"{}\"", escaped);
181            // Search for all occurrences and pick the one that's a real key
182            let mut search_from = 0;
183            while let Some(pos) = s[search_from..].find(&needle) {
184                let idx = search_from + pos;
185                let rest = &s[idx + needle.len()..];
186                // Must be followed by optional whitespace then ':'
187                let trimmed = rest.trim_start();
188                if trimmed.starts_with(':') {
189                    let after_colon = trimmed[1..].trim_start();
190                    let end = after_colon
191                        .find(|c: char| !c.is_ascii_digit())
192                        .unwrap_or(after_colon.len());
193                    return after_colon[..end].parse().ok();
194                }
195                search_from = idx + needle.len();
196            }
197            None
198        }
199
200        let mut values = [0u64; MAX_TABLES];
201        for (i, name) in names.iter().enumerate() {
202            values[i] = extract_u64(s, name)?;
203        }
204        Some(TableCost {
205            values,
206            count: names.len() as u8,
207        })
208    }
209
210    /// Format a compact annotation string showing non-zero cost fields.
211    pub fn format_annotation(&self, short_names: &[&str]) -> String {
212        let n = self.count as usize;
213        let mut parts = Vec::new();
214        for i in 0..n.min(short_names.len()) {
215            if self.values[i] > 0 {
216                parts.push(format!("{}={}", short_names[i], self.values[i]));
217            }
218        }
219        parts.join(" ")
220    }
221}
222
223/// Look up builtin cost using a named target's cost model.
224pub(crate) fn cost_builtin(target: &str, name: &str) -> TableCost {
225    create_cost_model(target).builtin_cost(name)
226}
227
228/// Select the cost model for a given target name.
229///
230/// Falls back to Triton for unknown targets (currently the only model).
231pub(crate) fn create_cost_model(target_name: &str) -> &'static dyn CostModel {
232    match target_name {
233        "triton" => &TritonCostModel,
234        _ => &TritonCostModel,
235    }
236}