1use crate::ast::{Program, Stmt};
10use crate::parser::Parser;
11
12pub struct StdlibLoader;
13
14impl StdlibLoader {
15 fn source_for(path: &[String]) -> Option<&'static str> {
16 match path.join("::").as_str() {
17 "std::trit" => Some(include_str!("../stdlib/std/trit.tern")),
18 "std::math" => Some(include_str!("../stdlib/std/math.tern")),
19 "std::tensor" => Some(include_str!("../stdlib/std/tensor.tern")),
20 "std::io" => Some(include_str!("../stdlib/std/io.tern")),
21 "ml::quantize" => Some(include_str!("../stdlib/ml/quantize.tern")),
22 "ml::inference" => Some(include_str!("../stdlib/ml/inference.tern")),
23 _ => None,
24 }
25 }
26
27 fn collect_use_paths(stmts: &[Stmt]) -> Vec<Vec<String>> {
29 let mut paths = Vec::new();
30 for stmt in stmts {
31 match stmt {
32 Stmt::Use { path } => paths.push(path.clone()),
33 Stmt::Block(inner) => paths.extend(Self::collect_use_paths(inner)),
34 Stmt::IfTernary { on_pos, on_zero, on_neg, .. } => {
35 paths.extend(Self::collect_use_paths(&[*on_pos.clone()]));
36 paths.extend(Self::collect_use_paths(&[*on_zero.clone()]));
37 paths.extend(Self::collect_use_paths(&[*on_neg.clone()]));
38 }
39 Stmt::Match { arms, .. } => {
40 for (_, arm_stmt) in arms {
41 paths.extend(Self::collect_use_paths(&[arm_stmt.clone()]));
42 }
43 }
44 _ => {}
45 }
46 }
47 paths
48 }
49
50 pub fn resolve(program: &mut Program) {
54 let mut known: std::collections::HashSet<String> =
56 program.functions.iter().map(|f| f.name.clone()).collect();
57
58 let mut all_paths: Vec<Vec<String>> = program
60 .functions
61 .iter()
62 .flat_map(|f| Self::collect_use_paths(&f.body))
63 .collect();
64
65 all_paths.sort();
67 all_paths.dedup();
68
69 let mut stdlib_fns = Vec::new();
70
71 for path in &all_paths {
72 let key = path.join("::");
73 let Some(src) = Self::source_for(path) else {
74 continue;
76 };
77 let mut parser = Parser::new(src);
78 match parser.parse_program() {
79 Ok(stdlib_prog) => {
80 for f in stdlib_prog.functions {
81 if !known.contains(&f.name) {
82 known.insert(f.name.clone());
83 stdlib_fns.push(f);
84 }
85 }
86 }
87 Err(e) => {
88 eprintln!("[stdlib] Failed to parse {}: {:?}", key, e);
89 }
90 }
91 }
92
93 stdlib_fns.extend(program.functions.drain(..));
96 program.functions = stdlib_fns;
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Returns `true` when `prog` contains a function called `name`.
    fn has_fn(prog: &Program, name: &str) -> bool {
        prog.functions.iter().any(|f| f.name == name)
    }

    /// Counts the functions in `prog` that are called `name`.
    fn count_fn(prog: &Program, name: &str) -> usize {
        prog.functions.iter().filter(|f| f.name == name).count()
    }

    /// Every bundled module must be known to `source_for` and must parse.
    #[test]
    fn all_stdlib_modules_parse() {
        const MODULES: [[&str; 2]; 6] = [
            ["std", "trit"],
            ["std", "math"],
            ["std", "tensor"],
            ["std", "io"],
            ["ml", "quantize"],
            ["ml", "inference"],
        ];
        for segments in &MODULES {
            let path: Vec<String> = segments.iter().map(|seg| seg.to_string()).collect();
            let src = StdlibLoader::source_for(&path)
                .unwrap_or_else(|| panic!("Missing stdlib source for {}", path.join("::")));
            let mut parser = Parser::new(src);
            parser
                .parse_program()
                .unwrap_or_else(|e| panic!("Parse error in {}: {:?}", path.join("::"), e));
        }
    }

    /// `use std::trit;` makes the functions of that module available.
    #[test]
    fn resolve_injects_trit_stdlib() {
        let src = r#"
fn main() -> trit {
    use std::trit;
    let x: trit = abs(-1);
    return x;
}
"#;
        let mut parser = Parser::new(src);
        let mut prog = parser.parse_program().expect("parse failed");
        assert!(
            !has_fn(&prog, "abs"),
            "abs should not be present before resolve"
        );
        StdlibLoader::resolve(&mut prog);
        assert!(has_fn(&prog, "abs"), "abs should be injected after resolve");
        assert!(has_fn(&prog, "min"));
        assert!(has_fn(&prog, "majority"));
    }

    /// Importing two modules must not inject the same function name twice.
    #[test]
    fn resolve_multiple_modules_no_duplicates() {
        let src = r#"
fn main() -> trit {
    use std::trit;
    use std::math;
    let x: trit = neg(1);
    return x;
}
"#;
        let mut parser = Parser::new(src);
        let mut prog = parser.parse_program().expect("parse failed");
        StdlibLoader::resolve(&mut prog);

        assert_eq!(count_fn(&prog, "neg"), 1, "neg should appear exactly once");

        assert!(has_fn(&prog, "abs"));
        assert!(has_fn(&prog, "rectify"));
    }

    /// Resolving an already resolved program must not duplicate functions.
    #[test]
    fn resolve_is_idempotent() {
        let src = r#"
fn main() -> trit {
    use std::trit;
    return 0;
}
"#;
        let mut parser = Parser::new(src);
        let mut prog = parser.parse_program().expect("parse failed");
        StdlibLoader::resolve(&mut prog);
        StdlibLoader::resolve(&mut prog);
        assert_eq!(
            count_fn(&prog, "abs"),
            1,
            "abs should not be duplicated by double resolve"
        );
    }

    /// A `use` of a module that is not bundled is ignored: no panic, and the
    /// program's function list is left exactly as it was.
    #[test]
    fn unknown_module_skipped_gracefully() {
        let src = r#"
fn main() -> trit {
    use std::nonexistent;
    return 0;
}
"#;
        let mut parser = Parser::new(src);
        let mut prog = parser.parse_program().expect("parse failed");
        let before: Vec<String> = prog.functions.iter().map(|f| f.name.clone()).collect();
        StdlibLoader::resolve(&mut prog);
        let after: Vec<String> = prog.functions.iter().map(|f| f.name.clone()).collect();
        assert_eq!(
            before, after,
            "an unknown module must leave the function list untouched"
        );
    }
}