Skip to main content

shape_runtime/type_system/
checker.rs

//! Type Checker
//!
//! Performs type checking on Shape programs using the type inference engine
//! and reports type errors with helpful messages.
5
6use super::errors::{TypeError, TypeErrorWithLocation, TypeResult};
7use super::inference::TypeInferenceEngine;
8use super::*;
9use shape_ast::ast::{EnumDef, Expr, Item, Program, Span, Statement, TypeAnnotation};
10use std::collections::{HashMap, HashSet};
11
/// Controls how the type checker reacts to type-inference errors.
///
/// The default is [`TypeAnalysisMode::FailFast`], matching the mode a freshly
/// constructed `TypeChecker` uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TypeAnalysisMode {
    /// Stop at the first inference error and report only that error.
    #[default]
    FailFast,
    /// Run best-effort inference, collect every inference error, and keep
    /// running the remaining checks.
    RecoverAll,
}
17
18pub struct TypeChecker {
19    /// Type inference engine
20    inference_engine: TypeInferenceEngine,
21    /// Collected errors
22    errors: Vec<TypeErrorWithLocation>,
23    /// Source code for error reporting
24    source: Option<String>,
25    /// File name for error reporting
26    filename: Option<String>,
27    /// Enum definitions for resolving named types
28    enum_defs: HashMap<String, EnumDef>,
29    /// Current function's parameter types (name -> type annotation)
30    current_function_params: HashMap<String, shape_ast::ast::TypeAnnotation>,
31    /// Error emission behavior for semantic analysis.
32    analysis_mode: TypeAnalysisMode,
33}
34
35impl Default for TypeChecker {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl TypeChecker {
42    pub fn new() -> Self {
43        TypeChecker {
44            inference_engine: TypeInferenceEngine::new(),
45            errors: Vec::new(),
46            source: None,
47            filename: None,
48            enum_defs: HashMap::new(),
49            current_function_params: HashMap::new(),
50            analysis_mode: TypeAnalysisMode::FailFast,
51        }
52    }
53
54    /// Set source code for error reporting
55    pub fn with_source(mut self, source: String) -> Self {
56        self.source = Some(source);
57        self
58    }
59
60    /// Set filename for error reporting
61    pub fn with_filename(mut self, filename: String) -> Self {
62        self.filename = Some(filename);
63        self
64    }
65
66    /// Register host-provided root-scope bindings (e.g. extension module namespaces).
67    pub fn with_known_bindings(mut self, names: &[String]) -> Self {
68        self.inference_engine.register_known_bindings(names);
69        self
70    }
71
72    pub fn with_analysis_mode(mut self, mode: TypeAnalysisMode) -> Self {
73        self.analysis_mode = mode;
74        self
75    }
76
77    /// Type check a complete program
78    pub fn check_program(
79        &mut self,
80        program: &Program,
81    ) -> Result<TypeCheckResult, Vec<TypeErrorWithLocation>> {
82        // Clear previous errors
83        self.errors.clear();
84        self.enum_defs.clear();
85
86        // Collect enum definitions for type resolution
87        for item in &program.items {
88            if let Item::Enum(enum_def, _) = item {
89                self.enum_defs
90                    .insert(enum_def.name.clone(), enum_def.clone());
91            }
92        }
93
94        let types = match self.analysis_mode {
95            TypeAnalysisMode::FailFast => match self.inference_engine.infer_program(program) {
96                Ok(types) => types,
97                Err(err) => {
98                    self.add_inference_error(err);
99                    return Err(self.errors.clone());
100                }
101            },
102            TypeAnalysisMode::RecoverAll => {
103                let (types, inference_errors) =
104                    self.inference_engine.infer_program_best_effort(program);
105                for err in inference_errors {
106                    self.add_inference_error(err);
107                }
108                types
109            }
110        };
111
112        // Perform additional type checking
113        self.check_items(&program.items);
114
115        // Check exhaustiveness of match expressions
116        self.check_expressions(&program.items);
117
118        self.prune_error_cascades();
119
120        if self.errors.is_empty() {
121            // Convert inference types to semantic types
122            let semantic_types: HashMap<String, SemanticType> = types
123                .iter()
124                .filter_map(|(name, ty)| ty.to_semantic().map(|st| (name.clone(), st)))
125                .collect();
126
127            Ok(TypeCheckResult {
128                types,
129                semantic_types,
130                warnings: Vec::new(),
131            })
132        } else {
133            Err(self.errors.clone())
134        }
135    }
136
137    fn add_inference_error(&mut self, err: TypeError) {
138        let (line, col) = self.find_inference_error_position(&err);
139        self.add_error(err, line, col);
140    }
141
142    fn prune_error_cascades(&mut self) {
143        let has_specific_errors = self
144            .errors
145            .iter()
146            .any(|err| !matches!(err.error, TypeError::UnsolvedConstraints(_)));
147        if has_specific_errors {
148            self.errors
149                .retain(|err| !matches!(err.error, TypeError::UnsolvedConstraints(_)));
150        }
151
152        let mut seen = HashSet::new();
153        self.errors.retain(|err| {
154            let key = (err.line, err.column, err.error.to_string());
155            seen.insert(key)
156        });
157    }
158
159    fn find_inference_error_position(&self, error: &TypeError) -> (usize, usize) {
160        match error {
161            TypeError::UnknownProperty(_, property) => {
162                if let Some(span) = self
163                    .inference_engine
164                    .lookup_unknown_property_origin(property)
165                {
166                    if let Some((line, col)) = self.span_to_line_col(span) {
167                        return (line, col);
168                    }
169                }
170                (0, 0)
171            }
172            TypeError::UndefinedVariable(name) => self
173                .inference_engine
174                .lookup_undefined_variable_origin(name)
175                .and_then(|span| self.span_to_line_col(span))
176                .unwrap_or((0, 0)),
177            TypeError::UnsolvedConstraints(constraints) => {
178                if let Some(span) = self
179                    .inference_engine
180                    .find_origin_for_unsolved_constraints(constraints)
181                {
182                    if let Some((line, col)) = self.span_to_line_col(span) {
183                        return (line, col);
184                    }
185                }
186                if let Some(span) = self.inference_engine.find_any_constraint_origin() {
187                    if let Some((line, col)) = self.span_to_line_col(span) {
188                        return (line, col);
189                    }
190                }
191                (0, 0)
192            }
193            TypeError::InvalidAssertion(_, _) => (0, 0),
194            TypeError::NonExhaustiveMatch { enum_name, .. } => self
195                .inference_engine
196                .lookup_non_exhaustive_match_origin(enum_name)
197                .and_then(|span| self.span_to_line_col(span))
198                .unwrap_or((0, 0)),
199            TypeError::GenericTypeError { symbol, .. } => {
200                if let Some(symbol) = symbol
201                    && let Some(span) = self
202                        .inference_engine
203                        .lookup_callable_origin_for_name(symbol)
204                    && let Some((line, col)) = self.span_to_line_col(span)
205                {
206                    return (line, col);
207                }
208                if let Some(span) = self.inference_engine.find_any_constraint_origin() {
209                    if let Some((line, col)) = self.span_to_line_col(span) {
210                        return (line, col);
211                    }
212                }
213                (0, 0)
214            }
215            _ => (0, 0),
216        }
217    }
218
219    fn span_to_line_col(&self, span: shape_ast::ast::Span) -> Option<(usize, usize)> {
220        let source = self.source.as_ref()?;
221        let start = span.start.min(source.len());
222        let prefix = &source[..start];
223        let line = prefix.bytes().filter(|b| *b == b'\n').count() + 1;
224        let line_start = prefix.rfind('\n').map(|idx| idx + 1).unwrap_or(0);
225        let column = prefix[line_start..].chars().count() + 1;
226        Some((line, column))
227    }
228
229    /// Check all items in the program
230    fn check_items(&mut self, items: &[Item]) {
231        for item in items {
232            self.check_item(item);
233        }
234    }
235
236    /// Check all expressions in the program for exhaustiveness
237    fn check_expressions(&mut self, items: &[Item]) {
238        for item in items {
239            self.check_item_expressions(item);
240        }
241    }
242
243    /// Check expressions within an item
244    fn check_item_expressions(&mut self, item: &Item) {
245        if let Item::Function(func, _) = item {
246            // Set up function parameter context for type resolution
247            self.current_function_params.clear();
248            for param in &func.params {
249                if let Some(type_ann) = &param.type_annotation {
250                    // Insert all identifiers from the pattern
251                    for name in param.get_identifiers() {
252                        self.current_function_params.insert(name, type_ann.clone());
253                    }
254                }
255            }
256
257            for stmt in &func.body {
258                self.check_statement_expressions(stmt);
259            }
260
261            // Clear function parameter context
262            self.current_function_params.clear();
263        }
264    }
265
266    /// Check expressions within a statement
267    fn check_statement_expressions(&mut self, stmt: &Statement) {
268        match stmt {
269            Statement::Expression(expr, _) => self.check_expr(expr),
270            Statement::Return(Some(expr), _) => self.check_expr(expr),
271            Statement::VariableDecl(decl, _) => {
272                if let Some(init) = &decl.value {
273                    self.check_expr(init);
274                }
275            }
276            Statement::If(if_stmt, _) => {
277                self.check_expr(&if_stmt.condition);
278                for stmt in &if_stmt.then_body {
279                    self.check_statement_expressions(stmt);
280                }
281                if let Some(else_body) = &if_stmt.else_body {
282                    for stmt in else_body {
283                        self.check_statement_expressions(stmt);
284                    }
285                }
286            }
287            Statement::While(while_loop, _) => {
288                self.check_expr(&while_loop.condition);
289                for stmt in &while_loop.body {
290                    self.check_statement_expressions(stmt);
291                }
292            }
293            Statement::For(for_loop, _) => {
294                for stmt in &for_loop.body {
295                    self.check_statement_expressions(stmt);
296                }
297            }
298            _ => {}
299        }
300    }
301
302    /// Check a single expression for type issues
303    ///
304    /// Note: Exhaustiveness checking is now handled by the inference engine
305    /// during `infer_expr()`, so we don't need to check it here.
306    fn check_expr(&mut self, expr: &Expr) {
307        match expr {
308            Expr::Match(match_expr, _span) => {
309                // Exhaustiveness is checked by inference engine - just check sub-expressions
310                self.check_expr(&match_expr.scrutinee);
311                for arm in &match_expr.arms {
312                    if let Some(guard) = &arm.guard {
313                        self.check_expr(guard);
314                    }
315                    self.check_expr(&arm.body);
316                }
317            }
318            // Recursively check sub-expressions
319            Expr::BinaryOp { left, right, .. } => {
320                self.check_expr(left);
321                self.check_expr(right);
322            }
323            Expr::UnaryOp { operand, .. } => {
324                self.check_expr(operand);
325            }
326            Expr::Conditional {
327                condition,
328                then_expr,
329                else_expr,
330                ..
331            } => {
332                self.check_expr(condition);
333                self.check_expr(then_expr);
334                if let Some(else_e) = else_expr {
335                    self.check_expr(else_e);
336                }
337            }
338            Expr::If(if_expr, _) => {
339                self.check_expr(&if_expr.condition);
340                self.check_expr(&if_expr.then_branch);
341                if let Some(else_branch) = &if_expr.else_branch {
342                    self.check_expr(else_branch);
343                }
344            }
345            Expr::FunctionCall { args, .. } => {
346                for arg in args {
347                    self.check_expr(arg);
348                }
349            }
350            Expr::QualifiedFunctionCall { args, .. } => {
351                for arg in args {
352                    self.check_expr(arg);
353                }
354            }
355            Expr::MethodCall { receiver, args, .. } => {
356                self.check_expr(receiver);
357                for arg in args {
358                    self.check_expr(arg);
359                }
360            }
361            Expr::Array(elems, _) => {
362                for elem in elems {
363                    self.check_expr(elem);
364                }
365            }
366            Expr::PropertyAccess { object, .. } => {
367                self.check_expr(object);
368            }
369            Expr::IndexAccess {
370                object,
371                index,
372                end_index,
373                ..
374            } => {
375                self.check_expr(object);
376                self.check_expr(index);
377                if let Some(end) = end_index {
378                    self.check_expr(end);
379                }
380            }
381            _ => {}
382        }
383    }
384
385    // Note: check_match_exhaustiveness, resolve_named_to_enum, and span_to_location
386    // were removed as exhaustiveness checking is now handled by the inference engine
387    // in TypeInferenceEngine::infer_expr() for Match expressions.
388
389    /// Check a single item
390    fn check_item(&mut self, item: &Item) {
391        match item {
392            Item::Function(func, span) => {
393                // Check for missing return statements
394                if func.return_type.is_some()
395                    && !matches!(func.return_type.as_ref().unwrap(), TypeAnnotation::Void)
396                    && !self.has_return_statement(&func.body)
397                {
398                    let (line, col) = self.item_span_to_line_col(*span);
399                    self.add_error(TypeError::MissingReturn(func.name.clone()), line, col);
400                }
401            }
402
403            Item::TypeAlias(alias, span) => {
404                // Check for cyclic type aliases
405                if self.is_cyclic_type_alias(&alias.name, &alias.type_annotation) {
406                    let (line, col) = self.item_span_to_line_col(*span);
407                    self.add_error(TypeError::CyclicTypeAlias(alias.name.clone()), line, col);
408                }
409            }
410
411            Item::Interface(interface, span) => {
412                // Validate interface definition
413                self.check_interface(interface, *span);
414            }
415
416            _ => {}
417        }
418    }
419
420    /// Check if statements contain a return statement
421    fn has_return_statement(&self, stmts: &[Statement]) -> bool {
422        for stmt in stmts {
423            match stmt {
424                Statement::Return(_, _) => return true,
425                Statement::If(if_stmt, _) => {
426                    // Both branches must have returns
427                    if let Some(else_body) = &if_stmt.else_body {
428                        if self.has_return_statement(&if_stmt.then_body)
429                            && self.has_return_statement(else_body)
430                        {
431                            return true;
432                        }
433                    }
434                }
435                Statement::While(while_loop, _) => {
436                    if self.has_return_statement(&while_loop.body) {
437                        // Note: This is conservative - while loop might not execute
438                        return true;
439                    }
440                }
441                Statement::For(for_loop, _) => {
442                    if self.has_return_statement(&for_loop.body) {
443                        // Note: This is conservative - for loop might not execute
444                        return true;
445                    }
446                }
447                _ => {}
448            }
449        }
450
451        false
452    }
453
454    /// Check for cyclic type aliases
455    fn is_cyclic_type_alias(&self, name: &str, ty: &TypeAnnotation) -> bool {
456        self.references_type(ty, name)
457    }
458
459    /// Check if a type annotation references a specific type name
460    fn references_type(&self, ty: &TypeAnnotation, name: &str) -> bool {
461        match ty {
462            TypeAnnotation::Reference(ref_name) => ref_name == name,
463            TypeAnnotation::Array(elem) => self.references_type(elem, name),
464            TypeAnnotation::Tuple(elems) => {
465                elems.iter().any(|elem| self.references_type(elem, name))
466            }
467            TypeAnnotation::Object(fields) => fields
468                .iter()
469                .any(|field| self.references_type(&field.type_annotation, name)),
470            TypeAnnotation::Function { params, returns } => {
471                params
472                    .iter()
473                    .any(|param| self.references_type(&param.type_annotation, name))
474                    || self.references_type(returns, name)
475            }
476            TypeAnnotation::Union(types) => types.iter().any(|ty| self.references_type(ty, name)),
477            TypeAnnotation::Generic { args, .. } => {
478                args.iter().any(|arg| self.references_type(arg, name))
479            }
480            _ => false,
481        }
482    }
483
484    /// Check interface definition
485    fn check_interface(&mut self, interface: &shape_ast::ast::InterfaceDef, interface_span: Span) {
486        // Check for duplicate members
487        let mut seen_members = HashMap::new();
488
489        for (i, member) in interface.members.iter().enumerate() {
490            let member_name = match member {
491                shape_ast::ast::InterfaceMember::Property { name, .. } => name,
492                shape_ast::ast::InterfaceMember::Method { name, .. } => name,
493                shape_ast::ast::InterfaceMember::IndexSignature { .. } => continue,
494            };
495
496            if let Some(_prev_index) = seen_members.get(member_name) {
497                let (line, col) = self.item_span_to_line_col(interface_span);
498                self.add_error(
499                    TypeError::InterfaceError(
500                        interface.name.clone(),
501                        format!("Duplicate member '{}'", member_name),
502                    ),
503                    line,
504                    col,
505                );
506            } else {
507                seen_members.insert(member_name.clone(), i);
508            }
509        }
510    }
511
512    fn item_span_to_line_col(&self, span: Span) -> (usize, usize) {
513        self.span_to_line_col(span).unwrap_or((0, 0))
514    }
515
516    /// Add an error with location information
517    fn add_error(&mut self, error: TypeError, line: usize, column: usize) {
518        let mut err = TypeErrorWithLocation::new(error, line, column);
519
520        if let Some(filename) = &self.filename {
521            err = err.with_file(filename.clone());
522        }
523
524        if let Some(source) = &self.source {
525            // Extract the source line
526            if let Some(source_line) = source.lines().nth(line.saturating_sub(1)) {
527                err = err.with_source_line(source_line.to_string());
528            }
529        }
530
531        self.errors.push(err);
532    }
533
534    /// Get all collected errors
535    pub fn errors(&self) -> &[TypeErrorWithLocation] {
536        &self.errors
537    }
538
539    /// Format all errors for display
540    pub fn format_errors(&self) -> String {
541        self.errors
542            .iter()
543            .map(|err| err.format_with_source())
544            .collect::<Vec<_>>()
545            .join("\n")
546    }
547}
548
549/// Shared single-entry type analysis used by compiler and LSP.
550pub fn analyze_program(
551    program: &Program,
552    source: Option<&str>,
553    filename: Option<&str>,
554    known_bindings: Option<&[String]>,
555) -> Result<TypeCheckResult, Vec<TypeErrorWithLocation>> {
556    analyze_program_with_mode(
557        program,
558        source,
559        filename,
560        known_bindings,
561        TypeAnalysisMode::FailFast,
562    )
563}
564
565/// Shared type analysis with explicit recovery behavior.
566pub fn analyze_program_with_mode(
567    program: &Program,
568    source: Option<&str>,
569    filename: Option<&str>,
570    known_bindings: Option<&[String]>,
571    analysis_mode: TypeAnalysisMode,
572) -> Result<TypeCheckResult, Vec<TypeErrorWithLocation>> {
573    let mut checker = TypeChecker::new();
574    if let Some(src) = source {
575        checker = checker.with_source(src.to_string());
576    }
577    if let Some(file) = filename {
578        checker = checker.with_filename(file.to_string());
579    }
580    if let Some(names) = known_bindings {
581        checker = checker.with_known_bindings(names);
582    }
583    checker = checker.with_analysis_mode(analysis_mode);
584    checker.check_program(program)
585}
586
587/// Result of type checking
588#[derive(Debug)]
589pub struct TypeCheckResult {
590    /// Inferred types for all declarations (inference-level types)
591    pub types: HashMap<String, Type>,
592    /// Semantic types for all declarations (user-facing types)
593    pub semantic_types: HashMap<String, SemanticType>,
594    /// Type warnings (non-fatal issues)
595    pub warnings: Vec<TypeWarning>,
596}
597
598impl TypeCheckResult {
599    /// Get the semantic type for a declaration
600    pub fn get_semantic_type(&self, name: &str) -> Option<&SemanticType> {
601        self.semantic_types.get(name)
602    }
603
604    /// Get all function declarations that are fallible (return Result)
605    pub fn fallible_functions(&self) -> Vec<&str> {
606        self.semantic_types
607            .iter()
608            .filter_map(|(name, ty)| {
609                if let SemanticType::Function(sig) = ty {
610                    if sig.return_type.is_result() {
611                        return Some(name.as_str());
612                    }
613                }
614                None
615            })
616            .collect()
617    }
618}
619
620/// Type warning for non-fatal issues
621#[derive(Debug)]
622pub struct TypeWarning {
623    pub message: String,
624    pub line: usize,
625    pub column: usize,
626}
627
628/// Type check an expression and return its type
629pub fn type_of_expr(expr: &Expr, _env: &TypeEnvironment) -> TypeResult<Type> {
630    let mut engine = TypeInferenceEngine::new();
631    engine.infer_expr(expr)
632}
633
634/// Quick type check for REPL and testing
635pub fn quick_check(source: &str) -> Result<TypeCheckResult, String> {
636    use shape_ast::parser::parse_program;
637
638    let program = parse_program(source).map_err(|e| format!("Parse error: {}", e))?;
639
640    let mut checker = TypeChecker::new().with_source(source.to_string());
641
642    checker.check_program(&program).map_err(|errors| {
643        errors
644            .iter()
645            .map(|e| e.format_with_source())
646            .collect::<Vec<_>>()
647            .join("\n")
648    })
649}
650
651#[cfg(test)]
652mod tests {
653    use super::*;
654
655    #[test]
656    fn test_exhaustiveness_integration_non_exhaustive_match_produces_error() {
657        // This test proves exhaustiveness checking is connected to the compiler pipeline.
658        // A match on an enum that doesn't cover all variants should produce an error.
659        let source = r#"
660            enum Status { Active, Inactive, Pending }
661
662            function check(s: Status) {
663                return match s {
664                    Status::Active => "yes"
665                };
666            }
667        "#;
668
669        let result = quick_check(source);
670
671        // The match is non-exhaustive (missing Inactive and Pending)
672        // so we expect an error
673        assert!(
674            result.is_err(),
675            "Expected error for non-exhaustive match, got: {:?}",
676            result
677        );
678        let err = result.unwrap_err();
679        assert!(
680            err.contains("NonExhaustive")
681                || err.contains("non-exhaustive")
682                || err.contains("missing"),
683            "Expected non-exhaustive match error, got: {}",
684            err
685        );
686    }
687
688    #[test]
689    fn test_exhaustiveness_integration_exhaustive_match_succeeds() {
690        // A match that covers all variants should succeed
691        let source = r#"
692            enum Status { Active, Inactive }
693
694            function check(s: Status) {
695                return match s {
696                    Status::Active => "yes",
697                    Status::Inactive => "no"
698                };
699            }
700        "#;
701
702        let result = quick_check(source);
703
704        // The match is exhaustive, so no error expected from exhaustiveness
705        // (there might be other errors, but not NonExhaustiveMatch)
706        if let Err(err) = &result {
707            assert!(
708                !err.contains("NonExhaustive") && !err.contains("non-exhaustive"),
709                "Should not have non-exhaustive error for exhaustive match, got: {}",
710                err
711            );
712        }
713    }
714
715    #[test]
716    fn test_exhaustiveness_integration_wildcard_makes_exhaustive() {
717        // A match with wildcard pattern should be trivially exhaustive
718        let source = r#"
719            enum Status { Active, Inactive, Pending }
720
721            function check(s: Status) {
722                return match s {
723                    Status::Active => "yes",
724                    _ => "other"
725                };
726            }
727        "#;
728
729        let result = quick_check(source);
730
731        // The wildcard makes it exhaustive
732        if let Err(err) = &result {
733            assert!(
734                !err.contains("NonExhaustive") && !err.contains("non-exhaustive"),
735                "Wildcard should make match exhaustive, got: {}",
736                err
737            );
738        }
739    }
740
741    #[test]
742    fn test_undefined_variable_reports_identifier_position() {
743        use shape_ast::parser::parse_program;
744
745        let source = r#"
746let x = 1
747let y = duckdb.connect("duckdb://analytics.db")
748"#;
749
750        let program = parse_program(source).expect("program should parse");
751        let result = analyze_program(&program, Some(source), None, None);
752        let errors = result.expect_err("undefined variable should fail analysis");
753        let undef = errors
754            .iter()
755            .find(|e| matches!(&e.error, TypeError::UndefinedVariable(name) if name == "duckdb"))
756            .expect("missing undefined-variable error for duckdb");
757
758        assert_eq!(undef.line, 3);
759        assert_eq!(undef.column, 9);
760    }
761
762    #[test]
763    fn test_known_bindings_allow_extension_namespace_in_type_analysis() {
764        use shape_ast::parser::parse_program;
765
766        let source = r#"let conn = duckdb.connect("duckdb://analytics.db")"#;
767        let program = parse_program(source).expect("program should parse");
768        let known = vec!["duckdb".to_string()];
769
770        let result = analyze_program(&program, Some(source), None, Some(&known));
771        assert!(
772            result.is_ok(),
773            "known extension namespaces should not fail type analysis: {:?}",
774            result.err()
775        );
776    }
777}