trident/cost/model/
mod.rs1pub mod triton;
2
3use crate::ast::BinOp;
4
5pub(crate) use triton::TritonCostModel;
6
/// Capacity of a [`TableCost`]: the largest number of tables whose costs can
/// be tracked at once. Values beyond this many tables are dropped by the
/// `TableCost` constructors.
pub const MAX_TABLES: usize = 8;

14pub(crate) trait CostModel {
20 fn table_names(&self) -> &[&str];
22
23 fn table_short_names(&self) -> &[&str];
25
26 fn builtin_cost(&self, name: &str) -> TableCost;
28
29 fn binop_cost(&self, op: &BinOp) -> TableCost;
31
32 fn call_overhead(&self) -> TableCost;
34
35 fn stack_op(&self) -> TableCost;
37
38 fn if_overhead(&self) -> TableCost;
40
41 fn loop_overhead(&self) -> TableCost;
43
44 fn hash_rows_per_permutation(&self) -> u64;
46
47 fn trace_column_count(&self) -> u64;
49}
50
51#[derive(Clone, Copy, Debug, PartialEq, Eq)]
59pub struct TableCost {
60 pub values: [u64; MAX_TABLES],
62 pub count: u8,
64}
65
66impl Default for TableCost {
67 fn default() -> Self {
68 Self::ZERO
69 }
70}
71
72impl TableCost {
73 pub const ZERO: TableCost = TableCost {
77 values: [0; MAX_TABLES],
78 count: 0,
79 };
80
81 pub fn from_slice(vals: &[u64]) -> TableCost {
83 let mut values = [0u64; MAX_TABLES];
84 let n = vals.len().min(MAX_TABLES);
85 values[..n].copy_from_slice(&vals[..n]);
86 TableCost {
87 values,
88 count: n as u8,
89 }
90 }
91
92 pub fn get(&self, i: usize) -> u64 {
94 self.values[i]
95 }
96
97 pub fn is_nonzero(&self) -> bool {
99 let n = self.count as usize;
100 self.values[..n].iter().any(|&v| v > 0)
101 }
102
103 pub fn add(&self, other: &TableCost) -> TableCost {
104 let n = self.count.max(other.count) as usize;
105 let mut values = [0u64; MAX_TABLES];
106 for i in 0..n {
107 values[i] = self.values[i] + other.values[i];
108 }
109 TableCost {
110 values,
111 count: n as u8,
112 }
113 }
114
115 pub fn scale(&self, factor: u64) -> TableCost {
116 let n = self.count as usize;
117 let mut values = [0u64; MAX_TABLES];
118 for i in 0..n {
119 values[i] = self.values[i].saturating_mul(factor);
120 }
121 TableCost {
122 values,
123 count: self.count,
124 }
125 }
126
127 pub fn max(&self, other: &TableCost) -> TableCost {
128 let n = self.count.max(other.count) as usize;
129 let mut values = [0u64; MAX_TABLES];
130 for i in 0..n {
131 values[i] = self.values[i].max(other.values[i]);
132 }
133 TableCost {
134 values,
135 count: n as u8,
136 }
137 }
138
139 pub fn max_height(&self) -> u64 {
141 let n = self.count as usize;
142 self.values[..n].iter().copied().max().unwrap_or(0)
143 }
144
145 pub fn dominant_table<'a>(&self, short_names: &[&'a str]) -> &'a str {
147 let n = self.count as usize;
148 if n == 0 || short_names.is_empty() {
149 return "?";
150 }
151 let max = self.max_height();
152 if max == 0 {
153 return short_names[0];
154 }
155 for i in 0..n.min(short_names.len()) {
156 if self.values[i] == max {
157 return short_names[i];
158 }
159 }
160 short_names[0]
161 }
162
163 pub fn to_json_value(&self, names: &[&str]) -> String {
165 let n = self.count as usize;
166 let mut parts = Vec::new();
167 for i in 0..n.min(names.len()) {
168 let escaped = names[i].replace('\\', "\\\\").replace('"', "\\\"");
170 parts.push(format!("\"{}\": {}", escaped, self.values[i]));
171 }
172 format!("{{{}}}", parts.join(", "))
173 }
174
175 pub fn from_json_value(s: &str, names: &[&str]) -> Option<TableCost> {
177 fn extract_u64(s: &str, key: &str) -> Option<u64> {
178 let escaped = key.replace('\\', "\\\\").replace('"', "\\\"");
180 let needle = format!("\"{}\"", escaped);
181 let mut search_from = 0;
183 while let Some(pos) = s[search_from..].find(&needle) {
184 let idx = search_from + pos;
185 let rest = &s[idx + needle.len()..];
186 let trimmed = rest.trim_start();
188 if trimmed.starts_with(':') {
189 let after_colon = trimmed[1..].trim_start();
190 let end = after_colon
191 .find(|c: char| !c.is_ascii_digit())
192 .unwrap_or(after_colon.len());
193 return after_colon[..end].parse().ok();
194 }
195 search_from = idx + needle.len();
196 }
197 None
198 }
199
200 let mut values = [0u64; MAX_TABLES];
201 for (i, name) in names.iter().enumerate() {
202 values[i] = extract_u64(s, name)?;
203 }
204 Some(TableCost {
205 values,
206 count: names.len() as u8,
207 })
208 }
209
210 pub fn format_annotation(&self, short_names: &[&str]) -> String {
212 let n = self.count as usize;
213 let mut parts = Vec::new();
214 for i in 0..n.min(short_names.len()) {
215 if self.values[i] > 0 {
216 parts.push(format!("{}={}", short_names[i], self.values[i]));
217 }
218 }
219 parts.join(" ")
220 }
221}
222
223pub(crate) fn cost_builtin(target: &str, name: &str) -> TableCost {
225 create_cost_model(target).builtin_cost(name)
226}
227
228pub(crate) fn create_cost_model(target_name: &str) -> &'static dyn CostModel {
232 match target_name {
233 "triton" => &TritonCostModel,
234 _ => &TritonCostModel,
235 }
236}