Skip to main content

truthlinked_axiom_compiler/
typeck.rs

//! Type checker for the `.cell` AST.
//!
//! Ownership / borrow rules enforced here:
//!
//! 1. Every `owned` param must be consumed **exactly once** - passed to a token op,
//!    a local or cross-cell call, or returned.
//! 2. Consumption in only one branch of an `if` is an error (branch asymmetry).
//! 3. Consumption inside a loop body is an error (it would consume N times).
//! 4. Aliasing an owned param through a `let` binding propagates ownership -
//!    consuming the alias counts as consuming the original, so exactly one of them
//!    may be consumed.
//! 5. Storage reads are never consuming positions - a storage slot is not owned.
12
use crate::ast::*;
use std::collections::{BTreeMap, HashMap, HashSet};
15
16#[derive(Debug, thiserror::Error)]
17#[error("type error: {0}")]
18pub struct TypeError(pub String);
19
20fn e(msg: impl Into<String>) -> TypeError {
21    TypeError(msg.into())
22}
23
24pub struct TypeChecker<'a> {
25    cell: &'a CellDef,
26    errors: HashMap<String, ()>,
27    storage: HashMap<String, Type>,
28    structs: HashMap<String, Vec<FieldDef>>,
29}
30
31impl<'a> TypeChecker<'a> {
32    pub fn new(cell: &'a CellDef) -> Self {
33        let errors = cell.errors.iter().map(|e| (e.name.clone(), ())).collect();
34        let storage = cell
35            .storage
36            .iter()
37            .map(|s| (s.name.clone(), s.ty.clone()))
38            .collect();
39        let structs = cell
40            .structs
41            .iter()
42            .map(|s| (s.name.clone(), s.fields.clone()))
43            .collect();
44        Self {
45            cell,
46            errors,
47            storage,
48            structs,
49        }
50    }
51
52    pub fn check(&self) -> Result<(), TypeError> {
53        if let Some(init) = &self.cell.init {
54            self.check_body(&init.body, &self.params_scope(&init.params), None)?;
55            self.check_ownership(&init.body, &init.params)?;
56        }
57        for f in &self.cell.fns {
58            self.check_body(&f.body, &self.params_scope(&f.params), f.ret.as_deref())?;
59            self.check_ownership(&f.body, &f.params)?;
60        }
61        Ok(())
62    }
63
64    /// Ownership checker: every `owned` param must be consumed exactly once.
65    fn check_ownership(&self, stmts: &[Stmt], params: &[Param]) -> Result<(), TypeError> {
66        let owned: Vec<&str> = params
67            .iter()
68            .filter(|p| p.owned)
69            .map(|p| p.name.as_str())
70            .collect();
71        if owned.is_empty() {
72            return Ok(());
73        }
74
75        // consumed[name] = count. Special sentinel: 2 = branch asymmetry error.
76        let mut consumed: HashMap<&str, usize> = owned.iter().map(|n| (*n, 0)).collect();
77        // aliases: local let-bindings that alias an owned param
78        let mut aliases: HashMap<String, String> = HashMap::new();
79        self.count_consumptions(stmts, &mut consumed, &mut aliases, false)?;
80
81        for name in &owned {
82            let count = consumed[name];
83            if count == 0 {
84                // Check if it was consumed via an alias
85                let alias_consumed = aliases.iter().any(|(alias, src)| {
86                    src.as_str() == *name && consumed.get(alias.as_str()).copied().unwrap_or(0) > 0
87                });
88                if !alias_consumed {
89                    return Err(e(format!(
90                        "owned parameter '{}' is never consumed - must be passed to a token op, call, or returned",
91                        name
92                    )));
93                }
94            } else if count == 2 {
95                return Err(e(format!(
96                    "owned parameter '{}' is consumed in only one branch of an if - must be consumed in both branches or neither",
97                    name
98                )));
99            } else if count > 1 {
100                return Err(e(format!(
101                    "owned parameter '{}' is consumed {} times - double-spend detected",
102                    name, count
103                )));
104            }
105        }
106        Ok(())
107    }
108
109    fn count_consumptions<'b>(
110        &self,
111        stmts: &[Stmt],
112        consumed: &mut HashMap<&'b str, usize>,
113        aliases: &mut HashMap<String, String>,
114        in_loop: bool,
115    ) -> Result<(), TypeError>
116    where
117        'a: 'b,
118    {
119        for stmt in stmts {
120            match stmt {
121                Stmt::Let { name, expr, .. } => {
122                    // Track aliasing: `let x = owned_param` propagates ownership to x
123                    if let Expr::Var(src) = expr {
124                        if consumed.contains_key(src.as_str()) {
125                            aliases.insert(name.clone(), src.clone());
126                        }
127                    }
128                    self.count_expr_consumptions(expr, consumed, aliases, in_loop)?;
129                }
130                Stmt::Assign { expr, .. } => {
131                    self.count_expr_consumptions(expr, consumed, aliases, in_loop)?
132                }
133                Stmt::AssignAdd { expr, .. }
134                | Stmt::AssignSub { expr, .. }
135                | Stmt::AssignMul { expr, .. }
136                | Stmt::AssignDiv { expr, .. } => {
137                    self.count_expr_consumptions(expr, consumed, aliases, in_loop)?;
138                }
139                Stmt::Return { exprs } => {
140                    for ex in exprs {
141                        self.mark_if_owned(ex, consumed, aliases, in_loop)?;
142                        self.count_expr_consumptions(ex, consumed, aliases, in_loop)?;
143                    }
144                }
145                Stmt::If { cond, then, else_ } => {
146                    self.count_expr_consumptions(cond, consumed, aliases, in_loop)?;
147                    let mut then_c = consumed.clone();
148                    let mut else_c = consumed.clone();
149                    let mut then_a = aliases.clone();
150                    let mut else_a = aliases.clone();
151                    self.count_consumptions(then, &mut then_c, &mut then_a, in_loop)?;
152                    self.count_consumptions(else_, &mut else_c, &mut else_a, in_loop)?;
153                    for (k, v) in consumed.iter_mut() {
154                        let in_then = then_c[k];
155                        let in_else = else_c[k];
156                        if in_then != in_else {
157                            *v += 2; // asymmetry sentinel
158                        } else {
159                            *v += in_then;
160                        }
161                    }
162                }
163                Stmt::While { cond, body } => {
164                    self.count_expr_consumptions(cond, consumed, aliases, in_loop)?;
165                    let mut loop_c = consumed.clone();
166                    let mut loop_a = aliases.clone();
167                    self.count_consumptions(body, &mut loop_c, &mut loop_a, true)?;
168                    for (k, v) in consumed.iter_mut() {
169                        if loop_c[k] > *v {
170                            return Err(e(format!(
171                                "owned parameter '{}' consumed inside a loop - would be consumed multiple times",
172                                k
173                            )));
174                        }
175                    }
176                }
177                Stmt::Loop { body } => {
178                    let mut loop_c = consumed.clone();
179                    let mut loop_a = aliases.clone();
180                    self.count_consumptions(body, &mut loop_c, &mut loop_a, true)?;
181                    for (k, v) in consumed.iter_mut() {
182                        if loop_c[k] > *v {
183                            return Err(e(format!(
184                                "owned parameter '{}' consumed inside a loop - would be consumed multiple times",
185                                k
186                            )));
187                        }
188                    }
189                }
190                Stmt::For {
191                    start, end, body, ..
192                } => {
193                    self.count_expr_consumptions(start, consumed, aliases, in_loop)?;
194                    self.count_expr_consumptions(end, consumed, aliases, in_loop)?;
195                    let mut loop_c = consumed.clone();
196                    let mut loop_a = aliases.clone();
197                    self.count_consumptions(body, &mut loop_c, &mut loop_a, true)?;
198                    for (k, v) in consumed.iter_mut() {
199                        if loop_c[k] > *v {
200                            return Err(e(format!(
201                                "owned parameter '{}' consumed inside a for loop - would be consumed multiple times",
202                                k
203                            )));
204                        }
205                    }
206                }
207                Stmt::Emit { fields, .. } => {
208                    for (_, ex) in fields {
209                        self.count_expr_consumptions(ex, consumed, aliases, in_loop)?;
210                    }
211                }
212                Stmt::Require { expr } => {
213                    self.count_expr_consumptions(expr, consumed, aliases, in_loop)?
214                }
215                Stmt::Expr(expr) => {
216                    self.count_expr_consumptions(expr, consumed, aliases, in_loop)?
217                }
218                _ => {}
219            }
220        }
221        Ok(())
222    }
223
224    fn count_expr_consumptions<'b>(
225        &self,
226        expr: &Expr,
227        consumed: &mut HashMap<&'b str, usize>,
228        aliases: &mut HashMap<String, String>,
229        in_loop: bool,
230    ) -> Result<(), TypeError>
231    where
232        'a: 'b,
233    {
234        match expr {
235            Expr::TokenTransfer {
236                token,
237                from,
238                to,
239                amount,
240            } => {
241                self.mark_if_owned(amount, consumed, aliases, in_loop)?;
242                self.count_expr_consumptions(token, consumed, aliases, in_loop)?;
243                self.count_expr_consumptions(from, consumed, aliases, in_loop)?;
244                self.count_expr_consumptions(to, consumed, aliases, in_loop)?;
245            }
246            Expr::TokenMint { token, to, amount } => {
247                self.mark_if_owned(amount, consumed, aliases, in_loop)?;
248                self.count_expr_consumptions(token, consumed, aliases, in_loop)?;
249                self.count_expr_consumptions(to, consumed, aliases, in_loop)?;
250            }
251            Expr::TokenBurn {
252                token,
253                owner,
254                amount,
255            } => {
256                self.mark_if_owned(amount, consumed, aliases, in_loop)?;
257                self.count_expr_consumptions(token, consumed, aliases, in_loop)?;
258                self.count_expr_consumptions(owner, consumed, aliases, in_loop)?;
259            }
260            Expr::CallCell { args, .. } => {
261                for a in args {
262                    self.mark_if_owned(a, consumed, aliases, in_loop)?;
263                }
264            }
265            Expr::Call { args, .. } => {
266                for a in args {
267                    self.mark_if_owned(a, consumed, aliases, in_loop)?;
268                }
269            }
270            Expr::Bin { lhs, rhs, .. } => {
271                self.count_expr_consumptions(lhs, consumed, aliases, in_loop)?;
272                self.count_expr_consumptions(rhs, consumed, aliases, in_loop)?;
273            }
274            Expr::Not(inner) | Expr::Hash(inner) => {
275                self.count_expr_consumptions(inner, consumed, aliases, in_loop)?
276            }
277            Expr::Index { base, key } => {
278                self.count_expr_consumptions(base, consumed, aliases, in_loop)?;
279                self.count_expr_consumptions(key, consumed, aliases, in_loop)?;
280            }
281            _ => {}
282        }
283        Ok(())
284    }
285
286    fn mark_if_owned<'b>(
287        &self,
288        expr: &Expr,
289        consumed: &mut HashMap<&'b str, usize>,
290        aliases: &HashMap<String, String>,
291        in_loop: bool,
292    ) -> Result<(), TypeError>
293    where
294        'a: 'b,
295    {
296        if let Expr::Var(name) = expr {
297            // Direct owned param
298            if let Some(count) = consumed.get_mut(name.as_str()) {
299                if in_loop {
300                    return Err(e(format!(
301                        "owned parameter '{}' consumed inside a loop - would be consumed multiple times",
302                        name
303                    )));
304                }
305                *count += 1;
306                return Ok(());
307            }
308            // Alias of an owned param
309            if let Some(src) = aliases.get(name.as_str()) {
310                if let Some(count) = consumed.get_mut(src.as_str()) {
311                    if in_loop {
312                        return Err(e(format!(
313                            "owned parameter '{}' (via alias '{}') consumed inside a loop",
314                            src, name
315                        )));
316                    }
317                    *count += 1;
318                }
319            }
320        }
321        Ok(())
322    }
323
324    fn params_scope(&self, params: &[Param]) -> HashMap<String, Type> {
325        params
326            .iter()
327            .map(|p| (p.name.clone(), p.ty.clone()))
328            .collect()
329    }
330
331    fn check_body(
332        &self,
333        stmts: &[Stmt],
334        scope: &HashMap<String, Type>,
335        ret_ty: Option<&[Type]>,
336    ) -> Result<(), TypeError> {
337        let mut scope = scope.clone();
338        for stmt in stmts {
339            self.check_stmt(stmt, &mut scope, ret_ty)?;
340        }
341        Ok(())
342    }
343
344    fn lvalue_type(&self, lv: &LValue, scope: &HashMap<String, Type>) -> Result<Type, TypeError> {
345        match lv {
346            LValue::Var(name) => self
347                .storage
348                .get(name)
349                .or_else(|| scope.get(name))
350                .cloned()
351                .ok_or_else(|| e(format!("undefined variable '{}'", name))),
352            LValue::Index { base, .. } => {
353                let base_ty = self
354                    .storage
355                    .get(base)
356                    .or_else(|| scope.get(base))
357                    .cloned()
358                    .ok_or_else(|| e(format!("undefined variable '{}'", base)))?;
359                match base_ty {
360                    Type::Mapping(_, v) => Ok(*v),
361                    Type::Array(v) => Ok(*v),
362                    _ => Err(e(format!("'{}' is not indexable", base))),
363                }
364            }
365            LValue::Field { base, field } => {
366                let base_ty = self
367                    .storage
368                    .get(base)
369                    .or_else(|| scope.get(base))
370                    .cloned()
371                    .ok_or_else(|| e(format!("undefined variable '{}'", base)))?;
372                match base_ty {
373                    Type::Struct(name) => {
374                        let fields = self
375                            .structs
376                            .get(&name)
377                            .ok_or_else(|| e(format!("unknown struct '{}'", name)))?;
378                        fields
379                            .iter()
380                            .find(|f| &f.name == field)
381                            .map(|f| f.ty.clone())
382                            .ok_or_else(|| e(format!("struct '{}' has no field '{}'", name, field)))
383                    }
384                    _ => Err(e(format!("'{}' is not a struct", base))),
385                }
386            }
387        }
388    }
389
390    fn check_stmt(
391        &self,
392        stmt: &Stmt,
393        scope: &mut HashMap<String, Type>,
394        ret_ty: Option<&[Type]>,
395    ) -> Result<(), TypeError> {
396        match stmt {
397            Stmt::Let { name, ty, expr } => {
398                let inferred = self.infer(expr, scope)?;
399                let final_ty = if let Some(declared) = ty {
400                    self.check_assignable(&inferred, declared, name)?;
401                    declared.clone()
402                } else {
403                    inferred
404                };
405                scope.insert(name.clone(), final_ty);
406            }
407            Stmt::Assign { target, expr } => {
408                let rhs_ty = self.infer(expr, scope)?;
409                let target_ty = self.lvalue_type(target, scope)?;
410                let name = match target {
411                    LValue::Var(n) => n.as_str(),
412                    _ => "target",
413                };
414                self.check_assignable(&rhs_ty, &target_ty, name)?;
415            }
416            Stmt::AssignAdd { target, expr }
417            | Stmt::AssignSub { target, expr }
418            | Stmt::AssignMul { target, expr }
419            | Stmt::AssignDiv { target, expr } => {
420                let rhs_ty = self.infer(expr, scope)?;
421                let target_ty = self.lvalue_type(target, scope)?;
422                let name = match target {
423                    LValue::Var(n) => n.as_str(),
424                    _ => "target",
425                };
426                self.check_assignable(&rhs_ty, &target_ty, name)?;
427            }
428            Stmt::Require { expr } => {
429                self.infer(expr, scope)?;
430            }
431            Stmt::Revert { error } => {
432                if !self.errors.contains_key(error) {
433                    return Err(e(format!("revert '{}': error not declared", error)));
434                }
435            }
436            Stmt::Return { exprs } => {
437                if let Some(ret_types) = ret_ty {
438                    if exprs.len() != ret_types.len() {
439                        return Err(e(format!(
440                            "return arity mismatch: expected {} values, got {}",
441                            ret_types.len(),
442                            exprs.len()
443                        )));
444                    }
445                    for (expr, expected) in exprs.iter().zip(ret_types.iter()) {
446                        let ty = self.infer(expr, scope)?;
447                        self.check_assignable(&ty, expected, "return")?;
448                    }
449                }
450            }
451            Stmt::Emit { fields, .. } => {
452                for (_, expr) in fields {
453                    self.infer(expr, scope)?;
454                }
455            }
456            Stmt::If { cond, then, else_ } => {
457                self.infer(cond, scope)?;
458                self.check_body(then, scope, ret_ty)?;
459                self.check_body(else_, scope, ret_ty)?;
460            }
461            Stmt::While { cond, body } => {
462                self.infer(cond, scope)?;
463                self.check_body(body, scope, ret_ty)?;
464            }
465            Stmt::For {
466                var,
467                start,
468                end,
469                body,
470            } => {
471                self.infer(start, scope)?;
472                self.infer(end, scope)?;
473                let mut inner = scope.clone();
474                inner.insert(var.clone(), Type::U64);
475                self.check_body(body, &inner, ret_ty)?;
476            }
477            Stmt::Loop { body } => {
478                self.check_body(body, scope, ret_ty)?;
479            }
480            Stmt::Break | Stmt::Continue => {}
481            Stmt::Expr(expr) => {
482                self.infer(expr, scope)?;
483            }
484        }
485        Ok(())
486    }
487
488    fn check_assignable(&self, from: &Type, to: &Type, target: &str) -> Result<(), TypeError> {
489        if from == to {
490            return Ok(());
491        }
492        if from == &Type::U256 && matches!(to, Type::U64 | Type::U128 | Type::U256) {
493            return Ok(());
494        }
495        if from == &Type::U64 && matches!(to, Type::U128 | Type::U256) {
496            return Ok(());
497        }
498        if from == &Type::U128 && to == &Type::U256 {
499            return Ok(());
500        }
501        // true/false are Int(1)/Int(0) → U64; allow assigning to bool and vice versa
502        if matches!(from, Type::U64 | Type::U256) && to == &Type::Bool {
503            return Ok(());
504        }
505        if from == &Type::Bool && matches!(to, Type::U64 | Type::U128 | Type::U256) {
506            return Ok(());
507        }
508        Err(e(format!(
509            "type mismatch for '{}': cannot assign {:?} to {:?}",
510            target, from, to
511        )))
512    }
513
514    fn infer(&self, expr: &Expr, scope: &HashMap<String, Type>) -> Result<Type, TypeError> {
515        match expr {
516            Expr::Int(v) => {
517                if *v <= u64::MAX as u128 { Ok(Type::U64) }
518                else if *v <= u128::MAX   { Ok(Type::U128) }
519                else                      { Ok(Type::U256) }
520            }
521            Expr::Bytes(_)     => Ok(Type::U256),
522            Expr::Caller       => Ok(Type::Address),
523            Expr::Owner        => Ok(Type::Address),
524            Expr::SelfAddr     => Ok(Type::Address),
525            Expr::Height       => Ok(Type::U64),
526            Expr::Timestamp    => Ok(Type::U64),
527            Expr::Value        => Ok(Type::U128),
528            Expr::Var(name) => {
529                if let Some(ty) = scope.get(name) { return Ok(ty.clone()); }
530                if let Some(ty) = self.storage.get(name) { return Ok(ty.clone()); }
531                if name == "value" { return Ok(Type::U128); }
532                Err(e(format!("undefined variable '{}'", name)))
533            }
534            Expr::Index { base, key } => {
535                let base_ty = self.infer(base, scope)?;
536                self.infer(key, scope)?;
537                match base_ty {
538                    Type::Mapping(_, v) => Ok(*v),
539                    Type::Array(v)      => Ok(*v),
540                    _ => Err(e("index operator requires mapping or array")),
541                }
542            }
543            Expr::Field { base, field } => {
544                let base_ty = self.infer(base, scope)?;
545                match base_ty {
546                    Type::Struct(name) => {
547                        let fields = self.structs.get(&name)
548                            .ok_or_else(|| e(format!("unknown struct '{}'", name)))?;
549                        fields.iter().find(|f| &f.name == field)
550                            .map(|f| f.ty.clone())
551                            .ok_or_else(|| e(format!("struct '{}' has no field '{}'", name, field)))
552                    }
553                    _ => Err(e(format!("field access on non-struct type"))),
554                }
555            }
556            Expr::Hash(_)          => Ok(Type::U256),
557            Expr::Not(inner)       => { self.infer(inner, scope)?; Ok(Type::Bool) }
558            Expr::Bin { lhs, rhs, op } => {
559                let lt = self.infer(lhs, scope)?;
560                let _rt = self.infer(rhs, scope)?;
561                match op {
562                    BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le |
563                    BinOp::Gt | BinOp::Ge | BinOp::LogicAnd | BinOp::LogicOr => Ok(Type::Bool),
564                    _ => Ok(lt),
565                }
566            }
567            Expr::TokenBalance { .. }  => Ok(Type::U128),
568            Expr::TokenTransfer { .. } | Expr::TokenMint { .. } | Expr::TokenBurn { .. } => Ok(Type::Bool),
569            Expr::AccordRequest { .. } | Expr::AccordRead { .. } => Ok(Type::U256),
570            Expr::CallCell { ret, method, .. } => {
571                match ret {
572                    Some(ty) => Ok(ty.clone()),
573                    None => Err(e(format!(
574                        "cross-cell call to '{}' has no declared return type - use `call x.{}(...) -> type`",
575                        method, method
576                    ))),
577                }
578            }
579            Expr::Call { name, .. } => {
580                if let Some(f) = self.cell.fns.iter().find(|f| &f.name == name) {
581                    match &f.ret {
582                        Some(types) if types.len() == 1 => Ok(types[0].clone()),
583                        Some(types) if types.len() > 1  => Ok(Type::U256), // tuple - caller unpacks
584                        _ => Ok(Type::Bool),
585                    }
586                } else {
587                    Err(e(format!("call to undefined function '{}'", name)))
588                }
589            }
590        }
591    }
592}
593
594pub fn check(cell: &CellDef) -> Result<(), TypeError> {
595    TypeChecker::new(cell).check()
596}