Skip to main content

shape_runtime/type_system/
checker.rs

//! Type Checker
//!
//! Performs type checking on Shape programs using the type inference engine
//! and reports type errors with helpful messages.
5
6use super::errors::{TypeError, TypeErrorWithLocation, TypeResult};
7use super::inference::TypeInferenceEngine;
8use super::*;
9use shape_ast::ast::{EnumDef, Expr, Item, Program, Span, Statement, TypeAnnotation};
10use std::collections::{HashMap, HashSet};
11
/// Selects how `TypeChecker::check_program` reacts to inference errors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TypeAnalysisMode {
    /// Run `infer_program` and return as soon as it reports an error.
    FailFast,
    /// Run `infer_program_best_effort`, record every reported inference error,
    /// and continue with the remaining checks.
    RecoverAll,
}
17
/// Drives type analysis of a program and accumulates located type errors.
pub struct TypeChecker {
    /// Type inference engine
    inference_engine: TypeInferenceEngine,
    /// Collected errors (cleared at the start of every `check_program` call)
    errors: Vec<TypeErrorWithLocation>,
    /// Source code for error reporting; without it spans cannot be mapped to
    /// line/column and errors are reported at `(0, 0)`
    source: Option<String>,
    /// File name for error reporting
    filename: Option<String>,
    /// Enum definitions for resolving named types, keyed by enum name and
    /// rebuilt from the program's `Item::Enum` items on every `check_program`
    enum_defs: HashMap<String, EnumDef>,
    /// Current function's parameter types (name -> type annotation); only
    /// populated while a function body is being walked
    current_function_params: HashMap<String, shape_ast::ast::TypeAnnotation>,
    /// Error emission behavior for semantic analysis.
    analysis_mode: TypeAnalysisMode,
}
34
35impl Default for TypeChecker {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl TypeChecker {
42    pub fn new() -> Self {
43        TypeChecker {
44            inference_engine: TypeInferenceEngine::new(),
45            errors: Vec::new(),
46            source: None,
47            filename: None,
48            enum_defs: HashMap::new(),
49            current_function_params: HashMap::new(),
50            analysis_mode: TypeAnalysisMode::FailFast,
51        }
52    }
53
54    /// Set source code for error reporting
55    pub fn with_source(mut self, source: String) -> Self {
56        self.source = Some(source);
57        self
58    }
59
60    /// Set filename for error reporting
61    pub fn with_filename(mut self, filename: String) -> Self {
62        self.filename = Some(filename);
63        self
64    }
65
66    /// Register host-provided root-scope bindings (e.g. extension module namespaces).
67    pub fn with_known_bindings(mut self, names: &[String]) -> Self {
68        self.inference_engine.register_known_bindings(names);
69        self
70    }
71
72    pub fn with_analysis_mode(mut self, mode: TypeAnalysisMode) -> Self {
73        self.analysis_mode = mode;
74        self
75    }
76
77    /// Type check a complete program
78    pub fn check_program(
79        &mut self,
80        program: &Program,
81    ) -> Result<TypeCheckResult, Vec<TypeErrorWithLocation>> {
82        // Clear previous errors
83        self.errors.clear();
84        self.enum_defs.clear();
85
86        // Collect enum definitions for type resolution
87        for item in &program.items {
88            if let Item::Enum(enum_def, _) = item {
89                self.enum_defs
90                    .insert(enum_def.name.clone(), enum_def.clone());
91            }
92        }
93
94        let types = match self.analysis_mode {
95            TypeAnalysisMode::FailFast => match self.inference_engine.infer_program(program) {
96                Ok(types) => types,
97                Err(err) => {
98                    self.add_inference_error(err);
99                    return Err(self.errors.clone());
100                }
101            },
102            TypeAnalysisMode::RecoverAll => {
103                let (types, inference_errors) =
104                    self.inference_engine.infer_program_best_effort(program);
105                for err in inference_errors {
106                    self.add_inference_error(err);
107                }
108                types
109            }
110        };
111
112        // Perform additional type checking
113        self.check_items(&program.items);
114
115        // Check exhaustiveness of match expressions
116        self.check_expressions(&program.items);
117
118        self.prune_error_cascades();
119
120        if self.errors.is_empty() {
121            // Convert inference types to semantic types
122            let semantic_types: HashMap<String, SemanticType> = types
123                .iter()
124                .filter_map(|(name, ty)| ty.to_semantic().map(|st| (name.clone(), st)))
125                .collect();
126
127            Ok(TypeCheckResult {
128                types,
129                semantic_types,
130                warnings: Vec::new(),
131            })
132        } else {
133            Err(self.errors.clone())
134        }
135    }
136
137    fn add_inference_error(&mut self, err: TypeError) {
138        let (line, col) = self.find_inference_error_position(&err);
139        self.add_error(err, line, col);
140    }
141
142    fn prune_error_cascades(&mut self) {
143        let has_specific_errors = self
144            .errors
145            .iter()
146            .any(|err| !matches!(err.error, TypeError::UnsolvedConstraints(_)));
147        if has_specific_errors {
148            self.errors
149                .retain(|err| !matches!(err.error, TypeError::UnsolvedConstraints(_)));
150        }
151
152        let mut seen = HashSet::new();
153        self.errors.retain(|err| {
154            let key = (err.line, err.column, err.error.to_string());
155            seen.insert(key)
156        });
157    }
158
159    fn find_inference_error_position(&self, error: &TypeError) -> (usize, usize) {
160        match error {
161            TypeError::UnknownProperty(_, property) => {
162                if let Some(span) = self
163                    .inference_engine
164                    .lookup_unknown_property_origin(property)
165                {
166                    if let Some((line, col)) = self.span_to_line_col(span) {
167                        return (line, col);
168                    }
169                }
170                (0, 0)
171            }
172            TypeError::UndefinedVariable(name) => self
173                .inference_engine
174                .lookup_undefined_variable_origin(name)
175                .and_then(|span| self.span_to_line_col(span))
176                .unwrap_or((0, 0)),
177            TypeError::UnsolvedConstraints(constraints) => {
178                if let Some(span) = self
179                    .inference_engine
180                    .find_origin_for_unsolved_constraints(constraints)
181                {
182                    if let Some((line, col)) = self.span_to_line_col(span) {
183                        return (line, col);
184                    }
185                }
186                if let Some(span) = self.inference_engine.find_any_constraint_origin() {
187                    if let Some((line, col)) = self.span_to_line_col(span) {
188                        return (line, col);
189                    }
190                }
191                (0, 0)
192            }
193            TypeError::InvalidAssertion(_, _) => (0, 0),
194            TypeError::NonExhaustiveMatch { enum_name, .. } => self
195                .inference_engine
196                .lookup_non_exhaustive_match_origin(enum_name)
197                .and_then(|span| self.span_to_line_col(span))
198                .unwrap_or((0, 0)),
199            TypeError::GenericTypeError { symbol, .. } => {
200                if let Some(symbol) = symbol
201                    && let Some(span) = self
202                        .inference_engine
203                        .lookup_callable_origin_for_name(symbol)
204                    && let Some((line, col)) = self.span_to_line_col(span)
205                {
206                    return (line, col);
207                }
208                if let Some(span) = self.inference_engine.find_any_constraint_origin() {
209                    if let Some((line, col)) = self.span_to_line_col(span) {
210                        return (line, col);
211                    }
212                }
213                (0, 0)
214            }
215            _ => (0, 0),
216        }
217    }
218
219    fn span_to_line_col(&self, span: shape_ast::ast::Span) -> Option<(usize, usize)> {
220        let source = self.source.as_ref()?;
221        let start = span.start.min(source.len());
222        let prefix = &source[..start];
223        let line = prefix.bytes().filter(|b| *b == b'\n').count() + 1;
224        let line_start = prefix.rfind('\n').map(|idx| idx + 1).unwrap_or(0);
225        let column = prefix[line_start..].chars().count() + 1;
226        Some((line, column))
227    }
228
229    /// Check all items in the program
230    fn check_items(&mut self, items: &[Item]) {
231        for item in items {
232            self.check_item(item);
233        }
234    }
235
236    /// Check all expressions in the program for exhaustiveness
237    fn check_expressions(&mut self, items: &[Item]) {
238        for item in items {
239            self.check_item_expressions(item);
240        }
241    }
242
243    /// Check expressions within an item
244    fn check_item_expressions(&mut self, item: &Item) {
245        if let Item::Function(func, _) = item {
246            // Set up function parameter context for type resolution
247            self.current_function_params.clear();
248            for param in &func.params {
249                if let Some(type_ann) = &param.type_annotation {
250                    // Insert all identifiers from the pattern
251                    for name in param.get_identifiers() {
252                        self.current_function_params.insert(name, type_ann.clone());
253                    }
254                }
255            }
256
257            for stmt in &func.body {
258                self.check_statement_expressions(stmt);
259            }
260
261            // Clear function parameter context
262            self.current_function_params.clear();
263        }
264    }
265
266    /// Check expressions within a statement
267    fn check_statement_expressions(&mut self, stmt: &Statement) {
268        match stmt {
269            Statement::Expression(expr, _) => self.check_expr(expr),
270            Statement::Return(Some(expr), _) => self.check_expr(expr),
271            Statement::VariableDecl(decl, _) => {
272                if let Some(init) = &decl.value {
273                    self.check_expr(init);
274                }
275            }
276            Statement::If(if_stmt, _) => {
277                self.check_expr(&if_stmt.condition);
278                for stmt in &if_stmt.then_body {
279                    self.check_statement_expressions(stmt);
280                }
281                if let Some(else_body) = &if_stmt.else_body {
282                    for stmt in else_body {
283                        self.check_statement_expressions(stmt);
284                    }
285                }
286            }
287            Statement::While(while_loop, _) => {
288                self.check_expr(&while_loop.condition);
289                for stmt in &while_loop.body {
290                    self.check_statement_expressions(stmt);
291                }
292            }
293            Statement::For(for_loop, _) => {
294                for stmt in &for_loop.body {
295                    self.check_statement_expressions(stmt);
296                }
297            }
298            _ => {}
299        }
300    }
301
302    /// Check a single expression for type issues
303    ///
304    /// Note: Exhaustiveness checking is now handled by the inference engine
305    /// during `infer_expr()`, so we don't need to check it here.
306    fn check_expr(&mut self, expr: &Expr) {
307        match expr {
308            Expr::Match(match_expr, _span) => {
309                // Exhaustiveness is checked by inference engine - just check sub-expressions
310                self.check_expr(&match_expr.scrutinee);
311                for arm in &match_expr.arms {
312                    if let Some(guard) = &arm.guard {
313                        self.check_expr(guard);
314                    }
315                    self.check_expr(&arm.body);
316                }
317            }
318            // Recursively check sub-expressions
319            Expr::BinaryOp { left, right, .. } => {
320                self.check_expr(left);
321                self.check_expr(right);
322            }
323            Expr::UnaryOp { operand, .. } => {
324                self.check_expr(operand);
325            }
326            Expr::Conditional {
327                condition,
328                then_expr,
329                else_expr,
330                ..
331            } => {
332                self.check_expr(condition);
333                self.check_expr(then_expr);
334                if let Some(else_e) = else_expr {
335                    self.check_expr(else_e);
336                }
337            }
338            Expr::If(if_expr, _) => {
339                self.check_expr(&if_expr.condition);
340                self.check_expr(&if_expr.then_branch);
341                if let Some(else_branch) = &if_expr.else_branch {
342                    self.check_expr(else_branch);
343                }
344            }
345            Expr::FunctionCall { args, .. } => {
346                for arg in args {
347                    self.check_expr(arg);
348                }
349            }
350            Expr::MethodCall { receiver, args, .. } => {
351                self.check_expr(receiver);
352                for arg in args {
353                    self.check_expr(arg);
354                }
355            }
356            Expr::Array(elems, _) => {
357                for elem in elems {
358                    self.check_expr(elem);
359                }
360            }
361            Expr::PropertyAccess { object, .. } => {
362                self.check_expr(object);
363            }
364            Expr::IndexAccess {
365                object,
366                index,
367                end_index,
368                ..
369            } => {
370                self.check_expr(object);
371                self.check_expr(index);
372                if let Some(end) = end_index {
373                    self.check_expr(end);
374                }
375            }
376            _ => {}
377        }
378    }
379
380    // Note: check_match_exhaustiveness, resolve_named_to_enum, and span_to_location
381    // were removed as exhaustiveness checking is now handled by the inference engine
382    // in TypeInferenceEngine::infer_expr() for Match expressions.
383
384    /// Check a single item
385    fn check_item(&mut self, item: &Item) {
386        match item {
387            Item::Function(func, span) => {
388                // Check for missing return statements
389                if func.return_type.is_some()
390                    && !matches!(func.return_type.as_ref().unwrap(), TypeAnnotation::Void)
391                    && !self.has_return_statement(&func.body)
392                {
393                    let (line, col) = self.item_span_to_line_col(*span);
394                    self.add_error(TypeError::MissingReturn(func.name.clone()), line, col);
395                }
396            }
397
398            Item::TypeAlias(alias, span) => {
399                // Check for cyclic type aliases
400                if self.is_cyclic_type_alias(&alias.name, &alias.type_annotation) {
401                    let (line, col) = self.item_span_to_line_col(*span);
402                    self.add_error(TypeError::CyclicTypeAlias(alias.name.clone()), line, col);
403                }
404            }
405
406            Item::Interface(interface, span) => {
407                // Validate interface definition
408                self.check_interface(interface, *span);
409            }
410
411            _ => {}
412        }
413    }
414
415    /// Check if statements contain a return statement
416    fn has_return_statement(&self, stmts: &[Statement]) -> bool {
417        for stmt in stmts {
418            match stmt {
419                Statement::Return(_, _) => return true,
420                Statement::If(if_stmt, _) => {
421                    // Both branches must have returns
422                    if let Some(else_body) = &if_stmt.else_body {
423                        if self.has_return_statement(&if_stmt.then_body)
424                            && self.has_return_statement(else_body)
425                        {
426                            return true;
427                        }
428                    }
429                }
430                Statement::While(while_loop, _) => {
431                    if self.has_return_statement(&while_loop.body) {
432                        // Note: This is conservative - while loop might not execute
433                        return true;
434                    }
435                }
436                Statement::For(for_loop, _) => {
437                    if self.has_return_statement(&for_loop.body) {
438                        // Note: This is conservative - for loop might not execute
439                        return true;
440                    }
441                }
442                _ => {}
443            }
444        }
445
446        false
447    }
448
    /// Check for cyclic type aliases
    ///
    /// Returns `true` when the alias body `ty` mentions the alias's own
    /// `name`. Only this direct self-reference is detected here: other
    /// aliases named inside `ty` are not expanded, so a cycle that goes
    /// through a second alias is not reported by this check.
    fn is_cyclic_type_alias(&self, name: &str, ty: &TypeAnnotation) -> bool {
        self.references_type(ty, name)
    }
453
454    /// Check if a type annotation references a specific type name
455    fn references_type(&self, ty: &TypeAnnotation, name: &str) -> bool {
456        match ty {
457            TypeAnnotation::Reference(ref_name) => ref_name == name,
458            TypeAnnotation::Array(elem) => self.references_type(elem, name),
459            TypeAnnotation::Tuple(elems) => {
460                elems.iter().any(|elem| self.references_type(elem, name))
461            }
462            TypeAnnotation::Object(fields) => fields
463                .iter()
464                .any(|field| self.references_type(&field.type_annotation, name)),
465            TypeAnnotation::Function { params, returns } => {
466                params
467                    .iter()
468                    .any(|param| self.references_type(&param.type_annotation, name))
469                    || self.references_type(returns, name)
470            }
471            TypeAnnotation::Union(types) => types.iter().any(|ty| self.references_type(ty, name)),
472            TypeAnnotation::Optional(ty) => self.references_type(ty, name),
473            TypeAnnotation::Generic { args, .. } => {
474                args.iter().any(|arg| self.references_type(arg, name))
475            }
476            _ => false,
477        }
478    }
479
480    /// Check interface definition
481    fn check_interface(&mut self, interface: &shape_ast::ast::InterfaceDef, interface_span: Span) {
482        // Check for duplicate members
483        let mut seen_members = HashMap::new();
484
485        for (i, member) in interface.members.iter().enumerate() {
486            let member_name = match member {
487                shape_ast::ast::InterfaceMember::Property { name, .. } => name,
488                shape_ast::ast::InterfaceMember::Method { name, .. } => name,
489                shape_ast::ast::InterfaceMember::IndexSignature { .. } => continue,
490            };
491
492            if let Some(_prev_index) = seen_members.get(member_name) {
493                let (line, col) = self.item_span_to_line_col(interface_span);
494                self.add_error(
495                    TypeError::InterfaceError(
496                        interface.name.clone(),
497                        format!("Duplicate member '{}'", member_name),
498                    ),
499                    line,
500                    col,
501                );
502            } else {
503                seen_members.insert(member_name.clone(), i);
504            }
505        }
506    }
507
    /// Maps an item's span to `(line, column)` via `span_to_line_col`,
    /// using `(0, 0)` when that returns `None` (no source attached).
    fn item_span_to_line_col(&self, span: Span) -> (usize, usize) {
        self.span_to_line_col(span).unwrap_or((0, 0))
    }
511
512    /// Add an error with location information
513    fn add_error(&mut self, error: TypeError, line: usize, column: usize) {
514        let mut err = TypeErrorWithLocation::new(error, line, column);
515
516        if let Some(filename) = &self.filename {
517            err = err.with_file(filename.clone());
518        }
519
520        if let Some(source) = &self.source {
521            // Extract the source line
522            if let Some(source_line) = source.lines().nth(line.saturating_sub(1)) {
523                err = err.with_source_line(source_line.to_string());
524            }
525        }
526
527        self.errors.push(err);
528    }
529
530    /// Get all collected errors
531    pub fn errors(&self) -> &[TypeErrorWithLocation] {
532        &self.errors
533    }
534
535    /// Format all errors for display
536    pub fn format_errors(&self) -> String {
537        self.errors
538            .iter()
539            .map(|err| err.format_with_source())
540            .collect::<Vec<_>>()
541            .join("\n")
542    }
543}
544
545/// Shared single-entry type analysis used by compiler and LSP.
546pub fn analyze_program(
547    program: &Program,
548    source: Option<&str>,
549    filename: Option<&str>,
550    known_bindings: Option<&[String]>,
551) -> Result<TypeCheckResult, Vec<TypeErrorWithLocation>> {
552    analyze_program_with_mode(
553        program,
554        source,
555        filename,
556        known_bindings,
557        TypeAnalysisMode::FailFast,
558    )
559}
560
561/// Shared type analysis with explicit recovery behavior.
562pub fn analyze_program_with_mode(
563    program: &Program,
564    source: Option<&str>,
565    filename: Option<&str>,
566    known_bindings: Option<&[String]>,
567    analysis_mode: TypeAnalysisMode,
568) -> Result<TypeCheckResult, Vec<TypeErrorWithLocation>> {
569    let mut checker = TypeChecker::new();
570    if let Some(src) = source {
571        checker = checker.with_source(src.to_string());
572    }
573    if let Some(file) = filename {
574        checker = checker.with_filename(file.to_string());
575    }
576    if let Some(names) = known_bindings {
577        checker = checker.with_known_bindings(names);
578    }
579    checker = checker.with_analysis_mode(analysis_mode);
580    checker.check_program(program)
581}
582
/// Result of type checking
///
/// Produced by `TypeChecker::check_program` only when no errors remain.
#[derive(Debug)]
pub struct TypeCheckResult {
    /// Inferred types for all declarations (inference-level types)
    pub types: HashMap<String, Type>,
    /// Semantic types for all declarations (user-facing types); holds only
    /// the declarations whose inference type converts via `to_semantic`
    pub semantic_types: HashMap<String, SemanticType>,
    /// Type warnings (non-fatal issues); `check_program` currently always
    /// leaves this empty
    pub warnings: Vec<TypeWarning>,
}
593
594impl TypeCheckResult {
595    /// Get the semantic type for a declaration
596    pub fn get_semantic_type(&self, name: &str) -> Option<&SemanticType> {
597        self.semantic_types.get(name)
598    }
599
600    /// Get all function declarations that are fallible (return Result)
601    pub fn fallible_functions(&self) -> Vec<&str> {
602        self.semantic_types
603            .iter()
604            .filter_map(|(name, ty)| {
605                if let SemanticType::Function(sig) = ty {
606                    if sig.return_type.is_result() {
607                        return Some(name.as_str());
608                    }
609                }
610                None
611            })
612            .collect()
613    }
614}
615
/// Type warning for non-fatal issues
#[derive(Debug)]
pub struct TypeWarning {
    /// Human-readable description of the issue
    pub message: String,
    /// Line the warning refers to
    pub line: usize,
    /// Column the warning refers to
    pub column: usize,
}
623
624/// Type check an expression and return its type
625pub fn type_of_expr(expr: &Expr, _env: &TypeEnvironment) -> TypeResult<Type> {
626    let mut engine = TypeInferenceEngine::new();
627    engine.infer_expr(expr)
628}
629
630/// Quick type check for REPL and testing
631pub fn quick_check(source: &str) -> Result<TypeCheckResult, String> {
632    use shape_ast::parser::parse_program;
633
634    let program = parse_program(source).map_err(|e| format!("Parse error: {}", e))?;
635
636    let mut checker = TypeChecker::new().with_source(source.to_string());
637
638    checker.check_program(&program).map_err(|errors| {
639        errors
640            .iter()
641            .map(|e| e.format_with_source())
642            .collect::<Vec<_>>()
643            .join("\n")
644    })
645}
646
/// Integration-style tests that run real Shape snippets through the parser
/// and the type checker.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_exhaustiveness_integration_non_exhaustive_match_produces_error() {
        // This test proves exhaustiveness checking is connected to the compiler pipeline.
        // A match on an enum that doesn't cover all variants should produce an error.
        let source = r#"
            enum Status { Active, Inactive, Pending }

            function check(s: Status) {
                return match s {
                    Status::Active => "yes"
                };
            }
        "#;

        let result = quick_check(source);

        // The match is non-exhaustive (missing Inactive and Pending)
        // so we expect an error
        assert!(
            result.is_err(),
            "Expected error for non-exhaustive match, got: {:?}",
            result
        );
        let err = result.unwrap_err();
        // The exact wording belongs to the error type, so accept any of the
        // phrasings that identify a non-exhaustive match.
        assert!(
            err.contains("NonExhaustive")
                || err.contains("non-exhaustive")
                || err.contains("missing"),
            "Expected non-exhaustive match error, got: {}",
            err
        );
    }

    #[test]
    fn test_exhaustiveness_integration_exhaustive_match_succeeds() {
        // A match that covers all variants should succeed
        let source = r#"
            enum Status { Active, Inactive }

            function check(s: Status) {
                return match s {
                    Status::Active => "yes",
                    Status::Inactive => "no"
                };
            }
        "#;

        let result = quick_check(source);

        // The match is exhaustive, so no error expected from exhaustiveness
        // (there might be other errors, but not NonExhaustiveMatch)
        if let Err(err) = &result {
            assert!(
                !err.contains("NonExhaustive") && !err.contains("non-exhaustive"),
                "Should not have non-exhaustive error for exhaustive match, got: {}",
                err
            );
        }
    }

    #[test]
    fn test_exhaustiveness_integration_wildcard_makes_exhaustive() {
        // A match with wildcard pattern should be trivially exhaustive
        let source = r#"
            enum Status { Active, Inactive, Pending }

            function check(s: Status) {
                return match s {
                    Status::Active => "yes",
                    _ => "other"
                };
            }
        "#;

        let result = quick_check(source);

        // The wildcard makes it exhaustive
        if let Err(err) = &result {
            assert!(
                !err.contains("NonExhaustive") && !err.contains("non-exhaustive"),
                "Wildcard should make match exhaustive, got: {}",
                err
            );
        }
    }

    #[test]
    fn test_undefined_variable_reports_identifier_position() {
        use shape_ast::parser::parse_program;

        // The snippet is deliberately flush-left and starts with a newline:
        // the assertions below depend on `duckdb` sitting on line 3, column 9.
        let source = r#"
let x = 1
let y = duckdb.connect("duckdb://analytics.db")
"#;

        let program = parse_program(source).expect("program should parse");
        let result = analyze_program(&program, Some(source), None, None);
        let errors = result.expect_err("undefined variable should fail analysis");
        let undef = errors
            .iter()
            .find(|e| matches!(&e.error, TypeError::UndefinedVariable(name) if name == "duckdb"))
            .expect("missing undefined-variable error for duckdb");

        assert_eq!(undef.line, 3);
        assert_eq!(undef.column, 9);
    }

    #[test]
    fn test_known_bindings_allow_extension_namespace_in_type_analysis() {
        use shape_ast::parser::parse_program;

        // Same call as above, but `duckdb` is registered as a known binding,
        // so analysis must succeed.
        let source = r#"let conn = duckdb.connect("duckdb://analytics.db")"#;
        let program = parse_program(source).expect("program should parse");
        let known = vec!["duckdb".to_string()];

        let result = analyze_program(&program, Some(source), None, Some(&known));
        assert!(
            result.is_ok(),
            "known extension namespaces should not fail type analysis: {:?}",
            result.err()
        );
    }
}