1use crate::ast::*;
2
3#[derive(Debug)]
6pub enum SemanticError {
7 TypeMismatch { expected: Type, found: Type },
8 UndefinedVariable(String),
9 UndefinedStruct(String),
10 UndefinedField { struct_name: String, field: String },
11 UndefinedFunction(String),
12 ReturnTypeMismatch { function: String, expected: Type, found: Type },
13 ArgCountMismatch { function: String, expected: usize, found: usize },
14 ArgTypeMismatch { function: String, param_index: usize, expected: Type, found: Type },
15 PropagateOnNonTrit { found: Type },
17}
18
19impl std::fmt::Display for SemanticError {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 match self {
22 Self::TypeMismatch { expected, found } =>
23 write!(f, "[TYPE-001] Type mismatch: expected {expected:?}, found {found:?}. Binary types don't map cleanly to ternary space."),
24 Self::UndefinedVariable(n) =>
25 write!(f, "[SCOPE-001] '{n}' is undefined. Hold state — declare before use."),
26 Self::UndefinedStruct(n) =>
27 write!(f, "[STRUCT-001] Struct '{n}' doesn't exist. The type system can't find it."),
28 Self::UndefinedField { struct_name, field } =>
29 write!(f, "[STRUCT-002] Struct '{struct_name}' has no field '{field}'. Check your definition."),
30 Self::UndefinedFunction(n) =>
31 write!(f, "[FN-001] '{n}' is not defined. Did you forget to declare it or import its module?"),
32 Self::ReturnTypeMismatch { function, expected, found } =>
33 write!(f, "[FN-002] Function '{function}' declared return type {expected:?} but returned {found:?}. Ternary contracts are strict."),
34 Self::ArgCountMismatch { function, expected, found } =>
35 write!(f, "[FN-003] '{function}' expects {expected} arg(s), got {found}. Arity is not optional."),
36 Self::ArgTypeMismatch { function, param_index, expected, found } =>
37 write!(f, "[FN-004] '{function}' arg {param_index}: expected {expected:?}, found {found:?}. Types travel with their values."),
38 Self::PropagateOnNonTrit { found } =>
39 write!(f, "[PROP-001] '?' used on a {found:?} expression. Only trit-returning functions can signal conflict. The third state requires a trit."),
40 }
41 }
42}
43
44#[derive(Debug, Clone)]
47pub struct FunctionSig {
48 pub params: Option<Vec<Type>>,
50 pub return_type: Type,
51}
52
53impl FunctionSig {
54 fn exact(params: Vec<Type>, return_type: Type) -> Self {
55 Self { params: Some(params), return_type }
56 }
57 fn variadic(return_type: Type) -> Self {
58 Self { params: None, return_type }
59 }
60}
61
62pub struct SemanticAnalyzer {
65 scopes: Vec<std::collections::HashMap<String, Type>>,
66 struct_defs: std::collections::HashMap<String, Vec<(String, Type)>>,
67 func_signatures: std::collections::HashMap<String, FunctionSig>,
68 current_fn_name: Option<String>,
70 current_fn_return: Option<Type>,
71}
72
73impl SemanticAnalyzer {
74 pub fn new() -> Self {
75 let mut sigs: std::collections::HashMap<String, FunctionSig> = std::collections::HashMap::new();
76
77 sigs.insert("consensus".into(), FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
79 sigs.insert("invert".into(), FunctionSig::exact(vec![Type::Trit], Type::Trit));
80 sigs.insert("truth".into(), FunctionSig::exact(vec![], Type::Trit));
81 sigs.insert("hold".into(), FunctionSig::exact(vec![], Type::Trit));
82 sigs.insert("conflict".into(), FunctionSig::exact(vec![], Type::Trit));
83 sigs.insert("mul".into(), FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
84
85 sigs.insert("matmul".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
87 sigs.insert("sparsity".into(), FunctionSig::variadic(Type::Int));
88 sigs.insert("shape".into(), FunctionSig::variadic(Type::Int));
89 sigs.insert("zeros".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
90
91 sigs.insert("print".into(), FunctionSig::variadic(Type::Trit));
93 sigs.insert("println".into(), FunctionSig::variadic(Type::Trit));
94
95 sigs.insert("abs".into(), FunctionSig::exact(vec![Type::Int], Type::Int));
97 sigs.insert("min".into(), FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
98 sigs.insert("max".into(), FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
99
100 sigs.insert("quantize".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
102 sigs.insert("threshold".into(),FunctionSig::variadic(Type::Float));
103
104 sigs.insert("forward".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
106 sigs.insert("argmax".into(), FunctionSig::variadic(Type::Int));
107
108 sigs.insert("cast".into(), FunctionSig::variadic(Type::Trit));
110
111 Self {
112 scopes: vec![std::collections::HashMap::new()],
113 struct_defs: std::collections::HashMap::new(),
114 func_signatures: sigs,
115 current_fn_name: None,
116 current_fn_return: None,
117 }
118 }
119
120 pub fn register_structs(&mut self, structs: &[StructDef]) {
123 for s in structs {
124 self.struct_defs.insert(s.name.clone(), s.fields.clone());
125 }
126 }
127
128 pub fn register_functions(&mut self, functions: &[Function]) {
129 for f in functions {
130 let params = f.params.iter().map(|(_, ty)| ty.clone()).collect();
131 self.func_signatures.insert(
132 f.name.clone(),
133 FunctionSig::exact(params, f.return_type.clone()),
134 );
135 }
136 }
137
138 pub fn register_agents(&mut self, agents: &[AgentDef]) {
139 for agent in agents {
140 for method in &agent.methods {
141 let params = method.params.iter().map(|(_, ty)| ty.clone()).collect();
142 let sig = FunctionSig::exact(params, method.return_type.clone());
143 self.func_signatures.insert(method.name.clone(), sig.clone());
144 self.func_signatures.insert(
145 format!("{}::{}", agent.name, method.name),
146 sig,
147 );
148 }
149 }
150 }
151
152 pub fn check_program(&mut self, program: &Program) -> Result<(), SemanticError> {
155 self.register_structs(&program.structs);
156 self.register_functions(&program.functions);
157 self.register_agents(&program.agents);
158 for agent in &program.agents {
159 for method in &agent.methods {
160 self.check_function(method)?;
161 }
162 }
163 for func in &program.functions {
164 self.check_function(func)?;
165 }
166 Ok(())
167 }
168
169 fn check_function(&mut self, func: &Function) -> Result<(), SemanticError> {
170 let prev_name = self.current_fn_name.take();
172 let prev_return = self.current_fn_return.take();
173 self.current_fn_name = Some(func.name.clone());
174 self.current_fn_return = Some(func.return_type.clone());
175
176 self.scopes.push(std::collections::HashMap::new());
177 for (name, ty) in &func.params {
178 self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
179 }
180 for stmt in &func.body {
181 self.check_stmt(stmt)?;
182 }
183 self.scopes.pop();
184
185 self.current_fn_name = prev_name;
187 self.current_fn_return = prev_return;
188 Ok(())
189 }
190
191 pub fn check_stmt(&mut self, stmt: &Stmt) -> Result<(), SemanticError> {
194 match stmt {
195 Stmt::Let { name, ty, value } => {
196 let val_ty = self.infer_expr_type(value)?;
197 let type_ok = val_ty == *ty
198 || matches!(value, Expr::Cast { .. })
199 || (matches!(ty, Type::Named(_)) && val_ty == Type::Trit)
200 || (matches!(ty, Type::TritTensor { .. }) && matches!(val_ty, Type::TritTensor { .. }))
201 || (*ty == Type::AgentRef && val_ty == Type::AgentRef);
202 if !type_ok {
203 return Err(SemanticError::TypeMismatch { expected: ty.clone(), found: val_ty });
204 }
205 self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
206 Ok(())
207 }
208
209 Stmt::Return(expr) => {
210 let found = self.infer_expr_type(expr)?;
211 if let (Some(fn_name), Some(expected)) = (&self.current_fn_name, &self.current_fn_return) {
212 let ok = found == *expected
214 || matches!(expr, Expr::Cast { .. })
215 || (matches!(expected, Type::TritTensor { .. }) && matches!(found, Type::TritTensor { .. }))
216 || (matches!(expected, Type::Named(_)) && found == Type::Trit);
217 if !ok {
218 return Err(SemanticError::ReturnTypeMismatch {
219 function: fn_name.clone(),
220 expected: expected.clone(),
221 found,
222 });
223 }
224 }
225 Ok(())
226 }
227
228 Stmt::IfTernary { condition, on_pos, on_zero, on_neg } => {
229 let cond_ty = self.infer_expr_type(condition)?;
230 if cond_ty != Type::Trit {
231 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
232 }
233 self.check_stmt(on_pos)?;
234 self.check_stmt(on_zero)?;
235 self.check_stmt(on_neg)?;
236 Ok(())
237 }
238
239 Stmt::Match { condition, arms } => {
240 let cond_ty = self.infer_expr_type(condition)?;
241 if cond_ty != Type::Trit {
242 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
243 }
244 for (_val, arm_stmt) in arms {
245 self.check_stmt(arm_stmt)?;
246 }
247 Ok(())
248 }
249
250 Stmt::Block(stmts) => {
251 self.scopes.push(std::collections::HashMap::new());
252 for s in stmts { self.check_stmt(s)?; }
253 self.scopes.pop();
254 Ok(())
255 }
256
257 Stmt::Decorated { stmt, .. } => self.check_stmt(stmt),
258
259 Stmt::Expr(expr) => { self.infer_expr_type(expr)?; Ok(()) }
260
261 Stmt::ForIn { var, iter, body } => {
262 self.infer_expr_type(iter)?;
263 self.scopes.push(std::collections::HashMap::new());
264 self.scopes.last_mut().unwrap().insert(var.clone(), Type::Trit);
265 self.check_stmt(body)?;
266 self.scopes.pop();
267 Ok(())
268 }
269
270 Stmt::WhileTernary { condition, on_pos, on_zero, on_neg } => {
271 let cond_ty = self.infer_expr_type(condition)?;
272 if cond_ty != Type::Trit {
273 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
274 }
275 self.check_stmt(on_pos)?;
276 self.check_stmt(on_zero)?;
277 self.check_stmt(on_neg)?;
278 Ok(())
279 }
280
281 Stmt::Loop { body } => self.check_stmt(body),
282 Stmt::Break => Ok(()),
283 Stmt::Continue => Ok(()),
284 Stmt::Use { .. } => Ok(()),
285
286 Stmt::Send { target, message } => {
287 self.infer_expr_type(target)?;
288 self.infer_expr_type(message)?;
289 Ok(())
290 }
291
292 Stmt::FieldSet { object, field, value } => {
293 let obj_ty = self.lookup_var(object)?;
294 if let Type::Named(struct_name) = obj_ty {
295 let field_ty = self.lookup_field(&struct_name, field)?;
296 let val_ty = self.infer_expr_type(value)?;
297 if val_ty != field_ty {
298 return Err(SemanticError::TypeMismatch { expected: field_ty, found: val_ty });
299 }
300 } else {
301 self.infer_expr_type(value)?;
302 }
303 Ok(())
304 }
305
306 Stmt::IndexSet { object, row, col, value } => {
307 self.lookup_var(object)?;
308 self.infer_expr_type(row)?;
309 self.infer_expr_type(col)?;
310 self.infer_expr_type(value)?;
311 Ok(())
312 }
313 }
314 }
315
316 fn infer_expr_type(&self, expr: &Expr) -> Result<Type, SemanticError> {
319 match expr {
320 Expr::TritLiteral(_) => Ok(Type::Trit),
321 Expr::IntLiteral(_) => Ok(Type::Int),
322 Expr::StringLiteral(_) => Ok(Type::String),
323 Expr::Ident(name) => self.lookup_var(name),
324
325 Expr::BinaryOp { op, lhs, rhs } => {
326 let l = self.infer_expr_type(lhs)?;
327 let r = self.infer_expr_type(rhs)?;
328 match op {
329 BinOp::Less | BinOp::Greater | BinOp::Equal | BinOp::NotEqual | BinOp::And | BinOp::Or => {
330 Ok(Type::Trit)
331 }
332 _ => {
333 if l != r {
334 return Err(SemanticError::TypeMismatch { expected: l, found: r });
335 }
336 Ok(l)
337 }
338 }
339 }
340
341 Expr::UnaryOp { expr, .. } => self.infer_expr_type(expr),
342
343 Expr::Call { callee, args } => {
344 let sig = self.func_signatures.get(callee.as_str())
345 .ok_or_else(|| SemanticError::UndefinedFunction(callee.clone()))?
346 .clone();
347
348 if let Some(param_types) = &sig.params {
350 if args.len() != param_types.len() {
351 return Err(SemanticError::ArgCountMismatch {
352 function: callee.clone(),
353 expected: param_types.len(),
354 found: args.len(),
355 });
356 }
357 for (i, (arg, expected_ty)) in args.iter().zip(param_types.iter()).enumerate() {
358 let found_ty = self.infer_expr_type(arg)?;
359 let ok = found_ty == *expected_ty
361 || matches!(arg, Expr::Cast { .. })
362 || (matches!(expected_ty, Type::TritTensor { .. })
363 && matches!(found_ty, Type::TritTensor { .. }))
364 || (matches!(expected_ty, Type::Named(_)) && found_ty == Type::Trit);
365 if !ok {
366 return Err(SemanticError::ArgTypeMismatch {
367 function: callee.clone(),
368 param_index: i,
369 expected: expected_ty.clone(),
370 found: found_ty,
371 });
372 }
373 }
374 } else {
375 for arg in args { self.infer_expr_type(arg)?; }
377 }
378
379 Ok(sig.return_type)
380 }
381
382 Expr::Cast { ty, .. } => Ok(ty.clone()),
383 Expr::Spawn { .. } => Ok(Type::AgentRef),
384 Expr::Await { .. } => Ok(Type::Trit),
385 Expr::NodeId => Ok(Type::String),
386
387 Expr::Propagate { expr } => {
388 let inner = self.infer_expr_type(expr)?;
389 if inner != Type::Trit {
390 return Err(SemanticError::PropagateOnNonTrit { found: inner });
391 }
392 Ok(Type::Trit)
393 }
394
395 Expr::FieldAccess { object, field } => {
396 let obj_ty = self.infer_expr_type(object)?;
397 if let Type::Named(struct_name) = obj_ty {
398 self.lookup_field(&struct_name, field)
399 } else {
400 Ok(Type::Trit)
401 }
402 }
403
404 Expr::Index { object, row, col } => {
405 self.infer_expr_type(object)?;
406 self.infer_expr_type(row)?;
407 self.infer_expr_type(col)?;
408 Ok(Type::Trit)
409 }
410 }
411 }
412
413 fn lookup_var(&self, name: &str) -> Result<Type, SemanticError> {
416 for scope in self.scopes.iter().rev() {
417 if let Some(ty) = scope.get(name) { return Ok(ty.clone()); }
418 }
419 Err(SemanticError::UndefinedVariable(name.to_string()))
420 }
421
422 fn lookup_field(&self, struct_name: &str, field: &str) -> Result<Type, SemanticError> {
423 let fields = self.struct_defs.get(struct_name)
424 .ok_or_else(|| SemanticError::UndefinedStruct(struct_name.to_string()))?;
425 fields.iter()
426 .find(|(f, _)| f == field)
427 .map(|(_, ty)| ty.clone())
428 .ok_or_else(|| SemanticError::UndefinedField {
429 struct_name: struct_name.to_string(),
430 field: field.to_string(),
431 })
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `src` (panicking if it does not parse) and analyzes the program.
    fn check(src: &str) -> Result<(), SemanticError> {
        let mut parser = Parser::new(src);
        let program = parser.parse_program().expect("parse failed");
        SemanticAnalyzer::new().check_program(&program)
    }

    /// Asserts that `src` passes semantic analysis.
    fn check_ok(src: &str) {
        let outcome = check(src);
        assert!(outcome.is_ok(), "expected ok, got: {:?}", outcome);
    }

    /// Asserts that `src` is rejected by semantic analysis.
    fn check_err(src: &str) {
        let outcome = check(src);
        assert!(outcome.is_err(), "expected error but check passed");
    }

    // --- return types ---

    #[test]
    fn test_return_correct_type() {
        check_ok("fn f() -> trit { return 1; }");
    }

    #[test]
    fn test_return_wrong_type_caught() {
        check_err("fn f() -> trit { let x: int = 42; return x; }");
    }

    #[test]
    fn test_return_trit_in_trit_fn() {
        check_ok("fn decide(a: trit, b: trit) -> trit { return consensus(a, b); }");
    }

    // --- call arity ---

    #[test]
    fn test_call_correct_arity() {
        check_ok("fn f() -> trit { return consensus(1, -1); }");
    }

    #[test]
    fn test_call_too_few_args_caught() {
        check_err("fn f() -> trit { return consensus(1); }");
    }

    #[test]
    fn test_call_too_many_args_caught() {
        check_err("fn f() -> trit { return invert(1, 1); }");
    }

    // --- call argument types ---

    #[test]
    fn test_call_wrong_arg_type_caught() {
        check_err("fn f() -> trit { let x: int = 42; return invert(x); }");
    }

    #[test]
    fn test_call_correct_arg_type() {
        check_ok("fn f(a: trit) -> trit { return invert(a); }");
    }

    // --- function resolution ---

    #[test]
    fn test_undefined_function_caught() {
        check_err("fn f() -> trit { return doesnt_exist(1); }");
    }

    #[test]
    fn test_user_fn_return_type_registered() {
        check_ok("fn helper(a: trit) -> trit { return invert(a); } fn main() -> trit { return helper(1); }");
    }

    #[test]
    fn test_user_fn_wrong_return_caught() {
        check_err("fn helper(a: trit) -> trit { let x: int = 1; return x; }");
    }

    // --- variables ---

    #[test]
    fn test_undefined_variable_caught() {
        check_err("fn f() -> trit { return ghost_var; }");
    }

    #[test]
    fn test_defined_variable_ok() {
        check_ok("fn f() -> trit { let x: trit = 1; return x; }");
    }

    // --- structs ---

    #[test]
    fn test_struct_field_access_ok() {
        check_ok("struct S { val: trit } fn f(s: S) -> trit { return s.val; }");
    }
}