1use crate::ast::*;
2
3#[derive(Debug)]
6pub enum SemanticError {
7 TypeMismatch { expected: Type, found: Type },
8 UndefinedVariable(String),
9 UndefinedStruct(String),
10 UndefinedField { struct_name: String, field: String },
11 UndefinedFunction(String),
12 ReturnTypeMismatch { function: String, expected: Type, found: Type },
13 ArgCountMismatch { function: String, expected: usize, found: usize },
14 ArgTypeMismatch { function: String, param_index: usize, expected: Type, found: Type },
15 PropagateOnNonTrit { found: Type },
17 NonExhaustiveMatch(String),
18}
19
20impl std::fmt::Display for SemanticError {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 match self {
23 Self::TypeMismatch { expected, found } =>
24 write!(f, "[TYPE-001] Type mismatch: expected {expected:?}, found {found:?}. A trit is not an int. An int is not a trit. They don't coerce.\n → details: stdlib/errors/TYPE-001.tern | ternlang errors TYPE-001"),
25 Self::UndefinedVariable(n) =>
26 write!(f, "[SCOPE-001] '{n}' is undefined — hold state. Declare before use, or check for a typo.\n → details: stdlib/errors/SCOPE-001.tern | ternlang errors SCOPE-001"),
27 Self::UndefinedStruct(n) =>
28 write!(f, "[STRUCT-001] Struct '{n}' doesn't exist. A ghost type — the type system can't find it anywhere.\n → details: stdlib/errors/STRUCT-001.tern | ternlang errors STRUCT-001"),
29 Self::UndefinedField { struct_name, field } =>
30 write!(f, "[STRUCT-002] Struct '{struct_name}' has no field '{field}'. Check the definition — maybe it was renamed.\n → details: stdlib/errors/STRUCT-002.tern | ternlang errors STRUCT-002"),
31 Self::UndefinedFunction(n) =>
32 write!(f, "[FN-001] '{n}' was called but never defined. Declare it above the call site, or check for a typo.\n → details: stdlib/errors/FN-001.tern | ternlang errors FN-001"),
33 Self::ReturnTypeMismatch { function, expected, found } =>
34 write!(f, "[FN-002] '{function}' promised to return {expected:?} but returned {found:?}. Ternary contracts are strict — all paths must match.\n → details: stdlib/errors/FN-002.tern | ternlang errors FN-002"),
35 Self::ArgCountMismatch { function, expected, found } =>
36 write!(f, "[FN-003] '{function}' expects {expected} arg(s), got {found}. Arity is not optional — not even in hold state.\n → details: stdlib/errors/FN-003.tern | ternlang errors FN-003"),
37 Self::ArgTypeMismatch { function, param_index, expected, found } =>
38 write!(f, "[FN-004] '{function}' arg {param_index}: expected {expected:?}, found {found:?}. Types travel with their values — they don't change at the border.\n → details: stdlib/errors/FN-004.tern | ternlang errors FN-004"),
39 Self::PropagateOnNonTrit { found } =>
40 write!(f, "[PROP-001] '?' used on a {found:?} expression. Only trit-returning functions carry the three-valued signal. The third state requires a trit.\n → details: stdlib/errors/PROP-001.tern | ternlang errors PROP-001"),
41 Self::NonExhaustiveMatch(msg) =>
42 write!(f, "Non-exhaustive match: {msg}"),
43 }
44 }
45}
46
47#[derive(Debug, Clone)]
50pub struct FunctionSig {
51 pub params: Option<Vec<Type>>,
53 pub return_type: Type,
54}
55
56impl FunctionSig {
57 fn exact(params: Vec<Type>, return_type: Type) -> Self {
58 Self { params: Some(params), return_type }
59 }
60 fn variadic(return_type: Type) -> Self {
61 Self { params: None, return_type }
62 }
63}
64
65pub struct SemanticAnalyzer {
68 scopes: Vec<std::collections::HashMap<String, Type>>,
69 struct_defs: std::collections::HashMap<String, Vec<(String, Type)>>,
70 func_signatures: std::collections::HashMap<String, FunctionSig>,
71 current_fn_name: Option<String>,
73 current_fn_return: Option<Type>,
74}
75
76impl SemanticAnalyzer {
77 pub fn new() -> Self {
78 let mut sigs: std::collections::HashMap<String, FunctionSig> = std::collections::HashMap::new();
79
80 sigs.insert("consensus".into(), FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
82 sigs.insert("invert".into(), FunctionSig::exact(vec![Type::Trit], Type::Trit));
83 sigs.insert("length".into(), FunctionSig::variadic(Type::Int));
84 sigs.insert("truth".into(), FunctionSig::exact(vec![], Type::Trit));
85 sigs.insert("hold".into(), FunctionSig::exact(vec![], Type::Trit));
86 sigs.insert("conflict".into(), FunctionSig::exact(vec![], Type::Trit));
87 sigs.insert("mul".into(), FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
88
89 sigs.insert("matmul".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
91 sigs.insert("sparsity".into(), FunctionSig::variadic(Type::Int));
92 sigs.insert("shape".into(), FunctionSig::variadic(Type::Int));
93 sigs.insert("zeros".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
94
95 sigs.insert("print".into(), FunctionSig::variadic(Type::Trit));
97 sigs.insert("println".into(), FunctionSig::variadic(Type::Trit));
98
99 sigs.insert("abs".into(), FunctionSig::exact(vec![Type::Int], Type::Int));
101 sigs.insert("min".into(), FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
102 sigs.insert("max".into(), FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
103
104 sigs.insert("quantize".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
106 sigs.insert("threshold".into(),FunctionSig::variadic(Type::Float));
107
108 sigs.insert("forward".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
110 sigs.insert("argmax".into(), FunctionSig::variadic(Type::Int));
111
112 sigs.insert("cast".into(), FunctionSig::variadic(Type::Trit));
114
115 Self {
116 scopes: vec![std::collections::HashMap::new()],
117 struct_defs: std::collections::HashMap::new(),
118 func_signatures: sigs,
119 current_fn_name: None,
120 current_fn_return: None,
121 }
122 }
123
124 pub fn register_structs(&mut self, structs: &[StructDef]) {
127 for s in structs {
128 self.struct_defs.insert(s.name.clone(), s.fields.clone());
129 }
130 }
131
132 pub fn register_functions(&mut self, functions: &[Function]) {
133 for f in functions {
134 let params = f.params.iter().map(|(_, ty)| ty.clone()).collect();
135 self.func_signatures.insert(
136 f.name.clone(),
137 FunctionSig::exact(params, f.return_type.clone()),
138 );
139 }
140 }
141
142 pub fn register_agents(&mut self, agents: &[AgentDef]) {
143 for agent in agents {
144 for method in &agent.methods {
145 let params = method.params.iter().map(|(_, ty)| ty.clone()).collect();
146 let sig = FunctionSig::exact(params, method.return_type.clone());
147 self.func_signatures.insert(method.name.clone(), sig.clone());
148 self.func_signatures.insert(
149 format!("{}::{}", agent.name, method.name),
150 sig,
151 );
152 }
153 }
154 }
155
156 pub fn check_program(&mut self, program: &Program) -> Result<(), SemanticError> {
159 self.register_structs(&program.structs);
160 self.register_functions(&program.functions);
161 self.register_agents(&program.agents);
162 for agent in &program.agents {
163 for method in &agent.methods {
164 self.check_function(method)?;
165 }
166 }
167 for func in &program.functions {
168 self.check_function(func)?;
169 }
170 Ok(())
171 }
172
173 fn check_function(&mut self, func: &Function) -> Result<(), SemanticError> {
174 let prev_name = self.current_fn_name.take();
176 let prev_return = self.current_fn_return.take();
177 self.current_fn_name = Some(func.name.clone());
178 self.current_fn_return = Some(func.return_type.clone());
179
180 self.scopes.push(std::collections::HashMap::new());
181 for (name, ty) in &func.params {
182 self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
183 }
184 for stmt in &func.body {
185 self.check_stmt(stmt)?;
186 }
187 self.scopes.pop();
188
189 self.current_fn_name = prev_name;
191 self.current_fn_return = prev_return;
192 Ok(())
193 }
194
195 pub fn check_stmt(&mut self, stmt: &Stmt) -> Result<(), SemanticError> {
198 match stmt {
199 Stmt::Let { name, ty, value } => {
200 let val_ty = self.infer_expr_type(value)?;
201 let type_ok = val_ty == *ty
202 || matches!(value, Expr::Cast { .. })
203 || matches!(value, Expr::StructLiteral { .. }) || (*ty == Type::Int && val_ty == Type::Trit)
205 || (*ty == Type::Trit && val_ty == Type::Int)
206 || (matches!(ty, Type::Named(_)) && val_ty == Type::Trit)
207 || (matches!(ty, Type::TritTensor { .. }) && matches!(val_ty, Type::TritTensor { .. }))
208 || (*ty == Type::AgentRef && val_ty == Type::AgentRef);
209 if !type_ok {
210 return Err(SemanticError::TypeMismatch { expected: ty.clone(), found: val_ty });
211 }
212 self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
213 Ok(())
214 }
215
216 Stmt::Return(expr) => {
217 let found = self.infer_expr_type(expr)?;
218 if let (Some(fn_name), Some(expected)) = (&self.current_fn_name, &self.current_fn_return) {
219 let ok = found == *expected
221 || matches!(expr, Expr::Cast { .. })
222 || matches!(expr, Expr::StructLiteral { .. })
223 || (*expected == Type::Int && found == Type::Trit)
224 || (*expected == Type::Trit && found == Type::Int)
225 || (matches!(expected, Type::TritTensor { .. }) && matches!(found, Type::TritTensor { .. }))
226 || (matches!(expected, Type::Named(_)) && found == Type::Trit);
227 if !ok {
228 return Err(SemanticError::ReturnTypeMismatch {
229 function: fn_name.clone(),
230 expected: expected.clone(),
231 found,
232 });
233 }
234 }
235 Ok(())
236 }
237
238 Stmt::IfTernary { condition, on_pos, on_zero, on_neg } => {
239 let cond_ty = self.infer_expr_type(condition)?;
240 if cond_ty != Type::Trit {
241 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
242 }
243 self.check_stmt(on_pos)?;
244 self.check_stmt(on_zero)?;
245 self.check_stmt(on_neg)?;
246 Ok(())
247 }
248
249 Stmt::Match { condition, arms } => {
250 let cond_ty = self.infer_expr_type(condition)?;
251 if cond_ty != Type::Trit && cond_ty != Type::Int && cond_ty != Type::Float {
252 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
253 }
254
255 if cond_ty == Type::Trit {
256 let has_pos = arms.iter().any(|(p, _)| matches!(p, Pattern::Trit(1) | Pattern::Int(1)));
258 let has_zero = arms.iter().any(|(p, _)| matches!(p, Pattern::Trit(0) | Pattern::Int(0)));
259 let has_neg = arms.iter().any(|(p, _)| matches!(p, Pattern::Trit(-1) | Pattern::Int(-1)));
260 if !has_pos || !has_zero || !has_neg {
261 return Err(SemanticError::NonExhaustiveMatch("Trit match must cover -1, 0, and 1".into()));
262 }
263 for (pattern, _) in arms {
264 match pattern {
265 Pattern::Trit(v) => if *v < -1 || *v > 1 { return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: Type::Int }); }
266 Pattern::Int(v) => if *v < -1 || *v > 1 { return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: Type::Int }); }
267 Pattern::Float(_) => return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: Type::Float }),
268 }
269 }
270 }
271
272 for (_pattern, arm_stmt) in arms {
273 self.check_stmt(arm_stmt)?;
274 }
275 Ok(())
276 }
277
278 Stmt::Block(stmts) => {
279 self.scopes.push(std::collections::HashMap::new());
280 for s in stmts { self.check_stmt(s)?; }
281 self.scopes.pop();
282 Ok(())
283 }
284
285 Stmt::Decorated { stmt, .. } => self.check_stmt(stmt),
286
287 Stmt::Expr(expr) => { self.infer_expr_type(expr)?; Ok(()) }
288
289 Stmt::ForIn { var, iter, body } => {
290 self.infer_expr_type(iter)?;
291 self.scopes.push(std::collections::HashMap::new());
292 self.scopes.last_mut().unwrap().insert(var.clone(), Type::Trit);
293 self.check_stmt(body)?;
294 self.scopes.pop();
295 Ok(())
296 }
297
298 Stmt::WhileTernary { condition, on_pos, on_zero, on_neg } => {
299 let cond_ty = self.infer_expr_type(condition)?;
300 if cond_ty != Type::Trit {
301 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
302 }
303 self.check_stmt(on_pos)?;
304 self.check_stmt(on_zero)?;
305 self.check_stmt(on_neg)?;
306 Ok(())
307 }
308
309 Stmt::Loop { body } => self.check_stmt(body),
310 Stmt::Break => Ok(()),
311 Stmt::Continue => Ok(()),
312 Stmt::Use { .. } => Ok(()),
313 Stmt::FromImport { .. } => Ok(()),
314
315 Stmt::Send { target, message } => {
316 self.infer_expr_type(target)?;
317 self.infer_expr_type(message)?;
318 Ok(())
319 }
320
321 Stmt::FieldSet { object, field, value } => {
322 let obj_ty = self.lookup_var(object)?;
323 if let Type::Named(struct_name) = obj_ty {
324 let field_ty = self.lookup_field(&struct_name, field)?;
325 let val_ty = self.infer_expr_type(value)?;
326 if val_ty != field_ty {
327 return Err(SemanticError::TypeMismatch { expected: field_ty, found: val_ty });
328 }
329 } else {
330 self.infer_expr_type(value)?;
331 }
332 Ok(())
333 }
334
335 Stmt::IndexSet { object, row, col, value } => {
336 self.lookup_var(object)?;
337 self.infer_expr_type(row)?;
338 self.infer_expr_type(col)?;
339 self.infer_expr_type(value)?;
340 Ok(())
341 }
342
343 Stmt::Set { name, value } => {
344 let var_ty = self.lookup_var(name)?;
345 let val_ty = self.infer_expr_type(value)?;
346 let ok = var_ty == val_ty
347 || matches!(value, Expr::Cast { .. })
348 || (var_ty == Type::Int && val_ty == Type::Trit)
349 || (var_ty == Type::Trit && val_ty == Type::Int);
350 if !ok {
351 return Err(SemanticError::TypeMismatch { expected: var_ty, found: val_ty });
352 }
353 Ok(())
354 }
355 }
356 }
357
358 fn infer_expr_type(&self, expr: &Expr) -> Result<Type, SemanticError> {
361 match expr {
362 Expr::TritLiteral(_) => Ok(Type::Trit),
363 Expr::IntLiteral(_) => Ok(Type::Int),
364 Expr::FloatLiteral(_) => Ok(Type::Float),
365 Expr::StringLiteral(_) => Ok(Type::String),
366 Expr::Ident(name) => self.lookup_var(name),
367
368 Expr::BinaryOp { op, lhs, rhs } => {
369 let l = self.infer_expr_type(lhs)?;
370 let r = self.infer_expr_type(rhs)?;
371 match op {
372 BinOp::Less | BinOp::Greater | BinOp::LessEqual | BinOp::GreaterEqual | BinOp::Equal | BinOp::NotEqual | BinOp::And | BinOp::Or => {
373 Ok(Type::Trit)
374 }
375
376 _ => {
377 let is_numeric = |t: &Type| matches!(t, Type::Int | Type::Trit | Type::Float);
379 if is_numeric(&l) && is_numeric(&r) {
380 if l == Type::Float || r == Type::Float { return Ok(Type::Float); }
381 if l == Type::Int || r == Type::Int { return Ok(Type::Int); }
382 return Ok(Type::Trit);
383 }
384
385 if l != r {
386 return Err(SemanticError::TypeMismatch { expected: l, found: r });
387 }
388 Ok(l)
389 }
390 }
391 }
392
393 Expr::UnaryOp { expr, .. } => self.infer_expr_type(expr),
394
395 Expr::Call { callee, args } => {
396 let sig = self.func_signatures.get(callee.as_str())
397 .ok_or_else(|| SemanticError::UndefinedFunction(callee.clone()))?
398 .clone();
399
400 if let Some(param_types) = &sig.params {
402 if args.len() != param_types.len() {
403 return Err(SemanticError::ArgCountMismatch {
404 function: callee.clone(),
405 expected: param_types.len(),
406 found: args.len(),
407 });
408 }
409 for (i, (arg, expected_ty)) in args.iter().zip(param_types.iter()).enumerate() {
410 let found_ty = self.infer_expr_type(arg)?;
411 let ok = found_ty == *expected_ty
413 || matches!(arg, Expr::Cast { .. })
414 || (expected_ty == &Type::Int && found_ty == Type::Trit)
415 || (expected_ty == &Type::Trit && found_ty == Type::Int)
416 || (matches!(expected_ty, Type::TritTensor { .. })
417 && matches!(found_ty, Type::TritTensor { .. }))
418 || (matches!(expected_ty, Type::Named(_)) && found_ty == Type::Trit);
419 if !ok {
420 return Err(SemanticError::ArgTypeMismatch {
421 function: callee.clone(),
422 param_index: i,
423 expected: expected_ty.clone(),
424 found: found_ty,
425 });
426 }
427 }
428 } else {
429 for arg in args { self.infer_expr_type(arg)?; }
431 }
432
433 Ok(sig.return_type)
434 }
435
436 Expr::Cast { ty, .. } => Ok(ty.clone()),
437 Expr::Spawn { .. } => Ok(Type::AgentRef),
438 Expr::Await { .. } => Ok(Type::Trit),
439 Expr::NodeId => Ok(Type::String),
440
441 Expr::Propagate { expr } => {
442 let inner = self.infer_expr_type(expr)?;
443 if inner != Type::Trit {
444 return Err(SemanticError::PropagateOnNonTrit { found: inner });
445 }
446 Ok(Type::Trit)
447 }
448
449 Expr::TritTensorLiteral(vals) => {
450 Ok(Type::TritTensor { dims: vec![vals.len()] })
451 }
452
453 Expr::StructLiteral { name, fields } => {
454 let def = self.struct_defs.get(name)
456 .ok_or_else(|| SemanticError::UndefinedStruct(name.clone()))?;
457
458 if fields.len() != def.len() {
459 return Err(SemanticError::ArgCountMismatch {
460 function: name.clone(),
461 expected: def.len(),
462 found: fields.len()
463 });
464 }
465
466 for (f_name, f_val) in fields {
467 let expected_f_ty = def.iter()
468 .find(|(n, _)| n == f_name)
469 .ok_or_else(|| SemanticError::UndefinedField {
470 struct_name: name.clone(),
471 field: f_name.clone()
472 })?
473 .1.clone();
474 let found_f_ty = self.infer_expr_type(f_val)?;
475 if found_f_ty != expected_f_ty {
476 return Err(SemanticError::TypeMismatch {
477 expected: expected_f_ty,
478 found: found_f_ty
479 });
480 }
481 }
482 Ok(Type::Named(name.clone()))
483 }
484
485 Expr::FieldAccess { object, field } => {
486 let obj_ty = self.infer_expr_type(object)?;
487 if let Type::Named(struct_name) = obj_ty {
488 self.lookup_field(&struct_name, field)
489 } else {
490 Ok(Type::Trit)
491 }
492 }
493
494 Expr::Index { object, row, col } => {
495 self.infer_expr_type(object)?;
496 self.infer_expr_type(row)?;
497 self.infer_expr_type(col)?;
498 Ok(Type::Trit)
499 }
500 }
501 }
502
503 fn lookup_var(&self, name: &str) -> Result<Type, SemanticError> {
506 for scope in self.scopes.iter().rev() {
507 if let Some(ty) = scope.get(name) { return Ok(ty.clone()); }
508 }
509 Err(SemanticError::UndefinedVariable(name.to_string()))
510 }
511
512 fn lookup_field(&self, struct_name: &str, field: &str) -> Result<Type, SemanticError> {
513 let fields = self.struct_defs.get(struct_name)
514 .ok_or_else(|| SemanticError::UndefinedStruct(struct_name.to_string()))?;
515 fields.iter()
516 .find(|(f, _)| f == field)
517 .map(|(_, ty)| ty.clone())
518 .ok_or_else(|| SemanticError::UndefinedField {
519 struct_name: struct_name.to_string(),
520 field: field.to_string(),
521 })
522 }
523}
524
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `src` and runs the full semantic pass over it; panics if parsing fails.
    fn check(src: &str) -> Result<(), SemanticError> {
        let mut parser = Parser::new(src);
        let prog = parser.parse_program().expect("parse failed");
        let mut analyzer = SemanticAnalyzer::new();
        analyzer.check_program(&prog)
    }

    /// Asserts that `src` type-checks, reporting the error (analysed only once) if it does not.
    fn check_ok(src: &str) {
        if let Err(err) = check(src) {
            panic!("expected ok, got: {err:?}");
        }
    }

    /// Asserts that `src` is rejected and returns the error so a test can pin its kind.
    fn check_err(src: &str) -> SemanticError {
        check(src).expect_err("expected error but check passed")
    }

    // ── Return types ──────────────────────────────────────────────────────

    #[test]
    fn test_return_correct_type() {
        check_ok("fn f() -> trit { return 1; }");
    }

    #[test]
    fn test_return_int_in_trit_fn() {
        // int and trit are interchangeable at `return`.
        check_ok("fn f() -> trit { let x: int = 42; return x; }");
    }

    #[test]
    fn test_return_trit_in_trit_fn() {
        check_ok("fn decide(a: trit, b: trit) -> trit { return consensus(a, b); }");
    }

    // ── Arity ─────────────────────────────────────────────────────────────

    #[test]
    fn test_call_correct_arity() {
        check_ok("fn f() -> trit { return consensus(1, -1); }");
    }

    #[test]
    fn test_call_too_few_args_caught() {
        let err = check_err("fn f() -> trit { return consensus(1); }");
        assert!(
            matches!(err, SemanticError::ArgCountMismatch { expected: 2, found: 1, .. }),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn test_call_too_many_args_caught() {
        let err = check_err("fn f() -> trit { return invert(1, 1); }");
        assert!(
            matches!(err, SemanticError::ArgCountMismatch { expected: 1, found: 2, .. }),
            "unexpected error: {err:?}"
        );
    }

    // ── Argument types ────────────────────────────────────────────────────

    #[test]
    fn test_call_int_arg_in_trit_fn() {
        // int and trit are interchangeable for arguments too.
        check_ok("fn f(a: trit) -> trit { return invert(a); } fn main() -> trit { let x: int = 42; return f(x); }");
    }

    #[test]
    fn test_call_correct_arg_type() {
        check_ok("fn f(a: trit) -> trit { return invert(a); }");
    }

    // ── Undefined functions ───────────────────────────────────────────────

    #[test]
    fn test_undefined_function_caught() {
        let err = check_err("fn f() -> trit { return doesnt_exist(1); }");
        assert!(
            matches!(&err, SemanticError::UndefinedFunction(n) if n == "doesnt_exist"),
            "unexpected error: {err:?}"
        );
    }

    // ── User-defined functions ────────────────────────────────────────────

    #[test]
    fn test_user_fn_return_type_registered() {
        check_ok("fn helper(a: trit) -> trit { return invert(a); } fn main() -> trit { return helper(1); }");
    }

    #[test]
    fn test_user_fn_int_return_ok() {
        check_ok("fn helper(a: trit) -> trit { let x: int = 1; return x; }");
    }

    // ── Scope ─────────────────────────────────────────────────────────────

    #[test]
    fn test_undefined_variable_caught() {
        let err = check_err("fn f() -> trit { return ghost_var; }");
        assert!(
            matches!(&err, SemanticError::UndefinedVariable(n) if n == "ghost_var"),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn test_defined_variable_ok() {
        check_ok("fn f() -> trit { let x: trit = 1; return x; }");
    }

    // ── Structs ───────────────────────────────────────────────────────────

    #[test]
    fn test_struct_field_access_ok() {
        check_ok("struct S { val: trit } fn f(s: S) -> trit { return s.val; }");
    }
}