1use crate::ast::*;
14use std::collections::HashMap;
15
/// Error returned by the type checker. The payload is a human-readable message;
/// the derived `Display` renders it as `type error: <message>`.
#[derive(Debug, thiserror::Error)]
#[error("type error: {0}")]
pub struct TypeError(pub String);
19
20fn e(msg: impl Into<String>) -> TypeError {
21 TypeError(msg.into())
22}
23
24pub struct TypeChecker<'a> {
25 cell: &'a CellDef,
26 errors: HashMap<String, ()>,
27 storage: HashMap<String, Type>,
28 structs: HashMap<String, Vec<FieldDef>>,
29}
30
31impl<'a> TypeChecker<'a> {
32 pub fn new(cell: &'a CellDef) -> Self {
33 let errors = cell.errors.iter().map(|e| (e.name.clone(), ())).collect();
34 let storage = cell
35 .storage
36 .iter()
37 .map(|s| (s.name.clone(), s.ty.clone()))
38 .collect();
39 let structs = cell
40 .structs
41 .iter()
42 .map(|s| (s.name.clone(), s.fields.clone()))
43 .collect();
44 Self {
45 cell,
46 errors,
47 storage,
48 structs,
49 }
50 }
51
52 pub fn check(&self) -> Result<(), TypeError> {
53 if let Some(init) = &self.cell.init {
54 self.check_body(&init.body, &self.params_scope(&init.params), None)?;
55 self.check_ownership(&init.body, &init.params)?;
56 }
57 for f in &self.cell.fns {
58 self.check_body(&f.body, &self.params_scope(&f.params), f.ret.as_deref())?;
59 self.check_ownership(&f.body, &f.params)?;
60 }
61 Ok(())
62 }
63
64 fn check_ownership(&self, stmts: &[Stmt], params: &[Param]) -> Result<(), TypeError> {
66 let owned: Vec<&str> = params
67 .iter()
68 .filter(|p| p.owned)
69 .map(|p| p.name.as_str())
70 .collect();
71 if owned.is_empty() {
72 return Ok(());
73 }
74
75 let mut consumed: HashMap<&str, usize> = owned.iter().map(|n| (*n, 0)).collect();
77 let mut aliases: HashMap<String, String> = HashMap::new();
79 self.count_consumptions(stmts, &mut consumed, &mut aliases, false)?;
80
81 for name in &owned {
82 let count = consumed[name];
83 if count == 0 {
84 let alias_consumed = aliases.iter().any(|(alias, src)| {
86 src.as_str() == *name && consumed.get(alias.as_str()).copied().unwrap_or(0) > 0
87 });
88 if !alias_consumed {
89 return Err(e(format!(
90 "owned parameter '{}' is never consumed - must be passed to a token op, call, or returned",
91 name
92 )));
93 }
94 } else if count == 2 {
95 return Err(e(format!(
96 "owned parameter '{}' is consumed in only one branch of an if - must be consumed in both branches or neither",
97 name
98 )));
99 } else if count > 1 {
100 return Err(e(format!(
101 "owned parameter '{}' is consumed {} times - double-spend detected",
102 name, count
103 )));
104 }
105 }
106 Ok(())
107 }
108
109 fn count_consumptions<'b>(
110 &self,
111 stmts: &[Stmt],
112 consumed: &mut HashMap<&'b str, usize>,
113 aliases: &mut HashMap<String, String>,
114 in_loop: bool,
115 ) -> Result<(), TypeError>
116 where
117 'a: 'b,
118 {
119 for stmt in stmts {
120 match stmt {
121 Stmt::Let { name, expr, .. } => {
122 if let Expr::Var(src) = expr {
124 if consumed.contains_key(src.as_str()) {
125 aliases.insert(name.clone(), src.clone());
126 }
127 }
128 self.count_expr_consumptions(expr, consumed, aliases, in_loop)?;
129 }
130 Stmt::Assign { expr, .. } => {
131 self.count_expr_consumptions(expr, consumed, aliases, in_loop)?
132 }
133 Stmt::AssignAdd { expr, .. }
134 | Stmt::AssignSub { expr, .. }
135 | Stmt::AssignMul { expr, .. }
136 | Stmt::AssignDiv { expr, .. } => {
137 self.count_expr_consumptions(expr, consumed, aliases, in_loop)?;
138 }
139 Stmt::Return { exprs } => {
140 for ex in exprs {
141 self.mark_if_owned(ex, consumed, aliases, in_loop)?;
142 self.count_expr_consumptions(ex, consumed, aliases, in_loop)?;
143 }
144 }
145 Stmt::If { cond, then, else_ } => {
146 self.count_expr_consumptions(cond, consumed, aliases, in_loop)?;
147 let mut then_c = consumed.clone();
148 let mut else_c = consumed.clone();
149 let mut then_a = aliases.clone();
150 let mut else_a = aliases.clone();
151 self.count_consumptions(then, &mut then_c, &mut then_a, in_loop)?;
152 self.count_consumptions(else_, &mut else_c, &mut else_a, in_loop)?;
153 for (k, v) in consumed.iter_mut() {
154 let in_then = then_c[k];
155 let in_else = else_c[k];
156 if in_then != in_else {
157 *v += 2; } else {
159 *v += in_then;
160 }
161 }
162 }
163 Stmt::While { cond, body } => {
164 self.count_expr_consumptions(cond, consumed, aliases, in_loop)?;
165 let mut loop_c = consumed.clone();
166 let mut loop_a = aliases.clone();
167 self.count_consumptions(body, &mut loop_c, &mut loop_a, true)?;
168 for (k, v) in consumed.iter_mut() {
169 if loop_c[k] > *v {
170 return Err(e(format!(
171 "owned parameter '{}' consumed inside a loop - would be consumed multiple times",
172 k
173 )));
174 }
175 }
176 }
177 Stmt::Loop { body } => {
178 let mut loop_c = consumed.clone();
179 let mut loop_a = aliases.clone();
180 self.count_consumptions(body, &mut loop_c, &mut loop_a, true)?;
181 for (k, v) in consumed.iter_mut() {
182 if loop_c[k] > *v {
183 return Err(e(format!(
184 "owned parameter '{}' consumed inside a loop - would be consumed multiple times",
185 k
186 )));
187 }
188 }
189 }
190 Stmt::For {
191 start, end, body, ..
192 } => {
193 self.count_expr_consumptions(start, consumed, aliases, in_loop)?;
194 self.count_expr_consumptions(end, consumed, aliases, in_loop)?;
195 let mut loop_c = consumed.clone();
196 let mut loop_a = aliases.clone();
197 self.count_consumptions(body, &mut loop_c, &mut loop_a, true)?;
198 for (k, v) in consumed.iter_mut() {
199 if loop_c[k] > *v {
200 return Err(e(format!(
201 "owned parameter '{}' consumed inside a for loop - would be consumed multiple times",
202 k
203 )));
204 }
205 }
206 }
207 Stmt::Emit { fields, .. } => {
208 for (_, ex) in fields {
209 self.count_expr_consumptions(ex, consumed, aliases, in_loop)?;
210 }
211 }
212 Stmt::Require { expr } => {
213 self.count_expr_consumptions(expr, consumed, aliases, in_loop)?
214 }
215 Stmt::Expr(expr) => {
216 self.count_expr_consumptions(expr, consumed, aliases, in_loop)?
217 }
218 _ => {}
219 }
220 }
221 Ok(())
222 }
223
224 fn count_expr_consumptions<'b>(
225 &self,
226 expr: &Expr,
227 consumed: &mut HashMap<&'b str, usize>,
228 aliases: &mut HashMap<String, String>,
229 in_loop: bool,
230 ) -> Result<(), TypeError>
231 where
232 'a: 'b,
233 {
234 match expr {
235 Expr::TokenTransfer {
236 token,
237 from,
238 to,
239 amount,
240 } => {
241 self.mark_if_owned(amount, consumed, aliases, in_loop)?;
242 self.count_expr_consumptions(token, consumed, aliases, in_loop)?;
243 self.count_expr_consumptions(from, consumed, aliases, in_loop)?;
244 self.count_expr_consumptions(to, consumed, aliases, in_loop)?;
245 }
246 Expr::TokenMint { token, to, amount } => {
247 self.mark_if_owned(amount, consumed, aliases, in_loop)?;
248 self.count_expr_consumptions(token, consumed, aliases, in_loop)?;
249 self.count_expr_consumptions(to, consumed, aliases, in_loop)?;
250 }
251 Expr::TokenBurn {
252 token,
253 owner,
254 amount,
255 } => {
256 self.mark_if_owned(amount, consumed, aliases, in_loop)?;
257 self.count_expr_consumptions(token, consumed, aliases, in_loop)?;
258 self.count_expr_consumptions(owner, consumed, aliases, in_loop)?;
259 }
260 Expr::CallCell { args, .. } => {
261 for a in args {
262 self.mark_if_owned(a, consumed, aliases, in_loop)?;
263 }
264 }
265 Expr::Call { args, .. } => {
266 for a in args {
267 self.mark_if_owned(a, consumed, aliases, in_loop)?;
268 }
269 }
270 Expr::Bin { lhs, rhs, .. } => {
271 self.count_expr_consumptions(lhs, consumed, aliases, in_loop)?;
272 self.count_expr_consumptions(rhs, consumed, aliases, in_loop)?;
273 }
274 Expr::Not(inner) | Expr::Hash(inner) => {
275 self.count_expr_consumptions(inner, consumed, aliases, in_loop)?
276 }
277 Expr::Index { base, key } => {
278 self.count_expr_consumptions(base, consumed, aliases, in_loop)?;
279 self.count_expr_consumptions(key, consumed, aliases, in_loop)?;
280 }
281 _ => {}
282 }
283 Ok(())
284 }
285
286 fn mark_if_owned<'b>(
287 &self,
288 expr: &Expr,
289 consumed: &mut HashMap<&'b str, usize>,
290 aliases: &HashMap<String, String>,
291 in_loop: bool,
292 ) -> Result<(), TypeError>
293 where
294 'a: 'b,
295 {
296 if let Expr::Var(name) = expr {
297 if let Some(count) = consumed.get_mut(name.as_str()) {
299 if in_loop {
300 return Err(e(format!(
301 "owned parameter '{}' consumed inside a loop - would be consumed multiple times",
302 name
303 )));
304 }
305 *count += 1;
306 return Ok(());
307 }
308 if let Some(src) = aliases.get(name.as_str()) {
310 if let Some(count) = consumed.get_mut(src.as_str()) {
311 if in_loop {
312 return Err(e(format!(
313 "owned parameter '{}' (via alias '{}') consumed inside a loop",
314 src, name
315 )));
316 }
317 *count += 1;
318 }
319 }
320 }
321 Ok(())
322 }
323
324 fn params_scope(&self, params: &[Param]) -> HashMap<String, Type> {
325 params
326 .iter()
327 .map(|p| (p.name.clone(), p.ty.clone()))
328 .collect()
329 }
330
331 fn check_body(
332 &self,
333 stmts: &[Stmt],
334 scope: &HashMap<String, Type>,
335 ret_ty: Option<&[Type]>,
336 ) -> Result<(), TypeError> {
337 let mut scope = scope.clone();
338 for stmt in stmts {
339 self.check_stmt(stmt, &mut scope, ret_ty)?;
340 }
341 Ok(())
342 }
343
344 fn lvalue_type(&self, lv: &LValue, scope: &HashMap<String, Type>) -> Result<Type, TypeError> {
345 match lv {
346 LValue::Var(name) => self
347 .storage
348 .get(name)
349 .or_else(|| scope.get(name))
350 .cloned()
351 .ok_or_else(|| e(format!("undefined variable '{}'", name))),
352 LValue::Index { base, .. } => {
353 let base_ty = self
354 .storage
355 .get(base)
356 .or_else(|| scope.get(base))
357 .cloned()
358 .ok_or_else(|| e(format!("undefined variable '{}'", base)))?;
359 match base_ty {
360 Type::Mapping(_, v) => Ok(*v),
361 Type::Array(v) => Ok(*v),
362 _ => Err(e(format!("'{}' is not indexable", base))),
363 }
364 }
365 LValue::Field { base, field } => {
366 let base_ty = self
367 .storage
368 .get(base)
369 .or_else(|| scope.get(base))
370 .cloned()
371 .ok_or_else(|| e(format!("undefined variable '{}'", base)))?;
372 match base_ty {
373 Type::Struct(name) => {
374 let fields = self
375 .structs
376 .get(&name)
377 .ok_or_else(|| e(format!("unknown struct '{}'", name)))?;
378 fields
379 .iter()
380 .find(|f| &f.name == field)
381 .map(|f| f.ty.clone())
382 .ok_or_else(|| e(format!("struct '{}' has no field '{}'", name, field)))
383 }
384 _ => Err(e(format!("'{}' is not a struct", base))),
385 }
386 }
387 }
388 }
389
390 fn check_stmt(
391 &self,
392 stmt: &Stmt,
393 scope: &mut HashMap<String, Type>,
394 ret_ty: Option<&[Type]>,
395 ) -> Result<(), TypeError> {
396 match stmt {
397 Stmt::Let { name, ty, expr } => {
398 let inferred = self.infer(expr, scope)?;
399 let final_ty = if let Some(declared) = ty {
400 self.check_assignable(&inferred, declared, name)?;
401 declared.clone()
402 } else {
403 inferred
404 };
405 scope.insert(name.clone(), final_ty);
406 }
407 Stmt::Assign { target, expr } => {
408 let rhs_ty = self.infer(expr, scope)?;
409 let target_ty = self.lvalue_type(target, scope)?;
410 let name = match target {
411 LValue::Var(n) => n.as_str(),
412 _ => "target",
413 };
414 self.check_assignable(&rhs_ty, &target_ty, name)?;
415 }
416 Stmt::AssignAdd { target, expr }
417 | Stmt::AssignSub { target, expr }
418 | Stmt::AssignMul { target, expr }
419 | Stmt::AssignDiv { target, expr } => {
420 let rhs_ty = self.infer(expr, scope)?;
421 let target_ty = self.lvalue_type(target, scope)?;
422 let name = match target {
423 LValue::Var(n) => n.as_str(),
424 _ => "target",
425 };
426 self.check_assignable(&rhs_ty, &target_ty, name)?;
427 }
428 Stmt::Require { expr } => {
429 self.infer(expr, scope)?;
430 }
431 Stmt::Revert { error } => {
432 if !self.errors.contains_key(error) {
433 return Err(e(format!("revert '{}': error not declared", error)));
434 }
435 }
436 Stmt::Return { exprs } => {
437 if let Some(ret_types) = ret_ty {
438 if exprs.len() != ret_types.len() {
439 return Err(e(format!(
440 "return arity mismatch: expected {} values, got {}",
441 ret_types.len(),
442 exprs.len()
443 )));
444 }
445 for (expr, expected) in exprs.iter().zip(ret_types.iter()) {
446 let ty = self.infer(expr, scope)?;
447 self.check_assignable(&ty, expected, "return")?;
448 }
449 }
450 }
451 Stmt::Emit { fields, .. } => {
452 for (_, expr) in fields {
453 self.infer(expr, scope)?;
454 }
455 }
456 Stmt::If { cond, then, else_ } => {
457 self.infer(cond, scope)?;
458 self.check_body(then, scope, ret_ty)?;
459 self.check_body(else_, scope, ret_ty)?;
460 }
461 Stmt::While { cond, body } => {
462 self.infer(cond, scope)?;
463 self.check_body(body, scope, ret_ty)?;
464 }
465 Stmt::For {
466 var,
467 start,
468 end,
469 body,
470 } => {
471 self.infer(start, scope)?;
472 self.infer(end, scope)?;
473 let mut inner = scope.clone();
474 inner.insert(var.clone(), Type::U64);
475 self.check_body(body, &inner, ret_ty)?;
476 }
477 Stmt::Loop { body } => {
478 self.check_body(body, scope, ret_ty)?;
479 }
480 Stmt::Break | Stmt::Continue => {}
481 Stmt::Expr(expr) => {
482 self.infer(expr, scope)?;
483 }
484 }
485 Ok(())
486 }
487
488 fn check_assignable(&self, from: &Type, to: &Type, target: &str) -> Result<(), TypeError> {
489 if from == to {
490 return Ok(());
491 }
492 if from == &Type::U256 && matches!(to, Type::U64 | Type::U128 | Type::U256) {
493 return Ok(());
494 }
495 if from == &Type::U64 && matches!(to, Type::U128 | Type::U256) {
496 return Ok(());
497 }
498 if from == &Type::U128 && to == &Type::U256 {
499 return Ok(());
500 }
501 if matches!(from, Type::U64 | Type::U256) && to == &Type::Bool {
503 return Ok(());
504 }
505 if from == &Type::Bool && matches!(to, Type::U64 | Type::U128 | Type::U256) {
506 return Ok(());
507 }
508 Err(e(format!(
509 "type mismatch for '{}': cannot assign {:?} to {:?}",
510 target, from, to
511 )))
512 }
513
514 fn infer(&self, expr: &Expr, scope: &HashMap<String, Type>) -> Result<Type, TypeError> {
515 match expr {
516 Expr::Int(v) => {
517 if *v <= u64::MAX as u128 { Ok(Type::U64) }
518 else if *v <= u128::MAX { Ok(Type::U128) }
519 else { Ok(Type::U256) }
520 }
521 Expr::Bytes(_) => Ok(Type::U256),
522 Expr::Caller => Ok(Type::Address),
523 Expr::Owner => Ok(Type::Address),
524 Expr::SelfAddr => Ok(Type::Address),
525 Expr::Height => Ok(Type::U64),
526 Expr::Timestamp => Ok(Type::U64),
527 Expr::Value => Ok(Type::U128),
528 Expr::Var(name) => {
529 if let Some(ty) = scope.get(name) { return Ok(ty.clone()); }
530 if let Some(ty) = self.storage.get(name) { return Ok(ty.clone()); }
531 if name == "value" { return Ok(Type::U128); }
532 Err(e(format!("undefined variable '{}'", name)))
533 }
534 Expr::Index { base, key } => {
535 let base_ty = self.infer(base, scope)?;
536 self.infer(key, scope)?;
537 match base_ty {
538 Type::Mapping(_, v) => Ok(*v),
539 Type::Array(v) => Ok(*v),
540 _ => Err(e("index operator requires mapping or array")),
541 }
542 }
543 Expr::Field { base, field } => {
544 let base_ty = self.infer(base, scope)?;
545 match base_ty {
546 Type::Struct(name) => {
547 let fields = self.structs.get(&name)
548 .ok_or_else(|| e(format!("unknown struct '{}'", name)))?;
549 fields.iter().find(|f| &f.name == field)
550 .map(|f| f.ty.clone())
551 .ok_or_else(|| e(format!("struct '{}' has no field '{}'", name, field)))
552 }
553 _ => Err(e(format!("field access on non-struct type"))),
554 }
555 }
556 Expr::Hash(_) => Ok(Type::U256),
557 Expr::Not(inner) => { self.infer(inner, scope)?; Ok(Type::Bool) }
558 Expr::Bin { lhs, rhs, op } => {
559 let lt = self.infer(lhs, scope)?;
560 let _rt = self.infer(rhs, scope)?;
561 match op {
562 BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le |
563 BinOp::Gt | BinOp::Ge | BinOp::LogicAnd | BinOp::LogicOr => Ok(Type::Bool),
564 _ => Ok(lt),
565 }
566 }
567 Expr::TokenBalance { .. } => Ok(Type::U128),
568 Expr::TokenTransfer { .. } | Expr::TokenMint { .. } | Expr::TokenBurn { .. } => Ok(Type::Bool),
569 Expr::AccordRequest { .. } | Expr::AccordRead { .. } => Ok(Type::U256),
570 Expr::CallCell { ret, method, .. } => {
571 match ret {
572 Some(ty) => Ok(ty.clone()),
573 None => Err(e(format!(
574 "cross-cell call to '{}' has no declared return type - use `call x.{}(...) -> type`",
575 method, method
576 ))),
577 }
578 }
579 Expr::Call { name, .. } => {
580 if let Some(f) = self.cell.fns.iter().find(|f| &f.name == name) {
581 match &f.ret {
582 Some(types) if types.len() == 1 => Ok(types[0].clone()),
583 Some(types) if types.len() > 1 => Ok(Type::U256), _ => Ok(Type::Bool),
585 }
586 } else {
587 Err(e(format!("call to undefined function '{}'", name)))
588 }
589 }
590 }
591 }
592}
593
594pub fn check(cell: &CellDef) -> Result<(), TypeError> {
595 TypeChecker::new(cell).check()
596}