1use crate::ast::*;
2
3#[derive(Debug)]
6pub enum SemanticError {
7 TypeMismatch { expected: Type, found: Type },
8 UndefinedVariable(String),
9 UndefinedStruct(String),
10 UndefinedField { struct_name: String, field: String },
11 UndefinedFunction(String),
12 ReturnTypeMismatch { function: String, expected: Type, found: Type },
13 ArgCountMismatch { function: String, expected: usize, found: usize },
14 ArgTypeMismatch { function: String, param_index: usize, expected: Type, found: Type },
15}
16
17impl std::fmt::Display for SemanticError {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 match self {
20 Self::TypeMismatch { expected, found } =>
21 write!(f, "type mismatch: expected {:?}, found {:?}", expected, found),
22 Self::UndefinedVariable(n) =>
23 write!(f, "undefined variable: '{}'", n),
24 Self::UndefinedStruct(n) =>
25 write!(f, "undefined struct: '{}'", n),
26 Self::UndefinedField { struct_name, field } =>
27 write!(f, "struct '{}' has no field '{}'", struct_name, field),
28 Self::UndefinedFunction(n) =>
29 write!(f, "undefined function: '{}'", n),
30 Self::ReturnTypeMismatch { function, expected, found } =>
31 write!(f, "function '{}' declared return type {:?} but returns {:?}", function, expected, found),
32 Self::ArgCountMismatch { function, expected, found } =>
33 write!(f, "function '{}' expects {} argument(s), got {}", function, expected, found),
34 Self::ArgTypeMismatch { function, param_index, expected, found } =>
35 write!(f, "function '{}' argument {}: expected {:?}, found {:?}", function, param_index, expected, found),
36 }
37 }
38}
39
40#[derive(Debug, Clone)]
43pub struct FunctionSig {
44 pub params: Option<Vec<Type>>,
46 pub return_type: Type,
47}
48
49impl FunctionSig {
50 fn exact(params: Vec<Type>, return_type: Type) -> Self {
51 Self { params: Some(params), return_type }
52 }
53 fn variadic(return_type: Type) -> Self {
54 Self { params: None, return_type }
55 }
56}
57
58pub struct SemanticAnalyzer {
61 scopes: Vec<std::collections::HashMap<String, Type>>,
62 struct_defs: std::collections::HashMap<String, Vec<(String, Type)>>,
63 func_signatures: std::collections::HashMap<String, FunctionSig>,
64 current_fn_name: Option<String>,
66 current_fn_return: Option<Type>,
67}
68
69impl SemanticAnalyzer {
70 pub fn new() -> Self {
71 let mut sigs: std::collections::HashMap<String, FunctionSig> = std::collections::HashMap::new();
72
73 sigs.insert("consensus".into(), FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
75 sigs.insert("invert".into(), FunctionSig::exact(vec![Type::Trit], Type::Trit));
76 sigs.insert("truth".into(), FunctionSig::exact(vec![], Type::Trit));
77 sigs.insert("hold".into(), FunctionSig::exact(vec![], Type::Trit));
78 sigs.insert("conflict".into(), FunctionSig::exact(vec![], Type::Trit));
79 sigs.insert("mul".into(), FunctionSig::exact(vec![Type::Trit, Type::Trit], Type::Trit));
80
81 sigs.insert("matmul".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
83 sigs.insert("sparsity".into(), FunctionSig::variadic(Type::Int));
84 sigs.insert("shape".into(), FunctionSig::variadic(Type::Int));
85 sigs.insert("zeros".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
86
87 sigs.insert("print".into(), FunctionSig::variadic(Type::Trit));
89 sigs.insert("println".into(), FunctionSig::variadic(Type::Trit));
90
91 sigs.insert("abs".into(), FunctionSig::exact(vec![Type::Int], Type::Int));
93 sigs.insert("min".into(), FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
94 sigs.insert("max".into(), FunctionSig::exact(vec![Type::Int, Type::Int], Type::Int));
95
96 sigs.insert("quantize".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
98 sigs.insert("threshold".into(),FunctionSig::variadic(Type::Float));
99
100 sigs.insert("forward".into(), FunctionSig::variadic(Type::TritTensor { dims: vec![0, 0] }));
102 sigs.insert("argmax".into(), FunctionSig::variadic(Type::Int));
103
104 sigs.insert("cast".into(), FunctionSig::variadic(Type::Trit));
106
107 Self {
108 scopes: vec![std::collections::HashMap::new()],
109 struct_defs: std::collections::HashMap::new(),
110 func_signatures: sigs,
111 current_fn_name: None,
112 current_fn_return: None,
113 }
114 }
115
116 pub fn register_structs(&mut self, structs: &[StructDef]) {
119 for s in structs {
120 self.struct_defs.insert(s.name.clone(), s.fields.clone());
121 }
122 }
123
124 pub fn register_functions(&mut self, functions: &[Function]) {
125 for f in functions {
126 let params = f.params.iter().map(|(_, ty)| ty.clone()).collect();
127 self.func_signatures.insert(
128 f.name.clone(),
129 FunctionSig::exact(params, f.return_type.clone()),
130 );
131 }
132 }
133
134 pub fn register_agents(&mut self, agents: &[AgentDef]) {
135 for agent in agents {
136 for method in &agent.methods {
137 let params = method.params.iter().map(|(_, ty)| ty.clone()).collect();
138 let sig = FunctionSig::exact(params, method.return_type.clone());
139 self.func_signatures.insert(method.name.clone(), sig.clone());
140 self.func_signatures.insert(
141 format!("{}::{}", agent.name, method.name),
142 sig,
143 );
144 }
145 }
146 }
147
148 pub fn check_program(&mut self, program: &Program) -> Result<(), SemanticError> {
151 self.register_structs(&program.structs);
152 self.register_functions(&program.functions);
153 self.register_agents(&program.agents);
154 for agent in &program.agents {
155 for method in &agent.methods {
156 self.check_function(method)?;
157 }
158 }
159 for func in &program.functions {
160 self.check_function(func)?;
161 }
162 Ok(())
163 }
164
165 fn check_function(&mut self, func: &Function) -> Result<(), SemanticError> {
166 let prev_name = self.current_fn_name.take();
168 let prev_return = self.current_fn_return.take();
169 self.current_fn_name = Some(func.name.clone());
170 self.current_fn_return = Some(func.return_type.clone());
171
172 self.scopes.push(std::collections::HashMap::new());
173 for (name, ty) in &func.params {
174 self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
175 }
176 for stmt in &func.body {
177 self.check_stmt(stmt)?;
178 }
179 self.scopes.pop();
180
181 self.current_fn_name = prev_name;
183 self.current_fn_return = prev_return;
184 Ok(())
185 }
186
187 pub fn check_stmt(&mut self, stmt: &Stmt) -> Result<(), SemanticError> {
190 match stmt {
191 Stmt::Let { name, ty, value } => {
192 let val_ty = self.infer_expr_type(value)?;
193 let type_ok = val_ty == *ty
194 || matches!(value, Expr::Cast { .. })
195 || (matches!(ty, Type::Named(_)) && val_ty == Type::Trit)
196 || (matches!(ty, Type::TritTensor { .. }) && matches!(val_ty, Type::TritTensor { .. }))
197 || (*ty == Type::AgentRef && val_ty == Type::AgentRef);
198 if !type_ok {
199 return Err(SemanticError::TypeMismatch { expected: ty.clone(), found: val_ty });
200 }
201 self.scopes.last_mut().unwrap().insert(name.clone(), ty.clone());
202 Ok(())
203 }
204
205 Stmt::Return(expr) => {
206 let found = self.infer_expr_type(expr)?;
207 if let (Some(fn_name), Some(expected)) = (&self.current_fn_name, &self.current_fn_return) {
208 let ok = found == *expected
210 || matches!(expr, Expr::Cast { .. })
211 || (matches!(expected, Type::TritTensor { .. }) && matches!(found, Type::TritTensor { .. }))
212 || (matches!(expected, Type::Named(_)) && found == Type::Trit);
213 if !ok {
214 return Err(SemanticError::ReturnTypeMismatch {
215 function: fn_name.clone(),
216 expected: expected.clone(),
217 found,
218 });
219 }
220 }
221 Ok(())
222 }
223
224 Stmt::IfTernary { condition, on_pos, on_zero, on_neg } => {
225 let cond_ty = self.infer_expr_type(condition)?;
226 if cond_ty != Type::Trit {
227 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
228 }
229 self.check_stmt(on_pos)?;
230 self.check_stmt(on_zero)?;
231 self.check_stmt(on_neg)?;
232 Ok(())
233 }
234
235 Stmt::Match { condition, arms } => {
236 let cond_ty = self.infer_expr_type(condition)?;
237 if cond_ty != Type::Trit {
238 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
239 }
240 for (_val, arm_stmt) in arms {
241 self.check_stmt(arm_stmt)?;
242 }
243 Ok(())
244 }
245
246 Stmt::Block(stmts) => {
247 self.scopes.push(std::collections::HashMap::new());
248 for s in stmts { self.check_stmt(s)?; }
249 self.scopes.pop();
250 Ok(())
251 }
252
253 Stmt::Decorated { stmt, .. } => self.check_stmt(stmt),
254
255 Stmt::Expr(expr) => { self.infer_expr_type(expr)?; Ok(()) }
256
257 Stmt::ForIn { var, iter, body } => {
258 self.infer_expr_type(iter)?;
259 self.scopes.push(std::collections::HashMap::new());
260 self.scopes.last_mut().unwrap().insert(var.clone(), Type::Trit);
261 self.check_stmt(body)?;
262 self.scopes.pop();
263 Ok(())
264 }
265
266 Stmt::WhileTernary { condition, on_pos, on_zero, on_neg } => {
267 let cond_ty = self.infer_expr_type(condition)?;
268 if cond_ty != Type::Trit {
269 return Err(SemanticError::TypeMismatch { expected: Type::Trit, found: cond_ty });
270 }
271 self.check_stmt(on_pos)?;
272 self.check_stmt(on_zero)?;
273 self.check_stmt(on_neg)?;
274 Ok(())
275 }
276
277 Stmt::Loop { body } => self.check_stmt(body),
278 Stmt::Break => Ok(()),
279 Stmt::Continue => Ok(()),
280 Stmt::Use { .. } => Ok(()),
281
282 Stmt::Send { target, message } => {
283 self.infer_expr_type(target)?;
284 self.infer_expr_type(message)?;
285 Ok(())
286 }
287
288 Stmt::FieldSet { object, field, value } => {
289 let obj_ty = self.lookup_var(object)?;
290 if let Type::Named(struct_name) = obj_ty {
291 let field_ty = self.lookup_field(&struct_name, field)?;
292 let val_ty = self.infer_expr_type(value)?;
293 if val_ty != field_ty {
294 return Err(SemanticError::TypeMismatch { expected: field_ty, found: val_ty });
295 }
296 } else {
297 self.infer_expr_type(value)?;
298 }
299 Ok(())
300 }
301 }
302 }
303
304 fn infer_expr_type(&self, expr: &Expr) -> Result<Type, SemanticError> {
307 match expr {
308 Expr::TritLiteral(_) => Ok(Type::Trit),
309 Expr::IntLiteral(_) => Ok(Type::Int),
310 Expr::StringLiteral(_) => Ok(Type::String),
311 Expr::Ident(name) => self.lookup_var(name),
312
313 Expr::BinaryOp { lhs, rhs, .. } => {
314 let l = self.infer_expr_type(lhs)?;
315 let r = self.infer_expr_type(rhs)?;
316 if l != r {
317 return Err(SemanticError::TypeMismatch { expected: l, found: r });
318 }
319 Ok(l)
320 }
321
322 Expr::UnaryOp { expr, .. } => self.infer_expr_type(expr),
323
324 Expr::Call { callee, args } => {
325 let sig = self.func_signatures.get(callee.as_str())
326 .ok_or_else(|| SemanticError::UndefinedFunction(callee.clone()))?
327 .clone();
328
329 if let Some(param_types) = &sig.params {
331 if args.len() != param_types.len() {
332 return Err(SemanticError::ArgCountMismatch {
333 function: callee.clone(),
334 expected: param_types.len(),
335 found: args.len(),
336 });
337 }
338 for (i, (arg, expected_ty)) in args.iter().zip(param_types.iter()).enumerate() {
339 let found_ty = self.infer_expr_type(arg)?;
340 let ok = found_ty == *expected_ty
342 || matches!(arg, Expr::Cast { .. })
343 || (matches!(expected_ty, Type::TritTensor { .. })
344 && matches!(found_ty, Type::TritTensor { .. }))
345 || (matches!(expected_ty, Type::Named(_)) && found_ty == Type::Trit);
346 if !ok {
347 return Err(SemanticError::ArgTypeMismatch {
348 function: callee.clone(),
349 param_index: i,
350 expected: expected_ty.clone(),
351 found: found_ty,
352 });
353 }
354 }
355 } else {
356 for arg in args { self.infer_expr_type(arg)?; }
358 }
359
360 Ok(sig.return_type)
361 }
362
363 Expr::Cast { ty, .. } => Ok(ty.clone()),
364 Expr::Spawn { .. } => Ok(Type::AgentRef),
365 Expr::Await { .. } => Ok(Type::Trit),
366 Expr::NodeId => Ok(Type::String),
367
368 Expr::FieldAccess { object, field } => {
369 let obj_ty = self.infer_expr_type(object)?;
370 if let Type::Named(struct_name) = obj_ty {
371 self.lookup_field(&struct_name, field)
372 } else {
373 Ok(Type::Trit)
374 }
375 }
376 }
377 }
378
379 fn lookup_var(&self, name: &str) -> Result<Type, SemanticError> {
382 for scope in self.scopes.iter().rev() {
383 if let Some(ty) = scope.get(name) { return Ok(ty.clone()); }
384 }
385 Err(SemanticError::UndefinedVariable(name.to_string()))
386 }
387
388 fn lookup_field(&self, struct_name: &str, field: &str) -> Result<Type, SemanticError> {
389 let fields = self.struct_defs.get(struct_name)
390 .ok_or_else(|| SemanticError::UndefinedStruct(struct_name.to_string()))?;
391 fields.iter()
392 .find(|(f, _)| f == field)
393 .map(|(_, ty)| ty.clone())
394 .ok_or_else(|| SemanticError::UndefinedField {
395 struct_name: struct_name.to_string(),
396 field: field.to_string(),
397 })
398 }
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `src` and runs a fresh analyzer over the whole program.
    ///
    /// Panics if `src` does not parse: these tests target the analyzer only.
    fn check(src: &str) -> Result<(), SemanticError> {
        let mut parser = Parser::new(src);
        let prog = parser.parse_program().expect("parse failed");
        let mut analyzer = SemanticAnalyzer::new();
        analyzer.check_program(&prog)
    }

    /// Asserts that `src` passes semantic analysis, reporting the error otherwise.
    fn check_ok(src: &str) {
        if let Err(err) = check(src) {
            panic!("expected ok, got: {:?}\nsource: {}", err, src);
        }
    }

    /// Asserts that `src` is rejected by semantic analysis.
    fn check_err(src: &str) {
        assert!(check(src).is_err(), "expected error but check passed\nsource: {}", src);
    }

    // --- return types ---

    #[test]
    fn test_return_correct_type() {
        check_ok("fn f() -> trit { return 1; }");
    }

    #[test]
    fn test_return_wrong_type_caught() {
        check_err("fn f() -> trit { let x: int = 42; return x; }");
    }

    #[test]
    fn test_return_trit_in_trit_fn() {
        check_ok("fn decide(a: trit, b: trit) -> trit { return consensus(a, b); }");
    }

    // --- call arity ---

    #[test]
    fn test_call_correct_arity() {
        check_ok("fn f() -> trit { return consensus(1, -1); }");
    }

    #[test]
    fn test_call_too_few_args_caught() {
        check_err("fn f() -> trit { return consensus(1); }");
    }

    #[test]
    fn test_call_too_many_args_caught() {
        check_err("fn f() -> trit { return invert(1, 1); }");
    }

    #[test]
    fn test_call_arity_error_reports_counts() {
        match check("fn f() -> trit { return consensus(1); }") {
            Err(SemanticError::ArgCountMismatch { function, expected, found }) => {
                assert_eq!(function, "consensus");
                assert_eq!(expected, 2);
                assert_eq!(found, 1);
            }
            other => panic!("expected ArgCountMismatch, got: {:?}", other),
        }
    }

    #[test]
    fn test_user_fn_arity_checked() {
        check_err(
            "fn helper(a: trit) -> trit { return a; } fn main() -> trit { return helper(1, 1); }",
        );
    }

    // --- argument types ---

    #[test]
    fn test_call_wrong_arg_type_caught() {
        check_err("fn f() -> trit { let x: int = 42; return invert(x); }");
    }

    #[test]
    fn test_call_correct_arg_type() {
        check_ok("fn f(a: trit) -> trit { return invert(a); }");
    }

    // --- function resolution ---

    #[test]
    fn test_undefined_function_caught() {
        check_err("fn f() -> trit { return doesnt_exist(1); }");
    }

    #[test]
    fn test_user_fn_return_type_registered() {
        check_ok("fn helper(a: trit) -> trit { return invert(a); } fn main() -> trit { return helper(1); }");
    }

    #[test]
    fn test_user_fn_wrong_return_caught() {
        check_err("fn helper(a: trit) -> trit { let x: int = 1; return x; }");
    }

    // --- variables ---

    #[test]
    fn test_undefined_variable_caught() {
        check_err("fn f() -> trit { return ghost_var; }");
    }

    #[test]
    fn test_undefined_variable_in_call_arg_caught() {
        check_err("fn f() -> trit { return invert(ghost); }");
    }

    #[test]
    fn test_defined_variable_ok() {
        check_ok("fn f() -> trit { let x: trit = 1; return x; }");
    }

    // --- structs ---

    #[test]
    fn test_struct_field_access_ok() {
        check_ok("struct S { val: trit } fn f(s: S) -> trit { return s.val; }");
    }

    // --- diagnostics ---

    #[test]
    fn test_error_display() {
        let err = SemanticError::UndefinedVariable("ghost".to_string());
        assert_eq!(err.to_string(), "undefined variable: 'ghost'");
    }
}