//! `ternlang_core/semantic.rs` — semantic analysis (scope resolution and type
//! checking) for parsed ternlang programs.

use std::collections::HashMap;

use crate::ast::*;

// ─── Errors ───────────────────────────────────────────────────────────────────

5#[derive(Debug)]
6pub enum SemanticError {
7    TypeMismatch { expected: Type, found: Type },
8    UndefinedVariable(String),
9    UndefinedStruct(String),
10    UndefinedField { struct_name: String, field: String },
11    UndefinedFunction(String),
12    ReturnTypeMismatch { function: String, expected: Type, found: Type },
13    ArgCountMismatch { function: String, expected: usize, found: usize },
14    ArgTypeMismatch { function: String, param_index: usize, expected: Type, found: Type },
15    /// `?` used on an expression that doesn't return trit
16    PropagateOnNonTrit { found: Type },
17}
18
19impl std::fmt::Display for SemanticError {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        match self {
22            Self::TypeMismatch { expected, found } =>
23                write!(f, "[TYPE-001] Type mismatch: expected {expected:?}, found {found:?}. Binary types don't map cleanly to ternary space."),
24            Self::UndefinedVariable(n) =>
25                write!(f, "[SCOPE-001] '{n}' is undefined. Hold state — declare before use."),
26            Self::UndefinedStruct(n) =>
27                write!(f, "[STRUCT-001] Struct '{n}' doesn't exist. The type system can't find it."),
28            Self::UndefinedField { struct_name, field } =>
29                write!(f, "[STRUCT-002] Struct '{struct_name}' has no field '{field}'. Check your definition."),
30            Self::UndefinedFunction(n) =>
31                write!(f, "[FN-001] '{n}' is not defined. Did you forget to declare it or import its module?"),
32            Self::ReturnTypeMismatch { function, expected, found } =>
33                write!(f, "[FN-002] Function '{function}' declared return type {expected:?} but returned {found:?}. Ternary contracts are strict."),
34            Self::ArgCountMismatch { function, expected, found } =>
35                write!(f, "[FN-003] '{function}' expects {expected} arg(s), got {found}. Arity is not optional."),
36            Self::ArgTypeMismatch { function, param_index, expected, found } =>
37                write!(f, "[FN-004] '{function}' arg {param_index}: expected {expected:?}, found {found:?}. Types travel with their values."),
38            Self::PropagateOnNonTrit { found } =>
39                write!(f, "[PROP-001] '?' used on a {found:?} expression. Only trit-returning functions can signal conflict. The third state requires a trit."),
40        }
41    }
42}

// ─── Full function signature ──────────────────────────────────────────────────

46#[derive(Debug, Clone)]
47pub struct FunctionSig {
48    /// Parameter types in declaration order. None = variadic / unknown (built-ins with flexible arity).
49    pub params: Option<Vec<Type>>,
50    pub return_type: Type,
51}
52
53impl FunctionSig {
54    fn exact(params: Vec<Type>, return_type: Type) -> Self {
55        Self { params: Some(params), return_type }
56    }
57    fn variadic(return_type: Type) -> Self {
58        Self { params: None, return_type }
59    }
60}

// ─── Analyzer ────────────────────────────────────────────────────────────────

64pub struct SemanticAnalyzer {
65    scopes:           Vec<std::collections::HashMap<String, Type>>,
66    struct_defs:      std::collections::HashMap<String, Vec<(String, Type)>>,
67    func_signatures:  std::collections::HashMap<String, FunctionSig>,
68    /// Set while type-checking a function body so Return stmts can be validated.
69    current_fn_name:       Option<String>,
70    current_fn_return:     Option<Type>,
71}
72
73impl SemanticAnalyzer {
74    pub fn new() -> Self {
75        let mut sigs: std::collections::HashMap<String, FunctionSig> = std::collections::HashMap::new();
76
77        // ── std::trit built-ins ────────────────────────────────────────────
78        sigs.insert("consensus".into(), FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
79        sigs.insert("invert".into(),    FunctionSig::exact(vec![Type::Trit],             Type::Trit));
80        sigs.insert("truth".into(),     FunctionSig::exact(vec![],                       Type::Trit));
81        sigs.insert("hold".into(),      FunctionSig::exact(vec![],                       Type::Trit));
82        sigs.insert("conflict".into(),  FunctionSig::exact(vec![],                       Type::Trit));
83        sigs.insert("mul".into(),       FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
84
85        // ── std::tensor ────────────────────────────────────────────────────
86        sigs.insert("matmul".into(),   FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
87        sigs.insert("sparsity".into(), FunctionSig::variadic(Type::Int));
88        sigs.insert("shape".into(),    FunctionSig::variadic(Type::Int));
89        sigs.insert("zeros".into(),    FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
90
91        // ── std::io ────────────────────────────────────────────────────────
92        sigs.insert("print".into(),    FunctionSig::variadic(Type::Trit));
93        sigs.insert("println".into(),  FunctionSig::variadic(Type::Trit));
94
95        // ── std::math ──────────────────────────────────────────────────────
96        sigs.insert("abs".into(),      FunctionSig::exact(vec![Type::Int],  Type::Int));
97        sigs.insert("min".into(),      FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
98        sigs.insert("max".into(),      FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
99
100        // ── ml::quantize ───────────────────────────────────────────────────
101        sigs.insert("quantize".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
102        sigs.insert("threshold".into(),FunctionSig::variadic(Type::Float));
103
104        // ── ml::inference ──────────────────────────────────────────────────
105        sigs.insert("forward".into(),  FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
106        sigs.insert("argmax".into(),   FunctionSig::variadic(Type::Int));
107
108        // ── type coercion ──────────────────────────────────────────────────
109        sigs.insert("cast".into(),     FunctionSig::variadic(Type::Trit));
110
111        Self {
112            scopes: vec![std::collections::HashMap::new()],
113            struct_defs: std::collections::HashMap::new(),
114            func_signatures: sigs,
115            current_fn_name: None,
116            current_fn_return: None,
117        }
118    }
119
120    // ── Registration ─────────────────────────────────────────────────────────
121
122    pub fn register_structs(&mut self, structs: &[StructDef]) {
123        for s in structs {
124            self.struct_defs.insert(s.name.clone(), s.fields.clone());
125        }
126    }
127
128    pub fn register_functions(&mut self, functions: &[Function]) {
129        for f in functions {
130            let params = f.params.iter().map(|(_, ty)| ty.clone()).collect();
131            self.func_signatures.insert(
132                f.name.clone(),
133                FunctionSig::exact(params, f.return_type.clone()),
134            );
135        }
136    }
137
138    pub fn register_agents(&mut self, agents: &[AgentDef]) {
139        for agent in agents {
140            for method in &agent.methods {
141                let params = method.params.iter().map(|(_, ty)| ty.clone()).collect();
142                let sig = FunctionSig::exact(params, method.return_type.clone());
143                self.func_signatures.insert(method.name.clone(), sig.clone());
144                self.func_signatures.insert(
145                    format!("{}::{}", agent.name, method.name),
146                    sig,
147                );
148            }
149        }
150    }
151
152    // ── Entry points ─────────────────────────────────────────────────────────
153
154    pub fn check_program(&mut self, program: &Program) -> Result<(), SemanticError> {
155        self.register_structs(&program.structs);
156        self.register_functions(&program.functions);
157        self.register_agents(&program.agents);
158        for agent in &program.agents {
159            for method in &agent.methods {
160                self.check_function(method)?;
161            }
162        }
163        for func in &program.functions {
164            self.check_function(func)?;
165        }
166        Ok(())
167    }
168
169    fn check_function(&mut self, func: &Function) -> Result<(), SemanticError> {
170        // Track return type context for this function body.
171        let prev_name   = self.current_fn_name.take();
172        let prev_return = self.current_fn_return.take();
173        self.current_fn_name   = Some(func.name.clone());
174        self.current_fn_return = Some(func.return_type.clone());
175
176        self.scopes.push(std::collections::HashMap::new());
177        for (name, ty) in &func.params {
178            self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
179        }
180        for stmt in &func.body {
181            self.check_stmt(stmt)?;
182        }
183        self.scopes.pop();
184
185        // Restore outer context (handles nested definitions if ever needed).
186        self.current_fn_name   = prev_name;
187        self.current_fn_return = prev_return;
188        Ok(())
189    }
190
191    // ── Statement checking ───────────────────────────────────────────────────
192
193    pub fn check_stmt(&mut self, stmt: &Stmt) -> Result<(), SemanticError> {
194        match stmt {
195            Stmt::Let { name, ty, value } => {
196                let val_ty = self.infer_expr_type(value)?;
197                let type_ok = val_ty == *ty
198                    || matches!(value, Expr::Cast { .. })
199                    || (matches!(ty, Type::Named(_)) && val_ty == Type::Trit)
200                    || (matches!(ty, Type::TritTensor { .. }) && matches!(val_ty, Type::TritTensor { .. }))
201                    || (*ty == Type::AgentRef && val_ty == Type::AgentRef);
202                if !type_ok {
203                    return Err(SemanticError::TypeMismatch { expected: ty.clone(), found: val_ty });
204                }
205                self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
206                Ok(())
207            }
208
209            Stmt::Return(expr) => {
210                let found = self.infer_expr_type(expr)?;
211                if let (Some(fn_name), Some(expected)) = (&self.current_fn_name, &self.current_fn_return) {
212                    // Allow TritTensor shape flexibility and AgentRef, cast
213                    let ok = found == *expected
214                        || matches!(expr, Expr::Cast { .. })
215                        || (matches!(expected, Type::TritTensor { .. }) && matches!(found, Type::TritTensor { .. }))
216                        || (matches!(expected, Type::Named(_)) && found == Type::Trit);
217                    if !ok {
218                        return Err(SemanticError::ReturnTypeMismatch {
219                            function: fn_name.clone(),
220                            expected: expected.clone(),
221                            found,
222                        });
223                    }
224                }
225                Ok(())
226            }
227
228            Stmt::IfTernary { condition, on_pos, on_zero, on_neg } => {
229                let cond_ty = self.infer_expr_type(condition)?;
230                if cond_ty != Type::Trit {
231                    return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
232                }
233                self.check_stmt(on_pos)?;
234                self.check_stmt(on_zero)?;
235                self.check_stmt(on_neg)?;
236                Ok(())
237            }
238
239            Stmt::Match { condition, arms } => {
240                let cond_ty = self.infer_expr_type(condition)?;
241                if cond_ty != Type::Trit {
242                    return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
243                }
244                for (_val, arm_stmt) in arms {
245                    self.check_stmt(arm_stmt)?;
246                }
247                Ok(())
248            }
249
250            Stmt::Block(stmts) => {
251                self.scopes.push(std::collections::HashMap::new());
252                for s in stmts { self.check_stmt(s)?; }
253                self.scopes.pop();
254                Ok(())
255            }
256
257            Stmt::Decorated { stmt, .. } => self.check_stmt(stmt),
258
259            Stmt::Expr(expr) => { self.infer_expr_type(expr)?; Ok(()) }
260
261            Stmt::ForIn { var, iter, body } => {
262                self.infer_expr_type(iter)?;
263                self.scopes.push(std::collections::HashMap::new());
264                self.scopes.last_mut().unwrap().insert(var.clone(), Type::Trit);
265                self.check_stmt(body)?;
266                self.scopes.pop();
267                Ok(())
268            }
269
270            Stmt::WhileTernary { condition, on_pos, on_zero, on_neg } => {
271                let cond_ty = self.infer_expr_type(condition)?;
272                if cond_ty != Type::Trit {
273                    return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
274                }
275                self.check_stmt(on_pos)?;
276                self.check_stmt(on_zero)?;
277                self.check_stmt(on_neg)?;
278                Ok(())
279            }
280
281            Stmt::Loop { body }   => self.check_stmt(body),
282            Stmt::Break           => Ok(()),
283            Stmt::Continue        => Ok(()),
284            Stmt::Use { .. }      => Ok(()),
285
286            Stmt::Send { target, message } => {
287                self.infer_expr_type(target)?;
288                self.infer_expr_type(message)?;
289                Ok(())
290            }
291
292            Stmt::FieldSet { object, field, value } => {
293                let obj_ty = self.lookup_var(object)?;
294                if let Type::Named(struct_name) = obj_ty {
295                    let field_ty = self.lookup_field(&struct_name, field)?;
296                    let val_ty   = self.infer_expr_type(value)?;
297                    if val_ty != field_ty {
298                        return Err(SemanticError::TypeMismatch { expected: field_ty, found: val_ty });
299                    }
300                } else {
301                    self.infer_expr_type(value)?;
302                }
303                Ok(())
304            }
305
306            Stmt::IndexSet { object, row, col, value } => {
307                self.lookup_var(object)?;
308                self.infer_expr_type(row)?;
309                self.infer_expr_type(col)?;
310                self.infer_expr_type(value)?;
311                Ok(())
312            }
313        }
314    }
315
316    // ── Expression type inference ─────────────────────────────────────────────
317
318    fn infer_expr_type(&self, expr: &Expr) -> Result<Type, SemanticError> {
319        match expr {
320            Expr::TritLiteral(_)   => Ok(Type::Trit),
321            Expr::IntLiteral(_)    => Ok(Type::Int),
322            Expr::StringLiteral(_) => Ok(Type::String),
323            Expr::Ident(name)      => self.lookup_var(name),
324
325            Expr::BinaryOp { op, lhs, rhs } => {
326                let l = self.infer_expr_type(lhs)?;
327                let r = self.infer_expr_type(rhs)?;
328                match op {
329                    BinOp::Less | BinOp::Greater | BinOp::Equal | BinOp::NotEqual | BinOp::And | BinOp::Or => {
330                        Ok(Type::Trit)
331                    }
332                    _ => {
333                        if l != r {
334                            return Err(SemanticError::TypeMismatch { expected: l, found: r });
335                        }
336                        Ok(l)
337                    }
338                }
339            }
340
341            Expr::UnaryOp { expr, .. } => self.infer_expr_type(expr),
342
343            Expr::Call { callee, args } => {
344                let sig = self.func_signatures.get(callee.as_str())
345                    .ok_or_else(|| SemanticError::UndefinedFunction(callee.clone()))?
346                    .clone();
347
348                // Argument arity + type checking (only for exact signatures).
349                if let Some(param_types) = &sig.params {
350                    if args.len() != param_types.len() {
351                        return Err(SemanticError::ArgCountMismatch {
352                            function: callee.clone(),
353                            expected: param_types.len(),
354                            found:    args.len(),
355                        });
356                    }
357                    for (i, (arg, expected_ty)) in args.iter().zip(param_types.iter()).enumerate() {
358                        let found_ty = self.infer_expr_type(arg)?;
359                        // Allow TritTensor shape flexibility and cast coercion.
360                        let ok = found_ty == *expected_ty
361                            || matches!(arg, Expr::Cast { .. })
362                            || (matches!(expected_ty, Type::TritTensor { .. })
363                                && matches!(found_ty, Type::TritTensor { .. }))
364                            || (matches!(expected_ty, Type::Named(_)) && found_ty == Type::Trit);
365                        if !ok {
366                            return Err(SemanticError::ArgTypeMismatch {
367                                function:    callee.clone(),
368                                param_index: i,
369                                expected:    expected_ty.clone(),
370                                found:       found_ty,
371                            });
372                        }
373                    }
374                } else {
375                    // Variadic — still infer arg types to catch undefined variables.
376                    for arg in args { self.infer_expr_type(arg)?; }
377                }
378
379                Ok(sig.return_type)
380            }
381
382            Expr::Cast { ty, .. }     => Ok(ty.clone()),
383            Expr::Spawn { .. }        => Ok(Type::AgentRef),
384            Expr::Await { .. }        => Ok(Type::Trit),
385            Expr::NodeId              => Ok(Type::String),
386
387            Expr::Propagate { expr } => {
388                let inner = self.infer_expr_type(expr)?;
389                if inner != Type::Trit {
390                    return Err(SemanticError::PropagateOnNonTrit { found: inner });
391                }
392                Ok(Type::Trit)
393            }
394
395            Expr::FieldAccess { object, field } => {
396                let obj_ty = self.infer_expr_type(object)?;
397                if let Type::Named(struct_name) = obj_ty {
398                    self.lookup_field(&struct_name, field)
399                } else {
400                    Ok(Type::Trit)
401                }
402            }
403
404            Expr::Index { object, row, col } => {
405                self.infer_expr_type(object)?;
406                self.infer_expr_type(row)?;
407                self.infer_expr_type(col)?;
408                Ok(Type::Trit)
409            }
410        }
411    }
412
413    // ── Scope helpers ─────────────────────────────────────────────────────────
414
415    fn lookup_var(&self, name: &str) -> Result<Type, SemanticError> {
416        for scope in self.scopes.iter().rev() {
417            if let Some(ty) = scope.get(name) { return Ok(ty.clone()); }
418        }
419        Err(SemanticError::UndefinedVariable(name.to_string()))
420    }
421
422    fn lookup_field(&self, struct_name: &str, field: &str) -> Result<Type, SemanticError> {
423        let fields = self.struct_defs.get(struct_name)
424            .ok_or_else(|| SemanticError::UndefinedStruct(struct_name.to_string()))?;
425        fields.iter()
426            .find(|(f, _)| f == field)
427            .map(|(_, ty)| ty.clone())
428            .ok_or_else(|| SemanticError::UndefinedField {
429                struct_name: struct_name.to_string(),
430                field: field.to_string(),
431            })
432    }
433}

// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `src` and runs a fresh analyzer over the resulting program.
    ///
    /// Panics if `src` does not parse, so these tests only exercise semantic checks.
    fn check(src: &str) -> Result<(), SemanticError> {
        let mut parser = Parser::new(src);
        let prog = parser.parse_program().expect("parse failed");
        let mut analyzer = SemanticAnalyzer::new();
        analyzer.check_program(&prog)
    }

    /// Asserts that `src` passes semantic analysis.
    fn check_ok(src: &str) {
        // Run the pipeline once and reuse the result for the failure message.
        let result = check(src);
        assert!(result.is_ok(), "expected ok, got: {:?}", result);
    }

    /// Asserts that `src` is rejected by semantic analysis.
    fn check_err(src: &str) {
        assert!(check(src).is_err(), "expected error but check passed");
    }

    // ── Return type validation ────────────────────────────────────────────────

    #[test]
    fn test_return_correct_type() {
        check_ok("fn f() -> trit { return 1; }");
    }

    #[test]
    fn test_return_wrong_type_caught() {
        // Returns Int but declared -> trit
        check_err("fn f() -> trit { let x: int = 42; return x; }");
    }

    #[test]
    fn test_return_trit_in_trit_fn() {
        check_ok("fn decide(a: trit, b: trit) -> trit { return consensus(a, b); }");
    }

    // ── Argument count checking ───────────────────────────────────────────────

    #[test]
    fn test_call_correct_arity() {
        check_ok("fn f() -> trit { return consensus(1, -1); }");
    }

    #[test]
    fn test_call_too_few_args_caught() {
        check_err("fn f() -> trit { return consensus(1); }");
    }

    #[test]
    fn test_call_too_many_args_caught() {
        check_err("fn f() -> trit { return invert(1, 1); }");
    }

    // ── Argument type checking ────────────────────────────────────────────────

    #[test]
    fn test_call_wrong_arg_type_caught() {
        // invert expects trit; passing an int-typed variable must be rejected.
        check_err("fn f() -> trit { let x: int = 42; return invert(x); }");
    }

    #[test]
    fn test_call_correct_arg_type() {
        check_ok("fn f(a: trit) -> trit { return invert(a); }");
    }

    // ── Undefined function ────────────────────────────────────────────────────

    #[test]
    fn test_undefined_function_caught() {
        check_err("fn f() -> trit { return doesnt_exist(1); }");
    }

    // ── User-defined function forward references ──────────────────────────────

    #[test]
    fn test_user_fn_return_type_registered() {
        check_ok("fn helper(a: trit) -> trit { return invert(a); } fn main() -> trit { return helper(1); }");
    }

    #[test]
    fn test_user_fn_forward_reference() {
        // `helper` is declared after its caller; signatures are registered before bodies are checked.
        check_ok("fn main() -> trit { return helper(1); } fn helper(a: trit) -> trit { return invert(a); }");
    }

    #[test]
    fn test_user_fn_wrong_return_caught() {
        check_err("fn helper(a: trit) -> trit { let x: int = 1; return x; }");
    }

    #[test]
    fn test_user_fn_arity_caught() {
        check_err("fn helper(a: trit) -> trit { return invert(a); } fn main() -> trit { return helper(1, 1); }");
    }

    #[test]
    fn test_user_fn_wrong_arg_type_caught() {
        check_err("fn helper(a: trit) -> trit { return invert(a); } fn main() -> trit { let x: int = 42; return helper(x); }");
    }

    // ── Undefined variable ────────────────────────────────────────────────────

    #[test]
    fn test_undefined_variable_caught() {
        check_err("fn f() -> trit { return ghost_var; }");
    }

    #[test]
    fn test_defined_variable_ok() {
        check_ok("fn f() -> trit { let x: trit = 1; return x; }");
    }

    // ── Struct field types ────────────────────────────────────────────────────

    #[test]
    fn test_struct_field_access_ok() {
        check_ok("struct S { val: trit } fn f(s: S) -> trit { return s.val; }");
    }

    #[test]
    fn test_struct_undefined_field_caught() {
        check_err("struct S { val: trit } fn f(s: S) -> trit { return s.missing; }");
    }
}