Skip to main content

trident/api/
mod.rs

1pub(crate) use std::collections::{BTreeMap, BTreeSet};
2pub(crate) use std::path::Path;
3
4pub(crate) use crate::ast::{self, FileKind};
5pub(crate) use crate::cost;
6pub(crate) use crate::diagnostic::{render_diagnostics, Diagnostic};
7pub(crate) use crate::resolve::resolve_modules;
8pub(crate) use crate::span;
9pub(crate) use crate::target::TerrainConfig;
10pub(crate) use crate::tir::builder::TIRBuilder;
11pub(crate) use crate::tir::linker::{link, ModuleTasm};
12pub(crate) use crate::tir::lower::create_stack_lowering;
13pub(crate) use crate::tir::optimize::optimize as optimize_tir;
14pub(crate) use crate::typecheck::{ModuleExports, TypeChecker};
15pub(crate) use crate::{format, lexer, parser, project, solve, sym};
16
17#[cfg(test)]
18mod tests;
19
20/// Options controlling compilation: VM target + conditional compilation flags.
21#[derive(Clone, Debug)]
22pub struct CompileOptions {
23    /// Profile name for cfg flags (e.g. "debug", "release").
24    pub profile: String,
25    /// Active cfg flags for conditional compilation.
26    pub cfg_flags: BTreeSet<String>,
27    /// Target VM configuration.
28    pub target_config: TerrainConfig,
29    /// Additional module search directories (from locked dependencies).
30    pub dep_dirs: Vec<std::path::PathBuf>,
31}
32
33impl Default for CompileOptions {
34    fn default() -> Self {
35        Self {
36            profile: "debug".to_string(),
37            cfg_flags: BTreeSet::from(["debug".to_string()]),
38            target_config: TerrainConfig::triton(),
39            dep_dirs: Vec::new(),
40        }
41    }
42}
43
44impl CompileOptions {
45    /// Create options for a named profile (debug/release/custom).
46    pub fn for_profile(profile: &str) -> Self {
47        Self {
48            profile: profile.to_string(),
49            cfg_flags: BTreeSet::from([profile.to_string()]),
50            target_config: TerrainConfig::triton(),
51            dep_dirs: Vec::new(),
52        }
53    }
54
55    /// Create options for a named built-in target (backward compat alias).
56    pub fn for_target(target: &str) -> Self {
57        Self::for_profile(target)
58    }
59}
60
61/// Compile a single Trident source string to TASM.
62pub fn compile(source: &str, filename: &str) -> Result<String, Vec<Diagnostic>> {
63    compile_with_options(source, filename, &CompileOptions::default())
64}
65
66/// Compile a single Trident source string to TASM with options.
67pub fn compile_with_options(
68    source: &str,
69    filename: &str,
70    options: &CompileOptions,
71) -> Result<String, Vec<Diagnostic>> {
72    let file = crate::parse_source(source, filename)?;
73
74    // Type check
75    let exports = match TypeChecker::with_target(options.target_config.clone())
76        .with_cfg_flags(options.cfg_flags.clone())
77        .check_file(&file)
78    {
79        Ok(exports) => exports,
80        Err(errors) => {
81            render_diagnostics(&errors, filename, source);
82            return Err(errors);
83        }
84    };
85
86    // Build IR, optimize, and lower to target assembly
87    let ir = TIRBuilder::new(options.target_config.clone())
88        .with_cfg_flags(options.cfg_flags.clone())
89        .with_mono_instances(exports.mono_instances)
90        .with_call_resolutions(exports.call_resolutions)
91        .build_file(&file);
92    let ir = optimize_tir(ir);
93    let lowering = create_stack_lowering(&options.target_config.name);
94    let tasm = lowering.lower(&ir).join("\n");
95    Ok(tasm)
96}
97
98/// Compile a multi-module project from an entry point path.
99pub fn compile_project(entry_path: &Path) -> Result<String, Vec<Diagnostic>> {
100    compile_project_with_options(entry_path, &CompileOptions::default())
101}
102
103/// Compile a multi-module project with options.
104pub fn compile_project_with_options(
105    entry_path: &Path,
106    options: &CompileOptions,
107) -> Result<String, Vec<Diagnostic>> {
108    use crate::pipeline::PreparedProject;
109
110    let project = PreparedProject::build(entry_path, options)?;
111
112    let intrinsic_map = project.intrinsic_map();
113    let module_aliases = project.module_aliases();
114    let external_constants = project.external_constants();
115
116    // Emit TASM for each module
117    let mut tasm_modules = Vec::new();
118    for (i, pm) in project.modules.iter().enumerate() {
119        let is_program = pm.file.kind == FileKind::Program;
120        let mono = project
121            .exports
122            .get(i)
123            .map(|e| e.mono_instances.clone())
124            .unwrap_or_default();
125        let call_res = project
126            .exports
127            .get(i)
128            .map(|e| e.call_resolutions.clone())
129            .unwrap_or_default();
130        let ir = TIRBuilder::new(options.target_config.clone())
131            .with_cfg_flags(options.cfg_flags.clone())
132            .with_intrinsics(intrinsic_map.clone())
133            .with_module_aliases(module_aliases.clone())
134            .with_constants(external_constants.clone())
135            .with_mono_instances(mono)
136            .with_call_resolutions(call_res)
137            .build_file(&pm.file);
138        let ir = optimize_tir(ir);
139        let lowering = create_stack_lowering(&options.target_config.name);
140        let tasm = lowering.lower(&ir).join("\n");
141        tasm_modules.push(ModuleTasm {
142            module_name: pm.file.name.node.clone(),
143            is_program,
144            tasm,
145        });
146    }
147
148    // Link
149    let linked = link(tasm_modules);
150    Ok(linked)
151}
152
153/// Type-check only (no TASM emission).
154pub fn check(source: &str, filename: &str) -> Result<(), Vec<Diagnostic>> {
155    let file = crate::parse_source(source, filename)?;
156
157    if let Err(errors) = TypeChecker::new().check_file(&file) {
158        render_diagnostics(&errors, filename, source);
159        return Err(errors);
160    }
161
162    Ok(())
163}
164
165/// Project-aware type-check from an entry point path.
166/// Resolves all modules (including std.*) and type-checks in dependency order.
167pub fn check_project(entry_path: &Path) -> Result<(), Vec<Diagnostic>> {
168    use crate::pipeline::PreparedProject;
169
170    PreparedProject::build_default(entry_path)?;
171    Ok(())
172}
173
174/// Discover `#[test]` functions in a parsed file.
175pub fn discover_tests(file: &ast::File) -> Vec<String> {
176    let mut tests = Vec::new();
177    for item in &file.items {
178        if let ast::Item::Fn(func) = &item.node {
179            if func.is_test {
180                tests.push(func.name.node.clone());
181            }
182        }
183    }
184    tests
185}
186
187/// A single test result.
188#[derive(Clone, Debug)]
189pub struct TestResult {
190    pub name: String,
191    pub passed: bool,
192    pub cost: Option<cost::TableCost>,
193    pub error: Option<String>,
194}
195
196/// Run all `#[test]` functions in a project.
197///
198/// For each test function, we:
199/// 1. Parse and type-check the project
200/// 2. Compile a mini-program that just calls the test function
201/// 3. Report pass/fail with cost summary
202pub fn run_tests(
203    entry_path: &std::path::Path,
204    options: &CompileOptions,
205) -> Result<String, Vec<Diagnostic>> {
206    use crate::pipeline::PreparedProject;
207
208    let project = PreparedProject::build(entry_path, options)?;
209
210    // Discover all #[test] functions across all modules
211    let mut test_fns: Vec<(String, String)> = Vec::new(); // (module_name, fn_name)
212    for pm in &project.modules {
213        for test_name in discover_tests(&pm.file) {
214            test_fns.push((pm.file.name.node.clone(), test_name));
215        }
216    }
217
218    if test_fns.is_empty() {
219        return Ok("No #[test] functions found.\n".to_string());
220    }
221
222    // For each test function, compile a mini-program and report
223    let mut results: Vec<TestResult> = Vec::new();
224    let mut short_names: Vec<String> = Vec::new();
225    for (module_name, test_name) in &test_fns {
226        // Find the source file for this module
227        let source_entry = project
228            .modules
229            .iter()
230            .find(|m| m.file.name.node == *module_name);
231
232        if let Some(pm) = source_entry {
233            // Build a mini-program source that just calls the test function
234            let mini_source = if module_name.starts_with("module") || module_name.contains('.') {
235                // For module test functions, we'd need cross-module calls
236                // For simplicity, compile in-context
237                pm.source.clone()
238            } else {
239                pm.source.clone()
240            };
241
242            // Try to compile (type-check + emit) the source.
243            // The test function itself is validated by the type checker.
244            // For now, "passing" means it compiles without errors.
245            match compile_with_options(&mini_source, &pm.file_path.to_string_lossy(), options) {
246                Ok(tasm) => {
247                    // Compute cost for the test function
248                    let test_cost =
249                        analyze_costs(&mini_source, &pm.file_path.to_string_lossy()).ok();
250                    if short_names.is_empty() {
251                        if let Some(ref pc) = test_cost {
252                            short_names = pc.table_short_names.clone();
253                        }
254                    }
255                    let fn_cost = test_cost.as_ref().and_then(|pc| {
256                        pc.functions
257                            .iter()
258                            .find(|f| f.name == *test_name)
259                            .map(|f| f.cost.clone())
260                    });
261                    // Check if the generated TASM contains an assert failure marker
262                    let has_error = tasm.contains("// ERROR");
263                    results.push(TestResult {
264                        name: test_name.clone(),
265                        passed: !has_error,
266                        cost: fn_cost,
267                        error: if has_error {
268                            Some("compilation produced errors".to_string())
269                        } else {
270                            None
271                        },
272                    });
273                }
274                Err(errors) => {
275                    let msg = errors
276                        .iter()
277                        .map(|d| d.message.clone())
278                        .collect::<Vec<_>>()
279                        .join("; ");
280                    results.push(TestResult {
281                        name: test_name.clone(),
282                        passed: false,
283                        cost: None,
284                        error: Some(msg),
285                    });
286                }
287            }
288        }
289    }
290
291    // Format the report
292    let mut report = String::new();
293    let total = results.len();
294    let passed = results.iter().filter(|r| r.passed).count();
295    let failed = total - passed;
296
297    report.push_str(&format!(
298        "running {} test{}\n",
299        total,
300        if total == 1 { "" } else { "s" }
301    ));
302
303    for result in &results {
304        let status = if result.passed { "ok" } else { "FAILED" };
305        let cost_str = if let Some(ref c) = result.cost {
306            let sn: Vec<&str> = short_names.iter().map(|s| s.as_str()).collect();
307            let ann = c.format_annotation(&sn);
308            if ann.is_empty() {
309                String::new()
310            } else {
311                format!(" ({})", ann)
312            }
313        } else {
314            String::new()
315        };
316        report.push_str(&format!(
317            "  test {} ... {}{}\n",
318            result.name, status, cost_str
319        ));
320        if let Some(ref err) = result.error {
321            report.push_str(&format!("    error: {}\n", err));
322        }
323    }
324
325    report.push('\n');
326    if failed == 0 {
327        report.push_str(&format!("test result: ok. {} passed; 0 failed\n", passed));
328    } else {
329        report.push_str(&format!(
330            "test result: FAILED. {} passed; {} failed\n",
331            passed, failed
332        ));
333    }
334
335    Ok(report)
336}
337
338/// Compile a module and emit TASM for all its functions (no linking, no DCE).
339/// Dependencies are resolved and type-checked, but only the target module's
340/// TASM is returned. Labels use the raw `__funcname:` format.
341pub fn compile_module(
342    module_path: &Path,
343    options: &CompileOptions,
344) -> Result<String, Vec<Diagnostic>> {
345    use crate::pipeline::PreparedProject;
346
347    let project = PreparedProject::build(module_path, options)?;
348
349    let intrinsic_map = project.intrinsic_map();
350    let module_aliases = project.module_aliases();
351    let external_constants = project.external_constants();
352
353    // Emit TASM for only the target module (last in topological order)
354    if let Some((i, pm)) = project.modules.iter().enumerate().last() {
355        let mono = project
356            .exports
357            .get(i)
358            .map(|e| e.mono_instances.clone())
359            .unwrap_or_default();
360        let call_res = project
361            .exports
362            .get(i)
363            .map(|e| e.call_resolutions.clone())
364            .unwrap_or_default();
365        let ir = TIRBuilder::new(options.target_config.clone())
366            .with_cfg_flags(options.cfg_flags.clone())
367            .with_intrinsics(intrinsic_map)
368            .with_module_aliases(module_aliases)
369            .with_constants(external_constants)
370            .with_mono_instances(mono)
371            .with_call_resolutions(call_res)
372            .build_file(&pm.file);
373        let ir = optimize_tir(ir);
374        let lowering = create_stack_lowering(&options.target_config.name);
375        let tasm = lowering.lower(&ir).join("\n");
376        Ok(tasm)
377    } else {
378        Err(vec![Diagnostic::error(
379            "no module found".to_string(),
380            span::Span::dummy(),
381        )])
382    }
383}
384
385/// Build TIR (optimized intermediate representation) from a single source file.
386///
387/// Returns the IR ops before lowering to target assembly. Used by the
388/// neural optimizer to analyze and improve the compilation.
389pub fn build_tir(
390    source: &str,
391    filename: &str,
392    options: &CompileOptions,
393) -> Result<Vec<crate::tir::TIROp>, Vec<Diagnostic>> {
394    let file = crate::parse_source(source, filename)?;
395
396    let exports = match TypeChecker::with_target(options.target_config.clone())
397        .with_cfg_flags(options.cfg_flags.clone())
398        .check_file(&file)
399    {
400        Ok(exports) => exports,
401        Err(errors) => {
402            render_diagnostics(&errors, filename, source);
403            return Err(errors);
404        }
405    };
406
407    let ir = TIRBuilder::new(options.target_config.clone())
408        .with_cfg_flags(options.cfg_flags.clone())
409        .with_mono_instances(exports.mono_instances)
410        .with_call_resolutions(exports.call_resolutions)
411        .build_file(&file);
412    Ok(optimize_tir(ir))
413}
414
415/// Build TIR from a project entry point with full module resolution.
416///
417/// Uses the same multi-module pipeline as `compile_project_with_options`
418/// but returns combined TIR ops instead of TASM. Required for neural
419/// training on files that import other modules (e.g. merkle.tri imports
420/// vm.crypto.merkle).
421pub fn build_tir_project(
422    entry_path: &Path,
423    options: &CompileOptions,
424) -> Result<Vec<crate::tir::TIROp>, Vec<Diagnostic>> {
425    use crate::pipeline::PreparedProject;
426
427    let project = PreparedProject::build(entry_path, options)?;
428
429    let intrinsic_map = project.intrinsic_map();
430    let module_aliases = project.module_aliases();
431    let external_constants = project.external_constants();
432
433    let mut all_ir = Vec::new();
434    for (i, pm) in project.modules.iter().enumerate() {
435        let mono = project
436            .exports
437            .get(i)
438            .map(|e| e.mono_instances.clone())
439            .unwrap_or_default();
440        let call_res = project
441            .exports
442            .get(i)
443            .map(|e| e.call_resolutions.clone())
444            .unwrap_or_default();
445        let ir = TIRBuilder::new(options.target_config.clone())
446            .with_cfg_flags(options.cfg_flags.clone())
447            .with_intrinsics(intrinsic_map.clone())
448            .with_module_aliases(module_aliases.clone())
449            .with_constants(external_constants.clone())
450            .with_mono_instances(mono)
451            .with_call_resolutions(call_res)
452            .build_file(&pm.file);
453        all_ir.extend(optimize_tir(ir));
454    }
455    Ok(all_ir)
456}
457
458pub(crate) mod doc;
459pub(crate) mod pipeline;
460mod tools;
461pub use tools::*;
462
463/// Compile a multi-module project to a `ProgramBundle` artifact.
464///
465/// This is the primary entry point for warriors: it produces a
466/// self-contained bundle with compiled assembly, cost analysis,
467/// function signatures, and metadata.
468pub fn compile_to_bundle(
469    entry_path: &Path,
470    options: &CompileOptions,
471) -> Result<crate::runtime::ProgramBundle, Vec<Diagnostic>> {
472    use crate::runtime::artifact::{BundleCost, BundleFunction, ProgramBundle};
473    use pipeline::PreparedProject;
474
475    let tasm = compile_project_with_options(entry_path, options)?;
476
477    // Cost analysis (best-effort — use zeros on failure)
478    let program_cost =
479        analyze_costs_project(entry_path, options).unwrap_or_else(|_| cost::ProgramCost {
480            program_name: String::new(),
481            functions: Vec::new(),
482            total: cost::TableCost::ZERO,
483            table_names: Vec::new(),
484            table_short_names: Vec::new(),
485            attestation_hash_rows: 0,
486            padded_height: 0,
487            estimated_proving_ns: 0,
488            loop_bound_waste: Vec::new(),
489        });
490
491    // Parse entry file for function signatures + content hashes
492    let project = PreparedProject::build(entry_path, options)?;
493    let entry_file = project
494        .modules
495        .iter()
496        .find(|m| m.file.kind == FileKind::Program)
497        .or_else(|| project.modules.last());
498
499    let (functions, entry_point, source_hash) = if let Some(pm) = entry_file {
500        let fn_hashes = crate::hash::hash_file(&pm.file);
501        let fns: Vec<BundleFunction> = pm
502            .file
503            .items
504            .iter()
505            .filter_map(|item| {
506                if let ast::Item::Fn(func) = &item.node {
507                    if !func.is_test {
508                        let hash = fn_hashes
509                            .get(&func.name.node)
510                            .map(|h| h.to_hex())
511                            .unwrap_or_default();
512                        return Some(BundleFunction {
513                            name: func.name.node.clone(),
514                            hash,
515                            signature: crate::deploy::format_fn_signature(func),
516                        });
517                    }
518                }
519                None
520            })
521            .collect();
522        let ep = if fns.iter().any(|f| f.name == "main") {
523            "main".to_string()
524        } else {
525            fns.first()
526                .map(|f| f.name.clone())
527                .unwrap_or_else(|| "main".to_string())
528        };
529        let sh = crate::hash::hash_file_content(&pm.file).to_hex();
530        (fns, ep, sh)
531    } else {
532        (Vec::new(), "main".to_string(), String::new())
533    };
534
535    let name = entry_path
536        .file_stem()
537        .and_then(|s| s.to_str())
538        .unwrap_or("program")
539        .to_string();
540
541    Ok(ProgramBundle {
542        name,
543        version: "0.1.0".to_string(),
544        target_vm: options.target_config.name.clone(),
545        target_os: None,
546        assembly: tasm,
547        entry_point,
548        functions,
549        cost: BundleCost {
550            table_values: (0..program_cost.total.count as usize)
551                .map(|i| program_cost.total.get(i))
552                .collect(),
553            table_names: program_cost.table_names,
554            padded_height: program_cost.padded_height,
555            estimated_proving_ns: program_cost.estimated_proving_ns,
556        },
557        source_hash,
558    })
559}