Skip to main content

trident/api/
tools.rs

1use super::*;
2
3pub fn analyze_costs(source: &str, filename: &str) -> Result<cost::ProgramCost, Vec<Diagnostic>> {
4    let file = crate::parse_source(source, filename)?;
5
6    if let Err(errors) = TypeChecker::new().check_file(&file) {
7        render_diagnostics(&errors, filename, source);
8        return Err(errors);
9    }
10
11    let cost = cost::CostAnalyzer::default().analyze_file(&file);
12    Ok(cost)
13}
14
15/// Parse, type-check, and compute cost analysis for a multi-module project.
16/// Falls back to single-file analysis if module resolution fails.
17pub fn analyze_costs_project(
18    entry_path: &Path,
19    options: &CompileOptions,
20) -> Result<cost::ProgramCost, Vec<Diagnostic>> {
21    use crate::pipeline::PreparedProject;
22
23    let project = PreparedProject::build(entry_path, options)?;
24
25    // Analyze costs for the program file (last in topological order)
26    if let Some(file) = project.last_file() {
27        let cost = cost::CostAnalyzer::for_target(&options.target_config.name).analyze_file(file);
28        Ok(cost)
29    } else {
30        Err(vec![Diagnostic::error(
31            "no program file found".to_string(),
32            span::Span::dummy(),
33        )])
34    }
35}
36
37/// Parse, type-check, and verify a project using symbolic execution + solver.
38///
39/// Analyzes all functions across all modules, not just `main`.
40/// Returns a `VerificationReport` with static analysis, random testing (Schwartz-Zippel),
41/// and bounded model checking results.
42pub fn verify_project(entry_path: &Path) -> Result<solve::VerificationReport, Vec<Diagnostic>> {
43    use crate::pipeline::PreparedProject;
44
45    let project = PreparedProject::build_default(entry_path)?;
46
47    // Collect constraint systems from all functions in all modules
48    let mut combined = sym::ConstraintSystem::new();
49    for pm in &project.modules {
50        for (_, system) in sym::analyze_all(&pm.file) {
51            combined.constraints.extend(system.constraints);
52            combined.num_variables += system.num_variables;
53            for (k, v) in system.variables {
54                combined.variables.insert(k, v);
55            }
56            combined.pub_inputs.extend(system.pub_inputs);
57            combined.pub_outputs.extend(system.pub_outputs);
58            combined.divine_inputs.extend(system.divine_inputs);
59        }
60    }
61
62    Ok(solve::verify(&combined))
63}
64
65/// Verify all functions in a project, returning per-function results.
66///
67/// Each entry in the returned vec is `(module_name, fn_name, report)`.
68pub fn verify_project_per_function(
69    entry_path: &Path,
70) -> Result<Vec<(String, String, solve::VerificationReport)>, Vec<Diagnostic>> {
71    use crate::pipeline::PreparedProject;
72
73    let project = PreparedProject::build_default(entry_path)?;
74
75    let mut results = Vec::new();
76    for pm in &project.modules {
77        let module_name = pm.file.name.node.clone();
78        for (fn_name, system) in sym::analyze_all(&pm.file) {
79            let report = solve::verify(&system);
80            results.push((module_name.clone(), fn_name, report));
81        }
82    }
83
84    Ok(results)
85}
86
/// Count the TASM instructions in a compiled output string.
///
/// Each line is trimmed first. Blank lines, `//` comments, labels (lines
/// ending in `:`), and the literal `halt` instruction are not counted; every
/// other line counts as one instruction.
pub fn count_tasm_instructions(tasm: &str) -> usize {
    let mut total = 0;
    for raw_line in tasm.lines() {
        let line = raw_line.trim();
        let is_non_instruction = line.is_empty()
            || line.starts_with("//")
            || line.ends_with(':')
            || line == "halt";
        if !is_non_instruction {
            total += 1;
        }
    }
    total
}
97
/// Parse TASM into per-function instruction counts.
///
/// Returns a `BTreeMap` from normalized function name to the number of
/// instructions under that label. Blank lines, `//` comments, and `halt` are
/// not counted. Instructions before the first label are ignored, and labels
/// with no instructions are omitted.
///
/// Label normalization keeps only the text after the last `__`:
/// - `__funcname` -> `funcname`
/// - `std_crypto_mod__funcname` -> `funcname`
/// - `__then__0` (empty or all-digit suffix) is a compiler-internal deferred
///   block. Its instructions are not attributed to any function.
pub fn parse_tasm_functions(tasm: &str) -> BTreeMap<String, usize> {
    let mut counts: BTreeMap<String, usize> = BTreeMap::new();
    // The label currently being counted, with its running instruction total.
    let mut open: Option<(String, usize)> = None;

    for line in tasm.lines().map(str::trim) {
        if line.is_empty() || line.starts_with("//") {
            continue;
        }

        if !line.ends_with(':') {
            if line != "halt" {
                if let Some((_, total)) = open.as_mut() {
                    *total += 1;
                }
            }
            continue;
        }

        // A new label closes the previous one.
        if let Some((label, total)) = open.take() {
            if total > 0 {
                counts.insert(label, total);
            }
        }

        let raw = line.trim_end_matches(':');
        open = match raw.rfind("__") {
            None => Some((raw.to_string(), 0)),
            Some(pos) => {
                let tail = &raw[pos + 2..];
                let is_deferred_block = tail.is_empty() || tail.chars().all(|c| c.is_ascii_digit());
                if is_deferred_block {
                    None
                } else {
                    Some((tail.to_string(), 0))
                }
            }
        };
    }

    if let Some((label, total)) = open {
        if total > 0 {
            counts.insert(label, total);
        }
    }
    counts
}
152
/// Per-function benchmark comparison: the compiler's instruction count for one
/// function next to a baseline count for the same function.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct FunctionBenchmark {
    /// Function name as shown in the benchmark table.
    pub name: String,
    /// Instructions in the compiled output (rendered in the "Tri" column).
    pub compiled_instructions: usize,
    /// Instructions in the baseline (rendered in the "Hand" column).
    pub baseline_instructions: usize,
}

/// Module-level benchmark result with per-function comparisons.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ModuleBenchmarkResult {
    /// Module path shown on the module's header row.
    pub module_path: String,
    /// One comparison per benchmarked function in the module.
    pub functions: Vec<FunctionBenchmark>,
    /// Module-wide compiled instruction total (shown on the module header row).
    pub total_compiled: usize,
    /// Module-wide baseline instruction total (shown on the module header row).
    pub total_baseline: usize,
}
169
/// Format a number with a comma between each group of three digits
/// (e.g. 2097152 -> "2,097,152").
///
/// Zero is rendered as an em-dash instead of "0".
pub fn fmt_num(n: usize) -> String {
    if n == 0 {
        return "\u{2014}".to_string();
    }
    let digits = n.to_string();
    // The leading group has 1-3 digits, so every later group has exactly 3.
    let lead = match digits.len() % 3 {
        0 => 3,
        rem => rem,
    };
    let mut grouped = String::with_capacity(digits.len() + digits.len() / 3);
    grouped.push_str(&digits[..lead]);
    for start in (lead..digits.len()).step_by(3) {
        grouped.push(',');
        grouped.push_str(&digits[start..start + 3]);
    }
    grouped
}
186
/// Format the ratio `num / den` as `N.NNx`, truncated (not rounded) to two
/// decimal places, using integer arithmetic.
///
/// Returns an em-dash when `den` is zero.
pub fn fmt_ratio(num: usize, den: usize) -> String {
    if den == 0 {
        return "\u{2014}".to_string();
    }
    // Widen to u128 so `num * 100` cannot overflow. A usize is at most 64 bits,
    // so these `as` casts are lossless. This matters on 32-bit targets, where
    // counts above ~42M would otherwise overflow.
    let ratio_100 = num as u128 * 100 / den as u128;
    format!("{}.{:02}x", ratio_100 / 100, ratio_100 % 100)
}
198
/// Return a status icon comparing `num` against `den`.
///
/// - a space when `den` is zero (no baseline);
/// - a checkmark when `num <= 2 * den`;
/// - a warning triangle otherwise.
pub fn status_icon(num: usize, den: usize) -> &'static str {
    if den == 0 {
        " "
    } else if num.saturating_sub(den) <= den {
        // Same test as `num <= 2 * den`, but `2 * den` could overflow for
        // den > usize::MAX / 2.
        "\u{2713}"
    } else {
        "\u{25b3}"
    }
}
210
211impl ModuleBenchmarkResult {
212    pub fn format_header() -> String {
213        let top = format!(
214            "\u{250c}{}\u{252c}{}\u{252c}{}\u{252c}{}\u{252c}{}\u{2510}",
215            "\u{2500}".repeat(30),
216            "\u{2500}".repeat(10),
217            "\u{2500}".repeat(10),
218            "\u{2500}".repeat(9),
219            "\u{2500}".repeat(3),
220        );
221        let header = format!(
222            "\u{2502} {:<28} \u{2502} {:>8} \u{2502} {:>8} \u{2502} {:>7} \u{2502}   \u{2502}",
223            "Function", "Tri", "Hand", "Ratio"
224        );
225        let mid = format!(
226            "\u{251c}{}\u{253c}{}\u{253c}{}\u{253c}{}\u{253c}{}\u{2524}",
227            "\u{2500}".repeat(30),
228            "\u{2500}".repeat(10),
229            "\u{2500}".repeat(10),
230            "\u{2500}".repeat(9),
231            "\u{2500}".repeat(3),
232        );
233        format!("{}\n{}\n{}", top, header, mid)
234    }
235
236    pub fn format_module_header(&self) -> String {
237        format!(
238            "\u{251c}{}\u{253c}{}\u{253c}{}\u{253c}{}\u{253c}{}\u{2524}\n\u{2502} {:<28} \u{2502} {:>8} \u{2502} {:>8} \u{2502} {:>7} \u{2502} {} \u{2502}",
239            "\u{2500}".repeat(30),
240            "\u{2500}".repeat(10),
241            "\u{2500}".repeat(10),
242            "\u{2500}".repeat(9),
243            "\u{2500}".repeat(3),
244            self.module_path,
245            fmt_num(self.total_compiled),
246            fmt_num(self.total_baseline),
247            fmt_ratio(self.total_compiled, self.total_baseline),
248            status_icon(self.total_compiled, self.total_baseline),
249        )
250    }
251
252    pub fn format_function(&self, f: &FunctionBenchmark) -> String {
253        format!(
254            "\u{2502}   {:<26} \u{2502} {:>8} \u{2502} {:>8} \u{2502} {:>7} \u{2502} {} \u{2502}",
255            f.name,
256            fmt_num(f.compiled_instructions),
257            fmt_num(f.baseline_instructions),
258            fmt_ratio(f.compiled_instructions, f.baseline_instructions),
259            status_icon(f.compiled_instructions, f.baseline_instructions),
260        )
261    }
262
263    pub fn format_separator() -> String {
264        format!(
265            "\u{2514}{}\u{2534}{}\u{2534}{}\u{2534}{}\u{2534}{}\u{2518}",
266            "\u{2500}".repeat(30),
267            "\u{2500}".repeat(10),
268            "\u{2500}".repeat(10),
269            "\u{2500}".repeat(9),
270            "\u{2500}".repeat(3),
271        )
272    }
273
274    /// Format a summary line. `avg_num`/`avg_den` and `max_num`/`max_den` are
275    /// numerator/denominator pairs for the average and max ratios.
276    pub fn format_summary(
277        avg_num: usize,
278        avg_den: usize,
279        max_num: usize,
280        max_den: usize,
281        count: usize,
282    ) -> String {
283        format!(
284            "  Avg: {}  Max: {}  ({} modules)",
285            fmt_ratio(avg_num, avg_den),
286            fmt_ratio(max_num, max_den),
287            count
288        )
289    }
290}
291
292/// Generate markdown documentation for a Trident project.
293pub fn generate_docs(
294    entry_path: &Path,
295    options: &CompileOptions,
296) -> Result<String, Vec<Diagnostic>> {
297    doc::generate_docs(entry_path, options)
298}
299
300/// Parse, type-check, and produce per-line cost-annotated source output.
301pub fn annotate_source(source: &str, filename: &str) -> Result<String, Vec<Diagnostic>> {
302    annotate_source_with_target(source, filename, "triton")
303}
304
305/// Like `annotate_source`, but uses the specified target's cost model.
306pub fn annotate_source_with_target(
307    source: &str,
308    filename: &str,
309    target: &str,
310) -> Result<String, Vec<Diagnostic>> {
311    let file = crate::parse_source(source, filename)?;
312
313    if let Err(errors) = TypeChecker::new().check_file(&file) {
314        render_diagnostics(&errors, filename, source);
315        return Err(errors);
316    }
317
318    let mut analyzer = cost::CostAnalyzer::for_target(target);
319    let pc = analyzer.analyze_file(&file);
320    let short_names = pc.short_names();
321    let stmt_costs = analyzer.stmt_costs(&file, source);
322
323    // Build a map from line number to aggregated cost
324    let mut line_costs: BTreeMap<u32, cost::TableCost> = BTreeMap::new();
325    for (line, cost) in &stmt_costs {
326        line_costs
327            .entry(*line)
328            .and_modify(|existing| *existing = existing.add(cost))
329            .or_insert_with(|| cost.clone());
330    }
331
332    let lines: Vec<&str> = source.lines().collect();
333    let line_count = lines.len();
334    let line_num_width = format!("{}", line_count).len().max(2);
335
336    // Find max line length for alignment
337    let max_line_len = lines.iter().map(|l| l.len()).max().unwrap_or(0).min(60);
338
339    let mut out = String::new();
340    for (i, line) in lines.iter().enumerate() {
341        let line_num = (i + 1) as u32;
342        let padded_line = format!("{:<width$}", line, width = max_line_len);
343        if let Some(cost) = line_costs.get(&line_num) {
344            let annotation = cost.format_annotation(&short_names);
345            if !annotation.is_empty() {
346                out.push_str(&format!(
347                    "{:>width$} | {}  [{}]\n",
348                    line_num,
349                    padded_line,
350                    annotation,
351                    width = line_num_width,
352                ));
353                continue;
354            }
355        }
356        out.push_str(&format!(
357            "{:>width$} | {}\n",
358            line_num,
359            line,
360            width = line_num_width,
361        ));
362    }
363
364    Ok(out)
365}
366
367/// Format Trident source code, preserving comments.
368pub fn format_source(source: &str, _filename: &str) -> Result<String, Vec<Diagnostic>> {
369    let (tokens, comments, lex_errors) = lexer::Lexer::new(source, 0).tokenize();
370    if !lex_errors.is_empty() {
371        return Err(lex_errors);
372    }
373    let file = parser::Parser::new(tokens).parse_file()?;
374    Ok(format::format_file(&file, &comments))
375}
376
377/// Type-check only, without rendering diagnostics to stderr.
378/// Used by the LSP server to get structured errors.
379pub fn check_silent(source: &str, filename: &str) -> Result<(), Vec<Diagnostic>> {
380    let file = crate::parse_source_silent(source, filename)?;
381    TypeChecker::new().check_file(&file)?;
382    Ok(())
383}
384
385/// Project-aware type-check for the LSP.
386/// Finds trident.toml, resolves dependencies, and type-checks
387/// the given file with full module context.
388/// Falls back to single-file check if no project is found.
389pub fn check_file_in_project(source: &str, file_path: &Path) -> Result<(), Vec<Diagnostic>> {
390    let dir = file_path.parent().unwrap_or(Path::new("."));
391    let entry = match project::Project::find(dir) {
392        Some(toml_path) => match project::Project::load(&toml_path) {
393            Ok(p) => p.entry,
394            Err(_) => file_path.to_path_buf(),
395        },
396        None => file_path.to_path_buf(),
397    };
398
399    // Resolve all modules from the entry point (handles std.* even without project)
400    let modules = match resolve_modules(&entry) {
401        Ok(m) => m,
402        Err(_) => return check_silent(source, &file_path.to_string_lossy()),
403    };
404
405    // Parse and type-check all modules in dependency order
406    let mut all_exports: Vec<ModuleExports> = Vec::new();
407    let file_path_canon = file_path
408        .canonicalize()
409        .unwrap_or_else(|_| file_path.to_path_buf());
410
411    for module in &modules {
412        let mod_path_canon = module
413            .file_path
414            .canonicalize()
415            .unwrap_or_else(|_| module.file_path.clone());
416        let is_target = mod_path_canon == file_path_canon;
417
418        // Use live buffer for the file being edited
419        let src = if is_target { source } else { &module.source };
420        let parsed = crate::parse_source_silent(src, &module.file_path.to_string_lossy())?;
421
422        let mut tc = TypeChecker::new();
423        for exports in &all_exports {
424            tc.import_module(exports);
425        }
426
427        match tc.check_file(&parsed) {
428            Ok(exports) => {
429                all_exports.push(exports);
430            }
431            Err(errors) => {
432                if is_target {
433                    return Err(errors);
434                }
435                // Dep has errors — stop, but don't report
436                // dep errors as if they're in this file
437                return Ok(());
438            }
439        }
440    }
441
442    Ok(())
443}