Skip to main content

ternlang_core/
semantic.rs

1use crate::ast::*;
2
3// ─── Errors ───────────────────────────────────────────────────────────────────
4
5#[derive(Debug)]
6pub enum SemanticError {
7    TypeMismatch { expected: Type, found: Type },
8    UndefinedVariable(String),
9    UndefinedStruct(String),
10    UndefinedField { struct_name: String, field: String },
11    UndefinedFunction(String),
12    ReturnTypeMismatch { function: String, expected: Type, found: Type },
13    ArgCountMismatch { function: String, expected: usize, found: usize },
14    ArgTypeMismatch { function: String, param_index: usize, expected: Type, found: Type },
15}
16
17impl std::fmt::Display for SemanticError {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        match self {
20            Self::TypeMismatch { expected, found } =>
21                write!(f, "type mismatch: expected {:?}, found {:?}", expected, found),
22            Self::UndefinedVariable(n) =>
23                write!(f, "undefined variable: '{}'", n),
24            Self::UndefinedStruct(n) =>
25                write!(f, "undefined struct: '{}'", n),
26            Self::UndefinedField { struct_name, field } =>
27                write!(f, "struct '{}' has no field '{}'", struct_name, field),
28            Self::UndefinedFunction(n) =>
29                write!(f, "undefined function: '{}'", n),
30            Self::ReturnTypeMismatch { function, expected, found } =>
31                write!(f, "function '{}' declared return type {:?} but returns {:?}", function, expected, found),
32            Self::ArgCountMismatch { function, expected, found } =>
33                write!(f, "function '{}' expects {} argument(s), got {}", function, expected, found),
34            Self::ArgTypeMismatch { function, param_index, expected, found } =>
35                write!(f, "function '{}' argument {}: expected {:?}, found {:?}", function, param_index, expected, found),
36        }
37    }
38}
39
40// ─── Full function signature ──────────────────────────────────────────────────
41
42#[derive(Debug, Clone)]
43pub struct FunctionSig {
44    /// Parameter types in declaration order. None = variadic / unknown (built-ins with flexible arity).
45    pub params: Option<Vec<Type>>,
46    pub return_type: Type,
47}
48
49impl FunctionSig {
50    fn exact(params: Vec<Type>, return_type: Type) -> Self {
51        Self { params: Some(params), return_type }
52    }
53    fn variadic(return_type: Type) -> Self {
54        Self { params: None, return_type }
55    }
56}
57
58// ─── Analyzer ────────────────────────────────────────────────────────────────
59
60pub struct SemanticAnalyzer {
61    scopes:           Vec<std::collections::HashMap<String, Type>>,
62    struct_defs:      std::collections::HashMap<String, Vec<(String, Type)>>,
63    func_signatures:  std::collections::HashMap<String, FunctionSig>,
64    /// Set while type-checking a function body so Return stmts can be validated.
65    current_fn_name:       Option<String>,
66    current_fn_return:     Option<Type>,
67}
68
69impl SemanticAnalyzer {
70    pub fn new() -> Self {
71        let mut sigs: std::collections::HashMap<String, FunctionSig> = std::collections::HashMap::new();
72
73        // ── std::trit built-ins ────────────────────────────────────────────
74        sigs.insert("consensus".into(), FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
75        sigs.insert("invert".into(),    FunctionSig::exact(vec![Type::Trit],             Type::Trit));
76        sigs.insert("truth".into(),     FunctionSig::exact(vec![],                       Type::Trit));
77        sigs.insert("hold".into(),      FunctionSig::exact(vec![],                       Type::Trit));
78        sigs.insert("conflict".into(),  FunctionSig::exact(vec![],                       Type::Trit));
79        sigs.insert("mul".into(),       FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
80
81        // ── std::tensor ────────────────────────────────────────────────────
82        sigs.insert("matmul".into(),   FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
83        sigs.insert("sparsity".into(), FunctionSig::variadic(Type::Int));
84        sigs.insert("shape".into(),    FunctionSig::variadic(Type::Int));
85        sigs.insert("zeros".into(),    FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
86
87        // ── std::io ────────────────────────────────────────────────────────
88        sigs.insert("print".into(),    FunctionSig::variadic(Type::Trit));
89        sigs.insert("println".into(),  FunctionSig::variadic(Type::Trit));
90
91        // ── std::math ──────────────────────────────────────────────────────
92        sigs.insert("abs".into(),      FunctionSig::exact(vec![Type::Int],  Type::Int));
93        sigs.insert("min".into(),      FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
94        sigs.insert("max".into(),      FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
95
96        // ── ml::quantize ───────────────────────────────────────────────────
97        sigs.insert("quantize".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
98        sigs.insert("threshold".into(),FunctionSig::variadic(Type::Float));
99
100        // ── ml::inference ──────────────────────────────────────────────────
101        sigs.insert("forward".into(),  FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
102        sigs.insert("argmax".into(),   FunctionSig::variadic(Type::Int));
103
104        // ── type coercion ──────────────────────────────────────────────────
105        sigs.insert("cast".into(),     FunctionSig::variadic(Type::Trit));
106
107        Self {
108            scopes: vec![std::collections::HashMap::new()],
109            struct_defs: std::collections::HashMap::new(),
110            func_signatures: sigs,
111            current_fn_name: None,
112            current_fn_return: None,
113        }
114    }
115
116    // ── Registration ─────────────────────────────────────────────────────────
117
118    pub fn register_structs(&mut self, structs: &[StructDef]) {
119        for s in structs {
120            self.struct_defs.insert(s.name.clone(), s.fields.clone());
121        }
122    }
123
124    pub fn register_functions(&mut self, functions: &[Function]) {
125        for f in functions {
126            let params = f.params.iter().map(|(_, ty)| ty.clone()).collect();
127            self.func_signatures.insert(
128                f.name.clone(),
129                FunctionSig::exact(params, f.return_type.clone()),
130            );
131        }
132    }
133
134    pub fn register_agents(&mut self, agents: &[AgentDef]) {
135        for agent in agents {
136            for method in &agent.methods {
137                let params = method.params.iter().map(|(_, ty)| ty.clone()).collect();
138                let sig = FunctionSig::exact(params, method.return_type.clone());
139                self.func_signatures.insert(method.name.clone(), sig.clone());
140                self.func_signatures.insert(
141                    format!("{}::{}", agent.name, method.name),
142                    sig,
143                );
144            }
145        }
146    }
147
148    // ── Entry points ─────────────────────────────────────────────────────────
149
150    pub fn check_program(&mut self, program: &Program) -> Result<(), SemanticError> {
151        self.register_structs(&program.structs);
152        self.register_functions(&program.functions);
153        self.register_agents(&program.agents);
154        for agent in &program.agents {
155            for method in &agent.methods {
156                self.check_function(method)?;
157            }
158        }
159        for func in &program.functions {
160            self.check_function(func)?;
161        }
162        Ok(())
163    }
164
165    fn check_function(&mut self, func: &Function) -> Result<(), SemanticError> {
166        // Track return type context for this function body.
167        let prev_name   = self.current_fn_name.take();
168        let prev_return = self.current_fn_return.take();
169        self.current_fn_name   = Some(func.name.clone());
170        self.current_fn_return = Some(func.return_type.clone());
171
172        self.scopes.push(std::collections::HashMap::new());
173        for (name, ty) in &func.params {
174            self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
175        }
176        for stmt in &func.body {
177            self.check_stmt(stmt)?;
178        }
179        self.scopes.pop();
180
181        // Restore outer context (handles nested definitions if ever needed).
182        self.current_fn_name   = prev_name;
183        self.current_fn_return = prev_return;
184        Ok(())
185    }
186
187    // ── Statement checking ───────────────────────────────────────────────────
188
189    pub fn check_stmt(&mut self, stmt: &Stmt) -> Result<(), SemanticError> {
190        match stmt {
191            Stmt::Let { name, ty, value } => {
192                let val_ty = self.infer_expr_type(value)?;
193                let type_ok = val_ty == *ty
194                    || matches!(value, Expr::Cast { .. })
195                    || (matches!(ty, Type::Named(_)) && val_ty == Type::Trit)
196                    || (matches!(ty, Type::TritTensor { .. }) && matches!(val_ty, Type::TritTensor { .. }))
197                    || (*ty == Type::AgentRef && val_ty == Type::AgentRef);
198                if !type_ok {
199                    return Err(SemanticError::TypeMismatch { expected: ty.clone(), found: val_ty });
200                }
201                self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
202                Ok(())
203            }
204
205            Stmt::Return(expr) => {
206                let found = self.infer_expr_type(expr)?;
207                if let (Some(fn_name), Some(expected)) = (&self.current_fn_name, &self.current_fn_return) {
208                    // Allow TritTensor shape flexibility and AgentRef, cast
209                    let ok = found == *expected
210                        || matches!(expr, Expr::Cast { .. })
211                        || (matches!(expected, Type::TritTensor { .. }) && matches!(found, Type::TritTensor { .. }))
212                        || (matches!(expected, Type::Named(_)) && found == Type::Trit);
213                    if !ok {
214                        return Err(SemanticError::ReturnTypeMismatch {
215                            function: fn_name.clone(),
216                            expected: expected.clone(),
217                            found,
218                        });
219                    }
220                }
221                Ok(())
222            }
223
224            Stmt::IfTernary { condition, on_pos, on_zero, on_neg } => {
225                let cond_ty = self.infer_expr_type(condition)?;
226                if cond_ty != Type::Trit {
227                    return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
228                }
229                self.check_stmt(on_pos)?;
230                self.check_stmt(on_zero)?;
231                self.check_stmt(on_neg)?;
232                Ok(())
233            }
234
235            Stmt::Match { condition, arms } => {
236                let cond_ty = self.infer_expr_type(condition)?;
237                if cond_ty != Type::Trit {
238                    return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
239                }
240                for (_val, arm_stmt) in arms {
241                    self.check_stmt(arm_stmt)?;
242                }
243                Ok(())
244            }
245
246            Stmt::Block(stmts) => {
247                self.scopes.push(std::collections::HashMap::new());
248                for s in stmts { self.check_stmt(s)?; }
249                self.scopes.pop();
250                Ok(())
251            }
252
253            Stmt::Decorated { stmt, .. } => self.check_stmt(stmt),
254
255            Stmt::Expr(expr) => { self.infer_expr_type(expr)?; Ok(()) }
256
257            Stmt::ForIn { var, iter, body } => {
258                self.infer_expr_type(iter)?;
259                self.scopes.push(std::collections::HashMap::new());
260                self.scopes.last_mut().unwrap().insert(var.clone(), Type::Trit);
261                self.check_stmt(body)?;
262                self.scopes.pop();
263                Ok(())
264            }
265
266            Stmt::WhileTernary { condition, on_pos, on_zero, on_neg } => {
267                let cond_ty = self.infer_expr_type(condition)?;
268                if cond_ty != Type::Trit {
269                    return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
270                }
271                self.check_stmt(on_pos)?;
272                self.check_stmt(on_zero)?;
273                self.check_stmt(on_neg)?;
274                Ok(())
275            }
276
277            Stmt::Loop { body }   => self.check_stmt(body),
278            Stmt::Break           => Ok(()),
279            Stmt::Continue        => Ok(()),
280            Stmt::Use { .. }      => Ok(()),
281
282            Stmt::Send { target, message } => {
283                self.infer_expr_type(target)?;
284                self.infer_expr_type(message)?;
285                Ok(())
286            }
287
288            Stmt::FieldSet { object, field, value } => {
289                let obj_ty = self.lookup_var(object)?;
290                if let Type::Named(struct_name) = obj_ty {
291                    let field_ty = self.lookup_field(&struct_name, field)?;
292                    let val_ty   = self.infer_expr_type(value)?;
293                    if val_ty != field_ty {
294                        return Err(SemanticError::TypeMismatch { expected: field_ty, found: val_ty });
295                    }
296                } else {
297                    self.infer_expr_type(value)?;
298                }
299                Ok(())
300            }
301        }
302    }
303
304    // ── Expression type inference ─────────────────────────────────────────────
305
306    fn infer_expr_type(&self, expr: &Expr) -> Result<Type, SemanticError> {
307        match expr {
308            Expr::TritLiteral(_)   => Ok(Type::Trit),
309            Expr::IntLiteral(_)    => Ok(Type::Int),
310            Expr::StringLiteral(_) => Ok(Type::String),
311            Expr::Ident(name)      => self.lookup_var(name),
312
313            Expr::BinaryOp { lhs, rhs, .. } => {
314                let l = self.infer_expr_type(lhs)?;
315                let r = self.infer_expr_type(rhs)?;
316                if l != r {
317                    return Err(SemanticError::TypeMismatch { expected: l, found: r });
318                }
319                Ok(l)
320            }
321
322            Expr::UnaryOp { expr, .. } => self.infer_expr_type(expr),
323
324            Expr::Call { callee, args } => {
325                let sig = self.func_signatures.get(callee.as_str())
326                    .ok_or_else(|| SemanticError::UndefinedFunction(callee.clone()))?
327                    .clone();
328
329                // Argument arity + type checking (only for exact signatures).
330                if let Some(param_types) = &sig.params {
331                    if args.len() != param_types.len() {
332                        return Err(SemanticError::ArgCountMismatch {
333                            function: callee.clone(),
334                            expected: param_types.len(),
335                            found:    args.len(),
336                        });
337                    }
338                    for (i, (arg, expected_ty)) in args.iter().zip(param_types.iter()).enumerate() {
339                        let found_ty = self.infer_expr_type(arg)?;
340                        // Allow TritTensor shape flexibility and cast coercion.
341                        let ok = found_ty == *expected_ty
342                            || matches!(arg, Expr::Cast { .. })
343                            || (matches!(expected_ty, Type::TritTensor { .. })
344                                && matches!(found_ty, Type::TritTensor { .. }))
345                            || (matches!(expected_ty, Type::Named(_)) && found_ty == Type::Trit);
346                        if !ok {
347                            return Err(SemanticError::ArgTypeMismatch {
348                                function:    callee.clone(),
349                                param_index: i,
350                                expected:    expected_ty.clone(),
351                                found:       found_ty,
352                            });
353                        }
354                    }
355                } else {
356                    // Variadic — still infer arg types to catch undefined variables.
357                    for arg in args { self.infer_expr_type(arg)?; }
358                }
359
360                Ok(sig.return_type)
361            }
362
363            Expr::Cast { ty, .. }     => Ok(ty.clone()),
364            Expr::Spawn { .. }        => Ok(Type::AgentRef),
365            Expr::Await { .. }        => Ok(Type::Trit),
366            Expr::NodeId              => Ok(Type::String),
367
368            Expr::FieldAccess { object, field } => {
369                let obj_ty = self.infer_expr_type(object)?;
370                if let Type::Named(struct_name) = obj_ty {
371                    self.lookup_field(&struct_name, field)
372                } else {
373                    Ok(Type::Trit)
374                }
375            }
376        }
377    }
378
379    // ── Scope helpers ─────────────────────────────────────────────────────────
380
381    fn lookup_var(&self, name: &str) -> Result<Type, SemanticError> {
382        for scope in self.scopes.iter().rev() {
383            if let Some(ty) = scope.get(name) { return Ok(ty.clone()); }
384        }
385        Err(SemanticError::UndefinedVariable(name.to_string()))
386    }
387
388    fn lookup_field(&self, struct_name: &str, field: &str) -> Result<Type, SemanticError> {
389        let fields = self.struct_defs.get(struct_name)
390            .ok_or_else(|| SemanticError::UndefinedStruct(struct_name.to_string()))?;
391        fields.iter()
392            .find(|(f, _)| f == field)
393            .map(|(_, ty)| ty.clone())
394            .ok_or_else(|| SemanticError::UndefinedField {
395                struct_name: struct_name.to_string(),
396                field: field.to_string(),
397            })
398    }
399}
400
401// ─── Tests ───────────────────────────────────────────────────────────────────
402
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `src` as a whole program and runs a fresh analyzer over it.
    /// Panics if the source does not parse.
    fn check(src: &str) -> Result<(), SemanticError> {
        let mut parser = Parser::new(src);
        let prog = parser.parse_program().expect("parse failed");
        SemanticAnalyzer::new().check_program(&prog)
    }

    /// Asserts that `src` passes semantic analysis.
    fn check_ok(src: &str) {
        let outcome = check(src);
        assert!(outcome.is_ok(), "expected ok, got: {:?}", outcome);
    }

    /// Asserts that `src` is rejected by semantic analysis.
    fn check_err(src: &str) {
        if check(src).is_ok() {
            panic!("expected error but check passed");
        }
    }

    // ── Return type validation ────────────────────────────────────────────────

    #[test]
    fn test_return_correct_type() {
        check_ok("fn f() -> trit { return 1; }");
    }

    #[test]
    fn test_return_wrong_type_caught() {
        // Returns an int variable from a function declared `-> trit`.
        check_err("fn f() -> trit { let x: int = 42; return x; }");
    }

    #[test]
    fn test_return_trit_in_trit_fn() {
        check_ok("fn decide(a: trit, b: trit) -> trit { return consensus(a, b); }");
    }

    // ── Argument count checking ───────────────────────────────────────────────

    #[test]
    fn test_call_correct_arity() {
        check_ok("fn f() -> trit { return consensus(1, -1); }");
    }

    #[test]
    fn test_call_too_few_args_caught() {
        check_err("fn f() -> trit { return consensus(1); }");
    }

    #[test]
    fn test_call_too_many_args_caught() {
        check_err("fn f() -> trit { return invert(1, 1); }");
    }

    // ── Argument type checking ────────────────────────────────────────────────

    #[test]
    fn test_call_wrong_arg_type_caught() {
        // `invert` expects a trit; `x` is declared as an int.
        check_err("fn f() -> trit { let x: int = 42; return invert(x); }");
    }

    #[test]
    fn test_call_correct_arg_type() {
        check_ok("fn f(a: trit) -> trit { return invert(a); }");
    }

    // ── Undefined function ────────────────────────────────────────────────────

    #[test]
    fn test_undefined_function_caught() {
        check_err("fn f() -> trit { return doesnt_exist(1); }");
    }

    // ── User-defined function forward references ──────────────────────────────

    #[test]
    fn test_user_fn_return_type_registered() {
        check_ok("fn helper(a: trit) -> trit { return invert(a); } fn main() -> trit { return helper(1); }");
    }

    #[test]
    fn test_user_fn_wrong_return_caught() {
        check_err("fn helper(a: trit) -> trit { let x: int = 1; return x; }");
    }

    // ── Undefined variable ────────────────────────────────────────────────────

    #[test]
    fn test_undefined_variable_caught() {
        check_err("fn f() -> trit { return ghost_var; }");
    }

    #[test]
    fn test_defined_variable_ok() {
        check_ok("fn f() -> trit { let x: trit = 1; return x; }");
    }

    // ── Struct field types ────────────────────────────────────────────────────

    #[test]
    fn test_struct_field_access_ok() {
        check_ok("struct S { val: trit } fn f(s: S) -> trit { return s.val; }");
    }
}
508}