Skip to main content

ternlang_core/
semantic.rs

1use crate::ast::*;
2
3// ─── Errors ───────────────────────────────────────────────────────────────────
4
5#[derive(Debug)]
6pub enum SemanticError {
7    TypeMismatch { expected: Type, found: Type },
8    UndefinedVariable(String),
9    UndefinedStruct(String),
10    UndefinedField { struct_name: String, field: String },
11    UndefinedFunction(String),
12    ReturnTypeMismatch { function: String, expected: Type, found: Type },
13    ArgCountMismatch { function: String, expected: usize, found: usize },
14    ArgTypeMismatch { function: String, param_index: usize, expected: Type, found: Type },
15    /// `?` used on an expression that doesn't return trit
16    PropagateOnNonTrit { found: Type },
17    NonExhaustiveMatch(String),
18}
19
20impl std::fmt::Display for SemanticError {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        match self {
23            Self::TypeMismatch { expected, found } =>
24                write!(f, "[TYPE-001] Type mismatch: expected {expected:?}, found {found:?}. A trit is not an int. An int is not a trit. They don't coerce.\n           → details: stdlib/errors/TYPE-001.tern  |  ternlang errors TYPE-001"),
25            Self::UndefinedVariable(n) =>
26                write!(f, "[SCOPE-001] '{n}' is undefined — hold state. Declare before use, or check for a typo.\n            → details: stdlib/errors/SCOPE-001.tern  |  ternlang errors SCOPE-001"),
27            Self::UndefinedStruct(n) =>
28                write!(f, "[STRUCT-001] Struct '{n}' doesn't exist. A ghost type — the type system can't find it anywhere.\n             → details: stdlib/errors/STRUCT-001.tern  |  ternlang errors STRUCT-001"),
29            Self::UndefinedField { struct_name, field } =>
30                write!(f, "[STRUCT-002] Struct '{struct_name}' has no field '{field}'. Check the definition — maybe it was renamed.\n             → details: stdlib/errors/STRUCT-002.tern  |  ternlang errors STRUCT-002"),
31            Self::UndefinedFunction(n) =>
32                write!(f, "[FN-001] '{n}' was called but never defined. Declare it above the call site, or check for a typo.\n          → details: stdlib/errors/FN-001.tern  |  ternlang errors FN-001"),
33            Self::ReturnTypeMismatch { function, expected, found } =>
34                write!(f, "[FN-002] '{function}' promised to return {expected:?} but returned {found:?}. Ternary contracts are strict — all paths must match.\n          → details: stdlib/errors/FN-002.tern  |  ternlang errors FN-002"),
35            Self::ArgCountMismatch { function, expected, found } =>
36                write!(f, "[FN-003] '{function}' expects {expected} arg(s), got {found}. Arity is not optional — not even in hold state.\n          → details: stdlib/errors/FN-003.tern  |  ternlang errors FN-003"),
37            Self::ArgTypeMismatch { function, param_index, expected, found } =>
38                write!(f, "[FN-004] '{function}' arg {param_index}: expected {expected:?}, found {found:?}. Types travel with their values — they don't change at the border.\n          → details: stdlib/errors/FN-004.tern  |  ternlang errors FN-004"),
39            Self::PropagateOnNonTrit { found } =>
40                write!(f, "[PROP-001] '?' used on a {found:?} expression. Only trit-returning functions carry the three-valued signal. The third state requires a trit.\n            → details: stdlib/errors/PROP-001.tern  |  ternlang errors PROP-001"),
41            Self::NonExhaustiveMatch(msg) =>
42                write!(f, "[MATCH-001] Non-exhaustive match: {msg}. Ternary has three states — all must be covered.\n            → details: stdlib/errors/MATCH-001.tern  |  ternlang errors MATCH-001"),
43        }
44    }
45}
46
47// ─── Full function signature ──────────────────────────────────────────────────
48
49#[derive(Debug, Clone)]
50pub struct FunctionSig {
51    /// Parameter types in declaration order. None = variadic / unknown (built-ins with flexible arity).
52    pub params: Option<Vec<Type>>,
53    pub return_type: Type,
54}
55
56impl FunctionSig {
57    fn exact(params: Vec<Type>, return_type: Type) -> Self {
58        Self { params: Some(params), return_type }
59    }
60    fn variadic(return_type: Type) -> Self {
61        Self { params: None, return_type }
62    }
63}
64
65// ─── Analyzer ────────────────────────────────────────────────────────────────
66
67pub struct SemanticAnalyzer {
68    scopes:           Vec<std::collections::HashMap<String, Type>>,
69    struct_defs:      std::collections::HashMap<String, Vec<(String, Type)>>,
70    func_signatures:  std::collections::HashMap<String, FunctionSig>,
71    /// Set while type-checking a function body so Return stmts can be validated.
72    current_fn_name:       Option<String>,
73    current_fn_return:     Option<Type>,
74}
75
76impl SemanticAnalyzer {
77    pub fn new() -> Self {
78        let mut sigs: std::collections::HashMap<String, FunctionSig> = std::collections::HashMap::new();
79
80        // ── std::trit built-ins ────────────────────────────────────────────
81        sigs.insert("consensus".into(), FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
82        sigs.insert("invert".into(),    FunctionSig::exact(vec![Type::Trit],             Type::Trit));
83        sigs.insert("length".into(),    FunctionSig::variadic(Type::Int));
84        sigs.insert("truth".into(),     FunctionSig::exact(vec![],                       Type::Trit));
85        sigs.insert("hold".into(),      FunctionSig::exact(vec![],                       Type::Trit));
86        sigs.insert("conflict".into(),  FunctionSig::exact(vec![],                       Type::Trit));
87        sigs.insert("mul".into(),       FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
88
89        // ── std::tensor ────────────────────────────────────────────────────
90        sigs.insert("matmul".into(),   FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
91        sigs.insert("sparsity".into(), FunctionSig::variadic(Type::Int));
92        sigs.insert("shape".into(),    FunctionSig::variadic(Type::Int));
93        sigs.insert("zeros".into(),    FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
94
95        // ── std::io ────────────────────────────────────────────────────────
96        sigs.insert("print".into(),    FunctionSig::variadic(Type::Trit));
97        sigs.insert("println".into(),  FunctionSig::variadic(Type::Trit));
98
99        // ── std::math ──────────────────────────────────────────────────────
100        sigs.insert("abs".into(),      FunctionSig::exact(vec![Type::Int],  Type::Int));
101        sigs.insert("min".into(),      FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
102        sigs.insert("max".into(),      FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
103
104        // ── ml::quantize ───────────────────────────────────────────────────
105        sigs.insert("quantize".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
106        sigs.insert("threshold".into(),FunctionSig::variadic(Type::Float));
107
108        // ── ml::inference ──────────────────────────────────────────────────
109        sigs.insert("forward".into(),  FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
110        sigs.insert("argmax".into(),   FunctionSig::variadic(Type::Int));
111
112        // ── type coercion ──────────────────────────────────────────────────
113        sigs.insert("cast".into(),     FunctionSig::variadic(Type::Trit));
114
115        // ── trace / XAI ────────────────────────────────────────────────────
116        // explain(label: str, value: any) -> trit
117        // Logs a structured decision trace entry; returns hold().
118        sigs.insert("explain".into(),  FunctionSig::variadic(Type::Trit));
119
120        Self {
121            scopes: vec![std::collections::HashMap::new()],
122            struct_defs: std::collections::HashMap::new(),
123            func_signatures: sigs,
124            current_fn_name: None,
125            current_fn_return: None,
126        }
127    }
128
129    // ── Registration ─────────────────────────────────────────────────────────
130
131    pub fn register_structs(&mut self, structs: &[StructDef]) {
132        for s in structs {
133            self.struct_defs.insert(s.name.clone(), s.fields.clone());
134        }
135    }
136
137    pub fn register_functions(&mut self, functions: &[Function]) {
138        for f in functions {
139            let params = f.params.iter().map(|(_, ty)| ty.clone()).collect();
140            self.func_signatures.insert(
141                f.name.clone(),
142                FunctionSig::exact(params, f.return_type.clone()),
143            );
144        }
145    }
146
147    pub fn register_agents(&mut self, agents: &[AgentDef]) {
148        for agent in agents {
149            for method in &agent.methods {
150                let params = method.params.iter().map(|(_, ty)| ty.clone()).collect();
151                let sig = FunctionSig::exact(params, method.return_type.clone());
152                self.func_signatures.insert(method.name.clone(), sig.clone());
153                self.func_signatures.insert(
154                    format!("{}::{}", agent.name, method.name),
155                    sig,
156                );
157            }
158        }
159    }
160
161    // ── Entry points ─────────────────────────────────────────────────────────
162
163    pub fn check_program(&mut self, program: &Program) -> Result<(), SemanticError> {
164        self.register_structs(&program.structs);
165        self.register_functions(&program.functions);
166        self.register_agents(&program.agents);
167        for agent in &program.agents {
168            for method in &agent.methods {
169                self.check_function(method)?;
170            }
171        }
172        for func in &program.functions {
173            self.check_function(func)?;
174        }
175        Ok(())
176    }
177
178    fn check_function(&mut self, func: &Function) -> Result<(), SemanticError> {
179        // Track return type context for this function body.
180        let prev_name   = self.current_fn_name.take();
181        let prev_return = self.current_fn_return.take();
182        self.current_fn_name   = Some(func.name.clone());
183        self.current_fn_return = Some(func.return_type.clone());
184
185        self.scopes.push(std::collections::HashMap::new());
186        for (name, ty) in &func.params {
187            self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
188        }
189        for stmt in &func.body {
190            self.check_stmt(stmt)?;
191        }
192        self.scopes.pop();
193
194        // Restore outer context (handles nested definitions if ever needed).
195        self.current_fn_name   = prev_name;
196        self.current_fn_return = prev_return;
197        Ok(())
198    }
199
200    // ── Statement checking ───────────────────────────────────────────────────
201
202    pub fn check_stmt(&mut self, stmt: &Stmt) -> Result<(), SemanticError> {
203        match stmt {
204            Stmt::Let { name, ty, value } => {
205                let val_ty = self.infer_expr_type(value)?;
206                let type_ok = val_ty == *ty
207                    || matches!(value, Expr::Cast { .. })
208                    || matches!(value, Expr::StructLiteral { .. }) // Struct literals checked in infer_expr_type
209                    || (*ty == Type::Int && val_ty == Type::Trit)
210                    || (*ty == Type::Trit && val_ty == Type::Int)
211                    || (matches!(ty, Type::Named(_)) && val_ty == Type::Trit)
212                    || (matches!(ty, Type::TritTensor { .. }) && matches!(val_ty, Type::TritTensor { .. }))
213                    || (*ty == Type::AgentRef && val_ty == Type::AgentRef);
214                if !type_ok {
215                    return Err(SemanticError::TypeMismatch { expected: ty.clone(), found: val_ty });
216                }
217                self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
218                Ok(())
219            }
220
221            Stmt::Return(expr) => {
222                let found = self.infer_expr_type(expr)?;
223                if let (Some(fn_name), Some(expected)) = (&self.current_fn_name, &self.current_fn_return) {
224                    // Allow TritTensor shape flexibility and AgentRef, cast
225                    let ok = found == *expected
226                        || matches!(expr, Expr::Cast { .. })
227                        || matches!(expr, Expr::StructLiteral { .. })
228                        || (*expected == Type::Int && found == Type::Trit)
229                        || (*expected == Type::Trit && found == Type::Int)
230                        || (matches!(expected, Type::TritTensor { .. }) && matches!(found, Type::TritTensor { .. }))
231                        || (matches!(expected, Type::Named(_)) && found == Type::Trit);
232                    if !ok {
233                        return Err(SemanticError::ReturnTypeMismatch {
234                            function: fn_name.clone(),
235                            expected: expected.clone(),
236                            found,
237                        });
238                    }
239                }
240                Ok(())
241            }
242
243            Stmt::IfTernary { condition, on_pos, on_zero, on_neg } => {
244                let cond_ty = self.infer_expr_type(condition)?;
245                if cond_ty != Type::Trit {
246                    return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
247                }
248                self.check_stmt(on_pos)?;
249                self.check_stmt(on_zero)?;
250                self.check_stmt(on_neg)?;
251                Ok(())
252            }
253
254            Stmt::Match { condition, arms } => {
255                let cond_ty = self.infer_expr_type(condition)?;
256                if cond_ty != Type::Trit && cond_ty != Type::Int && cond_ty != Type::Float {
257                    return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
258                }
259                
260                if cond_ty == Type::Trit {
261                    // Enforce exhaustiveness and value range for Trit match
262                    let has_pos = arms.iter().any(|(p, _)| matches!(p, Pattern::Trit(1) | Pattern::Int(1)));
263                    let has_wildcard = arms.iter().any(|(p, _)| matches!(p, Pattern::Wildcard));
264                    if !has_wildcard {
265                        let has_zero = arms.iter().any(|(p, _)| matches!(p, Pattern::Trit(0) | Pattern::Int(0)));
266                        let has_neg = arms.iter().any(|(p, _)| matches!(p, Pattern::Trit(-1) | Pattern::Int(-1)));
267                        if !has_pos || !has_zero || !has_neg {
268                            return Err(SemanticError::NonExhaustiveMatch("Trit match must cover -1, 0, and 1 (or use _ wildcard)".into()));
269                        }
270                    }
271                    for (pattern, _) in arms {
272                        match pattern {
273                            Pattern::Trit(v) => if *v < -1 || *v > 1 { return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: Type::Int }); }
274                            Pattern::Int(v)  => if *v < -1 || *v > 1 { return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: Type::Int }); }
275                            Pattern::Float(_) => return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: Type::Float }),
276                            Pattern::Wildcard => {} // valid in any match
277                        }
278                    }
279                }
280
281                for (_pattern, arm_stmt) in arms {
282                    self.check_stmt(arm_stmt)?;
283                }
284                Ok(())
285            }
286
287            Stmt::Block(stmts) => {
288                self.scopes.push(std::collections::HashMap::new());
289                for s in stmts { self.check_stmt(s)?; }
290                self.scopes.pop();
291                Ok(())
292            }
293
294            Stmt::Decorated { stmt, .. } => self.check_stmt(stmt),
295
296            Stmt::Expr(expr) => { self.infer_expr_type(expr)?; Ok(()) }
297
298            Stmt::ForIn { var, iter, body } => {
299                self.infer_expr_type(iter)?;
300                self.scopes.push(std::collections::HashMap::new());
301                self.scopes.last_mut().unwrap().insert(var.clone(), Type::Trit);
302                self.check_stmt(body)?;
303                self.scopes.pop();
304                Ok(())
305            }
306
307            Stmt::WhileTernary { condition, on_pos, on_zero, on_neg } => {
308                let cond_ty = self.infer_expr_type(condition)?;
309                if cond_ty != Type::Trit {
310                    return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
311                }
312                self.check_stmt(on_pos)?;
313                self.check_stmt(on_zero)?;
314                self.check_stmt(on_neg)?;
315                Ok(())
316            }
317
318            Stmt::Loop { body }   => self.check_stmt(body),
319            Stmt::Break           => Ok(()),
320            Stmt::Continue        => Ok(()),
321            Stmt::Use { .. }      => Ok(()),
322            Stmt::FromImport { .. } => Ok(()),
323
324            Stmt::Send { target, message } => {
325                self.infer_expr_type(target)?;
326                self.infer_expr_type(message)?;
327                Ok(())
328            }
329
330            Stmt::FieldSet { object, field, value } => {
331                let obj_ty = self.lookup_var(object)?;
332                if let Type::Named(struct_name) = obj_ty {
333                    let field_ty = self.lookup_field(&struct_name, field)?;
334                    let val_ty   = self.infer_expr_type(value)?;
335                    if val_ty != field_ty {
336                        return Err(SemanticError::TypeMismatch { expected: field_ty, found: val_ty });
337                    }
338                } else {
339                    self.infer_expr_type(value)?;
340                }
341                Ok(())
342            }
343
344            Stmt::IndexSet { object, row, col, value } => {
345                self.lookup_var(object)?;
346                self.infer_expr_type(row)?;
347                self.infer_expr_type(col)?;
348                self.infer_expr_type(value)?;
349                Ok(())
350            }
351
352            Stmt::Set { name, value } => {
353                let var_ty = self.lookup_var(name)?;
354                let val_ty = self.infer_expr_type(value)?;
355                let ok = var_ty == val_ty
356                    || matches!(value, Expr::Cast { .. })
357                    || (var_ty == Type::Int && val_ty == Type::Trit)
358                    || (var_ty == Type::Trit && val_ty == Type::Int);
359                if !ok {
360                    return Err(SemanticError::TypeMismatch { expected: var_ty, found: val_ty });
361                }
362                Ok(())
363            }
364        }
365    }
366
367    // ── Expression type inference ─────────────────────────────────────────────
368
369    fn infer_expr_type(&self, expr: &Expr) -> Result<Type, SemanticError> {
370        match expr {
371            Expr::TritLiteral(_)   => Ok(Type::Trit),
372            Expr::IntLiteral(_)    => Ok(Type::Int),
373            Expr::FloatLiteral(_)  => Ok(Type::Float),
374            Expr::StringLiteral(_) => Ok(Type::String),
375            Expr::Ident(name)      => self.lookup_var(name),
376
377            Expr::BinaryOp { op, lhs, rhs } => {
378                let l = self.infer_expr_type(lhs)?;
379                let r = self.infer_expr_type(rhs)?;
380                match op {
381                    BinOp::Less | BinOp::Greater | BinOp::LessEqual | BinOp::GreaterEqual | BinOp::Equal | BinOp::NotEqual | BinOp::And | BinOp::Or => {
382                        Ok(Type::Trit)
383                    }
384
385                    _ => {
386                        // Allow cross-type Numeric operations (Int vs Trit)
387                        let is_numeric = |t: &Type| matches!(t, Type::Int | Type::Trit | Type::Float);
388                        if is_numeric(&l) && is_numeric(&r) {
389                            if l == Type::Float || r == Type::Float { return Ok(Type::Float); }
390                            if l == Type::Int || r == Type::Int { return Ok(Type::Int); }
391                            return Ok(Type::Trit);
392                        }
393
394                        if l != r {
395                            return Err(SemanticError::TypeMismatch { expected: l, found: r });
396                        }
397                        Ok(l)
398                    }
399                }
400            }
401
402            Expr::UnaryOp { expr, .. } => self.infer_expr_type(expr),
403
404            Expr::Call { callee, args } => {
405                let sig = self.func_signatures.get(callee.as_str())
406                    .ok_or_else(|| SemanticError::UndefinedFunction(callee.clone()))?
407                    .clone();
408
409                // Argument arity + type checking (only for exact signatures).
410                if let Some(param_types) = &sig.params {
411                    if args.len() != param_types.len() {
412                        return Err(SemanticError::ArgCountMismatch {
413                            function: callee.clone(),
414                            expected: param_types.len(),
415                            found:    args.len(),
416                        });
417                    }
418                    for (i, (arg, expected_ty)) in args.iter().zip(param_types.iter()).enumerate() {
419                        let found_ty = self.infer_expr_type(arg)?;
420                        // Allow TritTensor shape flexibility and cast coercion.
421                        let ok = found_ty == *expected_ty
422                            || matches!(arg, Expr::Cast { .. })
423                            || (expected_ty == &Type::Int && found_ty == Type::Trit)
424                            || (expected_ty == &Type::Trit && found_ty == Type::Int)
425                            || (matches!(expected_ty, Type::TritTensor { .. })
426                                && matches!(found_ty, Type::TritTensor { .. }))
427                            || (matches!(expected_ty, Type::Named(_)) && found_ty == Type::Trit);
428                        if !ok {
429                            return Err(SemanticError::ArgTypeMismatch {
430                                function:    callee.clone(),
431                                param_index: i,
432                                expected:    expected_ty.clone(),
433                                found:       found_ty,
434                            });
435                        }
436                    }
437                } else {
438                    // Variadic — still infer arg types to catch undefined variables.
439                    for arg in args { self.infer_expr_type(arg)?; }
440                }
441
442                Ok(sig.return_type)
443            }
444
445            Expr::Cast { ty, .. }     => Ok(ty.clone()),
446            Expr::Spawn { .. }        => Ok(Type::AgentRef),
447            Expr::Await { .. }        => Ok(Type::Trit),
448            Expr::NodeId              => Ok(Type::String),
449
450            Expr::Propagate { expr } => {
451                let inner = self.infer_expr_type(expr)?;
452                if inner != Type::Trit {
453                    return Err(SemanticError::PropagateOnNonTrit { found: inner });
454                }
455                Ok(Type::Trit)
456            }
457
458            Expr::TritTensorLiteral(vals) => {
459                Ok(Type::TritTensor { dims: vec![vals.len()] })
460            }
461
462            Expr::StructLiteral { name, fields } => {
463                // Verify struct exists and fields match
464                let def = self.struct_defs.get(name)
465                    .ok_or_else(|| SemanticError::UndefinedStruct(name.clone()))?;
466                
467                if fields.len() != def.len() {
468                    return Err(SemanticError::ArgCountMismatch { 
469                        function: name.clone(), 
470                        expected: def.len(), 
471                        found: fields.len() 
472                    });
473                }
474
475                for (f_name, f_val) in fields {
476                    let expected_f_ty = def.iter()
477                        .find(|(n, _)| n == f_name)
478                        .ok_or_else(|| SemanticError::UndefinedField { 
479                            struct_name: name.clone(), 
480                            field: f_name.clone() 
481                        })?
482                        .1.clone();
483                    let found_f_ty = self.infer_expr_type(f_val)?;
484                    if found_f_ty != expected_f_ty {
485                        return Err(SemanticError::TypeMismatch { 
486                            expected: expected_f_ty, 
487                            found: found_f_ty 
488                        });
489                    }
490                }
491                Ok(Type::Named(name.clone()))
492            }
493
494            Expr::FieldAccess { object, field } => {
495                let obj_ty = self.infer_expr_type(object)?;
496                if let Type::Named(struct_name) = obj_ty {
497                    self.lookup_field(&struct_name, field)
498                } else {
499                    Ok(Type::Trit)
500                }
501            }
502
503            Expr::Index { object, row, col } => {
504                self.infer_expr_type(object)?;
505                self.infer_expr_type(row)?;
506                self.infer_expr_type(col)?;
507                Ok(Type::Trit)
508            }
509
510            Expr::Slice { object, start, end, stride } => {
511                self.infer_expr_type(object)?;
512                self.infer_expr_type(start)?;
513                self.infer_expr_type(end)?;
514                self.infer_expr_type(stride)?;
515                // Returns a 1D view (TensorView is treated as TensorRef in types)
516                Ok(Type::TritTensor { dims: vec![0] })
517            }
518        }
519    }
520
521    // ── Scope helpers ─────────────────────────────────────────────────────────
522
523    fn lookup_var(&self, name: &str) -> Result<Type, SemanticError> {
524        for scope in self.scopes.iter().rev() {
525            if let Some(ty) = scope.get(name) { return Ok(ty.clone()); }
526        }
527        Err(SemanticError::UndefinedVariable(name.to_string()))
528    }
529
530    fn lookup_field(&self, struct_name: &str, field: &str) -> Result<Type, SemanticError> {
531        let fields = self.struct_defs.get(struct_name)
532            .ok_or_else(|| SemanticError::UndefinedStruct(struct_name.to_string()))?;
533        fields.iter()
534            .find(|(f, _)| f == field)
535            .map(|(_, ty)| ty.clone())
536            .ok_or_else(|| SemanticError::UndefinedField {
537                struct_name: struct_name.to_string(),
538                field: field.to_string(),
539            })
540    }
541}
542
// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `src` (panicking on a parse error) and runs the full semantic check.
    fn check(src: &str) -> Result<(), SemanticError> {
        let mut parser = Parser::new(src);
        let prog = parser.parse_program().expect("parse failed");
        let mut analyzer = SemanticAnalyzer::new();
        analyzer.check_program(&prog)
    }

    /// Asserts that `src` type-checks; on failure reports the source and the error.
    fn check_ok(src: &str) {
        if let Err(err) = check(src) {
            panic!("expected ok for `{src}`, got: {err:?}");
        }
    }

    /// Asserts that `src` fails to type-check and returns the error so the
    /// caller can pin the exact variant.
    fn check_err(src: &str) -> SemanticError {
        match check(src) {
            Ok(()) => panic!("expected error but check passed for `{src}`"),
            Err(err) => err,
        }
    }

    // ── Return type validation ────────────────────────────────────────────────

    #[test]
    fn test_return_correct_type() {
        check_ok("fn f() -> trit { return 1; }");
    }

    #[test]
    fn test_return_int_in_trit_fn() {
        // Now allowed via implicit coercion
        check_ok("fn f() -> trit { let x: int = 42; return x; }");
    }

    #[test]
    fn test_return_trit_in_trit_fn() {
        check_ok("fn decide(a: trit, b: trit) -> trit { return consensus(a, b); }");
    }

    #[test]
    fn test_return_struct_in_trit_fn_caught() {
        let err = check_err("struct S { val: trit } fn f(s: S) -> trit { return s; }");
        assert!(
            matches!(&err, SemanticError::ReturnTypeMismatch { function, .. } if function == "f"),
            "unexpected error: {err:?}"
        );
    }

    // ── Argument count checking ───────────────────────────────────────────────

    #[test]
    fn test_call_correct_arity() {
        check_ok("fn f() -> trit { return consensus(1, -1); }");
    }

    #[test]
    fn test_call_too_few_args_caught() {
        let err = check_err("fn f() -> trit { return consensus(1); }");
        assert!(
            matches!(&err, SemanticError::ArgCountMismatch { expected: 2, found: 1, .. }),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn test_call_too_many_args_caught() {
        let err = check_err("fn f() -> trit { return invert(1, 1); }");
        assert!(
            matches!(&err, SemanticError::ArgCountMismatch { expected: 1, found: 2, .. }),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn test_error_display_carries_code() {
        let err = check_err("fn f() -> trit { return invert(1, 1); }");
        assert!(err.to_string().starts_with("[FN-003]"), "got: {err}");
    }

    // ── Argument type checking ────────────────────────────────────────────────

    #[test]
    fn test_call_int_arg_in_trit_fn() {
        // Now allowed via implicit coercion
        check_ok("fn f(a: trit) -> trit { return invert(a); } fn main() -> trit { let x: int = 42; return f(x); }");
    }

    #[test]
    fn test_call_correct_arg_type() {
        check_ok("fn f(a: trit) -> trit { return invert(a); }");
    }

    // ── Undefined function ────────────────────────────────────────────────────

    #[test]
    fn test_undefined_function_caught() {
        let err = check_err("fn f() -> trit { return doesnt_exist(1); }");
        assert!(
            matches!(&err, SemanticError::UndefinedFunction(n) if n == "doesnt_exist"),
            "unexpected error: {err:?}"
        );
    }

    // ── User-defined function forward references ──────────────────────────────

    #[test]
    fn test_user_fn_return_type_registered() {
        check_ok("fn helper(a: trit) -> trit { return invert(a); } fn main() -> trit { return helper(1); }");
    }

    #[test]
    fn test_user_fn_int_return_ok() {
        // Now allowed
        check_ok("fn helper(a: trit) -> trit { let x: int = 1; return x; }");
    }

    // ── Undefined variable ────────────────────────────────────────────────────

    #[test]
    fn test_undefined_variable_caught() {
        let err = check_err("fn f() -> trit { return ghost_var; }");
        assert!(
            matches!(&err, SemanticError::UndefinedVariable(n) if n == "ghost_var"),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn test_defined_variable_ok() {
        check_ok("fn f() -> trit { let x: trit = 1; return x; }");
    }

    // ── Struct field types ────────────────────────────────────────────────────

    #[test]
    fn test_struct_field_access_ok() {
        check_ok("struct S { val: trit } fn f(s: S) -> trit { return s.val; }");
    }

    #[test]
    fn test_struct_unknown_field_caught() {
        let err = check_err("struct S { val: trit } fn f(s: S) -> trit { return s.nope; }");
        assert!(
            matches!(
                &err,
                SemanticError::UndefinedField { struct_name, field }
                    if struct_name == "S" && field == "nope"
            ),
            "unexpected error: {err:?}"
        );
    }
}
647    #[test]
648    fn test_struct_field_access_ok() {
649        check_ok("struct S { val: trit } fn f(s: S) -> trit { return s.val; }");
650    }
651}