1use crate::ast::*;
2
3#[derive(Debug)]
6pub enum SemanticError {
7 TypeMismatch { expected: Type, found: Type },
8 UndefinedVariable(String),
9 UndefinedStruct(String),
10 UndefinedField { struct_name: String, field: String },
11 UndefinedFunction(String),
12 ReturnTypeMismatch { function: String, expected: Type, found: Type },
13 ArgCountMismatch { function: String, expected: usize, found: usize },
14 ArgTypeMismatch { function: String, param_index: usize, expected: Type, found: Type },
15 PropagateOnNonTrit { found: Type },
17}
18
19impl std::fmt::Display for SemanticError {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 match self {
22 Self::TypeMismatch { expected, found } =>
23 write!(f, "[TYPE-001] Type mismatch: expected {expected:?}, found {found:?}. A trit is not an int. An int is not a trit. They don't coerce.\n → details: stdlib/errors/TYPE-001.tern | ternlang errors TYPE-001"),
24 Self::UndefinedVariable(n) =>
25 write!(f, "[SCOPE-001] '{n}' is undefined — hold state. Declare before use, or check for a typo.\n → details: stdlib/errors/SCOPE-001.tern | ternlang errors SCOPE-001"),
26 Self::UndefinedStruct(n) =>
27 write!(f, "[STRUCT-001] Struct '{n}' doesn't exist. A ghost type — the type system can't find it anywhere.\n → details: stdlib/errors/STRUCT-001.tern | ternlang errors STRUCT-001"),
28 Self::UndefinedField { struct_name, field } =>
29 write!(f, "[STRUCT-002] Struct '{struct_name}' has no field '{field}'. Check the definition — maybe it was renamed.\n → details: stdlib/errors/STRUCT-002.tern | ternlang errors STRUCT-002"),
30 Self::UndefinedFunction(n) =>
31 write!(f, "[FN-001] '{n}' was called but never defined. Declare it above the call site, or check for a typo.\n → details: stdlib/errors/FN-001.tern | ternlang errors FN-001"),
32 Self::ReturnTypeMismatch { function, expected, found } =>
33 write!(f, "[FN-002] '{function}' promised to return {expected:?} but returned {found:?}. Ternary contracts are strict — all paths must match.\n → details: stdlib/errors/FN-002.tern | ternlang errors FN-002"),
34 Self::ArgCountMismatch { function, expected, found } =>
35 write!(f, "[FN-003] '{function}' expects {expected} arg(s), got {found}. Arity is not optional — not even in hold state.\n → details: stdlib/errors/FN-003.tern | ternlang errors FN-003"),
36 Self::ArgTypeMismatch { function, param_index, expected, found } =>
37 write!(f, "[FN-004] '{function}' arg {param_index}: expected {expected:?}, found {found:?}. Types travel with their values — they don't change at the border.\n → details: stdlib/errors/FN-004.tern | ternlang errors FN-004"),
38 Self::PropagateOnNonTrit { found } =>
39 write!(f, "[PROP-001] '?' used on a {found:?} expression. Only trit-returning functions carry the three-valued signal. The third state requires a trit.\n → details: stdlib/errors/PROP-001.tern | ternlang errors PROP-001"),
40 }
41 }
42}
43
44#[derive(Debug, Clone)]
47pub struct FunctionSig {
48 pub params: Option<Vec<Type>>,
50 pub return_type: Type,
51}
52
53impl FunctionSig {
54 fn exact(params: Vec<Type>, return_type: Type) -> Self {
55 Self { params: Some(params), return_type }
56 }
57 fn variadic(return_type: Type) -> Self {
58 Self { params: None, return_type }
59 }
60}
61
62pub struct SemanticAnalyzer {
65 scopes: Vec<std::collections::HashMap<String, Type>>,
66 struct_defs: std::collections::HashMap<String, Vec<(String, Type)>>,
67 func_signatures: std::collections::HashMap<String, FunctionSig>,
68 current_fn_name: Option<String>,
70 current_fn_return: Option<Type>,
71}
72
73impl SemanticAnalyzer {
74 pub fn new() -> Self {
75 let mut sigs: std::collections::HashMap<String, FunctionSig> = std::collections::HashMap::new();
76
77 sigs.insert("consensus".into(), FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
79 sigs.insert("invert".into(), FunctionSig::exact(vec![Type::Trit], Type::Trit));
80 sigs.insert("length".into(), FunctionSig::variadic(Type::Int));
81 sigs.insert("truth".into(), FunctionSig::exact(vec![], Type::Trit));
82 sigs.insert("hold".into(), FunctionSig::exact(vec![], Type::Trit));
83 sigs.insert("conflict".into(), FunctionSig::exact(vec![], Type::Trit));
84 sigs.insert("mul".into(), FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
85
86 sigs.insert("matmul".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
88 sigs.insert("sparsity".into(), FunctionSig::variadic(Type::Int));
89 sigs.insert("shape".into(), FunctionSig::variadic(Type::Int));
90 sigs.insert("zeros".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
91
92 sigs.insert("print".into(), FunctionSig::variadic(Type::Trit));
94 sigs.insert("println".into(), FunctionSig::variadic(Type::Trit));
95
96 sigs.insert("abs".into(), FunctionSig::exact(vec![Type::Int], Type::Int));
98 sigs.insert("min".into(), FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
99 sigs.insert("max".into(), FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
100
101 sigs.insert("quantize".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
103 sigs.insert("threshold".into(),FunctionSig::variadic(Type::Float));
104
105 sigs.insert("forward".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
107 sigs.insert("argmax".into(), FunctionSig::variadic(Type::Int));
108
109 sigs.insert("cast".into(), FunctionSig::variadic(Type::Trit));
111
112 Self {
113 scopes: vec![std::collections::HashMap::new()],
114 struct_defs: std::collections::HashMap::new(),
115 func_signatures: sigs,
116 current_fn_name: None,
117 current_fn_return: None,
118 }
119 }
120
121 pub fn register_structs(&mut self, structs: &[StructDef]) {
124 for s in structs {
125 self.struct_defs.insert(s.name.clone(), s.fields.clone());
126 }
127 }
128
129 pub fn register_functions(&mut self, functions: &[Function]) {
130 for f in functions {
131 let params = f.params.iter().map(|(_, ty)| ty.clone()).collect();
132 self.func_signatures.insert(
133 f.name.clone(),
134 FunctionSig::exact(params, f.return_type.clone()),
135 );
136 }
137 }
138
139 pub fn register_agents(&mut self, agents: &[AgentDef]) {
140 for agent in agents {
141 for method in &agent.methods {
142 let params = method.params.iter().map(|(_, ty)| ty.clone()).collect();
143 let sig = FunctionSig::exact(params, method.return_type.clone());
144 self.func_signatures.insert(method.name.clone(), sig.clone());
145 self.func_signatures.insert(
146 format!("{}::{}", agent.name, method.name),
147 sig,
148 );
149 }
150 }
151 }
152
153 pub fn check_program(&mut self, program: &Program) -> Result<(), SemanticError> {
156 self.register_structs(&program.structs);
157 self.register_functions(&program.functions);
158 self.register_agents(&program.agents);
159 for agent in &program.agents {
160 for method in &agent.methods {
161 self.check_function(method)?;
162 }
163 }
164 for func in &program.functions {
165 self.check_function(func)?;
166 }
167 Ok(())
168 }
169
170 fn check_function(&mut self, func: &Function) -> Result<(), SemanticError> {
171 let prev_name = self.current_fn_name.take();
173 let prev_return = self.current_fn_return.take();
174 self.current_fn_name = Some(func.name.clone());
175 self.current_fn_return = Some(func.return_type.clone());
176
177 self.scopes.push(std::collections::HashMap::new());
178 for (name, ty) in &func.params {
179 self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
180 }
181 for stmt in &func.body {
182 self.check_stmt(stmt)?;
183 }
184 self.scopes.pop();
185
186 self.current_fn_name = prev_name;
188 self.current_fn_return = prev_return;
189 Ok(())
190 }
191
192 pub fn check_stmt(&mut self, stmt: &Stmt) -> Result<(), SemanticError> {
195 match stmt {
196 Stmt::Let { name, ty, value } => {
197 let val_ty = self.infer_expr_type(value)?;
198 let type_ok = val_ty == *ty
199 || matches!(value, Expr::Cast { .. })
200 || (*ty == Type::Int && val_ty == Type::Trit)
201 || (*ty == Type::Trit && val_ty == Type::Int)
202 || (matches!(ty, Type::Named(_)) && val_ty == Type::Trit)
203 || (matches!(ty, Type::TritTensor { .. }) && matches!(val_ty, Type::TritTensor { .. }))
204 || (*ty == Type::AgentRef && val_ty == Type::AgentRef);
205 if !type_ok {
206 return Err(SemanticError::TypeMismatch { expected: ty.clone(), found: val_ty });
207 }
208 self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
209 Ok(())
210 }
211
212 Stmt::Return(expr) => {
213 let found = self.infer_expr_type(expr)?;
214 if let (Some(fn_name), Some(expected)) = (&self.current_fn_name, &self.current_fn_return) {
215 let ok = found == *expected
217 || matches!(expr, Expr::Cast { .. })
218 || (*expected == Type::Int && found == Type::Trit)
219 || (*expected == Type::Trit && found == Type::Int)
220 || (matches!(expected, Type::TritTensor { .. }) && matches!(found, Type::TritTensor { .. }))
221 || (matches!(expected, Type::Named(_)) && found == Type::Trit);
222 if !ok {
223 return Err(SemanticError::ReturnTypeMismatch {
224 function: fn_name.clone(),
225 expected: expected.clone(),
226 found,
227 });
228 }
229 }
230 Ok(())
231 }
232
233 Stmt::IfTernary { condition, on_pos, on_zero, on_neg } => {
234 let cond_ty = self.infer_expr_type(condition)?;
235 if cond_ty != Type::Trit {
236 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
237 }
238 self.check_stmt(on_pos)?;
239 self.check_stmt(on_zero)?;
240 self.check_stmt(on_neg)?;
241 Ok(())
242 }
243
244 Stmt::Match { condition, arms } => {
245 let cond_ty = self.infer_expr_type(condition)?;
246 if cond_ty != Type::Trit && cond_ty != Type::Int {
247 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
248 }
249
250 if cond_ty == Type::Trit {
251 let has_pos = arms.iter().any(|(v, _)| *v == 1);
253 let has_zero = arms.iter().any(|(v, _)| *v == 0);
254 let has_neg = arms.iter().any(|(v, _)| *v == -1);
255 if !has_pos || !has_zero || !has_neg {
256 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
259 }
260 for (val, _) in arms {
261 if *val < -1 || *val > 1 {
262 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: Type::Int });
263 }
264 }
265 }
266
267 for (_val, arm_stmt) in arms {
268 self.check_stmt(arm_stmt)?;
269 }
270 Ok(())
271 }
272
273 Stmt::Block(stmts) => {
274 self.scopes.push(std::collections::HashMap::new());
275 for s in stmts { self.check_stmt(s)?; }
276 self.scopes.pop();
277 Ok(())
278 }
279
280 Stmt::Decorated { stmt, .. } => self.check_stmt(stmt),
281
282 Stmt::Expr(expr) => { self.infer_expr_type(expr)?; Ok(()) }
283
284 Stmt::ForIn { var, iter, body } => {
285 self.infer_expr_type(iter)?;
286 self.scopes.push(std::collections::HashMap::new());
287 self.scopes.last_mut().unwrap().insert(var.clone(), Type::Trit);
288 self.check_stmt(body)?;
289 self.scopes.pop();
290 Ok(())
291 }
292
293 Stmt::WhileTernary { condition, on_pos, on_zero, on_neg } => {
294 let cond_ty = self.infer_expr_type(condition)?;
295 if cond_ty != Type::Trit {
296 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
297 }
298 self.check_stmt(on_pos)?;
299 self.check_stmt(on_zero)?;
300 self.check_stmt(on_neg)?;
301 Ok(())
302 }
303
304 Stmt::Loop { body } => self.check_stmt(body),
305 Stmt::Break => Ok(()),
306 Stmt::Continue => Ok(()),
307 Stmt::Use { .. } => Ok(()),
308
309 Stmt::Send { target, message } => {
310 self.infer_expr_type(target)?;
311 self.infer_expr_type(message)?;
312 Ok(())
313 }
314
315 Stmt::FieldSet { object, field, value } => {
316 let obj_ty = self.lookup_var(object)?;
317 if let Type::Named(struct_name) = obj_ty {
318 let field_ty = self.lookup_field(&struct_name, field)?;
319 let val_ty = self.infer_expr_type(value)?;
320 if val_ty != field_ty {
321 return Err(SemanticError::TypeMismatch { expected: field_ty, found: val_ty });
322 }
323 } else {
324 self.infer_expr_type(value)?;
325 }
326 Ok(())
327 }
328
329 Stmt::IndexSet { object, row, col, value } => {
330 self.lookup_var(object)?;
331 self.infer_expr_type(row)?;
332 self.infer_expr_type(col)?;
333 self.infer_expr_type(value)?;
334 Ok(())
335 }
336
337 Stmt::Set { name, value } => {
338 let var_ty = self.lookup_var(name)?;
339 let val_ty = self.infer_expr_type(value)?;
340 let ok = var_ty == val_ty
341 || matches!(value, Expr::Cast { .. })
342 || (var_ty == Type::Int && val_ty == Type::Trit)
343 || (var_ty == Type::Trit && val_ty == Type::Int);
344 if !ok {
345 return Err(SemanticError::TypeMismatch { expected: var_ty, found: val_ty });
346 }
347 Ok(())
348 }
349 }
350 }
351
352 fn infer_expr_type(&self, expr: &Expr) -> Result<Type, SemanticError> {
355 match expr {
356 Expr::TritLiteral(_) => Ok(Type::Trit),
357 Expr::IntLiteral(_) => Ok(Type::Int),
358 Expr::FloatLiteral(_) => Ok(Type::Float),
359 Expr::StringLiteral(_) => Ok(Type::String),
360 Expr::Ident(name) => self.lookup_var(name),
361
362 Expr::BinaryOp { op, lhs, rhs } => {
363 let l = self.infer_expr_type(lhs)?;
364 let r = self.infer_expr_type(rhs)?;
365 match op {
366 BinOp::Less | BinOp::Greater | BinOp::LessEqual | BinOp::GreaterEqual | BinOp::Equal | BinOp::NotEqual | BinOp::And | BinOp::Or => {
367 Ok(Type::Trit)
368 }
369
370 _ => {
371 let is_numeric = |t: &Type| matches!(t, Type::Int | Type::Trit | Type::Float);
373 if is_numeric(&l) && is_numeric(&r) {
374 if l == Type::Float || r == Type::Float { return Ok(Type::Float); }
375 if l == Type::Int || r == Type::Int { return Ok(Type::Int); }
376 return Ok(Type::Trit);
377 }
378
379 if l != r {
380 return Err(SemanticError::TypeMismatch { expected: l, found: r });
381 }
382 Ok(l)
383 }
384 }
385 }
386
387 Expr::UnaryOp { expr, .. } => self.infer_expr_type(expr),
388
389 Expr::Call { callee, args } => {
390 let sig = self.func_signatures.get(callee.as_str())
391 .ok_or_else(|| SemanticError::UndefinedFunction(callee.clone()))?
392 .clone();
393
394 if let Some(param_types) = &sig.params {
396 if args.len() != param_types.len() {
397 return Err(SemanticError::ArgCountMismatch {
398 function: callee.clone(),
399 expected: param_types.len(),
400 found: args.len(),
401 });
402 }
403 for (i, (arg, expected_ty)) in args.iter().zip(param_types.iter()).enumerate() {
404 let found_ty = self.infer_expr_type(arg)?;
405 let ok = found_ty == *expected_ty
407 || matches!(arg, Expr::Cast { .. })
408 || (expected_ty == &Type::Int && found_ty == Type::Trit)
409 || (expected_ty == &Type::Trit && found_ty == Type::Int)
410 || (matches!(expected_ty, Type::TritTensor { .. })
411 && matches!(found_ty, Type::TritTensor { .. }))
412 || (matches!(expected_ty, Type::Named(_)) && found_ty == Type::Trit);
413 if !ok {
414 return Err(SemanticError::ArgTypeMismatch {
415 function: callee.clone(),
416 param_index: i,
417 expected: expected_ty.clone(),
418 found: found_ty,
419 });
420 }
421 }
422 } else {
423 for arg in args { self.infer_expr_type(arg)?; }
425 }
426
427 Ok(sig.return_type)
428 }
429
430 Expr::Cast { ty, .. } => Ok(ty.clone()),
431 Expr::Spawn { .. } => Ok(Type::AgentRef),
432 Expr::Await { .. } => Ok(Type::Trit),
433 Expr::NodeId => Ok(Type::String),
434
435 Expr::Propagate { expr } => {
436 let inner = self.infer_expr_type(expr)?;
437 if inner != Type::Trit {
438 return Err(SemanticError::PropagateOnNonTrit { found: inner });
439 }
440 Ok(Type::Trit)
441 }
442
443 Expr::TritTensorLiteral(vals) => {
444 Ok(Type::TritTensor { dims: vec![vals.len()] })
445 }
446
447 Expr::FieldAccess { object, field } => {
448 let obj_ty = self.infer_expr_type(object)?;
449 if let Type::Named(struct_name) = obj_ty {
450 self.lookup_field(&struct_name, field)
451 } else {
452 Ok(Type::Trit)
453 }
454 }
455
456 Expr::Index { object, row, col } => {
457 self.infer_expr_type(object)?;
458 self.infer_expr_type(row)?;
459 self.infer_expr_type(col)?;
460 Ok(Type::Trit)
461 }
462 }
463 }
464
465 fn lookup_var(&self, name: &str) -> Result<Type, SemanticError> {
468 for scope in self.scopes.iter().rev() {
469 if let Some(ty) = scope.get(name) { return Ok(ty.clone()); }
470 }
471 Err(SemanticError::UndefinedVariable(name.to_string()))
472 }
473
474 fn lookup_field(&self, struct_name: &str, field: &str) -> Result<Type, SemanticError> {
475 let fields = self.struct_defs.get(struct_name)
476 .ok_or_else(|| SemanticError::UndefinedStruct(struct_name.to_string()))?;
477 fields.iter()
478 .find(|(f, _)| f == field)
479 .map(|(_, ty)| ty.clone())
480 .ok_or_else(|| SemanticError::UndefinedField {
481 struct_name: struct_name.to_string(),
482 field: field.to_string(),
483 })
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `src` (panicking if it does not parse) and runs a fresh
    /// analyzer over the resulting program.
    fn check(src: &str) -> Result<(), SemanticError> {
        let mut parser = Parser::new(src);
        let program = parser.parse_program().expect("parse failed");
        SemanticAnalyzer::new().check_program(&program)
    }

    /// Asserts that `src` passes semantic analysis.
    fn check_ok(src: &str) {
        let outcome = check(src);
        assert!(outcome.is_ok(), "expected ok, got: {:?}", outcome);
    }

    /// Asserts that `src` is rejected by semantic analysis.
    fn check_err(src: &str) {
        let outcome = check(src);
        assert!(outcome.is_err(), "expected error but check passed");
    }

    // --- return types -----------------------------------------------------

    #[test]
    fn test_return_correct_type() {
        check_ok("fn f() -> trit { return 1; }");
    }

    #[test]
    fn test_return_int_in_trit_fn() {
        check_ok("fn f() -> trit { let x: int = 42; return x; }");
    }

    #[test]
    fn test_return_trit_in_trit_fn() {
        check_ok("fn decide(a: trit, b: trit) -> trit { return consensus(a, b); }");
    }

    // --- call arity -------------------------------------------------------

    #[test]
    fn test_call_correct_arity() {
        check_ok("fn f() -> trit { return consensus(1, -1); }");
    }

    #[test]
    fn test_call_too_few_args_caught() {
        check_err("fn f() -> trit { return consensus(1); }");
    }

    #[test]
    fn test_call_too_many_args_caught() {
        check_err("fn f() -> trit { return invert(1, 1); }");
    }

    // --- call argument types ----------------------------------------------

    #[test]
    fn test_call_int_arg_in_trit_fn() {
        check_ok("fn f(a: trit) -> trit { return invert(a); } fn main() -> trit { let x: int = 42; return f(x); }");
    }

    #[test]
    fn test_call_correct_arg_type() {
        check_ok("fn f(a: trit) -> trit { return invert(a); }");
    }

    // --- function resolution ----------------------------------------------

    #[test]
    fn test_undefined_function_caught() {
        check_err("fn f() -> trit { return doesnt_exist(1); }");
    }

    #[test]
    fn test_user_fn_return_type_registered() {
        check_ok("fn helper(a: trit) -> trit { return invert(a); } fn main() -> trit { return helper(1); }");
    }

    #[test]
    fn test_user_fn_int_return_ok() {
        check_ok("fn helper(a: trit) -> trit { let x: int = 1; return x; }");
    }

    // --- scopes and fields ------------------------------------------------

    #[test]
    fn test_undefined_variable_caught() {
        check_err("fn f() -> trit { return ghost_var; }");
    }

    #[test]
    fn test_defined_variable_ok() {
        check_ok("fn f() -> trit { let x: trit = 1; return x; }");
    }

    #[test]
    fn test_struct_field_access_ok() {
        check_ok("struct S { val: trit } fn f(s: S) -> trit { return s.val; }");
    }
}
595}