Skip to main content

trident/typecheck/
mod.rs

1mod analysis;
2mod block;
3mod builtins;
4mod expr;
5mod resolve;
6mod stmt;
7#[cfg(test)]
8mod tests;
9pub mod types;
10
11use std::collections::{BTreeMap, BTreeSet};
12
13use crate::ast::*;
14use crate::diagnostic::Diagnostic;
15use crate::span::{Span, Spanned};
16use crate::types::{StructTy, Ty};
17
18/// A function signature for type checking.
19#[derive(Clone, Debug)]
20pub(super) struct FnSig {
21    pub(super) params: Vec<(String, Ty)>,
22    pub(super) return_ty: Ty,
23}
24
25/// A generic (size-parameterized) function definition, stored unresolved.
26#[derive(Clone, Debug)]
27pub(super) struct GenericFnDef {
28    /// Size parameter names, e.g. `["N"]`.
29    pub(super) type_params: Vec<String>,
30    /// Parameter types as AST types (may contain `ArraySize::Param`).
31    pub(super) params: Vec<(String, Type)>,
32    /// Return type as AST type (may contain `ArraySize::Param`).
33    pub(super) return_ty: Option<Type>,
34}
35
/// A monomorphized instance of a generic function.
///
/// Identifies one concrete instantiation: the generic function's name plus
/// the size value bound to each of its size parameters.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MonoInstance {
    /// Original function name.
    pub name: String,
    /// Concrete size values for each type parameter.
    pub size_args: Vec<u64>,
}

impl MonoInstance {
    /// Returns the mangled label for this instance.
    ///
    /// The label is `<name>__N<sizes>`, where `<sizes>` is the decimal size
    /// arguments joined with `_`: `sum` with N=3 -> `sum__N3`, and a function
    /// `zip` with sizes 2 and 5 -> `zip__N2_5`.
    pub fn mangled_name(&self) -> String {
        let sizes: Vec<String> = self.size_args.iter().map(u64::to_string).collect();
        format!("{}__N{}", self.name, sizes.join("_"))
    }
}
52
53/// Variable info in scope.
54#[derive(Clone, Debug)]
55pub(super) struct VarInfo {
56    pub(super) ty: Ty,
57    pub(super) mutable: bool,
58}
59
60/// A function's exported signature: (name, params, return_type).
61pub type FnExport = (String, Vec<(String, Ty)>, Ty);
62
63/// Exported signatures from a type-checked module.
64#[derive(Clone, Debug)]
65pub struct ModuleExports {
66    pub module_name: String,
67    pub functions: Vec<FnExport>,
68    pub constants: Vec<(String, Ty, u64)>, // (name, ty, value)
69    pub structs: Vec<StructTy>,            // exported struct types
70    pub warnings: Vec<Diagnostic>,         // non-fatal diagnostics
71    /// Unique monomorphized instances of generic functions to emit.
72    pub mono_instances: Vec<MonoInstance>,
73    /// Per-call-site resolution: each generic call in AST order maps to a MonoInstance.
74    /// The emitter consumes these in order to know which mangled name to call.
75    pub call_resolutions: Vec<MonoInstance>,
76}
77
78pub(crate) struct TypeChecker {
79    /// Known function signatures (user-defined + builtins).
80    pub(super) functions: BTreeMap<String, FnSig>,
81    /// Variable scopes (stack of scope maps).
82    pub(super) scopes: Vec<BTreeMap<String, VarInfo>>,
83    /// Known constants (name -> value).
84    pub(super) constants: BTreeMap<String, u64>,
85    /// Known struct types (name or module.name -> StructTy).
86    pub(super) structs: BTreeMap<String, StructTy>,
87    /// Known event types (name -> field list).
88    pub(super) events: BTreeMap<String, Vec<(String, Ty)>>,
89    /// Accumulated diagnostics.
90    pub(super) diagnostics: Vec<Diagnostic>,
91    /// Variables proven to be in U32 range (via as_u32, split, or U32 type).
92    pub(super) u32_proven: BTreeSet<String>,
93    /// Generic (size-parameterized) function definitions.
94    pub(super) generic_fns: BTreeMap<String, GenericFnDef>,
95    /// Unique monomorphized instances collected during type checking.
96    pub(super) mono_instances: Vec<MonoInstance>,
97    /// Per-call-site resolutions in AST walk order.
98    pub(super) call_resolutions: Vec<MonoInstance>,
99    /// Active cfg flags for conditional compilation.
100    pub(super) cfg_flags: BTreeSet<String>,
101    /// Target VM configuration (digest width, hash rate, field limbs, etc.).
102    pub(super) target_config: crate::target::TerrainConfig,
103    /// Whether we are currently inside a `#[pure]` function body.
104    pub(super) in_pure_fn: bool,
105}
106
107impl Default for TypeChecker {
108    fn default() -> Self {
109        Self::new()
110    }
111}
112
113impl TypeChecker {
114    pub(crate) fn new() -> Self {
115        Self::with_target(crate::target::TerrainConfig::triton())
116    }
117
118    pub(crate) fn with_target(config: crate::target::TerrainConfig) -> Self {
119        let mut tc = Self {
120            functions: BTreeMap::new(),
121            scopes: Vec::new(),
122            constants: BTreeMap::new(),
123            structs: BTreeMap::new(),
124            events: BTreeMap::new(),
125            diagnostics: Vec::new(),
126            u32_proven: BTreeSet::new(),
127            generic_fns: BTreeMap::new(),
128            mono_instances: Vec::new(),
129            call_resolutions: Vec::new(),
130            cfg_flags: BTreeSet::from(["debug".to_string()]),
131            target_config: config,
132            in_pure_fn: false,
133        };
134        tc.register_builtins();
135        tc
136    }
137
138    /// Set active cfg flags for conditional compilation.
139    pub(crate) fn with_cfg_flags(mut self, flags: BTreeSet<String>) -> Self {
140        self.cfg_flags = flags;
141        self
142    }
143
144    /// Check if an item's cfg attribute is active.
145    fn is_cfg_active(&self, cfg: &Option<Spanned<String>>) -> bool {
146        match cfg {
147            None => true,
148            Some(flag) => self.cfg_flags.contains(&flag.node),
149        }
150    }
151
152    /// Check if a top-level item's cfg is active.
153    fn is_item_cfg_active(&self, item: &Item) -> bool {
154        match item {
155            Item::Fn(f) => self.is_cfg_active(&f.cfg),
156            Item::Const(c) => self.is_cfg_active(&c.cfg),
157            Item::Struct(s) => self.is_cfg_active(&s.cfg),
158            Item::Event(e) => self.is_cfg_active(&e.cfg),
159        }
160    }
161
162    /// Import exported signatures from another module.
163    /// Makes them available as `module_name.fn_name`.
164    /// For dotted modules like `std.hash`, also registers under
165    /// the short alias `hash.fn_name` so `hash.tip5()` works.
166    pub(crate) fn import_module(&mut self, exports: &ModuleExports) {
167        // Short alias: last segment of dotted module name
168        let short_prefix = exports
169            .module_name
170            .rsplit('.')
171            .next()
172            .unwrap_or(&exports.module_name);
173        let has_short = short_prefix != exports.module_name;
174
175        for (fn_name, params, return_ty) in &exports.functions {
176            let qualified = format!("{}.{}", exports.module_name, fn_name);
177            let sig = FnSig {
178                params: params.clone(),
179                return_ty: return_ty.clone(),
180            };
181            self.functions.insert(qualified, sig.clone());
182            if has_short {
183                let short = format!("{}.{}", short_prefix, fn_name);
184                self.functions.insert(short, sig);
185            }
186        }
187        for (const_name, _ty, value) in &exports.constants {
188            let qualified = format!("{}.{}", exports.module_name, const_name);
189            self.constants.insert(qualified, *value);
190            if has_short {
191                let short = format!("{}.{}", short_prefix, const_name);
192                self.constants.insert(short, *value);
193            }
194        }
195        for sty in &exports.structs {
196            let qualified = format!("{}.{}", exports.module_name, sty.name);
197            self.structs.insert(qualified, sty.clone());
198            if has_short {
199                let short = format!("{}.{}", short_prefix, sty.name);
200                self.structs.insert(short, sty.clone());
201            }
202        }
203    }
204
205    pub(crate) fn check_file(mut self, file: &File) -> Result<ModuleExports, Vec<Diagnostic>> {
206        let is_std_module = file.name.node.starts_with("std.")
207            || file.name.node.starts_with("vm.")
208            || file.name.node.starts_with("os.")
209            || file.name.node.starts_with("ext.")
210            || file.name.node.contains(".ext.");
211
212        // First pass: register all structs, function signatures, and constants
213        for item in &file.items {
214            // Skip items excluded by conditional compilation
215            if !self.is_item_cfg_active(&item.node) {
216                continue;
217            }
218            match &item.node {
219                Item::Struct(sdef) => {
220                    let fields: Vec<(String, Ty, bool)> = sdef
221                        .fields
222                        .iter()
223                        .map(|f| (f.name.node.clone(), self.resolve_type(&f.ty.node), f.is_pub))
224                        .collect();
225                    let sty = StructTy {
226                        name: sdef.name.node.clone(),
227                        fields,
228                    };
229                    self.structs.insert(sdef.name.node.clone(), sty);
230                }
231                Item::Fn(func) => {
232                    // #[intrinsic] is only allowed in vm.*/std.*/os.*/ext.* modules
233                    if func.intrinsic.is_some() && !is_std_module {
234                        self.error(
235                            format!(
236                                "#[intrinsic] is only allowed in vm.*/std.*/os.* modules, \
237                                 not in '{}'",
238                                file.name.node
239                            ),
240                            func.name.span,
241                        );
242                    }
243                    if func.type_params.is_empty() {
244                        // Non-generic function: resolve immediately.
245                        let params: Vec<(String, Ty)> = func
246                            .params
247                            .iter()
248                            .map(|p| (p.name.node.clone(), self.resolve_type(&p.ty.node)))
249                            .collect();
250                        let return_ty = func
251                            .return_ty
252                            .as_ref()
253                            .map(|t| self.resolve_type(&t.node))
254                            .unwrap_or(Ty::Unit);
255                        self.functions
256                            .insert(func.name.node.clone(), FnSig { params, return_ty });
257                    } else {
258                        // Generic function: store unresolved for monomorphization.
259                        let gdef = GenericFnDef {
260                            type_params: func.type_params.iter().map(|p| p.node.clone()).collect(),
261                            params: func
262                                .params
263                                .iter()
264                                .map(|p| (p.name.node.clone(), p.ty.node.clone()))
265                                .collect(),
266                            return_ty: func.return_ty.as_ref().map(|t| t.node.clone()),
267                        };
268                        self.generic_fns.insert(func.name.node.clone(), gdef);
269                    }
270                }
271                Item::Const(cdef) => {
272                    if let Expr::Literal(Literal::Integer(v)) = &cdef.value.node {
273                        self.constants.insert(cdef.name.node.clone(), *v);
274                    }
275                }
276                Item::Event(edef) => {
277                    if edef.fields.len() > 9 {
278                        self.error(
279                            format!(
280                                "event '{}' has {} fields, max is 9",
281                                edef.name.node,
282                                edef.fields.len()
283                            ),
284                            edef.name.span,
285                        );
286                    }
287                    let fields: Vec<(String, Ty)> = edef
288                        .fields
289                        .iter()
290                        .map(|f| {
291                            let ty = self.resolve_type(&f.ty.node);
292                            if ty != Ty::Field {
293                                self.error(
294                                    format!(
295                                        "event field '{}' must be Field type, got {}",
296                                        f.name.node,
297                                        ty.display()
298                                    ),
299                                    f.ty.span,
300                                );
301                            }
302                            (f.name.node.clone(), ty)
303                        })
304                        .collect();
305                    self.events.insert(edef.name.node.clone(), fields);
306                }
307            }
308        }
309
310        // Recursion detection: build call graph and reject cycles
311        self.detect_recursion(file);
312
313        // Second pass: type check function bodies
314        for item in &file.items {
315            if !self.is_item_cfg_active(&item.node) {
316                continue;
317            }
318            if let Item::Fn(func) = &item.node {
319                self.check_fn(func);
320            }
321        }
322
323        // Unused import detection: collect used module prefixes from all calls
324        let mut used_prefixes: BTreeSet<String> = BTreeSet::new();
325        for item in &file.items {
326            if !self.is_item_cfg_active(&item.node) {
327                continue;
328            }
329            if let Item::Fn(func) = &item.node {
330                if let Some(body) = &func.body {
331                    Self::collect_used_modules_block(&body.node, &mut used_prefixes);
332                }
333            }
334        }
335        for use_stmt in &file.uses {
336            let module_path = use_stmt.node.as_dotted();
337            // Short alias: last segment
338            let short = module_path
339                .rsplit('.')
340                .next()
341                .unwrap_or(&module_path)
342                .to_string();
343            if !used_prefixes.contains(&short) && !used_prefixes.contains(&module_path) {
344                self.warning(format!("unused import '{}'", module_path), use_stmt.span);
345            }
346        }
347
348        // Collect exports (pub items only)
349        let module_name = file.name.node.clone();
350        let mut exported_fns = Vec::new();
351        let mut exported_consts = Vec::new();
352        let mut exported_structs = Vec::new();
353
354        for item in &file.items {
355            if !self.is_item_cfg_active(&item.node) {
356                continue;
357            }
358            match &item.node {
359                Item::Fn(func) if func.is_pub => {
360                    let params: Vec<(String, Ty)> = func
361                        .params
362                        .iter()
363                        .map(|p| (p.name.node.clone(), self.resolve_type(&p.ty.node)))
364                        .collect();
365                    let return_ty = func
366                        .return_ty
367                        .as_ref()
368                        .map(|t| self.resolve_type(&t.node))
369                        .unwrap_or(Ty::Unit);
370                    exported_fns.push((func.name.node.clone(), params, return_ty));
371                }
372                Item::Const(cdef) if cdef.is_pub => {
373                    let ty = self.resolve_type(&cdef.ty.node);
374                    if let Expr::Literal(Literal::Integer(v)) = &cdef.value.node {
375                        exported_consts.push((cdef.name.node.clone(), ty, *v));
376                    }
377                }
378                Item::Struct(sdef) if sdef.is_pub => {
379                    if let Some(sty) = self.structs.get(&sdef.name.node) {
380                        exported_structs.push(sty.clone());
381                    }
382                }
383                _ => {}
384            }
385        }
386
387        let has_errors = self
388            .diagnostics
389            .iter()
390            .any(|d| d.severity == crate::diagnostic::Severity::Error);
391        if has_errors {
392            Err(self.diagnostics)
393        } else {
394            Ok(ModuleExports {
395                module_name,
396                functions: exported_fns,
397                constants: exported_consts,
398                structs: exported_structs,
399                warnings: self.diagnostics,
400                mono_instances: self.mono_instances,
401                call_resolutions: self.call_resolutions,
402            })
403        }
404    }
405
406    // --- Scope management ---
407
408    pub(super) fn push_scope(&mut self) {
409        self.scopes.push(BTreeMap::new());
410    }
411
412    pub(super) fn pop_scope(&mut self) {
413        self.scopes.pop();
414    }
415
416    pub(super) fn define_var(&mut self, name: &str, ty: Ty, mutable: bool) {
417        if let Some(scope) = self.scopes.last_mut() {
418            scope.insert(name.to_string(), VarInfo { ty, mutable });
419        }
420    }
421
422    pub(super) fn lookup_var(&self, name: &str) -> Option<&VarInfo> {
423        for scope in self.scopes.iter().rev() {
424            if let Some(info) = scope.get(name) {
425                return Some(info);
426            }
427        }
428        None
429    }
430
431    // --- Diagnostics ---
432
433    pub(super) fn error(&mut self, msg: String, span: Span) {
434        self.diagnostics.push(Diagnostic::error(msg, span));
435    }
436
437    pub(super) fn error_with_help(&mut self, msg: String, span: Span, help: String) {
438        self.diagnostics
439            .push(Diagnostic::error(msg, span).with_help(help));
440    }
441
442    pub(super) fn warning(&mut self, msg: String, span: Span) {
443        self.diagnostics.push(Diagnostic::warning(msg, span));
444    }
445}