Skip to main content

ternlang_core/
semantic.rs

1use crate::ast::*;
2
3// ─── Errors ───────────────────────────────────────────────────────────────────
4
5#[derive(Debug)]
6pub enum SemanticError {
7    TypeMismatch { expected: Type, found: Type },
8    UndefinedVariable(String),
9    UndefinedStruct(String),
10    UndefinedField { struct_name: String, field: String },
11    UndefinedFunction(String),
12    ReturnTypeMismatch { function: String, expected: Type, found: Type },
13    ArgCountMismatch { function: String, expected: usize, found: usize },
14    ArgTypeMismatch { function: String, param_index: usize, expected: Type, found: Type },
15    /// `?` used on an expression that doesn't return trit
16    PropagateOnNonTrit { found: Type },
17}
18
19impl std::fmt::Display for SemanticError {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        match self {
22            Self::TypeMismatch { expected, found } =>
23                write!(f, "[TYPE-001] Type mismatch: expected {expected:?}, found {found:?}. A trit is not an int. An int is not a trit. They don't coerce.\n           → details: stdlib/errors/TYPE-001.tern  |  ternlang errors TYPE-001"),
24            Self::UndefinedVariable(n) =>
25                write!(f, "[SCOPE-001] '{n}' is undefined — hold state. Declare before use, or check for a typo.\n            → details: stdlib/errors/SCOPE-001.tern  |  ternlang errors SCOPE-001"),
26            Self::UndefinedStruct(n) =>
27                write!(f, "[STRUCT-001] Struct '{n}' doesn't exist. A ghost type — the type system can't find it anywhere.\n             → details: stdlib/errors/STRUCT-001.tern  |  ternlang errors STRUCT-001"),
28            Self::UndefinedField { struct_name, field } =>
29                write!(f, "[STRUCT-002] Struct '{struct_name}' has no field '{field}'. Check the definition — maybe it was renamed.\n             → details: stdlib/errors/STRUCT-002.tern  |  ternlang errors STRUCT-002"),
30            Self::UndefinedFunction(n) =>
31                write!(f, "[FN-001] '{n}' was called but never defined. Declare it above the call site, or check for a typo.\n          → details: stdlib/errors/FN-001.tern  |  ternlang errors FN-001"),
32            Self::ReturnTypeMismatch { function, expected, found } =>
33                write!(f, "[FN-002] '{function}' promised to return {expected:?} but returned {found:?}. Ternary contracts are strict — all paths must match.\n          → details: stdlib/errors/FN-002.tern  |  ternlang errors FN-002"),
34            Self::ArgCountMismatch { function, expected, found } =>
35                write!(f, "[FN-003] '{function}' expects {expected} arg(s), got {found}. Arity is not optional — not even in hold state.\n          → details: stdlib/errors/FN-003.tern  |  ternlang errors FN-003"),
36            Self::ArgTypeMismatch { function, param_index, expected, found } =>
37                write!(f, "[FN-004] '{function}' arg {param_index}: expected {expected:?}, found {found:?}. Types travel with their values — they don't change at the border.\n          → details: stdlib/errors/FN-004.tern  |  ternlang errors FN-004"),
38            Self::PropagateOnNonTrit { found } =>
39                write!(f, "[PROP-001] '?' used on a {found:?} expression. Only trit-returning functions carry the three-valued signal. The third state requires a trit.\n            → details: stdlib/errors/PROP-001.tern  |  ternlang errors PROP-001"),
40        }
41    }
42}
43
44// ─── Full function signature ──────────────────────────────────────────────────
45
46#[derive(Debug, Clone)]
47pub struct FunctionSig {
48    /// Parameter types in declaration order. None = variadic / unknown (built-ins with flexible arity).
49    pub params: Option<Vec<Type>>,
50    pub return_type: Type,
51}
52
53impl FunctionSig {
54    fn exact(params: Vec<Type>, return_type: Type) -> Self {
55        Self { params: Some(params), return_type }
56    }
57    fn variadic(return_type: Type) -> Self {
58        Self { params: None, return_type }
59    }
60}
61
62// ─── Analyzer ────────────────────────────────────────────────────────────────
63
64pub struct SemanticAnalyzer {
65    scopes:           Vec<std::collections::HashMap<String, Type>>,
66    struct_defs:      std::collections::HashMap<String, Vec<(String, Type)>>,
67    func_signatures:  std::collections::HashMap<String, FunctionSig>,
68    /// Set while type-checking a function body so Return stmts can be validated.
69    current_fn_name:       Option<String>,
70    current_fn_return:     Option<Type>,
71}
72
73impl SemanticAnalyzer {
74    pub fn new() -> Self {
75        let mut sigs: std::collections::HashMap<String, FunctionSig> = std::collections::HashMap::new();
76
77        // ── std::trit built-ins ────────────────────────────────────────────
78        sigs.insert("consensus".into(), FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
79        sigs.insert("invert".into(),    FunctionSig::exact(vec![Type::Trit],             Type::Trit));
80        sigs.insert("length".into(),    FunctionSig::variadic(Type::Int));
81        sigs.insert("truth".into(),     FunctionSig::exact(vec![],                       Type::Trit));
82        sigs.insert("hold".into(),      FunctionSig::exact(vec![],                       Type::Trit));
83        sigs.insert("conflict".into(),  FunctionSig::exact(vec![],                       Type::Trit));
84        sigs.insert("mul".into(),       FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
85
86        // ── std::tensor ────────────────────────────────────────────────────
87        sigs.insert("matmul".into(),   FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
88        sigs.insert("sparsity".into(), FunctionSig::variadic(Type::Int));
89        sigs.insert("shape".into(),    FunctionSig::variadic(Type::Int));
90        sigs.insert("zeros".into(),    FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
91
92        // ── std::io ────────────────────────────────────────────────────────
93        sigs.insert("print".into(),    FunctionSig::variadic(Type::Trit));
94        sigs.insert("println".into(),  FunctionSig::variadic(Type::Trit));
95
96        // ── std::math ──────────────────────────────────────────────────────
97        sigs.insert("abs".into(),      FunctionSig::exact(vec![Type::Int],  Type::Int));
98        sigs.insert("min".into(),      FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
99        sigs.insert("max".into(),      FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
100
101        // ── ml::quantize ───────────────────────────────────────────────────
102        sigs.insert("quantize".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
103        sigs.insert("threshold".into(),FunctionSig::variadic(Type::Float));
104
105        // ── ml::inference ──────────────────────────────────────────────────
106        sigs.insert("forward".into(),  FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
107        sigs.insert("argmax".into(),   FunctionSig::variadic(Type::Int));
108
109        // ── type coercion ──────────────────────────────────────────────────
110        sigs.insert("cast".into(),     FunctionSig::variadic(Type::Trit));
111
112        Self {
113            scopes: vec![std::collections::HashMap::new()],
114            struct_defs: std::collections::HashMap::new(),
115            func_signatures: sigs,
116            current_fn_name: None,
117            current_fn_return: None,
118        }
119    }
120
121    // ── Registration ─────────────────────────────────────────────────────────
122
123    pub fn register_structs(&mut self, structs: &[StructDef]) {
124        for s in structs {
125            self.struct_defs.insert(s.name.clone(), s.fields.clone());
126        }
127    }
128
129    pub fn register_functions(&mut self, functions: &[Function]) {
130        for f in functions {
131            let params = f.params.iter().map(|(_, ty)| ty.clone()).collect();
132            self.func_signatures.insert(
133                f.name.clone(),
134                FunctionSig::exact(params, f.return_type.clone()),
135            );
136        }
137    }
138
139    pub fn register_agents(&mut self, agents: &[AgentDef]) {
140        for agent in agents {
141            for method in &agent.methods {
142                let params = method.params.iter().map(|(_, ty)| ty.clone()).collect();
143                let sig = FunctionSig::exact(params, method.return_type.clone());
144                self.func_signatures.insert(method.name.clone(), sig.clone());
145                self.func_signatures.insert(
146                    format!("{}::{}", agent.name, method.name),
147                    sig,
148                );
149            }
150        }
151    }
152
153    // ── Entry points ─────────────────────────────────────────────────────────
154
155    pub fn check_program(&mut self, program: &Program) -> Result<(), SemanticError> {
156        self.register_structs(&program.structs);
157        self.register_functions(&program.functions);
158        self.register_agents(&program.agents);
159        for agent in &program.agents {
160            for method in &agent.methods {
161                self.check_function(method)?;
162            }
163        }
164        for func in &program.functions {
165            self.check_function(func)?;
166        }
167        Ok(())
168    }
169
170    fn check_function(&mut self, func: &Function) -> Result<(), SemanticError> {
171        // Track return type context for this function body.
172        let prev_name   = self.current_fn_name.take();
173        let prev_return = self.current_fn_return.take();
174        self.current_fn_name   = Some(func.name.clone());
175        self.current_fn_return = Some(func.return_type.clone());
176
177        self.scopes.push(std::collections::HashMap::new());
178        for (name, ty) in &func.params {
179            self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
180        }
181        for stmt in &func.body {
182            self.check_stmt(stmt)?;
183        }
184        self.scopes.pop();
185
186        // Restore outer context (handles nested definitions if ever needed).
187        self.current_fn_name   = prev_name;
188        self.current_fn_return = prev_return;
189        Ok(())
190    }
191
192    // ── Statement checking ───────────────────────────────────────────────────
193
194    pub fn check_stmt(&mut self, stmt: &Stmt) -> Result<(), SemanticError> {
195        match stmt {
196            Stmt::Let { name, ty, value } => {
197                let val_ty = self.infer_expr_type(value)?;
198                let type_ok = val_ty == *ty
199                    || matches!(value, Expr::Cast { .. })
200                    || matches!(value, Expr::StructLiteral { .. }) // Struct literals checked in infer_expr_type
201                    || (*ty == Type::Int && val_ty == Type::Trit)
202                    || (*ty == Type::Trit && val_ty == Type::Int)
203                    || (matches!(ty, Type::Named(_)) && val_ty == Type::Trit)
204                    || (matches!(ty, Type::TritTensor { .. }) && matches!(val_ty, Type::TritTensor { .. }))
205                    || (*ty == Type::AgentRef && val_ty == Type::AgentRef);
206                if !type_ok {
207                    return Err(SemanticError::TypeMismatch { expected: ty.clone(), found: val_ty });
208                }
209                self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
210                Ok(())
211            }
212
213            Stmt::Return(expr) => {
214                let found = self.infer_expr_type(expr)?;
215                if let (Some(fn_name), Some(expected)) = (&self.current_fn_name, &self.current_fn_return) {
216                    // Allow TritTensor shape flexibility and AgentRef, cast
217                    let ok = found == *expected
218                        || matches!(expr, Expr::Cast { .. })
219                        || matches!(expr, Expr::StructLiteral { .. })
220                        || (*expected == Type::Int && found == Type::Trit)
221                        || (*expected == Type::Trit && found == Type::Int)
222                        || (matches!(expected, Type::TritTensor { .. }) && matches!(found, Type::TritTensor { .. }))
223                        || (matches!(expected, Type::Named(_)) && found == Type::Trit);
224                    if !ok {
225                        return Err(SemanticError::ReturnTypeMismatch {
226                            function: fn_name.clone(),
227                            expected: expected.clone(),
228                            found,
229                        });
230                    }
231                }
232                Ok(())
233            }
234
235            Stmt::IfTernary { condition, on_pos, on_zero, on_neg } => {
236                let cond_ty = self.infer_expr_type(condition)?;
237                if cond_ty != Type::Trit {
238                    return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
239                }
240                self.check_stmt(on_pos)?;
241                self.check_stmt(on_zero)?;
242                self.check_stmt(on_neg)?;
243                Ok(())
244            }
245
246            Stmt::Match { condition, arms } => {
247                let cond_ty = self.infer_expr_type(condition)?;
248                if cond_ty != Type::Trit && cond_ty != Type::Int {
249                    return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
250                }
251                
252                if cond_ty == Type::Trit {
253                    // Enforce exhaustiveness and value range for Trit match
254                    let has_pos = arms.iter().any(|(v, _)| *v == 1);
255                    let has_zero = arms.iter().any(|(v, _)| *v == 0);
256                    let has_neg = arms.iter().any(|(v, _)| *v == -1);
257                    if !has_pos || !has_zero || !has_neg {
258                        // We reuse TypeMismatch as a generic error here for simplicity
259                        // in the current schema if a specific exhaustiveness error isn't available.
260                        return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
261                    }
262                    for (val, _) in arms {
263                        if *val < -1 || *val > 1 {
264                            return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: Type::Int });
265                        }
266                    }
267                }
268
269                for (_val, arm_stmt) in arms {
270                    self.check_stmt(arm_stmt)?;
271                }
272                Ok(())
273            }
274
275            Stmt::Block(stmts) => {
276                self.scopes.push(std::collections::HashMap::new());
277                for s in stmts { self.check_stmt(s)?; }
278                self.scopes.pop();
279                Ok(())
280            }
281
282            Stmt::Decorated { stmt, .. } => self.check_stmt(stmt),
283
284            Stmt::Expr(expr) => { self.infer_expr_type(expr)?; Ok(()) }
285
286            Stmt::ForIn { var, iter, body } => {
287                self.infer_expr_type(iter)?;
288                self.scopes.push(std::collections::HashMap::new());
289                self.scopes.last_mut().unwrap().insert(var.clone(), Type::Trit);
290                self.check_stmt(body)?;
291                self.scopes.pop();
292                Ok(())
293            }
294
295            Stmt::WhileTernary { condition, on_pos, on_zero, on_neg } => {
296                let cond_ty = self.infer_expr_type(condition)?;
297                if cond_ty != Type::Trit {
298                    return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
299                }
300                self.check_stmt(on_pos)?;
301                self.check_stmt(on_zero)?;
302                self.check_stmt(on_neg)?;
303                Ok(())
304            }
305
306            Stmt::Loop { body }   => self.check_stmt(body),
307            Stmt::Break           => Ok(()),
308            Stmt::Continue        => Ok(()),
309            Stmt::Use { .. }      => Ok(()),
310            Stmt::FromImport { .. } => Ok(()),
311
312            Stmt::Send { target, message } => {
313                self.infer_expr_type(target)?;
314                self.infer_expr_type(message)?;
315                Ok(())
316            }
317
318            Stmt::FieldSet { object, field, value } => {
319                let obj_ty = self.lookup_var(object)?;
320                if let Type::Named(struct_name) = obj_ty {
321                    let field_ty = self.lookup_field(&struct_name, field)?;
322                    let val_ty   = self.infer_expr_type(value)?;
323                    if val_ty != field_ty {
324                        return Err(SemanticError::TypeMismatch { expected: field_ty, found: val_ty });
325                    }
326                } else {
327                    self.infer_expr_type(value)?;
328                }
329                Ok(())
330            }
331
332            Stmt::IndexSet { object, row, col, value } => {
333                self.lookup_var(object)?;
334                self.infer_expr_type(row)?;
335                self.infer_expr_type(col)?;
336                self.infer_expr_type(value)?;
337                Ok(())
338            }
339
340            Stmt::Set { name, value } => {
341                let var_ty = self.lookup_var(name)?;
342                let val_ty = self.infer_expr_type(value)?;
343                let ok = var_ty == val_ty
344                    || matches!(value, Expr::Cast { .. })
345                    || (var_ty == Type::Int && val_ty == Type::Trit)
346                    || (var_ty == Type::Trit && val_ty == Type::Int);
347                if !ok {
348                    return Err(SemanticError::TypeMismatch { expected: var_ty, found: val_ty });
349                }
350                Ok(())
351            }
352        }
353    }
354
355    // ── Expression type inference ─────────────────────────────────────────────
356
357    fn infer_expr_type(&self, expr: &Expr) -> Result<Type, SemanticError> {
358        match expr {
359            Expr::TritLiteral(_)   => Ok(Type::Trit),
360            Expr::IntLiteral(_)    => Ok(Type::Int),
361            Expr::FloatLiteral(_)  => Ok(Type::Float),
362            Expr::StringLiteral(_) => Ok(Type::String),
363            Expr::Ident(name)      => self.lookup_var(name),
364
365            Expr::BinaryOp { op, lhs, rhs } => {
366                let l = self.infer_expr_type(lhs)?;
367                let r = self.infer_expr_type(rhs)?;
368                match op {
369                    BinOp::Less | BinOp::Greater | BinOp::LessEqual | BinOp::GreaterEqual | BinOp::Equal | BinOp::NotEqual | BinOp::And | BinOp::Or => {
370                        Ok(Type::Trit)
371                    }
372
373                    _ => {
374                        // Allow cross-type Numeric operations (Int vs Trit)
375                        let is_numeric = |t: &Type| matches!(t, Type::Int | Type::Trit | Type::Float);
376                        if is_numeric(&l) && is_numeric(&r) {
377                            if l == Type::Float || r == Type::Float { return Ok(Type::Float); }
378                            if l == Type::Int || r == Type::Int { return Ok(Type::Int); }
379                            return Ok(Type::Trit);
380                        }
381
382                        if l != r {
383                            return Err(SemanticError::TypeMismatch { expected: l, found: r });
384                        }
385                        Ok(l)
386                    }
387                }
388            }
389
390            Expr::UnaryOp { expr, .. } => self.infer_expr_type(expr),
391
392            Expr::Call { callee, args } => {
393                let sig = self.func_signatures.get(callee.as_str())
394                    .ok_or_else(|| SemanticError::UndefinedFunction(callee.clone()))?
395                    .clone();
396
397                // Argument arity + type checking (only for exact signatures).
398                if let Some(param_types) = &sig.params {
399                    if args.len() != param_types.len() {
400                        return Err(SemanticError::ArgCountMismatch {
401                            function: callee.clone(),
402                            expected: param_types.len(),
403                            found:    args.len(),
404                        });
405                    }
406                    for (i, (arg, expected_ty)) in args.iter().zip(param_types.iter()).enumerate() {
407                        let found_ty = self.infer_expr_type(arg)?;
408                        // Allow TritTensor shape flexibility and cast coercion.
409                        let ok = found_ty == *expected_ty
410                            || matches!(arg, Expr::Cast { .. })
411                            || (expected_ty == &Type::Int && found_ty == Type::Trit)
412                            || (expected_ty == &Type::Trit && found_ty == Type::Int)
413                            || (matches!(expected_ty, Type::TritTensor { .. })
414                                && matches!(found_ty, Type::TritTensor { .. }))
415                            || (matches!(expected_ty, Type::Named(_)) && found_ty == Type::Trit);
416                        if !ok {
417                            return Err(SemanticError::ArgTypeMismatch {
418                                function:    callee.clone(),
419                                param_index: i,
420                                expected:    expected_ty.clone(),
421                                found:       found_ty,
422                            });
423                        }
424                    }
425                } else {
426                    // Variadic — still infer arg types to catch undefined variables.
427                    for arg in args { self.infer_expr_type(arg)?; }
428                }
429
430                Ok(sig.return_type)
431            }
432
433            Expr::Cast { ty, .. }     => Ok(ty.clone()),
434            Expr::Spawn { .. }        => Ok(Type::AgentRef),
435            Expr::Await { .. }        => Ok(Type::Trit),
436            Expr::NodeId              => Ok(Type::String),
437
438            Expr::Propagate { expr } => {
439                let inner = self.infer_expr_type(expr)?;
440                if inner != Type::Trit {
441                    return Err(SemanticError::PropagateOnNonTrit { found: inner });
442                }
443                Ok(Type::Trit)
444            }
445
446            Expr::TritTensorLiteral(vals) => {
447                Ok(Type::TritTensor { dims: vec![vals.len()] })
448            }
449
450            Expr::StructLiteral { name, fields } => {
451                // Verify struct exists and fields match
452                let def = self.struct_defs.get(name)
453                    .ok_or_else(|| SemanticError::UndefinedStruct(name.clone()))?;
454                
455                if fields.len() != def.len() {
456                    return Err(SemanticError::ArgCountMismatch { 
457                        function: name.clone(), 
458                        expected: def.len(), 
459                        found: fields.len() 
460                    });
461                }
462
463                for (f_name, f_val) in fields {
464                    let expected_f_ty = def.iter()
465                        .find(|(n, _)| n == f_name)
466                        .ok_or_else(|| SemanticError::UndefinedField { 
467                            struct_name: name.clone(), 
468                            field: f_name.clone() 
469                        })?
470                        .1.clone();
471                    let found_f_ty = self.infer_expr_type(f_val)?;
472                    if found_f_ty != expected_f_ty {
473                        return Err(SemanticError::TypeMismatch { 
474                            expected: expected_f_ty, 
475                            found: found_f_ty 
476                        });
477                    }
478                }
479                Ok(Type::Named(name.clone()))
480            }
481
482            Expr::FieldAccess { object, field } => {
483                let obj_ty = self.infer_expr_type(object)?;
484                if let Type::Named(struct_name) = obj_ty {
485                    self.lookup_field(&struct_name, field)
486                } else {
487                    Ok(Type::Trit)
488                }
489            }
490
491            Expr::Index { object, row, col } => {
492                self.infer_expr_type(object)?;
493                self.infer_expr_type(row)?;
494                self.infer_expr_type(col)?;
495                Ok(Type::Trit)
496            }
497        }
498    }
499
500    // ── Scope helpers ─────────────────────────────────────────────────────────
501
502    fn lookup_var(&self, name: &str) -> Result<Type, SemanticError> {
503        for scope in self.scopes.iter().rev() {
504            if let Some(ty) = scope.get(name) { return Ok(ty.clone()); }
505        }
506        Err(SemanticError::UndefinedVariable(name.to_string()))
507    }
508
509    fn lookup_field(&self, struct_name: &str, field: &str) -> Result<Type, SemanticError> {
510        let fields = self.struct_defs.get(struct_name)
511            .ok_or_else(|| SemanticError::UndefinedStruct(struct_name.to_string()))?;
512        fields.iter()
513            .find(|(f, _)| f == field)
514            .map(|(_, ty)| ty.clone())
515            .ok_or_else(|| SemanticError::UndefinedField {
516                struct_name: struct_name.to_string(),
517                field: field.to_string(),
518            })
519    }
520}
521
522// ─── Tests ───────────────────────────────────────────────────────────────────
523
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `src` (panicking on parse errors) and runs a fresh analyzer on it.
    fn check(src: &str) -> Result<(), SemanticError> {
        let mut parser = Parser::new(src);
        let prog = parser.parse_program().expect("parse failed");
        let mut analyzer = SemanticAnalyzer::new();
        analyzer.check_program(&prog)
    }

    /// Asserts that `src` type-checks; the analysis runs only once.
    fn check_ok(src: &str) {
        let result = check(src);
        assert!(result.is_ok(), "expected ok, got: {result:?}");
    }

    /// Asserts that `src` fails to type-check.
    fn check_err(src: &str) {
        let result = check(src);
        assert!(result.is_err(), "expected error but check passed: {result:?}");
    }

    // ── Return type validation ────────────────────────────────────────────────

    #[test]
    fn test_return_correct_type() {
        check_ok("fn f() -> trit { return 1; }");
    }

    #[test]
    fn test_return_int_in_trit_fn() {
        // Now allowed via implicit coercion
        check_ok("fn f() -> trit { let x: int = 42; return x; }");
    }

    #[test]
    fn test_return_trit_in_trit_fn() {
        check_ok("fn decide(a: trit, b: trit) -> trit { return consensus(a, b); }");
    }

    // ── Argument count checking ───────────────────────────────────────────────

    #[test]
    fn test_call_correct_arity() {
        check_ok("fn f() -> trit { return consensus(1, -1); }");
    }

    #[test]
    fn test_call_too_few_args_caught() {
        check_err("fn f() -> trit { return consensus(1); }");
    }

    #[test]
    fn test_call_too_many_args_caught() {
        check_err("fn f() -> trit { return invert(1, 1); }");
    }

    // ── Argument type checking ────────────────────────────────────────────────

    #[test]
    fn test_call_int_arg_in_trit_fn() {
        // Now allowed via implicit coercion
        check_ok("fn f(a: trit) -> trit { return invert(a); } fn main() -> trit { let x: int = 42; return f(x); }");
    }

    #[test]
    fn test_call_correct_arg_type() {
        check_ok("fn f(a: trit) -> trit { return invert(a); }");
    }

    // ── Undefined function ────────────────────────────────────────────────────

    #[test]
    fn test_undefined_function_caught() {
        check_err("fn f() -> trit { return doesnt_exist(1); }");
    }

    // ── User-defined function forward references ──────────────────────────────

    #[test]
    fn test_user_fn_return_type_registered() {
        check_ok("fn helper(a: trit) -> trit { return invert(a); } fn main() -> trit { return helper(1); }");
    }

    #[test]
    fn test_user_fn_int_return_ok() {
        // Now allowed
        check_ok("fn helper(a: trit) -> trit { let x: int = 1; return x; }");
    }

    // ── Undefined variable ────────────────────────────────────────────────────

    #[test]
    fn test_undefined_variable_caught() {
        check_err("fn f() -> trit { return ghost_var; }");
    }

    #[test]
    fn test_defined_variable_ok() {
        check_ok("fn f() -> trit { let x: trit = 1; return x; }");
    }

    // ── Struct field types ────────────────────────────────────────────────────

    #[test]
    fn test_struct_field_access_ok() {
        check_ok("struct S { val: trit } fn f(s: S) -> trit { return s.val; }");
    }

    // ── Specific error variants ───────────────────────────────────────────────

    #[test]
    fn test_error_variants_are_specific() {
        assert!(matches!(
            check("fn f() -> trit { return consensus(1); }"),
            Err(SemanticError::ArgCountMismatch { .. })
        ));
        assert!(matches!(
            check("fn f() -> trit { return doesnt_exist(1); }"),
            Err(SemanticError::UndefinedFunction(_))
        ));
        assert!(matches!(
            check("fn f() -> trit { return ghost_var; }"),
            Err(SemanticError::UndefinedVariable(_))
        ));
        assert!(matches!(
            check("struct S { val: trit } fn f(s: S) -> trit { return s.nope; }"),
            Err(SemanticError::UndefinedField { .. })
        ));
    }
}