1use crate::ast::*;
2
3#[derive(Debug)]
6pub enum SemanticError {
7 TypeMismatch { expected: Type, found: Type },
8 UndefinedVariable(String),
9 UndefinedStruct(String),
10 UndefinedField { struct_name: String, field: String },
11 UndefinedFunction(String),
12 ReturnTypeMismatch { function: String, expected: Type, found: Type },
13 ArgCountMismatch { function: String, expected: usize, found: usize },
14 ArgTypeMismatch { function: String, param_index: usize, expected: Type, found: Type },
15 PropagateOnNonTrit { found: Type },
17}
18
/// Formats each error as `[CODE] message`, followed on a new line by a
/// `→ details` pointer to the matching `stdlib/errors/<CODE>.tern` file.
impl std::fmt::Display for SemanticError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TypeMismatch { expected, found } =>
                write!(f, "[TYPE-001] Type mismatch: expected {expected:?}, found {found:?}. A trit is not an int. An int is not a trit. They don't coerce.\n → details: stdlib/errors/TYPE-001.tern | ternlang errors TYPE-001"),
            Self::UndefinedVariable(n) =>
                write!(f, "[SCOPE-001] '{n}' is undefined — hold state. Declare before use, or check for a typo.\n → details: stdlib/errors/SCOPE-001.tern | ternlang errors SCOPE-001"),
            Self::UndefinedStruct(n) =>
                write!(f, "[STRUCT-001] Struct '{n}' doesn't exist. A ghost type — the type system can't find it anywhere.\n → details: stdlib/errors/STRUCT-001.tern | ternlang errors STRUCT-001"),
            Self::UndefinedField { struct_name, field } =>
                write!(f, "[STRUCT-002] Struct '{struct_name}' has no field '{field}'. Check the definition — maybe it was renamed.\n → details: stdlib/errors/STRUCT-002.tern | ternlang errors STRUCT-002"),
            Self::UndefinedFunction(n) =>
                write!(f, "[FN-001] '{n}' was called but never defined. Declare it above the call site, or check for a typo.\n → details: stdlib/errors/FN-001.tern | ternlang errors FN-001"),
            Self::ReturnTypeMismatch { function, expected, found } =>
                write!(f, "[FN-002] '{function}' promised to return {expected:?} but returned {found:?}. Ternary contracts are strict — all paths must match.\n → details: stdlib/errors/FN-002.tern | ternlang errors FN-002"),
            Self::ArgCountMismatch { function, expected, found } =>
                write!(f, "[FN-003] '{function}' expects {expected} arg(s), got {found}. Arity is not optional — not even in hold state.\n → details: stdlib/errors/FN-003.tern | ternlang errors FN-003"),
            // `param_index` is printed as stored (0-based, as produced by `enumerate`).
            Self::ArgTypeMismatch { function, param_index, expected, found } =>
                write!(f, "[FN-004] '{function}' arg {param_index}: expected {expected:?}, found {found:?}. Types travel with their values — they don't change at the border.\n → details: stdlib/errors/FN-004.tern | ternlang errors FN-004"),
            Self::PropagateOnNonTrit { found } =>
                write!(f, "[PROP-001] '?' used on a {found:?} expression. Only trit-returning functions carry the three-valued signal. The third state requires a trit.\n → details: stdlib/errors/PROP-001.tern | ternlang errors PROP-001"),
        }
    }
}
43
/// Call signature used when checking calls to built-ins, user functions and agent methods.
#[derive(Debug, Clone)]
pub struct FunctionSig {
    /// Parameter types in declaration order. `None` marks a variadic signature:
    /// each argument is still type-inferred, but neither the argument count nor
    /// the argument types are checked.
    pub params: Option<Vec<Type>>,
    /// Type produced by a call to this function.
    pub return_type: Type,
}
52
53impl FunctionSig {
54 fn exact(params: Vec<Type>, return_type: Type) -> Self {
55 Self { params: Some(params), return_type }
56 }
57 fn variadic(return_type: Type) -> Self {
58 Self { params: None, return_type }
59 }
60}
61
/// Name resolver and type checker for a parsed `Program`.
pub struct SemanticAnalyzer {
    /// Lexical scope stack mapping variable names to their declared types.
    /// The innermost scope is last; lookups search from the end.
    scopes: Vec<std::collections::HashMap<String, Type>>,
    /// Struct name -> ordered `(field name, field type)` list.
    struct_defs: std::collections::HashMap<String, Vec<(String, Type)>>,
    /// Callable name -> signature (built-ins, user functions and agent methods).
    func_signatures: std::collections::HashMap<String, FunctionSig>,
    /// Name of the function whose body is currently being checked, if any.
    current_fn_name: Option<String>,
    /// Declared return type of that function; `return` statements are checked against it.
    current_fn_return: Option<Type>,
}
72
73impl SemanticAnalyzer {
74 pub fn new() -> Self {
75 let mut sigs: std::collections::HashMap<String, FunctionSig> = std::collections::HashMap::new();
76
77 sigs.insert("consensus".into(), FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
79 sigs.insert("invert".into(), FunctionSig::exact(vec![Type::Trit], Type::Trit));
80 sigs.insert("length".into(), FunctionSig::variadic(Type::Int));
81 sigs.insert("truth".into(), FunctionSig::exact(vec![], Type::Trit));
82 sigs.insert("hold".into(), FunctionSig::exact(vec![], Type::Trit));
83 sigs.insert("conflict".into(), FunctionSig::exact(vec![], Type::Trit));
84 sigs.insert("mul".into(), FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
85
86 sigs.insert("matmul".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
88 sigs.insert("sparsity".into(), FunctionSig::variadic(Type::Int));
89 sigs.insert("shape".into(), FunctionSig::variadic(Type::Int));
90 sigs.insert("zeros".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
91
92 sigs.insert("print".into(), FunctionSig::variadic(Type::Trit));
94 sigs.insert("println".into(), FunctionSig::variadic(Type::Trit));
95
96 sigs.insert("abs".into(), FunctionSig::exact(vec![Type::Int], Type::Int));
98 sigs.insert("min".into(), FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
99 sigs.insert("max".into(), FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
100
101 sigs.insert("quantize".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
103 sigs.insert("threshold".into(),FunctionSig::variadic(Type::Float));
104
105 sigs.insert("forward".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
107 sigs.insert("argmax".into(), FunctionSig::variadic(Type::Int));
108
109 sigs.insert("cast".into(), FunctionSig::variadic(Type::Trit));
111
112 Self {
113 scopes: vec![std::collections::HashMap::new()],
114 struct_defs: std::collections::HashMap::new(),
115 func_signatures: sigs,
116 current_fn_name: None,
117 current_fn_return: None,
118 }
119 }
120
121 pub fn register_structs(&mut self, structs: &[StructDef]) {
124 for s in structs {
125 self.struct_defs.insert(s.name.clone(), s.fields.clone());
126 }
127 }
128
129 pub fn register_functions(&mut self, functions: &[Function]) {
130 for f in functions {
131 let params = f.params.iter().map(|(_, ty)| ty.clone()).collect();
132 self.func_signatures.insert(
133 f.name.clone(),
134 FunctionSig::exact(params, f.return_type.clone()),
135 );
136 }
137 }
138
139 pub fn register_agents(&mut self, agents: &[AgentDef]) {
140 for agent in agents {
141 for method in &agent.methods {
142 let params = method.params.iter().map(|(_, ty)| ty.clone()).collect();
143 let sig = FunctionSig::exact(params, method.return_type.clone());
144 self.func_signatures.insert(method.name.clone(), sig.clone());
145 self.func_signatures.insert(
146 format!("{}::{}", agent.name, method.name),
147 sig,
148 );
149 }
150 }
151 }
152
153 pub fn check_program(&mut self, program: &Program) -> Result<(), SemanticError> {
156 self.register_structs(&program.structs);
157 self.register_functions(&program.functions);
158 self.register_agents(&program.agents);
159 for agent in &program.agents {
160 for method in &agent.methods {
161 self.check_function(method)?;
162 }
163 }
164 for func in &program.functions {
165 self.check_function(func)?;
166 }
167 Ok(())
168 }
169
170 fn check_function(&mut self, func: &Function) -> Result<(), SemanticError> {
171 let prev_name = self.current_fn_name.take();
173 let prev_return = self.current_fn_return.take();
174 self.current_fn_name = Some(func.name.clone());
175 self.current_fn_return = Some(func.return_type.clone());
176
177 self.scopes.push(std::collections::HashMap::new());
178 for (name, ty) in &func.params {
179 self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
180 }
181 for stmt in &func.body {
182 self.check_stmt(stmt)?;
183 }
184 self.scopes.pop();
185
186 self.current_fn_name = prev_name;
188 self.current_fn_return = prev_return;
189 Ok(())
190 }
191
192 pub fn check_stmt(&mut self, stmt: &Stmt) -> Result<(), SemanticError> {
195 match stmt {
196 Stmt::Let { name, ty, value } => {
197 let val_ty = self.infer_expr_type(value)?;
198 let type_ok = val_ty == *ty
199 || matches!(value, Expr::Cast { .. })
200 || matches!(value, Expr::StructLiteral { .. }) || (*ty == Type::Int && val_ty == Type::Trit)
202 || (*ty == Type::Trit && val_ty == Type::Int)
203 || (matches!(ty, Type::Named(_)) && val_ty == Type::Trit)
204 || (matches!(ty, Type::TritTensor { .. }) && matches!(val_ty, Type::TritTensor { .. }))
205 || (*ty == Type::AgentRef && val_ty == Type::AgentRef);
206 if !type_ok {
207 return Err(SemanticError::TypeMismatch { expected: ty.clone(), found: val_ty });
208 }
209 self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
210 Ok(())
211 }
212
213 Stmt::Return(expr) => {
214 let found = self.infer_expr_type(expr)?;
215 if let (Some(fn_name), Some(expected)) = (&self.current_fn_name, &self.current_fn_return) {
216 let ok = found == *expected
218 || matches!(expr, Expr::Cast { .. })
219 || matches!(expr, Expr::StructLiteral { .. })
220 || (*expected == Type::Int && found == Type::Trit)
221 || (*expected == Type::Trit && found == Type::Int)
222 || (matches!(expected, Type::TritTensor { .. }) && matches!(found, Type::TritTensor { .. }))
223 || (matches!(expected, Type::Named(_)) && found == Type::Trit);
224 if !ok {
225 return Err(SemanticError::ReturnTypeMismatch {
226 function: fn_name.clone(),
227 expected: expected.clone(),
228 found,
229 });
230 }
231 }
232 Ok(())
233 }
234
235 Stmt::IfTernary { condition, on_pos, on_zero, on_neg } => {
236 let cond_ty = self.infer_expr_type(condition)?;
237 if cond_ty != Type::Trit {
238 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
239 }
240 self.check_stmt(on_pos)?;
241 self.check_stmt(on_zero)?;
242 self.check_stmt(on_neg)?;
243 Ok(())
244 }
245
246 Stmt::Match { condition, arms } => {
247 let cond_ty = self.infer_expr_type(condition)?;
248 if cond_ty != Type::Trit && cond_ty != Type::Int {
249 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
250 }
251
252 if cond_ty == Type::Trit {
253 let has_pos = arms.iter().any(|(v, _)| *v == 1);
255 let has_zero = arms.iter().any(|(v, _)| *v == 0);
256 let has_neg = arms.iter().any(|(v, _)| *v == -1);
257 if !has_pos || !has_zero || !has_neg {
258 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
261 }
262 for (val, _) in arms {
263 if *val < -1 || *val > 1 {
264 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: Type::Int });
265 }
266 }
267 }
268
269 for (_val, arm_stmt) in arms {
270 self.check_stmt(arm_stmt)?;
271 }
272 Ok(())
273 }
274
275 Stmt::Block(stmts) => {
276 self.scopes.push(std::collections::HashMap::new());
277 for s in stmts { self.check_stmt(s)?; }
278 self.scopes.pop();
279 Ok(())
280 }
281
282 Stmt::Decorated { stmt, .. } => self.check_stmt(stmt),
283
284 Stmt::Expr(expr) => { self.infer_expr_type(expr)?; Ok(()) }
285
286 Stmt::ForIn { var, iter, body } => {
287 self.infer_expr_type(iter)?;
288 self.scopes.push(std::collections::HashMap::new());
289 self.scopes.last_mut().unwrap().insert(var.clone(), Type::Trit);
290 self.check_stmt(body)?;
291 self.scopes.pop();
292 Ok(())
293 }
294
295 Stmt::WhileTernary { condition, on_pos, on_zero, on_neg } => {
296 let cond_ty = self.infer_expr_type(condition)?;
297 if cond_ty != Type::Trit {
298 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
299 }
300 self.check_stmt(on_pos)?;
301 self.check_stmt(on_zero)?;
302 self.check_stmt(on_neg)?;
303 Ok(())
304 }
305
306 Stmt::Loop { body } => self.check_stmt(body),
307 Stmt::Break => Ok(()),
308 Stmt::Continue => Ok(()),
309 Stmt::Use { .. } => Ok(()),
310 Stmt::FromImport { .. } => Ok(()),
311
312 Stmt::Send { target, message } => {
313 self.infer_expr_type(target)?;
314 self.infer_expr_type(message)?;
315 Ok(())
316 }
317
318 Stmt::FieldSet { object, field, value } => {
319 let obj_ty = self.lookup_var(object)?;
320 if let Type::Named(struct_name) = obj_ty {
321 let field_ty = self.lookup_field(&struct_name, field)?;
322 let val_ty = self.infer_expr_type(value)?;
323 if val_ty != field_ty {
324 return Err(SemanticError::TypeMismatch { expected: field_ty, found: val_ty });
325 }
326 } else {
327 self.infer_expr_type(value)?;
328 }
329 Ok(())
330 }
331
332 Stmt::IndexSet { object, row, col, value } => {
333 self.lookup_var(object)?;
334 self.infer_expr_type(row)?;
335 self.infer_expr_type(col)?;
336 self.infer_expr_type(value)?;
337 Ok(())
338 }
339
340 Stmt::Set { name, value } => {
341 let var_ty = self.lookup_var(name)?;
342 let val_ty = self.infer_expr_type(value)?;
343 let ok = var_ty == val_ty
344 || matches!(value, Expr::Cast { .. })
345 || (var_ty == Type::Int && val_ty == Type::Trit)
346 || (var_ty == Type::Trit && val_ty == Type::Int);
347 if !ok {
348 return Err(SemanticError::TypeMismatch { expected: var_ty, found: val_ty });
349 }
350 Ok(())
351 }
352 }
353 }
354
355 fn infer_expr_type(&self, expr: &Expr) -> Result<Type, SemanticError> {
358 match expr {
359 Expr::TritLiteral(_) => Ok(Type::Trit),
360 Expr::IntLiteral(_) => Ok(Type::Int),
361 Expr::FloatLiteral(_) => Ok(Type::Float),
362 Expr::StringLiteral(_) => Ok(Type::String),
363 Expr::Ident(name) => self.lookup_var(name),
364
365 Expr::BinaryOp { op, lhs, rhs } => {
366 let l = self.infer_expr_type(lhs)?;
367 let r = self.infer_expr_type(rhs)?;
368 match op {
369 BinOp::Less | BinOp::Greater | BinOp::LessEqual | BinOp::GreaterEqual | BinOp::Equal | BinOp::NotEqual | BinOp::And | BinOp::Or => {
370 Ok(Type::Trit)
371 }
372
373 _ => {
374 let is_numeric = |t: &Type| matches!(t, Type::Int | Type::Trit | Type::Float);
376 if is_numeric(&l) && is_numeric(&r) {
377 if l == Type::Float || r == Type::Float { return Ok(Type::Float); }
378 if l == Type::Int || r == Type::Int { return Ok(Type::Int); }
379 return Ok(Type::Trit);
380 }
381
382 if l != r {
383 return Err(SemanticError::TypeMismatch { expected: l, found: r });
384 }
385 Ok(l)
386 }
387 }
388 }
389
390 Expr::UnaryOp { expr, .. } => self.infer_expr_type(expr),
391
392 Expr::Call { callee, args } => {
393 let sig = self.func_signatures.get(callee.as_str())
394 .ok_or_else(|| SemanticError::UndefinedFunction(callee.clone()))?
395 .clone();
396
397 if let Some(param_types) = &sig.params {
399 if args.len() != param_types.len() {
400 return Err(SemanticError::ArgCountMismatch {
401 function: callee.clone(),
402 expected: param_types.len(),
403 found: args.len(),
404 });
405 }
406 for (i, (arg, expected_ty)) in args.iter().zip(param_types.iter()).enumerate() {
407 let found_ty = self.infer_expr_type(arg)?;
408 let ok = found_ty == *expected_ty
410 || matches!(arg, Expr::Cast { .. })
411 || (expected_ty == &Type::Int && found_ty == Type::Trit)
412 || (expected_ty == &Type::Trit && found_ty == Type::Int)
413 || (matches!(expected_ty, Type::TritTensor { .. })
414 && matches!(found_ty, Type::TritTensor { .. }))
415 || (matches!(expected_ty, Type::Named(_)) && found_ty == Type::Trit);
416 if !ok {
417 return Err(SemanticError::ArgTypeMismatch {
418 function: callee.clone(),
419 param_index: i,
420 expected: expected_ty.clone(),
421 found: found_ty,
422 });
423 }
424 }
425 } else {
426 for arg in args { self.infer_expr_type(arg)?; }
428 }
429
430 Ok(sig.return_type)
431 }
432
433 Expr::Cast { ty, .. } => Ok(ty.clone()),
434 Expr::Spawn { .. } => Ok(Type::AgentRef),
435 Expr::Await { .. } => Ok(Type::Trit),
436 Expr::NodeId => Ok(Type::String),
437
438 Expr::Propagate { expr } => {
439 let inner = self.infer_expr_type(expr)?;
440 if inner != Type::Trit {
441 return Err(SemanticError::PropagateOnNonTrit { found: inner });
442 }
443 Ok(Type::Trit)
444 }
445
446 Expr::TritTensorLiteral(vals) => {
447 Ok(Type::TritTensor { dims: vec![vals.len()] })
448 }
449
450 Expr::StructLiteral { name, fields } => {
451 let def = self.struct_defs.get(name)
453 .ok_or_else(|| SemanticError::UndefinedStruct(name.clone()))?;
454
455 if fields.len() != def.len() {
456 return Err(SemanticError::ArgCountMismatch {
457 function: name.clone(),
458 expected: def.len(),
459 found: fields.len()
460 });
461 }
462
463 for (f_name, f_val) in fields {
464 let expected_f_ty = def.iter()
465 .find(|(n, _)| n == f_name)
466 .ok_or_else(|| SemanticError::UndefinedField {
467 struct_name: name.clone(),
468 field: f_name.clone()
469 })?
470 .1.clone();
471 let found_f_ty = self.infer_expr_type(f_val)?;
472 if found_f_ty != expected_f_ty {
473 return Err(SemanticError::TypeMismatch {
474 expected: expected_f_ty,
475 found: found_f_ty
476 });
477 }
478 }
479 Ok(Type::Named(name.clone()))
480 }
481
482 Expr::FieldAccess { object, field } => {
483 let obj_ty = self.infer_expr_type(object)?;
484 if let Type::Named(struct_name) = obj_ty {
485 self.lookup_field(&struct_name, field)
486 } else {
487 Ok(Type::Trit)
488 }
489 }
490
491 Expr::Index { object, row, col } => {
492 self.infer_expr_type(object)?;
493 self.infer_expr_type(row)?;
494 self.infer_expr_type(col)?;
495 Ok(Type::Trit)
496 }
497 }
498 }
499
500 fn lookup_var(&self, name: &str) -> Result<Type, SemanticError> {
503 for scope in self.scopes.iter().rev() {
504 if let Some(ty) = scope.get(name) { return Ok(ty.clone()); }
505 }
506 Err(SemanticError::UndefinedVariable(name.to_string()))
507 }
508
509 fn lookup_field(&self, struct_name: &str, field: &str) -> Result<Type, SemanticError> {
510 let fields = self.struct_defs.get(struct_name)
511 .ok_or_else(|| SemanticError::UndefinedStruct(struct_name.to_string()))?;
512 fields.iter()
513 .find(|(f, _)| f == field)
514 .map(|(_, ty)| ty.clone())
515 .ok_or_else(|| SemanticError::UndefinedField {
516 struct_name: struct_name.to_string(),
517 field: field.to_string(),
518 })
519 }
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `src` (panicking on a syntax error) and runs a fresh analyzer over it.
    fn check(src: &str) -> Result<(), SemanticError> {
        let mut parser = Parser::new(src);
        let program = parser.parse_program().expect("parse failed");
        let mut analyzer = SemanticAnalyzer::new();
        analyzer.check_program(&program)
    }

    /// Asserts that `src` passes semantic analysis.
    fn check_ok(src: &str) {
        let outcome = check(src);
        assert!(outcome.is_ok(), "expected ok, got: {:?}", outcome);
    }

    /// Asserts that `src` is rejected by semantic analysis.
    fn check_err(src: &str) {
        if check(src).is_ok() {
            panic!("expected error but check passed");
        }
    }

    // --- return types ---

    #[test]
    fn test_return_correct_type() {
        check_ok("fn f() -> trit { return 1; }");
    }

    #[test]
    fn test_return_int_in_trit_fn() {
        check_ok("fn f() -> trit { let x: int = 42; return x; }");
    }

    #[test]
    fn test_return_trit_in_trit_fn() {
        check_ok("fn decide(a: trit, b: trit) -> trit { return consensus(a, b); }");
    }

    // --- call arity ---

    #[test]
    fn test_call_correct_arity() {
        check_ok("fn f() -> trit { return consensus(1, -1); }");
    }

    #[test]
    fn test_call_too_few_args_caught() {
        check_err("fn f() -> trit { return consensus(1); }");
    }

    #[test]
    fn test_call_too_many_args_caught() {
        check_err("fn f() -> trit { return invert(1, 1); }");
    }

    // --- argument types ---

    #[test]
    fn test_call_int_arg_in_trit_fn() {
        check_ok("fn f(a: trit) -> trit { return invert(a); } fn main() -> trit { let x: int = 42; return f(x); }");
    }

    #[test]
    fn test_call_correct_arg_type() {
        check_ok("fn f(a: trit) -> trit { return invert(a); }");
    }

    // --- function resolution ---

    #[test]
    fn test_undefined_function_caught() {
        check_err("fn f() -> trit { return doesnt_exist(1); }");
    }

    #[test]
    fn test_user_fn_return_type_registered() {
        check_ok("fn helper(a: trit) -> trit { return invert(a); } fn main() -> trit { return helper(1); }");
    }

    #[test]
    fn test_user_fn_int_return_ok() {
        check_ok("fn helper(a: trit) -> trit { let x: int = 1; return x; }");
    }

    // --- variable resolution ---

    #[test]
    fn test_undefined_variable_caught() {
        check_err("fn f() -> trit { return ghost_var; }");
    }

    #[test]
    fn test_defined_variable_ok() {
        check_ok("fn f() -> trit { let x: trit = 1; return x; }");
    }

    // --- structs ---

    #[test]
    fn test_struct_field_access_ok() {
        check_ok("struct S { val: trit } fn f(s: S) -> trit { return s.val; }");
    }
}