Skip to main content

tsz_checker/
expr.rs

//! Expression Type Checking
//!
//! This module handles type inference and checking for expressions.
//! It follows the "Check Fast, Explain Slow" pattern: types are inferred
//! first, and the solver is used afterwards to explain any failures.
//!
//! ## Integration with `CheckerState`
//!
//! `ExpressionChecker` serves as the primary dispatcher for expression types.
//! Simple expressions are handled directly here, while complex expressions
//! that need full `CheckerState` context return `TypeId::DELEGATE` to signal
//! that `CheckerState::compute_type_of_node` should handle them.
//!
//! ### Expressions handled directly:
//! - The `null` literal (always `null`; contextual typing does not affect it)
//! - `typeof` expressions (always `string`)
//! - `void` expressions (always `undefined`)
//! - Parenthesized expressions (typed as their inner expression)
//!
//! ### Expressions delegated to `CheckerState`:
//! - Literals with contextual typing (numeric, string, boolean, template)
//! - Identifiers, `this`, `super` (need symbol resolution)
//! - Prefix and postfix unary expressions (need operand type checking)
//! - Binary expressions (need operator overloading, narrowing)
//! - Call/new expressions (need signature resolution)
//! - Property/element access (need object type resolution)
//! - Function/arrow expressions (need signature building)
//! - Class expressions (need class type building)
//! - Object/array literals (need contextual typing)
//! - Type assertions (`as`/`satisfies`/`<T>expr`) (need type node resolution)
//! - Conditional expressions (need union type building)
//! - Await expressions (need Promise unwrapping)
//! - Type nodes, type keywords and JSX elements that reach this dispatcher

33use super::context::CheckerContext;
34
35use tsz_parser::parser::NodeIndex;
36use tsz_parser::parser::syntax_kind_ext;
37use tsz_scanner::SyntaxKind;
38use tsz_solver::TypeId;
39
40use tsz_solver::recursion::{DepthCounter, RecursionProfile};
41
42/// Expression type checker that operates on the shared context.
43///
44/// This is a stateless checker that borrows the context mutably.
45/// All type inference for expressions goes through this checker.
46pub struct ExpressionChecker<'a, 'ctx> {
47    ctx: &'a mut CheckerContext<'ctx>,
48    /// Recursion depth counter for stack overflow protection.
49    depth: DepthCounter,
50}
51
52impl<'a, 'ctx> ExpressionChecker<'a, 'ctx> {
53    /// Create a new expression checker with a mutable context reference.
54    pub const fn new(ctx: &'a mut CheckerContext<'ctx>) -> Self {
55        Self {
56            ctx,
57            depth: DepthCounter::with_profile(RecursionProfile::ExpressionCheck),
58        }
59    }
60
61    /// Check an expression and return its type.
62    ///
63    /// This is the main entry point for expression type checking.
64    /// It handles caching and dispatches to specific expression handlers.
65    pub fn check(&mut self, idx: NodeIndex) -> TypeId {
66        self.check_with_context(idx, None)
67    }
68
69    /// Check an expression with a contextual type hint.
70    ///
71    /// Contextual types enable downward inference where the expected type
72    /// influences the inferred type. For example:
73    /// - `const x: string = expr` - `expr` is checked with context `string`
74    /// - `const f: (x: number) => void = (x) => {}` - `x` is inferred as `number`
75    ///
76    /// # Caching Behavior
77    ///
78    /// When `context_type` is `Some`, the cache is **bypassed** to avoid
79    /// incorrect results. The same expression can have different types
80    /// depending on the context, so caching by `NodeIndex` alone is unsound.
81    pub fn check_with_context(&mut self, idx: NodeIndex, context_type: Option<TypeId>) -> TypeId {
82        // Stack overflow protection
83        if !self.depth.enter() {
84            return TypeId::ERROR;
85        }
86
87        let result = if let Some(ctx_type) = context_type {
88            // Bypass cache when contextual type is provided
89            // Contextual types can produce different results for the same node
90            self.compute_type_with_context(idx, ctx_type)
91        } else {
92            // Check cache first for non-contextual checks
93            if let Some(&cached) = self.ctx.node_types.get(&idx.0) {
94                self.depth.leave();
95                return cached;
96            }
97
98            // Compute and cache
99            let result = self.compute_type(idx);
100            self.ctx.node_types.insert(idx.0, result);
101            result
102        };
103
104        self.depth.leave();
105        result
106    }
107
108    /// Compute the type of an expression without caching.
109    ///
110    /// This is called by `CheckerState::compute_type_of_node` to get an initial
111    /// type for expressions. Returns `TypeId::DELEGATE` if the expression needs
112    /// full `CheckerState` context for proper type resolution.
113    ///
114    /// Simple expressions that don't need contextual typing or symbol resolution
115    /// are handled directly here. Complex expressions delegate to `CheckerState`.
116    pub fn compute_type_uncached(&mut self, idx: NodeIndex) -> TypeId {
117        self.compute_type_impl(idx, None)
118    }
119
120    /// Compute the type of an expression with contextual typing (no caching).
121    ///
122    /// This is called when a contextual type is available (e.g., from variable
123    /// declarations, assignments, function parameters). The contextual type
124    /// influences how the expression is inferred.
125    fn compute_type_with_context(&mut self, idx: NodeIndex, context_type: TypeId) -> TypeId {
126        self.compute_type_impl(idx, Some(context_type))
127    }
128
129    /// Compute the type of an expression (internal, not cached).
130    fn compute_type(&mut self, idx: NodeIndex) -> TypeId {
131        self.compute_type_impl(idx, None)
132    }
133
134    /// Core implementation for computing expression types.
135    ///
136    /// Returns `TypeId::DELEGATE` for complex expressions that need `CheckerState`.
137    ///
138    /// # Parameters
139    /// - `idx`: The node index to check
140    /// - `context_type`: Optional contextual type hint for downward inference
141    fn compute_type_impl(&mut self, idx: NodeIndex, _context_type: Option<TypeId>) -> TypeId {
142        let Some(node) = self.ctx.arena.get(idx) else {
143            // Return UNKNOWN instead of ANY to expose missing nodes as errors
144            return TypeId::UNKNOWN;
145        };
146
147        match node.kind {
148            // =====================================================================
149            // Simple expressions handled directly
150            // =====================================================================
151
152            // Null literal - always TypeId::NULL (context doesn't affect null)
153            k if k == SyntaxKind::NullKeyword as u16 => TypeId::NULL,
154
155            // typeof expression always returns string (context doesn't affect typeof)
156            k if k == syntax_kind_ext::TYPE_OF_EXPRESSION => TypeId::STRING,
157
158            // void expression always returns undefined (context doesn't affect void)
159            k if k == syntax_kind_ext::VOID_EXPRESSION => TypeId::UNDEFINED,
160
161            // Parenthesized expression - pass through context to inner expression
162            k if k == syntax_kind_ext::PARENTHESIZED_EXPRESSION => {
163                if let Some(paren) = self.ctx.arena.get_parenthesized(node) {
164                    // Check if expression is missing (parse error: empty parentheses)
165                    if paren.expression.is_none() {
166                        // Parse error - return ERROR to suppress cascading errors
167                        return TypeId::ERROR;
168                    }
169                    // Recursively check inner expression with same context
170                    self.compute_type_impl(paren.expression, _context_type)
171                } else {
172                    // Return DELEGATE to let CheckerState handle malformed nodes
173                    TypeId::DELEGATE
174                }
175            }
176
177            // =====================================================================
178            // Literals with contextual typing - DELEGATE to CheckerState
179            // These need contextual typing analysis to decide between literal types
180            // (e.g., `42` as literal) vs widened types (e.g., `number`).
181            // =====================================================================
182            k if k == SyntaxKind::NumericLiteral as u16 => TypeId::DELEGATE,
183            k if k == SyntaxKind::StringLiteral as u16 => TypeId::DELEGATE,
184            k if k == SyntaxKind::TrueKeyword as u16 => TypeId::DELEGATE,
185            k if k == SyntaxKind::FalseKeyword as u16 => TypeId::DELEGATE,
186            k if k == SyntaxKind::NoSubstitutionTemplateLiteral as u16 => TypeId::DELEGATE,
187
188            // =====================================================================
189            // Expressions requiring symbol resolution - DELEGATE to CheckerState
190            // =====================================================================
191            k if k == SyntaxKind::Identifier as u16 => TypeId::DELEGATE,
192            k if k == SyntaxKind::ThisKeyword as u16 => TypeId::DELEGATE,
193            k if k == SyntaxKind::SuperKeyword as u16 => TypeId::DELEGATE,
194
195            // =====================================================================
196            // Complex expressions - DELEGATE to CheckerState
197            // =====================================================================
198
199            // Binary expressions need operator type resolution and narrowing
200            k if k == syntax_kind_ext::BINARY_EXPRESSION => TypeId::DELEGATE,
201
202            // Call/new expressions need signature resolution
203            k if k == syntax_kind_ext::CALL_EXPRESSION => TypeId::DELEGATE,
204            k if k == syntax_kind_ext::NEW_EXPRESSION => TypeId::DELEGATE,
205
206            // Property/element access need object type resolution
207            k if k == syntax_kind_ext::PROPERTY_ACCESS_EXPRESSION => TypeId::DELEGATE,
208            k if k == syntax_kind_ext::ELEMENT_ACCESS_EXPRESSION => TypeId::DELEGATE,
209
210            // Conditional expressions need union type building
211            k if k == syntax_kind_ext::CONDITIONAL_EXPRESSION => TypeId::DELEGATE,
212
213            // Function expressions need signature building
214            k if k == syntax_kind_ext::FUNCTION_EXPRESSION => TypeId::DELEGATE,
215            k if k == syntax_kind_ext::ARROW_FUNCTION => TypeId::DELEGATE,
216
217            // Object/array literals need contextual typing
218            k if k == syntax_kind_ext::OBJECT_LITERAL_EXPRESSION => TypeId::DELEGATE,
219            k if k == syntax_kind_ext::ARRAY_LITERAL_EXPRESSION => TypeId::DELEGATE,
220
221            // Class expressions need class type building
222            k if k == syntax_kind_ext::CLASS_EXPRESSION => TypeId::DELEGATE,
223
224            // Unary expressions need operand type checking
225            k if k == syntax_kind_ext::PREFIX_UNARY_EXPRESSION => TypeId::DELEGATE,
226            k if k == syntax_kind_ext::POSTFIX_UNARY_EXPRESSION => TypeId::DELEGATE,
227
228            // Await expressions need Promise unwrapping
229            k if k == syntax_kind_ext::AWAIT_EXPRESSION => TypeId::DELEGATE,
230
231            // Type assertions need type node resolution
232            k if k == syntax_kind_ext::AS_EXPRESSION => TypeId::DELEGATE,
233            k if k == syntax_kind_ext::SATISFIES_EXPRESSION => TypeId::DELEGATE,
234            k if k == syntax_kind_ext::TYPE_ASSERTION => TypeId::DELEGATE,
235
236            // Template expressions need string interpolation handling
237            k if k == syntax_kind_ext::TEMPLATE_EXPRESSION => TypeId::DELEGATE,
238
239            // Variable declarations need initializer/annotation handling
240            k if k == syntax_kind_ext::VARIABLE_DECLARATION => TypeId::DELEGATE,
241
242            // Function declarations need signature building
243            k if k == syntax_kind_ext::FUNCTION_DECLARATION => TypeId::DELEGATE,
244
245            // =====================================================================
246            // Type nodes - DELEGATE to CheckerState
247            // These are not expressions but may be passed through get_type_of_node
248            // =====================================================================
249            k if k == syntax_kind_ext::TYPE_REFERENCE => TypeId::DELEGATE,
250            k if k == syntax_kind_ext::UNION_TYPE => TypeId::DELEGATE,
251            k if k == syntax_kind_ext::INTERSECTION_TYPE => TypeId::DELEGATE,
252            k if k == syntax_kind_ext::ARRAY_TYPE => TypeId::DELEGATE,
253            k if k == syntax_kind_ext::TYPE_OPERATOR => TypeId::DELEGATE,
254            k if k == syntax_kind_ext::FUNCTION_TYPE => TypeId::DELEGATE,
255            k if k == syntax_kind_ext::TYPE_LITERAL => TypeId::DELEGATE,
256            k if k == syntax_kind_ext::TYPE_QUERY => TypeId::DELEGATE,
257            k if k == syntax_kind_ext::QUALIFIED_NAME => TypeId::DELEGATE,
258
259            // Type keywords - DELEGATE to CheckerState for consistency
260            k if k == SyntaxKind::NumberKeyword as u16 => TypeId::DELEGATE,
261            k if k == SyntaxKind::StringKeyword as u16 => TypeId::DELEGATE,
262            k if k == SyntaxKind::BooleanKeyword as u16 => TypeId::DELEGATE,
263            k if k == SyntaxKind::VoidKeyword as u16 => TypeId::DELEGATE,
264            k if k == SyntaxKind::AnyKeyword as u16 => TypeId::DELEGATE,
265            k if k == SyntaxKind::NeverKeyword as u16 => TypeId::DELEGATE,
266            k if k == SyntaxKind::UnknownKeyword as u16 => TypeId::DELEGATE,
267            k if k == SyntaxKind::UndefinedKeyword as u16 => TypeId::DELEGATE,
268            k if k == SyntaxKind::ObjectKeyword as u16 => TypeId::DELEGATE,
269            k if k == SyntaxKind::BigIntKeyword as u16 => TypeId::DELEGATE,
270            k if k == SyntaxKind::SymbolKeyword as u16 => TypeId::DELEGATE,
271
272            // JSX elements - DELEGATE to CheckerState
273            k if k == syntax_kind_ext::JSX_ELEMENT => TypeId::DELEGATE,
274            k if k == syntax_kind_ext::JSX_SELF_CLOSING_ELEMENT => TypeId::DELEGATE,
275            k if k == syntax_kind_ext::JSX_FRAGMENT => TypeId::DELEGATE,
276
277            // =====================================================================
278            // Default - unknown node type, delegate to CheckerState
279            // =====================================================================
280            _ => TypeId::DELEGATE,
281        }
282    }
283
284    /// Get the context reference (for read-only access).
285    pub const fn context(&self) -> &CheckerContext<'ctx> {
286        self.ctx
287    }
288}
289
// Unit tests for this module live out of line; `#[path]` resolves the file
// relative to the directory containing this source file.
#[cfg(test)]
#[path = "../tests/expr.rs"]
mod tests;