// ternlang_core — stdlib.rs
//! StdlibLoader — resolves `use` statements into parsed function definitions.
//!
//! When user code contains `use std::trit;` inside a function body, this module
//! parses the corresponding stdlib source and injects its functions into the
//! program before semantic analysis and codegen.
//!
//! Stdlib sources are embedded at compile time via `include_str!`, so the
//! compiler binary is fully self-contained — no filesystem lookups at runtime.
use std::collections::{HashSet, VecDeque};

use crate::ast::{Program, Stmt};
use crate::parser::Parser;
11
/// Stateless loader that injects bundled stdlib functions into a parsed
/// program in response to `use` statements; see [`StdlibLoader::resolve`].
#[derive(Debug, Clone, Copy, Default)]
pub struct StdlibLoader;
13
14impl StdlibLoader {
15    fn source_for(path: &[String]) -> Option<&'static str> {
16        match path.join("::").as_str() {
17            "std::trit"     => Some(include_str!("../stdlib/std/trit.tern")),
18            "std::math"     => Some(include_str!("../stdlib/std/math.tern")),
19            "std::tensor"   => Some(include_str!("../stdlib/std/tensor.tern")),
20            "std::io"       => Some(include_str!("../stdlib/std/io.tern")),
21            "ml::quantize"  => Some(include_str!("../stdlib/ml/quantize.tern")),
22            "ml::inference" => Some(include_str!("../stdlib/ml/inference.tern")),
23            _               => None,
24        }
25    }
26
27    /// Recursively collect `use` paths from a slice of statements.
28    fn collect_use_paths(stmts: &[Stmt]) -> Vec<Vec<String>> {
29        let mut paths = Vec::new();
30        for stmt in stmts {
31            match stmt {
32                Stmt::Use { path } => paths.push(path.clone()),
33                Stmt::Block(inner) => paths.extend(Self::collect_use_paths(inner)),
34                Stmt::IfTernary { on_pos, on_zero, on_neg, .. } => {
35                    paths.extend(Self::collect_use_paths(&[*on_pos.clone()]));
36                    paths.extend(Self::collect_use_paths(&[*on_zero.clone()]));
37                    paths.extend(Self::collect_use_paths(&[*on_neg.clone()]));
38                }
39                Stmt::Match { arms, .. } => {
40                    for (_, arm_stmt) in arms {
41                        paths.extend(Self::collect_use_paths(&[arm_stmt.clone()]));
42                    }
43                }
44                _ => {}
45            }
46        }
47        paths
48    }
49
50    /// Parse stdlib modules referenced by `use` statements and prepend their
51    /// functions to `program.functions`.  Functions already present by name are
52    /// not duplicated, so calling this multiple times is safe.
53    pub fn resolve(program: &mut Program) {
54        // Build the set of already-defined function names
55        let mut known: std::collections::HashSet<String> =
56            program.functions.iter().map(|f| f.name.clone()).collect();
57
58        // Collect all use paths from every function body
59        let mut all_paths: Vec<Vec<String>> = program
60            .functions
61            .iter()
62            .flat_map(|f| Self::collect_use_paths(&f.body))
63            .collect();
64
65        // Deduplicate so we parse each module at most once
66        all_paths.sort();
67        all_paths.dedup();
68
69        let mut stdlib_fns = Vec::new();
70
71        for path in &all_paths {
72            let key = path.join("::");
73            let Some(src) = Self::source_for(path) else {
74                // Unknown module — leave as-is; semantic checker will surface errors
75                continue;
76            };
77            let mut parser = Parser::new(src);
78            match parser.parse_program() {
79                Ok(stdlib_prog) => {
80                    for f in stdlib_prog.functions {
81                        if !known.contains(&f.name) {
82                            known.insert(f.name.clone());
83                            stdlib_fns.push(f);
84                        }
85                    }
86                }
87                Err(e) => {
88                    eprintln!("[stdlib] Failed to parse {}: {:?}", key, e);
89                }
90            }
91        }
92
93        // Prepend stdlib functions so they appear before user-defined functions
94        // (call order in bytecode doesn't matter, but it keeps the symbol table tidy)
95        stdlib_fns.extend(program.functions.drain(..));
96        program.functions = stdlib_fns;
97    }
98}
99
#[cfg(test)]
mod tests {
    // `Parser`, `Program` and `StdlibLoader` all come from the parent module.
    use super::*;

    /// Every module name `source_for` is expected to recognise.
    const MODULES: [&str; 6] = [
        "std::trit",
        "std::math",
        "std::tensor",
        "std::io",
        "ml::quantize",
        "ml::inference",
    ];

    /// Split a `::`-separated module name into `use`-path segments.
    fn module_path(name: &str) -> Vec<String> {
        name.split("::").map(str::to_string).collect()
    }

    /// Verify that each stdlib module is embedded and parses without errors.
    #[test]
    fn all_stdlib_modules_parse() {
        for name in MODULES {
            let path = module_path(name);
            let src = StdlibLoader::source_for(&path)
                .unwrap_or_else(|| panic!("Missing stdlib source for {}", name));
            let mut parser = Parser::new(src);
            parser
                .parse_program()
                .unwrap_or_else(|e| panic!("Parse error in {}: {:?}", name, e));
        }
    }

    /// Paths that are not bundled modules (including bare prefixes) yield `None`.
    #[test]
    fn source_for_rejects_unknown_paths() {
        assert!(StdlibLoader::source_for(&module_path("std::nonexistent")).is_none());
        assert!(StdlibLoader::source_for(&module_path("std")).is_none());
        assert!(StdlibLoader::source_for(&[]).is_none());
    }

    /// A program with `use std::trit;` should gain abs/min/max/etc after resolve.
    #[test]
    fn resolve_injects_trit_stdlib() {
        let src = r#"
fn main() -> trit {
    use std::trit;
    let x: trit = abs(-1);
    return x;
}
"#;
        let mut parser = Parser::new(src);
        let mut prog = parser.parse_program().expect("parse failed");
        assert!(!prog.functions.iter().any(|f| f.name == "abs"),
            "abs should not be present before resolve");
        StdlibLoader::resolve(&mut prog);
        assert!(prog.functions.iter().any(|f| f.name == "abs"),
            "abs should be injected after resolve");
        assert!(prog.functions.iter().any(|f| f.name == "min"));
        assert!(prog.functions.iter().any(|f| f.name == "majority"));
    }

    /// Multiple use statements should all be resolved, with no duplicates.
    #[test]
    fn resolve_multiple_modules_no_duplicates() {
        let src = r#"
fn main() -> trit {
    use std::trit;
    use std::math;
    let x: trit = neg(1);
    return x;
}
"#;
        let mut parser = Parser::new(src);
        let mut prog = parser.parse_program().expect("parse failed");
        StdlibLoader::resolve(&mut prog);

        // Count how many times "neg" appears — should be exactly 1
        let neg_count = prog.functions.iter().filter(|f| f.name == "neg").count();
        assert_eq!(neg_count, 1, "neg should appear exactly once");

        // Both modules should be present
        assert!(prog.functions.iter().any(|f| f.name == "abs"));     // std::trit
        assert!(prog.functions.iter().any(|f| f.name == "rectify")); // std::math
    }

    /// Resolve is idempotent — calling it twice should not duplicate functions.
    #[test]
    fn resolve_is_idempotent() {
        let src = r#"
fn main() -> trit {
    use std::trit;
    return 0;
}
"#;
        let mut parser = Parser::new(src);
        let mut prog = parser.parse_program().expect("parse failed");
        StdlibLoader::resolve(&mut prog);
        StdlibLoader::resolve(&mut prog);
        let abs_count = prog.functions.iter().filter(|f| f.name == "abs").count();
        assert_eq!(abs_count, 1, "abs should not be duplicated by double resolve");

        let mut names: Vec<&str> = prog.functions.iter().map(|f| f.name.as_str()).collect();
        let total = names.len();
        names.sort_unstable();
        names.dedup();
        assert_eq!(names.len(), total, "resolve must never introduce duplicate function names");
    }

    /// A user-defined function shadows the stdlib function of the same name.
    #[test]
    fn user_definition_shadows_stdlib() {
        let src = r#"
fn abs() -> trit {
    return 0;
}

fn main() -> trit {
    use std::trit;
    return abs();
}
"#;
        let mut parser = Parser::new(src);
        let mut prog = parser.parse_program().expect("parse failed");
        StdlibLoader::resolve(&mut prog);
        let abs_count = prog.functions.iter().filter(|f| f.name == "abs").count();
        assert_eq!(abs_count, 1, "user-defined abs must not be duplicated by stdlib abs");
        assert!(prog.functions.iter().any(|f| f.name == "min"),
            "other std::trit functions should still be injected");
    }

    /// A program without `use` statements is left unchanged.
    #[test]
    fn resolve_without_use_is_noop() {
        let src = r#"
fn main() -> trit {
    return 0;
}
"#;
        let mut parser = Parser::new(src);
        let mut prog = parser.parse_program().expect("parse failed");
        StdlibLoader::resolve(&mut prog);
        assert_eq!(prog.functions.len(), 1);
        assert_eq!(prog.functions[0].name, "main");
    }

    /// Unknown module paths are silently skipped (not a hard error).
    #[test]
    fn unknown_module_skipped_gracefully() {
        let src = r#"
fn main() -> trit {
    use std::nonexistent;
    return 0;
}
"#;
        let mut parser = Parser::new(src);
        let mut prog = parser.parse_program().expect("parse failed");
        // Should not panic, and must not inject anything.
        StdlibLoader::resolve(&mut prog);
        assert_eq!(prog.functions.len(), 1, "unknown module must not inject functions");
        assert_eq!(prog.functions[0].name, "main");
    }
}
197        let mut prog = parser.parse_program().expect("parse failed");
198        // Should not panic
199        StdlibLoader::resolve(&mut prog);
200    }
201}